FairShip
Loading...
Searching...
No Matches
shipDigiReco.py
Go to the documentation of this file.
1# SPDX-License-Identifier: LGPL-3.0-or-later
2# SPDX-FileCopyrightText: Copyright CERN for the benefit of the SHiP Collaboration
3
4import logging
5from array import array
6
7import global_variables
8import ROOT
9import rootUtils as ut
10import shipPatRec
11import shipunit as u
12import shipVertex
13from detectors.MTCDetector import MTCDetector
14from detectors.SBTDetector import SBTDetector
15from detectors.splitcalDetector import splitcalDetector
16from detectors.strawtubesDetector import strawtubesDetector
17from detectors.timeDetector import timeDetector
18from detectors.UpstreamTaggerDetector import UpstreamTaggerDetector
19
20logger = logging.getLogger(__name__)
21
22
24 "convert FairSHiP MC hits / digitized hits to measurements"
25
def __init__(self, finput, fout, fgeo) -> None:
    """Open input/output files and set up digitisation, track fitting and vertexing.

    Args:
        finput: input file name passed to ``ROOT.TFile.Open``; must contain the
            ``cbmsim`` tree with the MC information.
        fout: output file name; the file is (re)created and receives the
            ``ship_reco_sim`` tree holding only digi/reco branches.
        fgeo: geometry argument handed to ``shipPatRec.initialize``.
    """
    # Open input file (read-only) and get the MC tree
    self.inputFile = ROOT.TFile.Open(finput, "read")
    self.sTree = self.inputFile["cbmsim"]

    # Create output file and new tree for digi/reco branches only
    self.outputFile = ROOT.TFile.Open(fout, "recreate")
    self.recoTree = ROOT.TTree("ship_reco_sim", "Digitization and Reconstruction")

    # Disable GeoTracks branch if present in input
    if self.sTree.GetBranch("GeoTracks"):
        self.sTree.SetBranchStatus("GeoTracks", 0)

    # prepare for output
    # event header
    self.header = ROOT.FairEventHeader()
    self.eventHeader = self.recoTree.Branch("ShipEventHeader", self.header, 32000, -1)
    # fitted tracks
    # Must use pointer storage: genfit::Track has circular references with TrackPoint
    # requiring stable memory addresses (value storage would invalidate back-pointers on vector resize)
    self.fGenFitArray = ROOT.std.vector("genfit::Track*")()
    self.fitTrack2MC = ROOT.std.vector("int")()
    self.goodTracksVect = ROOT.std.vector("int")()
    self.mcLink = self.recoTree.Branch("fitTrack2MC", self.fitTrack2MC, 32000, -1)
    self.fitTracks = self.recoTree.Branch("FitTracks", self.fGenFitArray, 32000, -1)
    self.goodTracksBranch = self.recoTree.Branch("goodTracks", self.goodTracksVect, 32000, -1)
    self.fTrackletsArray = ROOT.std.vector("Tracklet")()
    self.Tracklets = self.recoTree.Branch("Tracklets", self.fTrackletsArray, 32000, -1)

    # Detectors: straw tubes are always set up; every other detector only when its
    # MC point branch exists, so callers must check that the attribute exists.
    self.strawtubes = strawtubesDetector("strawtubes", self.sTree, outtree=self.recoTree)
    if self.sTree.GetBranch("MTCDetPoint"):
        self.digiMTC = MTCDetector("MTCDet", self.sTree, "MTC", outtree=self.recoTree)
    if self.sTree.GetBranch("vetoPoint"):
        self.digiSBT = SBTDetector("veto", self.sTree, "SBT", mcBranchName="digiSBT2MC", outtree=self.recoTree)
        self.vetoHitOnTrackArray = ROOT.std.vector("vetoHitOnTrack")()
        self.vetoHitOnTrackBranch = self.recoTree.Branch("VetoHitOnTrack", self.vetoHitOnTrackArray)
    if self.sTree.GetBranch("TimeDetPoint"):
        self.timeDetector = timeDetector("TimeDet", self.sTree, outtree=self.recoTree)
    if self.sTree.GetBranch("UpstreamTaggerPoint"):
        self.upstreamTaggerDetector = UpstreamTaggerDetector("UpstreamTagger", self.sTree, outtree=self.recoTree)

    # for the digitizing step
    self.v_drift = global_variables.modules["strawtubes"].StrawVdrift()
    self.sigma_spatial = global_variables.modules["strawtubes"].StrawSigmaSpatial()
    # optional if present, splitcalCluster
    if self.sTree.GetBranch("splitcalPoint"):
        self.splitcalDetector = splitcalDetector("splitcal", self.sTree, outtree=self.recoTree)
    # Keep references for backward compatibility
    # NOTE(review): no statements follow this comment in the reviewed listing; confirm
    # against upstream whether backward-compatibility aliases are still required here.

    # prepare vertexing
    self.Vertexing = shipVertex.Task(global_variables.h, self.recoTree, self.sTree)
    # setup random number generator
    # NOTE(review): SetSeed acts on ROOT.gRandom only; self.random (used for the event
    # time t0 in digitize) is a separate TRandom instance that is not seeded here.
    self.random = ROOT.TRandom()
    ROOT.gRandom.SetSeed(13)
    self.PDG = ROOT.TDatabasePDG.Instance()
    # access ShipTree
    self.sTree.GetEvent(0)

    # init geometry and mag. field
    self.geoMat = ROOT.genfit.TGeoMaterialInterface()
    self.bfield = ROOT.genfit.FairShipFields()
    self.bfield.setField(global_variables.fieldMaker.getGlobalField())
    self.fM = ROOT.genfit.FieldManager.getInstance()
    self.fM.init(self.bfield)
    ROOT.genfit.MaterialEffects.getInstance().init(self.geoMat)

    # init fitter, to be done before importing shipPatRec
    # (genfit KalmanFitter / KalmanFitterRefTrack are possible alternatives to DAF)
    self.fitter = ROOT.genfit.DAF()
    self.fitter.setMaxIterations(50)
    if global_variables.debug:
        self.fitter.setDebugLvl(1)  # produces lot of printout

    # for 'real' PatRec
    # NOTE(review): this call was absent in the reviewed listing, leaving `fgeo` unused
    # while findTracks calls shipPatRec.execute(); restored per shipPatRec.initialize(fgeo).
    shipPatRec.initialize(fgeo)

