13from detectors.MTCDetector
import MTCDetector
14from detectors.SBTDetector
import SBTDetector
15from detectors.splitcalDetector
import splitcalDetector
16from detectors.strawtubesDetector
import strawtubesDetector
17from detectors.timeDetector
import timeDetector
18from detectors.UpstreamTaggerDetector
import UpstreamTaggerDetector
20logger = logging.getLogger(__name__)
24 "convert FairSHiP MC hits / digitized hits to measurements"
26 def __init__(self, finput, fout, fgeo) -> None:
33 self.
recoTree = ROOT.TTree(
"ship_reco_sim",
"Digitization and Reconstruction")
36 if self.
sTree.GetBranch(
"GeoTracks"):
37 self.
sTree.SetBranchStatus(
"GeoTracks", 0)
56 if self.
sTree.GetBranch(
"MTCDetPoint"):
58 if self.
sTree.GetBranch(
"vetoPoint"):
62 if self.
sTree.GetBranch(
"TimeDetPoint"):
64 if self.
sTree.GetBranch(
"UpstreamTaggerPoint"):
68 self.
v_drift = global_variables.modules[
"strawtubes"].StrawVdrift()
69 self.
sigma_spatial = global_variables.modules[
"strawtubes"].StrawSigmaSpatial()
71 if self.
sTree.GetBranch(
"splitcalPoint"):
81 ROOT.gRandom.SetSeed(13)
82 self.
PDG = ROOT.TDatabasePDG.Instance()
84 self.
sTree.GetEvent(0)
87 self.
geoMat = ROOT.genfit.TGeoMaterialInterface()
89 self.
bfield = ROOT.genfit.FairShipFields()
90 self.
bfield.setField(global_variables.fieldMaker.getGlobalField())
91 self.
fM = ROOT.genfit.FieldManager.getInstance()
93 ROOT.genfit.MaterialEffects.getInstance().init(self.
geoMat)
99 self.
fitter.setMaxIterations(50)
100 if global_variables.debug:
101 self.
fitter.setDebugLvl(1)
110 if hasattr(self,
"digiSBT"):
112 if global_variables.vertexing:
117 self.
sTree.t0 = self.
random.Rndm() * 1 * u.microsecond
119 self.
header.SetRunId(self.
sTree.MCEventHeader.GetRunID())
120 self.
header.SetMCEntryNumber(self.
sTree.MCEventHeader.GetEventID())
121 if hasattr(self,
"digiSBT"):
124 if hasattr(self,
"timeDetector"):
126 if hasattr(self,
"upstreamTaggerDetector"):
128 if hasattr(self,
"digiMTC"):
130 if self.
sTree.GetBranch(
"splitcalPoint"):
135 hit_detector_ids = {}
136 stationCrossed: dict[int, dict[int, int]] = {}
137 listOfIndices: dict[int, list[int]] = {}
138 trackParams: dict[int, dict] = {}
144 if global_variables.withT0:
152 if global_variables.realPR:
155 logger.debug(
"PatRec returned %d track candidates", len(track_hits))
157 for i_track
in track_hits:
158 atrack = track_hits[i_track]
159 atrack_y12 = atrack[
"y12"]
160 atrack_stereo12 = atrack[
"stereo12"]
161 atrack_y34 = atrack[
"y34"]
162 atrack_stereo34 = atrack[
"stereo34"]
163 atrack_smeared_hits = (
164 list(atrack_y12) + list(atrack_stereo12) + list(atrack_y34) + list(atrack_stereo34)
167 trackParams[i_track] = {
168 "k_y12": atrack.get(
"k_y12"),
169 "b_y12": atrack.get(
"b_y12"),
170 "k_y34": atrack.get(
"k_y34"),
171 "b_y34": atrack.get(
"b_y34"),
173 for sm
in atrack_smeared_hits:
175 station = self.
strawtubes.det[sm[
"digiHit"]].GetStationNumber()
178 if trID
not in hitPosLists:
179 hitPosLists[trID] = ROOT.std.vector(
"TVectorD")()
180 listOfIndices[trID] = []
181 stationCrossed[trID] = {}
182 hit_detector_ids[trID] = ROOT.std.vector(
"int")()
183 hit_detector_ids[trID].push_back(detID)
184 m = array(
"d", [sm[
"xtop"], sm[
"ytop"], sm[
"z"], sm[
"xbot"], sm[
"ybot"], sm[
"z"], sm[
"dist"]])
185 hitPosLists[trID].push_back(ROOT.TVectorD(7, m))
186 listOfIndices[trID].append(sm[
"digiHit"])
187 if station
not in stationCrossed[trID]:
188 stationCrossed[trID][station] = 0
189 stationCrossed[trID][station] += 1
192 detID = self.
strawtubes.det[sm[
"digiHit"]].GetDetectorID()
193 station = self.
strawtubes.det[sm[
"digiHit"]].GetStationNumber()
194 trID = self.
sTree.strawtubesPoint[sm[
"digiHit"]].GetTrackID()
195 if trID
not in hitPosLists:
196 hitPosLists[trID] = ROOT.std.vector(
"TVectorD")()
197 listOfIndices[trID] = []
198 stationCrossed[trID] = {}
199 hit_detector_ids[trID] = ROOT.std.vector(
"int")()
200 hit_detector_ids[trID].push_back(detID)
201 m = array(
"d", [sm[
"xtop"], sm[
"ytop"], sm[
"z"], sm[
"xbot"], sm[
"ybot"], sm[
"z"], sm[
"dist"]])
202 hitPosLists[trID].push_back(ROOT.TVectorD(7, m))
203 listOfIndices[trID].append(sm[
"digiHit"])
204 if station
not in stationCrossed[trID]:
205 stationCrossed[trID][station] = 0
206 stationCrossed[trID][station] += 1
209 n_too_few_stations = 0
216 for atrack
in hitPosLists:
220 meas = hitPosLists[atrack]
221 detIDs = hit_detector_ids[atrack]
226 if len(stationCrossed[atrack]) < 3:
227 n_too_few_stations += 1
229 if global_variables.debug:
230 self.
sTree.MCTrack[atrack]
236 covM = ROOT.TMatrixDSym(6)
238 if global_variables.withT0:
241 covM[i][i] = resolution * resolution
242 covM[0][0] = resolution * resolution * 100.0
243 for i
in range(3, 6):
244 covM[i][i] = ROOT.TMath.Power(resolution / nM / ROOT.TMath.Sqrt(3), 2)
246 rep = ROOT.genfit.RKTrackRep(pdg)
248 stateSmeared = ROOT.genfit.MeasuredStateOnPlane(rep)
249 rep.setPosMomCov(stateSmeared, posM, momM, covM)
251 seedState = ROOT.TVectorD(6)
252 seedCov = ROOT.TMatrixDSym(6)
253 rep.get6DStateCov(stateSmeared, seedState, seedCov)
254 theTrack = ROOT.genfit.Track(rep, seedState, seedCov)
255 hitCov = ROOT.TMatrixDSym(7)
256 hitCov[6][6] = resolution * resolution
258 for m, detID
in zip(meas, detIDs):
259 tp = ROOT.genfit.TrackPoint(theTrack)
260 measurement = ROOT.genfit.WireMeasurement(
261 m, hitCov, detID, hitID, tp
263 measurement.setMaxDistance(
264 global_variables.ShipGeo.strawtubes_geo.outer_straw_diameter / 2.0
265 - global_variables.ShipGeo.strawtubes_geo.wall_thickness
267 tp.addRawMeasurement(measurement)
268 theTrack.insertPoint(tp)
270 trackCandidates.append([theTrack, atrack])
272 for entry
in trackCandidates:
277 theTrack.checkConsistency()
278 except ROOT.genfit.Exception
as e:
280 logger.warning(
"Problem with track before fit, not consistent %s %s", atrack, theTrack)
281 logger.warning(e.what())
285 self.
fitter.processTrack(theTrack)
288 if global_variables.debug:
289 print(
"genfit failed to fit track")
290 error =
"genfit failed to fit track"
291 ut.reportError(error)
295 theTrack.checkConsistency()
296 except ROOT.genfit.Exception
as e:
298 if global_variables.debug:
299 print(
"Problem with track after fit, not consistent", atrack, theTrack)
301 error =
"Problem with track after fit, not consistent"
302 ut.reportError(error)
304 fittedState = theTrack.getFittedState()
305 fittedState.getMomMag()
308 error =
"Problem with fittedstate"
309 ut.reportError(error)
311 fitStatus = theTrack.getFitStatus()
313 fitStatus.isFitConverged()
314 except ROOT.genfit.Exception:
315 error =
"Fit not converged"
316 ut.reportError(error)
317 nmeas = fitStatus.getNdf()
318 global_variables.h[
"nmeas"].Fill(nmeas)
322 chi2 = fitStatus.getChi2() / nmeas
323 global_variables.h[
"chi2"].Fill(chi2)
326 trackCopy = ROOT.genfit.Track(theTrack)
327 ROOT.SetOwnership(trackCopy,
False)
329 if global_variables.debug:
330 print(
"save track", theTrack, chi2, nmeas, fitStatus.isFitConverged())
333 for index
in listOfIndices[atrack]:
334 ahit = self.
sTree.strawtubesPoint[index]
335 track_ids += [ahit.GetTrackID()]
339 indices = ROOT.std.vector(
"unsigned int")()
340 for index
in listOfIndices[atrack]:
341 indices.push_back(index)
342 aTracklet = ROOT.Tracklet(1, indices)
346 "findTracks: %d candidates, %d too few hits, %d too few stations, "
347 "%d prefit fail, %d fit fail, %d postfit fail, %d no state, "
348 "%d no NDF, %d fitted tracks saved",
361 if global_variables.debug:
362 print(
"save tracklets:")
364 print(x.getType(), len(x.getList()))
368 """Compute seed position and momentum for the track fitter.
370 When PatRec track parameters (k_y, b_y) are available, use them
371 to place the seed at the first hit's Z with the correct Y position
372 and momentum direction. Otherwise fall back to the default seed
373 at the decay vessel centre.
375 params = trackParams.get(atrack)
376 if params
and params.get(
"k_y12")
is not None:
378 k_y = params[
"k_y12"]
379 b_y = params[
"b_y12"]
382 y_seed = k_y * z_seed + b_y
383 posM = ROOT.TVector3(0, y_seed, z_seed)
385 p_total = 3.0 * u.GeV
386 pz = p_total / ROOT.TMath.Sqrt(1.0 + k_y * k_y)
388 momM = ROOT.TVector3(0, py, pz)
390 "seed from PatRec: z=%.1f y=%.1f k_y=%.4f p=(0, %.2f, %.2f)",
398 posM = ROOT.TVector3(0, 0, 5812.0)
399 momM = ROOT.TVector3(0, 0, 3.0 * u.GeV)
406 fitStatus = track.getFitStatus()
407 if not fitStatus.isFitConverged():
409 nmeas = fitStatus.getNdf()
410 chi2 = fitStatus.getChi2() / nmeas
411 if chi2 < 50
and not chi2 < 0:
419 xx = track.getFittedState()
420 rep = ROOT.genfit.RKTrackRep(xx.getPDG())
421 state = ROOT.genfit.StateOnPlane(rep)
422 rep.setPosMom(state, xx.getPos(), xx.getMom())
423 for i, vetoHit
in enumerate(self.
digiSBT.det):
426 rep.extrapolateToPoint(state, vetoHitPos,
False)
428 error =
"shipDigiReco::findVetoHitOnTrack extrapolation did not worked"
429 ut.reportError(error)
430 if global_variables.debug:
433 dist = (rep.getPos(state) - vetoHitPos).Mag()
437 return ROOT.vetoHitOnTrack(hitID, distMin)
454 tmax = max(track, key=track.get)
460 frac = float(track[tmax]) / float(nh)
465 print(
"finished writing tree")
469 ut.writeHists(global_variables.h,
"recohists.root")
470 if global_variables.realPR:
def _compute_seed_state(self, atrack, meas, trackParams)
def findVetoHitOnTrack(self, track)
None linkVetoOnTracks(self)
def fracMCsame(self, trackids)
None __init__(self, finput, fout, fgeo)
def execute(smeared_hits, ship_geo, str method="")