FairShip
Loading...
Searching...
No Matches
tracking_benchmark.TrackingBenchmark Class Reference

Public Member Functions

None __init__ (self, str sim_file, str reco_file, str geo_file, float purity_cut=0.70, int min_hits=25, int min_stations=3)
 
dict[str, Any] compute_metrics (self)
 
None save_json (self, str output_path)
 
None save_histograms (self, str output_path)
 
None print_summary (self)
 
None __del__ (self)
 

Public Attributes

 purity_cut
 
 min_hits
 
 min_stations
 
 f_sim
 
 sim_tree
 
 f_reco
 
 reco_tree
 
 f_geo
 
 PDG
 
 metrics
 

Protected Member Functions

bool _is_reconstructible (self, int mc_track_id)
 
tuple[float, float, float, float] _get_ptruth_first (self, int mc_track_id)
 
tuple[float, float, float] _get_truth_pos_first (self, int mc_track_id)
 
tuple[float, float] _get_truth_slopes (self, int mc_track_id)
 
tuple[float, int] _fracMCsame (self, int reco_track_idx)
 

Protected Attributes

 _histos
 

Detailed Description

Compute tracking benchmark metrics from simulation and reconstruction files.

Parameters
----------
sim_file : str
    Path to MC simulation ROOT file (contains cbmsim tree).
reco_file : str
    Path to reconstruction ROOT file (contains ship_reco_sim tree).
geo_file : str
    Path to geometry ROOT file.
purity_cut : float
    Minimum hit purity fraction for a reco track to be considered matched.
min_hits : int
    Minimum number of straw hits for reconstructibility.
min_stations : int
    Minimum number of tracking stations crossed for reconstructibility.

Definition at line 47 of file tracking_benchmark.py.

Constructor & Destructor Documentation

◆ __init__()

None tracking_benchmark.TrackingBenchmark.__init__ (   self,
str  sim_file,
str  reco_file,
str  geo_file,
float   purity_cut = 0.70,
int   min_hits = 25,
int   min_stations = 3 
)

Definition at line 66 of file tracking_benchmark.py.

def __init__(
    self,
    sim_file: str,
    reco_file: str,
    geo_file: str,
    purity_cut: float = 0.70,
    min_hits: int = 25,
    min_stations: int = 3,
) -> None:
    """Open the input ROOT files and store the benchmark configuration.

    Parameters
    ----------
    sim_file : str
        Path to the MC simulation ROOT file (cbmsim tree).
    reco_file : str
        Path to the reconstruction ROOT file (ship_reco_sim tree).
    geo_file : str
        Path to the geometry ROOT file.
    purity_cut : float
        Minimum hit-purity fraction for a matched reco track.
    min_hits : int
        Minimum straw-hit count for reconstructibility.
    min_stations : int
        Minimum number of stations crossed for reconstructibility.
    """
    # Selection configuration used by the matching / reconstructibility cuts.
    self.purity_cut = purity_cut
    self.min_hits = min_hits
    self.min_stations = min_stations

    # Input files: MC truth, reconstruction output, and geometry.
    self.f_sim = ROOT.TFile.Open(sim_file, "read")
    self.sim_tree = self.f_sim["cbmsim"]
    self.f_reco = ROOT.TFile.Open(reco_file, "read")
    self.reco_tree = self.f_reco["ship_reco_sim"]
    self.f_geo = ROOT.TFile.Open(geo_file, "read")

    # Particle database, used to look up charge by PDG code.
    self.PDG = ROOT.TDatabasePDG.Instance()

    # Filled by compute_metrics().
    self.metrics: dict[str, Any] = {}
    self._histos: dict[str, Any] = {}

◆ __del__()

None tracking_benchmark.TrackingBenchmark.__del__ (   self)

Definition at line 412 of file tracking_benchmark.py.