def reconstruct(self) -> None:
    """Run the per-event reconstruction chain on the digitised event.

    Steps: track finding and fitting, good-track selection, SBT veto-hit
    association (only if the SBT detector was set up), then two-track
    vertexing when ``global_variables.vertexing`` is enabled.
    """
    self.findTracks()
    self.findGoodTracks()
    sbt_available = hasattr(self, "digiSBT")
    if sbt_available:
        self.linkVetoOnTracks()
    if not global_variables.vertexing:
        return
    # now go for 2-track combinations
    self.Vertexing.execute()

def digitize(self) -> None:
    """Digitise the current input event for every detector that was set up.

    Draws a random event time t0 with ``self.random`` (``Rndm()`` scaled by one
    microsecond). The time is stored on the input tree object as ``self.sTree.t0``
    and in the output event header, together with the run ID and MC entry number.
    ``process()`` is then called on each detector.
    """
    self.sTree.t0 = self.random.Rndm() * 1 * u.microsecond
    self.header.SetEventTime(self.sTree.t0)
    self.header.SetRunId(self.sTree.MCEventHeader.GetRunID())
    self.header.SetMCEntryNumber(self.sTree.MCEventHeader.GetEventID())  # counts from 1
    # Optional detectors exist as attributes only when __init__ found their MC branch.
    if hasattr(self, "digiSBT"):
        self.digiSBT.process()
    self.strawtubes.process()
    if hasattr(self, "timeDetector"):
        self.timeDetector.process()
    if hasattr(self, "upstreamTaggerDetector"):
        self.upstreamTaggerDetector.process()
    if hasattr(self, "digiMTC"):
        self.digiMTC.process()
    # Same attribute test as the other optional detectors instead of a per-event
    # branch lookup; __init__ creates splitcalDetector exactly when splitcalPoint exists.
    if hasattr(self, "splitcalDetector"):
        self.splitcalDetector.process()

def findTracks(self) -> int:
    """Build straw-tube track candidates, fit them with genfit and store the results.

    Hits are smeared, with t0 estimation when ``global_variables.withT0`` is set.
    They are then grouped into candidates in one of two ways:

    - real pattern recognition (``shipPatRec.execute``), when
      ``global_variables.realPR`` is set;
    - otherwise by the MC track ID of the corresponding ``strawtubesPoint``.

    Candidates with a negative ID, fewer than 13 hits or fewer than 3 crossed
    stations are dropped. The rest are fitted with ``self.fitter``. Tracks that
    fit with NDF > 0 are stored as follows:

    - a copy of the track goes into ``self.fGenFitArray``;
    - the majority MC track ID of its hits goes into ``self.fitTrack2MC``;
    - its hit indices go into ``self.fTrackletsArray``.

    Returns:
        Number of fitted tracks stored for this event.
    """
    hitPosLists = {}
    hit_detector_ids = {}
    stationCrossed: dict[int, dict[int, int]] = {}
    listOfIndices: dict[int, list[int]] = {}
    trackParams: dict[int, dict] = {}
    self.fGenFitArray.clear()
    self.fTrackletsArray.clear()
    self.fitTrack2MC.clear()

    if global_variables.withT0:
        self.SmearedHits = self.strawtubes.withT0Estimate()
    else:
        # old procedure, not including estimation of t0
        self.SmearedHits = self.strawtubes.smearHits(global_variables.withNoStrawSmearing)

    # One (candidate ID, detector ID, station number, smeared hit) tuple per assigned hit.
    # Both pattern-recognition modes fill this list; the per-candidate containers are
    # then built from it in a single place.
    hit_assignments = []
    if global_variables.realPR:
        # Do real PatRec
        track_hits = shipPatRec.execute(self.SmearedHits, global_variables.ShipGeo, global_variables.realPR)
        logger.debug("PatRec returned %d track candidates", len(track_hits))
        for i_track in track_hits:
            atrack = track_hits[i_track]
            atrack_smeared_hits = (
                list(atrack["y12"]) + list(atrack["stereo12"]) + list(atrack["y34"]) + list(atrack["stereo34"])
            )
            # Store PatRec track parameters for seeding the fitter
            trackParams[i_track] = {
                "k_y12": atrack.get("k_y12"),
                "b_y12": atrack.get("b_y12"),
                "k_y34": atrack.get("k_y34"),
                "b_y34": atrack.get("b_y34"),
            }
            for sm in atrack_smeared_hits:
                station = self.strawtubes.det[sm["digiHit"]].GetStationNumber()
                hit_assignments.append((i_track, sm["detID"], station, sm))
    else:
        # do fake pattern recognition: group hits by the MC track ID of their strawtubesPoint
        for sm in self.SmearedHits:
            digiHit = self.strawtubes.det[sm["digiHit"]]
            trID = self.sTree.strawtubesPoint[sm["digiHit"]].GetTrackID()
            hit_assignments.append((trID, digiHit.GetDetectorID(), digiHit.GetStationNumber(), sm))

    # Collect hits for track fit
    for trID, detID, station, sm in hit_assignments:
        if trID not in hitPosLists:
            hitPosLists[trID] = ROOT.std.vector("TVectorD")()
            listOfIndices[trID] = []
            stationCrossed[trID] = {}
            hit_detector_ids[trID] = ROOT.std.vector("int")()
        hit_detector_ids[trID].push_back(detID)
        # (x, y, z) of the straw's "top" and "bottom" ends, then the drift distance
        m = array("d", [sm["xtop"], sm["ytop"], sm["z"], sm["xbot"], sm["ybot"], sm["z"], sm["dist"]])
        hitPosLists[trID].push_back(ROOT.TVectorD(7, m))
        listOfIndices[trID].append(sm["digiHit"])
        stationCrossed[trID][station] = stationCrossed[trID].get(station, 0) + 1

    n_too_few_hits = 0
    n_too_few_stations = 0
    n_prefit_fail = 0
    n_fit_fail = 0
    n_postfit_fail = 0
    n_no_state = 0
    n_no_ndf = 0

    trackCandidates = []
    for atrack in hitPosLists:
        if atrack < 0:
            continue  # these are hits not assigned to MC track because low E cut
        pdg = 13  # assume all tracks are muons
        meas = hitPosLists[atrack]
        detIDs = hit_detector_ids[atrack]
        nM = len(meas)
        if nM < 13:
            n_too_few_hits += 1
            continue  # not enough hits to make a good trackfit
        if len(stationCrossed[atrack]) < 3:
            n_too_few_stations += 1
            continue  # not enough stations crossed to make a good trackfit

        # Seed state: use PatRec track parameters when available
        posM, momM = self._compute_seed_state(atrack, meas, trackParams)

        # approximate covariance
        covM = ROOT.TMatrixDSym(6)
        resolution = self.sigma_spatial
        if global_variables.withT0:
            resolution *= 1.4  # worse resolution due to t0 estimate
        for i in range(3):
            covM[i][i] = resolution * resolution
        covM[0][0] = resolution * resolution * 100.0
        for i in range(3, 6):
            covM[i][i] = ROOT.TMath.Power(resolution / nM / ROOT.TMath.Sqrt(3), 2)
        # trackrep
        rep = ROOT.genfit.RKTrackRep(pdg)
        # smeared start state
        stateSmeared = ROOT.genfit.MeasuredStateOnPlane(rep)
        rep.setPosMomCov(stateSmeared, posM, momM, covM)
        # create track
        seedState = ROOT.TVectorD(6)
        seedCov = ROOT.TMatrixDSym(6)
        rep.get6DStateCov(stateSmeared, seedState, seedCov)
        theTrack = ROOT.genfit.Track(rep, seedState, seedCov)
        hitCov = ROOT.TMatrixDSym(7)
        hitCov[6][6] = resolution * resolution
        # largest allowed drift distance: outer straw radius minus wall thickness
        maxDistance = (
            global_variables.ShipGeo.strawtubes_geo.outer_straw_diameter / 2.0
            - global_variables.ShipGeo.strawtubes_geo.wall_thickness
        )
        for hitID, (m, detID) in enumerate(zip(meas, detIDs)):
            tp = ROOT.genfit.TrackPoint(theTrack)  # note how the point is told which track it belongs to
            # the measurement is told which trackpoint it belongs to
            measurement = ROOT.genfit.WireMeasurement(m, hitCov, detID, hitID, tp)
            measurement.setMaxDistance(maxDistance)
            tp.addRawMeasurement(measurement)  # package measurement in the TrackPoint
            theTrack.insertPoint(tp)  # add point to Track
        trackCandidates.append([theTrack, atrack])

    for theTrack, atrack in trackCandidates:
        # check
        try:
            theTrack.checkConsistency()
        except ROOT.genfit.Exception as e:
            n_prefit_fail += 1
            logger.warning("Problem with track before fit, not consistent %s %s", atrack, theTrack)
            logger.warning(e.what())
            ut.reportError(e)
            continue  # do not hand an inconsistent track to the fitter
        # do the fit
        try:
            self.fitter.processTrack(theTrack)  # processTrackWithRep(theTrack,rep,True)
        except Exception:
            n_fit_fail += 1
            if global_variables.debug:
                print("genfit failed to fit track")
            error = "genfit failed to fit track"
            ut.reportError(error)
            continue
        # check
        try:
            theTrack.checkConsistency()
        except ROOT.genfit.Exception as e:
            n_postfit_fail += 1
            if global_variables.debug:
                print("Problem with track after fit, not consistent", atrack, theTrack)
                print(e.what())
            error = "Problem with track after fit, not consistent"
            ut.reportError(error)
        try:
            fittedState = theTrack.getFittedState()
            fittedState.getMomMag()
        except Exception:
            n_no_state += 1
            error = "Problem with fittedstate"
            ut.reportError(error)
            continue
        fitStatus = theTrack.getFitStatus()
        # isFitConverged() returns a bool rather than raising, so test its value;
        # non-converged tracks are still stored (findGoodTracks filters them).
        if not fitStatus.isFitConverged():
            error = "Fit not converged"
            ut.reportError(error)
        nmeas = fitStatus.getNdf()
        global_variables.h["nmeas"].Fill(nmeas)
        if nmeas <= 0:
            n_no_ndf += 1
            continue
        chi2 = fitStatus.getChi2() / nmeas
        global_variables.h["chi2"].Fill(chi2)
        # make track persistent
        # Store pointer - make a copy and let ROOT manage lifetime
        # NOTE(review): fGenFitArray.clear() at the top of this method does not delete
        # the pointed-to copies; confirm something frees them, or they leak every event.
        trackCopy = ROOT.genfit.Track(theTrack)
        ROOT.SetOwnership(trackCopy, False)  # ROOT TTree owns the track
        self.fGenFitArray.push_back(trackCopy)
        if global_variables.debug:
            print("save track", theTrack, chi2, nmeas, fitStatus.isFitConverged())
        # Save MC link: majority MC track ID among this track's hits
        hitIndices = listOfIndices[atrack]
        track_ids = [self.sTree.strawtubesPoint[index].GetTrackID() for index in hitIndices]
        _frac, tmax = self.fracMCsame(track_ids)
        self.fitTrack2MC.push_back(tmax)
        # Save hits indexes of the fitted tracks
        indices = ROOT.std.vector("unsigned int")()
        for index in hitIndices:
            indices.push_back(index)
        aTracklet = ROOT.Tracklet(1, indices)  # type code 1, as used throughout this file
        self.fTrackletsArray.push_back(aTracklet)

    logger.debug(
        "findTracks: %d candidates, %d too few hits, %d too few stations, "
        "%d prefit fail, %d fit fail, %d postfit fail, %d no state, "
        "%d no NDF, %d fitted tracks saved",
        len(hitPosLists),
        n_too_few_hits,
        n_too_few_stations,
        n_prefit_fail,
        n_fit_fail,
        n_postfit_fail,
        n_no_state,
        n_no_ndf,
        len(self.fGenFitArray),
    )

    # debug
    if global_variables.debug:
        print("save tracklets:")
        for x in self.recoTree.Tracklets:
            print(x.getType(), len(x.getList()))
    return len(self.fGenFitArray)

def _compute_seed_state(self, atrack, meas, trackParams):
    """Compute seed position and momentum for the track fitter.

    Seeding from PatRec requires both the slope ``k_y12`` and the intercept
    ``b_y12`` of the station 1-2 line y = k_y12 * z + b_y12. When both are
    available:

    - the seed sits at the z of the first measurement, with y taken from that line;
    - the momentum direction in the y-z plane follows the slope.

    Otherwise the default seed at z = 5812 (decay vessel centre) with momentum
    along z is used. In both cases x = 0, px = 0 and |p| = 3 GeV.

    Args:
        atrack: candidate ID, used as key into ``trackParams``.
        meas: the candidate's measurements; ``meas[0][2]`` is the z of the first one.
        trackParams: per-candidate PatRec parameters (only filled for real PatRec).

    Returns:
        ``(posM, momM)`` as ``ROOT.TVector3``.
    """
    params = trackParams.get(atrack) or {}
    k_y = params.get("k_y12")
    b_y = params.get("b_y12")
    if k_y is None or b_y is None:
        # No complete station 1-2 line (e.g. fake PatRec): use the default seed.
        posM = ROOT.TVector3(0, 0, 5812.0)  # decay vessel centre
        momM = ROOT.TVector3(0, 0, 3.0 * u.GeV)
        return posM, momM

    # Seed Z at the first measurement's Z coordinate
    z_seed = meas[0][2]  # z is the 3rd element of the TVectorD
    y_seed = k_y * z_seed + b_y
    posM = ROOT.TVector3(0, y_seed, z_seed)
    # Use slope as py/pz ratio; assume 3 GeV total momentum
    p_total = 3.0 * u.GeV
    pz = p_total / ROOT.TMath.Sqrt(1.0 + k_y * k_y)
    py = k_y * pz
    momM = ROOT.TVector3(0, py, pz)
    logger.debug(
        "seed from PatRec: z=%.1f y=%.1f k_y=%.4f p=(0, %.2f, %.2f)",
        z_seed,
        y_seed,
        k_y,
        py,
        pz,
    )
    return posM, momM

def findGoodTracks(self, max_chi2_per_ndf: float = 50) -> int:
    """Select converged fitted tracks with 0 <= chi2/NDF < ``max_chi2_per_ndf``.

    Clears ``self.goodTracksVect`` and fills it with the indices (into
    ``self.fGenFitArray``) of the selected tracks.

    Args:
        max_chi2_per_ndf: exclusive upper bound on chi2/NDF; the default 50 is
            the previously hard-coded cut.

    Returns:
        Number of selected tracks.
    """
    self.goodTracksVect.clear()
    nGoodTracks = 0
    for i, track in enumerate(self.fGenFitArray):
        fitStatus = track.getFitStatus()
        if not fitStatus.isFitConverged():
            continue
        ndf = fitStatus.getNdf()
        if ndf <= 0:
            # chi2/NDF undefined; skip instead of dividing by zero
            continue
        chi2 = fitStatus.getChi2() / ndf
        # the chained comparison is also False for NaN, so NaN is rejected
        if 0 <= chi2 < max_chi2_per_ndf:
            self.goodTracksVect.push_back(i)
            nGoodTracks += 1
    return nGoodTracks

def findVetoHitOnTrack(self, track):
    """Return the SBT digi hit closest to the extrapolated fitted track.

    For every SBT hit, a genfit ``RKTrackRep`` state is set to the track's fitted
    position and momentum and extrapolated to the hit position. The hit with the
    smallest distance between the extrapolated point and the hit wins. A failed
    extrapolation is reported via ``ut.reportError`` and that hit is skipped.

    Returns:
        ``ROOT.vetoHitOnTrack(hitID, distance)``. ``hitID`` is -1 and the distance
        is 99999.0 when there are no SBT hits or no extrapolation succeeded.
    """
    distMin = 99999.0
    hitID = -1
    fitted = track.getFittedState()
    rep = ROOT.genfit.RKTrackRep(fitted.getPDG())
    state = ROOT.genfit.StateOnPlane(rep)
    startPos = fitted.getPos()
    startMom = fitted.getMom()
    for i, vetoHit in enumerate(self.digiSBT.det):
        vetoHitPos = vetoHit.GetXYZ()
        # Restart from the fitted state for every hit: extrapolateToPoint modifies
        # `state`, so otherwise each hit would start from the previous (possibly
        # failed) extrapolation and the result would depend on hit order.
        rep.setPosMom(state, startPos, startMom)
        try:
            rep.extrapolateToPoint(state, vetoHitPos, False)
        except Exception:
            error = "shipDigiReco::findVetoHitOnTrack extrapolation did not worked"
            ut.reportError(error)
            if global_variables.debug:
                print(error)
            continue
        dist = (rep.getPos(state) - vetoHitPos).Mag()
        if dist < distMin:
            distMin = dist
            hitID = i
    return ROOT.vetoHitOnTrack(hitID, distMin)

def linkVetoOnTracks(self) -> None:
    """Refill ``self.vetoHitOnTrackArray`` with one closest-SBT-hit record per good track.

    Entries follow the order of ``self.goodTracksVect``; each one is the result of
    ``self.findVetoHitOnTrack`` for the corresponding entry of ``self.fGenFitArray``.
    """
    self.vetoHitOnTrackArray.clear()
    fitted_tracks = self.fGenFitArray
    for track_index in self.goodTracksVect:
        self.vetoHitOnTrackArray.push_back(self.findVetoHitOnTrack(fitted_tracks[track_index]))

def fracMCsame(self, trackids):
    """Return the fraction and value of the most frequent MC track ID.

    Args:
        trackids: iterable of MC track IDs, typically one per hit of a fitted track.

    Returns:
        ``(frac, tmax)``. ``tmax`` is the most frequent ID; on ties it is the one
        encountered first. ``frac`` is its count divided by the total number of
        IDs. Empty input gives ``(0.0, -999)``.
    """
    from collections import Counter

    counts = Counter(trackids)
    if not counts:
        return 0.0, -999
    # most_common(1) keeps first-seen order among equal counts
    tmax, nmax = counts.most_common(1)[0]
    return nmax / sum(counts.values()), tmax

def finish(self) -> None:
    """Write the reconstruction output and close all files.

    Steps, in order:

    1. Delete the fitter.
    2. Write ``recoTree`` into the output file.
    3. Print the error summary.
    4. Write the histograms in ``global_variables.h`` to ``recohists.root``.
    5. Finalise shipPatRec when real pattern recognition is enabled.
    6. Close the output and input files.
    """
    del self.fitter
    print("finished writing tree")
    self.outputFile.cd()
    self.recoTree.Write()
    ut.errorSummary()
    ut.writeHists(global_variables.h, "recohists.root")
    if global_variables.realPR:
        # NOTE(review): the body of this `if` was missing in the reviewed listing
        # (a SyntaxError); restored as shipPatRec.finalize() -- confirm against upstream.
        shipPatRec.finalize()
    self.outputFile.Close()
    self.inputFile.Close()

def _compute_seed_state(self, atrack, meas, trackParams)
def findVetoHitOnTrack(self, track)
def fracMCsame(self, trackids)
None __init__(self, finput, fout, fgeo)
Definition: shipDigiReco.py:26
TVector3 GetXYZ() const
Definition: vetoHit.cxx:32
def execute(smeared_hits, ship_geo, str method="")
Definition: shipPatRec.py:25
None finalize()
Definition: shipPatRec.py:58
None initialize(fgeo)
Definition: shipPatRec.py:21