def __del__(self) -> None:
    """Close any input ROOT files that are still open.

    Uses ``getattr`` with a ``None`` default so that a partially
    constructed instance (e.g. ``__init__`` raised while opening one of
    the files) can still be finalized without an ``AttributeError``.
    """
    for attr in ("f_sim", "f_reco", "f_geo"):
        f = getattr(self, attr, None)
        # TFile.Open returns a null/falsy handle on failure, so guard
        # before calling IsOpen().
        if f and f.IsOpen():
            f.Close()

Member Function Documentation

◆ _fracMCsame()

tuple[float, int] tracking_benchmark.TrackingBenchmark._fracMCsame (   self,
int  reco_track_idx 
)
protected
Get the hit purity and dominant MC track ID for a reco track.

Uses the Tracklets branch to access hit indices, then checks
which MC track contributed most hits.

Definition at line 153 of file tracking_benchmark.py.

def _fracMCsame(self, reco_track_idx: int) -> tuple[float, int]:
    """Get the hit purity and dominant MC track ID for a reco track.

    Uses the Tracklets branch to access hit indices, then checks
    which MC track contributed most hits.
    """
    hit_indices = self.reco_tree.Tracklets[reco_track_idx].getList()

    # Tally how many of the tracklet's hits each MC track produced.
    counts: dict[int, int] = {}
    total_hits = 0
    for hit_idx in hit_indices:
        contributor = self.sim_tree.strawtubesPoint[hit_idx].GetTrackID()
        counts[contributor] = counts.get(contributor, 0) + 1
        total_hits += 1

    if not counts:
        # Tracklet carried no hits: zero purity, sentinel MC ID.
        return 0.0, -999

    dominant = max(counts, key=counts.__getitem__)
    purity = counts[dominant] / total_hits if total_hits > 0 else 0.0
    return purity, dominant

◆ _get_ptruth_first()

tuple[float, float, float, float] tracking_benchmark.TrackingBenchmark._get_ptruth_first (   self,
int  mc_track_id 
)
protected
Get MC truth momentum at the first straw hit.

Follows the pattern from macro/ShipAna.py:getPtruthFirst().

Definition at line 124 of file tracking_benchmark.py.

def _get_ptruth_first(self, mc_track_id: int) -> tuple[float, float, float, float]:
    """Get MC truth momentum at the first straw hit.

    Follows the pattern from macro/ShipAna.py:getPtruthFirst().
    """
    for point in self.sim_tree.strawtubesPoint:
        if point.GetTrackID() != mc_track_id:
            continue
        # First hit belonging to this MC track wins.
        px = point.GetPx()
        py = point.GetPy()
        pz = point.GetPz()
        return math.sqrt(px**2 + py**2 + pz**2), px, py, pz
    # No straw hit found for this MC track: flag with negative values.
    return -1.0, -1.0, -1.0, -1.0

◆ _get_truth_pos_first()

tuple[float, float, float] tracking_benchmark.TrackingBenchmark._get_truth_pos_first (   self,
int  mc_track_id 
)
protected
Get MC truth position at the first straw hit.

Definition at line 136 of file tracking_benchmark.py.

def _get_truth_pos_first(self, mc_track_id: int) -> tuple[float, float, float]:
    """Get MC truth position at the first straw hit."""
    for point in self.sim_tree.strawtubesPoint:
        if point.GetTrackID() != mc_track_id:
            continue
        # First matching straw hit wins.
        return point.GetX(), point.GetY(), point.GetZ()
    # No hit found for this MC track: fall back to the origin.
    return 0.0, 0.0, 0.0

◆ _get_truth_slopes()

tuple[float, float] tracking_benchmark.TrackingBenchmark._get_truth_slopes (   self,
int  mc_track_id 
)
protected
Get MC truth track slopes tx=px/pz, ty=py/pz at first straw hit.

Definition at line 143 of file tracking_benchmark.py.

def _get_truth_slopes(self, mc_track_id: int) -> tuple[float, float]:
    """Get MC truth track slopes tx=px/pz, ty=py/pz at first straw hit."""
    for point in self.sim_tree.strawtubesPoint:
        if point.GetTrackID() != mc_track_id:
            continue
        # Only the first matching hit is ever considered; if its
        # longitudinal momentum is (near-)zero the slopes are undefined
        # and (0, 0) is returned without scanning further hits.
        px, py, pz = point.GetPx(), point.GetPy(), point.GetPz()
        if abs(pz) > 1e-10:
            return px / pz, py / pz
        return 0.0, 0.0
    return 0.0, 0.0

◆ _is_reconstructible()

bool tracking_benchmark.TrackingBenchmark._is_reconstructible (   self,
int  mc_track_id 
)
protected
Check if an MC particle meets reconstructibility criteria.

A particle is reconstructible if it is a charged primary with
hits in >= min_stations tracking stations and >= min_hits total
straw hits. This matches the cuts in shipDigiReco.findTracks().

Definition at line 92 of file tracking_benchmark.py.

def _is_reconstructible(self, mc_track_id: int) -> bool:
    """Check if an MC particle meets reconstructibility criteria.

    A particle is reconstructible if it is a charged primary with
    hits in >= min_stations tracking stations and >= min_hits total
    straw hits. This matches the cuts in shipDigiReco.findTracks().
    """
    candidate = self.sim_tree.MCTrack[mc_track_id]

    # Primaries only: anything with a mother (id >= 0) is rejected.
    if candidate.GetMotherId() >= 0:
        return False

    # Neutral or unknown particles leave no track.
    pdg_entry = self.PDG.GetParticle(candidate.GetPdgCode())
    if pdg_entry is None or pdg_entry.Charge() == 0:
        return False

    # Tally this particle's straw hits and the stations they populate.
    hit_count = 0
    stations_hit: set[int] = set()
    for point in self.sim_tree.strawtubesPoint:
        if point.GetTrackID() != mc_track_id:
            continue
        hit_count += 1
        # Station index is taken as detectorID // 1_000_000.
        stations_hit.add(int(point.GetDetectorID() // 1_000_000))

    return hit_count >= self.min_hits and len(stations_hit) >= self.min_stations

◆ compute_metrics()

dict[str, Any] tracking_benchmark.TrackingBenchmark.compute_metrics (   self)
Run the full benchmark analysis over all events.

Returns
-------
dict
    Dictionary of metrics compatible with compare_metrics.py format.

Definition at line 176 of file tracking_benchmark.py.

def compute_metrics(self) -> dict[str, Any]:
    """Run the full benchmark analysis over all events.

    Returns
    -------
    dict
        Dictionary of metrics compatible with compare_metrics.py format.
    """
    n_events = self.sim_tree.GetEntries()
    n_reco_events = self.reco_tree.GetEntries()
    if n_events != n_reco_events:
        print(f"WARNING: sim has {n_events} events, reco has {n_reco_events}")
        n_events = min(n_events, n_reco_events)

    # Book histograms
    h_dp_over_p = ROOT.TH1D("h_dp_over_p", "#Deltap/p;(p_{reco} - p_{truth})/p_{truth};Entries", 100, -0.5, 0.5)
    h_dp_vs_p = ROOT.TH2D(
        "h_dp_vs_p", "#Deltap/p vs p_{truth};p_{truth} [GeV/c];#Deltap/p", 50, 0, 120, 100, -0.5, 0.5
    )
    h_dx = ROOT.TH1D("h_dx", "#Deltax at first hit;x_{reco} - x_{truth} [cm];Entries", 100, -5.0, 5.0)
    h_dy = ROOT.TH1D("h_dy", "#Deltay at first hit;y_{reco} - y_{truth} [cm];Entries", 100, -5.0, 5.0)
    h_dtx = ROOT.TH1D("h_dtx", "#Deltat_{x};t_{x,reco} - t_{x,truth};Entries", 100, -0.01, 0.01)
    h_dty = ROOT.TH1D("h_dty", "#Deltat_{y};t_{y,reco} - t_{y,truth};Entries", 100, -0.01, 0.01)
    h_chi2ndf = ROOT.TH1D("h_chi2ndf", "#chi^{2}/ndf;#chi^{2}/ndf;Entries", 100, 0, 20)
    h_p_truth = ROOT.TH1D("h_p_truth", "p_{truth} (reconstructible);p [GeV/c];Entries", 50, 0, 120)
    h_p_matched = ROOT.TH1D("h_p_matched", "p_{truth} (matched);p [GeV/c];Entries", 50, 0, 120)

    # Counters
    n_reconstructible = 0
    n_matched_mc = 0  # reconstructible MC particles with >= 1 matched reco track
    n_total_reco = 0
    n_matched_reco = 0  # reco tracks passing purity cut
    n_clone_reco = 0  # extra matches beyond the first for same MC particle

    for i_event in range(n_events):
        self.sim_tree.GetEvent(i_event)
        self.reco_tree.GetEvent(i_event)

        # Find reconstructible MC particles
        reconstructible_ids: set[int] = set()
        n_mc_tracks = len(self.sim_tree.MCTrack)
        for mc_id in range(n_mc_tracks):
            if self._is_reconstructible(mc_id):
                reconstructible_ids.add(mc_id)
                p_truth, _, _, _ = self._get_ptruth_first(mc_id)
                if p_truth > 0:
                    h_p_truth.Fill(p_truth)

        n_reconstructible += len(reconstructible_ids)

        # Match reco tracks to MC
        n_reco = self.reco_tree.FitTracks.size()
        n_total_reco += n_reco

        # Track which MC particles have been matched in this event
        matched_mc_this_event: set[int] = set()

        for i_reco in range(n_reco):
            track = self.reco_tree.FitTracks[i_reco]
            fit_status = track.getFitStatus()
            if not fit_status.isFitConverged():
                continue

            ndf = fit_status.getNdf()
            if ndf <= 0:
                continue
            chi2 = fit_status.getChi2() / ndf
            h_chi2ndf.Fill(chi2)

            # Use fitTrack2MC for the MC link (already computed by fracMCsame)
            mc_id = self.reco_tree.fitTrack2MC[i_reco]

            # Recompute purity to apply our cut
            frac, _dominant_id = self._fracMCsame(i_reco)

            if frac < self.purity_cut:
                # Ghost track
                continue

            n_matched_reco += 1

            if mc_id not in reconstructible_ids:
                continue

            if mc_id in matched_mc_this_event:
                # Extra reco track for an already-matched MC particle.
                # Count it as a clone and skip the resolution histograms
                # below, so that clones cannot double-fill them (the
                # resolution plots and h_p_matched use the first match
                # only).
                n_clone_reco += 1
                continue
            matched_mc_this_event.add(mc_id)

            # Resolution histograms (first match only)
            p_truth, _, _, _ = self._get_ptruth_first(mc_id)
            x_t, y_t, _ = self._get_truth_pos_first(mc_id)
            tx_t, ty_t = self._get_truth_slopes(mc_id)

            if p_truth > 0:
                try:
                    fitted_state = track.getFittedState()
                    p_reco = fitted_state.getMomMag()
                    mom = fitted_state.getMom()
                    pos = fitted_state.getPos()

                    dp_over_p = (p_reco - p_truth) / p_truth
                    h_dp_over_p.Fill(dp_over_p)
                    h_dp_vs_p.Fill(p_truth, dp_over_p)

                    h_dx.Fill(pos.X() - x_t)
                    h_dy.Fill(pos.Y() - y_t)

                    pz_reco = mom.Z()
                    if abs(pz_reco) > 1e-10:
                        tx_reco = mom.X() / pz_reco
                        ty_reco = mom.Y() / pz_reco
                        h_dtx.Fill(tx_reco - tx_t)
                        h_dty.Fill(ty_reco - ty_t)

                    h_p_matched.Fill(p_truth)
                except Exception:
                    # getFittedState() can throw for pathological fits;
                    # skip such tracks rather than aborting the benchmark.
                    pass

        n_matched_mc += len(matched_mc_this_event)

    # Compute metrics
    n_ghost_reco = n_total_reco - n_matched_reco

    efficiency = n_matched_mc / n_reconstructible if n_reconstructible > 0 else 0.0
    efficiency_unc = wilson_interval(n_matched_mc, n_reconstructible)

    clone_rate = n_clone_reco / n_matched_reco if n_matched_reco > 0 else 0.0
    clone_rate_unc = wilson_interval(n_clone_reco, n_matched_reco)

    ghost_rate = n_ghost_reco / n_total_reco if n_total_reco > 0 else 0.0
    ghost_rate_unc = wilson_interval(n_ghost_reco, n_total_reco)

    # Fit dp/p with Gaussian; fall back to the histogram RMS when the
    # statistics are too low or the fit fails.
    dp_p_sigma = h_dp_over_p.GetRMS()
    dp_p_sigma_unc = h_dp_over_p.GetRMSError()
    if h_dp_over_p.GetEntries() > 20:
        fit_result = h_dp_over_p.Fit("gaus", "QS")
        if fit_result and int(fit_result) == 0:
            dp_p_sigma = fit_result.Parameter(2)
            dp_p_sigma_unc = fit_result.ParError(2)

    self.metrics = {
        "tracking_benchmark": {
            "n_events": {"value": int(n_events), "compare": "exact"},
            "n_reconstructible": {"value": int(n_reconstructible), "compare": "exact"},
            "n_total_reco": {"value": int(n_total_reco), "compare": "exact"},
            "efficiency": {
                "value": round(efficiency, 6),
                "uncertainty": round(efficiency_unc, 6),
                "compare": "statistical",
            },
            "clone_rate": {
                "value": round(clone_rate, 6),
                "uncertainty": round(clone_rate_unc, 6),
                "compare": "statistical",
            },
            "ghost_rate": {
                "value": round(ghost_rate, 6),
                "uncertainty": round(ghost_rate_unc, 6),
                "compare": "statistical",
            },
            "dp_over_p_sigma": {
                "value": round(dp_p_sigma, 6),
                "uncertainty": round(dp_p_sigma_unc, 6),
                "compare": "statistical",
            },
            "dx_rms": {
                "value": round(h_dx.GetRMS(), 6),
                "uncertainty": round(h_dx.GetRMSError(), 6),
                "compare": "statistical",
            },
            "dy_rms": {
                "value": round(h_dy.GetRMS(), 6),
                "uncertainty": round(h_dy.GetRMSError(), 6),
                "compare": "statistical",
            },
            "dtx_rms": {
                "value": round(h_dtx.GetRMS(), 6),
                "uncertainty": round(h_dtx.GetRMSError(), 6),
                "compare": "statistical",
            },
            "dty_rms": {
                "value": round(h_dty.GetRMS(), 6),
                "uncertainty": round(h_dty.GetRMSError(), 6),
                "compare": "statistical",
            },
        }
    }

    self._histos = {
        "h_dp_over_p": h_dp_over_p,
        "h_dp_vs_p": h_dp_vs_p,
        "h_dx": h_dx,
        "h_dy": h_dy,
        "h_dtx": h_dtx,
        "h_dty": h_dty,
        "h_chi2ndf": h_chi2ndf,
        "h_p_truth": h_p_truth,
        "h_p_matched": h_p_matched,
    }

    return self.metrics

◆ print_summary()

None tracking_benchmark.TrackingBenchmark.print_summary (   self)
Print a human-readable summary of the benchmark results.

Definition at line 392 of file tracking_benchmark.py.

def print_summary(self) -> None:
    """Print a human-readable summary of the benchmark results."""
    if not self.metrics:
        # Nothing to report until compute_metrics() has been run.
        print("No metrics computed yet. Call compute_metrics() first.")
        return

    m = self.metrics["tracking_benchmark"]
    # Assemble the report once and emit it with a single print call;
    # the output is identical to printing line by line.
    report = [
        "\n=== Tracking Benchmark Summary ===",
        f" Events: {m['n_events']['value']}",
        f" Reconstructible: {m['n_reconstructible']['value']}",
        f" Total reco: {m['n_total_reco']['value']}",
        f" Efficiency: {m['efficiency']['value']:.4f} +/- {m['efficiency']['uncertainty']:.4f}",
        f" Clone rate: {m['clone_rate']['value']:.4f} +/- {m['clone_rate']['uncertainty']:.4f}",
        f" Ghost rate: {m['ghost_rate']['value']:.4f} +/- {m['ghost_rate']['uncertainty']:.4f}",
        f" dp/p sigma: {m['dp_over_p_sigma']['value']:.6f} +/- {m['dp_over_p_sigma']['uncertainty']:.6f}",
        f" dx RMS: {m['dx_rms']['value']:.4f} +/- {m['dx_rms']['uncertainty']:.4f} cm",
        f" dy RMS: {m['dy_rms']['value']:.4f} +/- {m['dy_rms']['uncertainty']:.4f} cm",
        f" dtx RMS: {m['dtx_rms']['value']:.6f} +/- {m['dtx_rms']['uncertainty']:.6f}",
        f" dty RMS: {m['dty_rms']['value']:.6f} +/- {m['dty_rms']['uncertainty']:.6f}",
        "==================================\n",
    ]
    print("\n".join(report))

◆ save_histograms()

None tracking_benchmark.TrackingBenchmark.save_histograms (   self,
str  output_path 
)
Save detailed histograms to a ROOT file.

Definition at line 384 of file tracking_benchmark.py.

def save_histograms(self, output_path: str) -> None:
    """Save detailed histograms to a ROOT file.

    Parameters
    ----------
    output_path : str
        Path of the ROOT file to (re)create.
    """
    f_out = ROOT.TFile.Open(output_path, "recreate")
    try:
        for name, hist in self._histos.items():
            # Write under the booking name so downstream tooling can find it.
            hist.Write(name)
    finally:
        # Close the output file even if a Write() call raises,
        # so the file handle is never leaked.
        f_out.Close()
    print(f"Histograms saved to {output_path}")

◆ save_json()

None tracking_benchmark.TrackingBenchmark.save_json (   self,
str  output_path 
)
Save metrics to JSON file.

Definition at line 378 of file tracking_benchmark.py.

def save_json(self, output_path: str) -> None:
    """Save metrics to JSON file."""
    # Serialize whatever compute_metrics() stored, pretty-printed.
    with open(output_path, "w") as out_file:
        json.dump(self.metrics, out_file, indent=2)
    print(f"Metrics saved to {output_path}")
Note: save_json uses Python's built-in open(); the C-library signature "int open(const char *, int)" ("Opens a file descriptor") shown here is a spurious Doxygen cross-reference and does not apply to this Python code.

Member Data Documentation

◆ _histos

tracking_benchmark.TrackingBenchmark._histos
protected

Definition at line 364 of file tracking_benchmark.py.

◆ f_geo

tracking_benchmark.TrackingBenchmark.f_geo

Definition at line 85 of file tracking_benchmark.py.

◆ f_reco

tracking_benchmark.TrackingBenchmark.f_reco

Definition at line 82 of file tracking_benchmark.py.

◆ f_sim

tracking_benchmark.TrackingBenchmark.f_sim

Definition at line 79 of file tracking_benchmark.py.

◆ metrics

tracking_benchmark.TrackingBenchmark.metrics

Definition at line 316 of file tracking_benchmark.py.

◆ min_hits

tracking_benchmark.TrackingBenchmark.min_hits

Definition at line 76 of file tracking_benchmark.py.

◆ min_stations

tracking_benchmark.TrackingBenchmark.min_stations

Definition at line 77 of file tracking_benchmark.py.

◆ PDG

tracking_benchmark.TrackingBenchmark.PDG

Definition at line 87 of file tracking_benchmark.py.

◆ purity_cut

tracking_benchmark.TrackingBenchmark.purity_cut

Definition at line 75 of file tracking_benchmark.py.

◆ reco_tree

tracking_benchmark.TrackingBenchmark.reco_tree

Definition at line 83 of file tracking_benchmark.py.

◆ sim_tree

tracking_benchmark.TrackingBenchmark.sim_tree

Definition at line 80 of file tracking_benchmark.py.


The documentation for this class was generated from the following file: