FairShip
Loading...
Searching...
No Matches
tracking_benchmark.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2# SPDX-License-Identifier: LGPL-3.0-or-later
3# SPDX-FileCopyrightText: Copyright CERN for the benefit of the SHiP Collaboration
4
5"""Tracking performance benchmark metrics for straw tube spectrometer.
6
7Computes track finding efficiency, clone rate, ghost rate, and resolution
8metrics by comparing MC truth with reconstructed tracks. Designed to
9establish a GenFit baseline and later measure ACTS performance.
10"""
11
12from __future__ import annotations
13
14import json
15import math
16from typing import Any
17
18import ROOT
19
20ROOT.gROOT.SetBatch(True)
21
22
def wilson_interval(k: int, n: int) -> float:
    """Half-width of the 68% (1-sigma) Wilson score interval.

    Parameters
    ----------
    k : int
        Number of successes.
    n : int
        Number of trials.

    Returns
    -------
    float
        Half-width of the Wilson score interval; 0.0 when there are no trials.
    """
    if not n:
        return 0.0
    z = 1.0  # one standard deviation
    p_hat = k / n
    scale = 1 + z**2 / n
    half_width = z * math.sqrt(p_hat * (1 - p_hat) / n + z**2 / (4 * n**2)) / scale
    return half_width
45
46
48 """Compute tracking benchmark metrics from simulation and reconstruction files.
49
50 Parameters
51 ----------
52 sim_file : str
53 Path to MC simulation ROOT file (contains cbmsim tree).
54 reco_file : str
55 Path to reconstruction ROOT file (contains ship_reco_sim tree).
56 geo_file : str
57 Path to geometry ROOT file.
58 purity_cut : float
59 Minimum hit purity fraction for a reco track to be considered matched.
60 min_hits : int
61 Minimum number of straw hits for reconstructibility.
62 min_stations : int
63 Minimum number of tracking stations crossed for reconstructibility.
64 """
65
67 self,
68 sim_file: str,
69 reco_file: str,
70 geo_file: str,
71 purity_cut: float = 0.70,
72 min_hits: int = 25,
73 min_stations: int = 3,
74 ) -> None:
75 self.purity_cut = purity_cut
76 self.min_hits = min_hits
77 self.min_stations = min_stations
78
79 self.f_sim = ROOT.TFile.Open(sim_file, "read")
80 self.sim_tree = self.f_sim["cbmsim"]
81
82 self.f_reco = ROOT.TFile.Open(reco_file, "read")
83 self.reco_tree = self.f_reco["ship_reco_sim"]
84
85 self.f_geo = ROOT.TFile.Open(geo_file, "read")
86
87 self.PDG = ROOT.TDatabasePDG.Instance()
88
89 self.metrics: dict[str, Any] = {}
90 self._histos: dict[str, Any] = {}
91
92 def _is_reconstructible(self, mc_track_id: int) -> bool:
93 """Check if an MC particle meets reconstructibility criteria.
94
95 A particle is reconstructible if it is a charged primary with
96 hits in >= min_stations tracking stations and >= min_hits total
97 straw hits. This matches the cuts in shipDigiReco.findTracks().
98 """
99 mc_track = self.sim_tree.MCTrack[mc_track_id]
100
101 # Must be primary (no mother)
102 if mc_track.GetMotherId() >= 0:
103 return False
104
105 # Must be charged
106 pdg_code = mc_track.GetPdgCode()
107 particle = self.PDG.GetParticle(pdg_code)
108 if particle is None or particle.Charge() == 0:
109 return False
110
111 # Count hits per station
112 stations: set[int] = set()
113 n_hits = 0
114 for hit in self.sim_tree.strawtubesPoint:
115 if hit.GetTrackID() != mc_track_id:
116 continue
117 n_hits += 1
118 det_id = hit.GetDetectorID()
119 station = int(det_id // 1_000_000)
120 stations.add(station)
121
122 return n_hits >= self.min_hits and len(stations) >= self.min_stations
123
124 def _get_ptruth_first(self, mc_track_id: int) -> tuple[float, float, float, float]:
125 """Get MC truth momentum at the first straw hit.
126
127 Follows the pattern from macro/ShipAna.py:getPtruthFirst().
128 """
129 for hit in self.sim_tree.strawtubesPoint:
130 if hit.GetTrackID() == mc_track_id:
131 px, py, pz = hit.GetPx(), hit.GetPy(), hit.GetPz()
132 p = math.sqrt(px**2 + py**2 + pz**2)
133 return p, px, py, pz
134 return -1.0, -1.0, -1.0, -1.0
135
136 def _get_truth_pos_first(self, mc_track_id: int) -> tuple[float, float, float]:
137 """Get MC truth position at the first straw hit."""
138 for hit in self.sim_tree.strawtubesPoint:
139 if hit.GetTrackID() == mc_track_id:
140 return hit.GetX(), hit.GetY(), hit.GetZ()
141 return 0.0, 0.0, 0.0
142
143 def _get_truth_slopes(self, mc_track_id: int) -> tuple[float, float]:
144 """Get MC truth track slopes tx=px/pz, ty=py/pz at first straw hit."""
145 for hit in self.sim_tree.strawtubesPoint:
146 if hit.GetTrackID() == mc_track_id:
147 px, py, pz = hit.GetPx(), hit.GetPy(), hit.GetPz()
148 if abs(pz) > 1e-10:
149 return px / pz, py / pz
150 return 0.0, 0.0
151 return 0.0, 0.0
152
153 def _fracMCsame(self, reco_track_idx: int) -> tuple[float, int]:
154 """Get the hit purity and dominant MC track ID for a reco track.
155
156 Uses the Tracklets branch to access hit indices, then checks
157 which MC track contributed most hits.
158 """
159 tracklet = self.reco_tree.Tracklets[reco_track_idx]
160 hit_indices = tracklet.getList()
161
162 track_counts: dict[int, int] = {}
163 n_hits = 0
164 for idx in hit_indices:
165 mc_id = self.sim_tree.strawtubesPoint[idx].GetTrackID()
166 track_counts[mc_id] = track_counts.get(mc_id, 0) + 1
167 n_hits += 1
168
169 if not track_counts:
170 return 0.0, -999
171
172 tmax = max(track_counts, key=track_counts.__getitem__)
173 frac = track_counts[tmax] / n_hits if n_hits > 0 else 0.0
174 return frac, tmax
175
    def compute_metrics(self) -> dict[str, Any]:
        """Run the full benchmark analysis over all events.

        Loops over sim/reco events in lockstep, classifies reco tracks as
        matched/clone/ghost via hit purity, and fills resolution histograms
        for matched tracks.

        Returns
        -------
        dict
            Dictionary of metrics compatible with compare_metrics.py format.
        """
        n_events = self.sim_tree.GetEntries()
        n_reco_events = self.reco_tree.GetEntries()
        # The two trees are read in lockstep; if the files disagree on the
        # number of events, process only the overlapping range.
        if n_events != n_reco_events:
            print(f"WARNING: sim has {n_events} events, reco has {n_reco_events}")
            n_events = min(n_events, n_reco_events)

        # Book histograms
        h_dp_over_p = ROOT.TH1D("h_dp_over_p", "#Deltap/p;(p_{reco} - p_{truth})/p_{truth};Entries", 100, -0.5, 0.5)
        h_dp_vs_p = ROOT.TH2D(
            "h_dp_vs_p", "#Deltap/p vs p_{truth};p_{truth} [GeV/c];#Deltap/p", 50, 0, 120, 100, -0.5, 0.5
        )
        h_dx = ROOT.TH1D("h_dx", "#Deltax at first hit;x_{reco} - x_{truth} [cm];Entries", 100, -5.0, 5.0)
        h_dy = ROOT.TH1D("h_dy", "#Deltay at first hit;y_{reco} - y_{truth} [cm];Entries", 100, -5.0, 5.0)
        h_dtx = ROOT.TH1D("h_dtx", "#Deltat_{x};t_{x,reco} - t_{x,truth};Entries", 100, -0.01, 0.01)
        h_dty = ROOT.TH1D("h_dty", "#Deltat_{y};t_{y,reco} - t_{y,truth};Entries", 100, -0.01, 0.01)
        h_chi2ndf = ROOT.TH1D("h_chi2ndf", "#chi^{2}/ndf;#chi^{2}/ndf;Entries", 100, 0, 20)
        h_p_truth = ROOT.TH1D("h_p_truth", "p_{truth} (reconstructible);p [GeV/c];Entries", 50, 0, 120)
        h_p_matched = ROOT.TH1D("h_p_matched", "p_{truth} (matched);p [GeV/c];Entries", 50, 0, 120)

        # Counters
        n_reconstructible = 0
        n_matched_mc = 0  # reconstructible MC particles with >= 1 matched reco track
        n_total_reco = 0
        n_matched_reco = 0  # reco tracks passing purity cut
        n_clone_reco = 0  # extra matches beyond the first for same MC particle

        for i_event in range(n_events):
            self.sim_tree.GetEvent(i_event)
            self.reco_tree.GetEvent(i_event)

            # Find reconstructible MC particles (denominator of the efficiency)
            reconstructible_ids: set[int] = set()
            n_mc_tracks = len(self.sim_tree.MCTrack)
            for mc_id in range(n_mc_tracks):
                if self._is_reconstructible(mc_id):
                    reconstructible_ids.add(mc_id)
                    p_truth, _, _, _ = self._get_ptruth_first(mc_id)
                    if p_truth > 0:
                        h_p_truth.Fill(p_truth)

            n_reconstructible += len(reconstructible_ids)

            # Match reco tracks to MC
            n_reco = self.reco_tree.FitTracks.size()
            n_total_reco += n_reco

            # Track which MC particles have been matched in this event
            matched_mc_this_event: set[int] = set()

            for i_reco in range(n_reco):
                track = self.reco_tree.FitTracks[i_reco]
                fit_status = track.getFitStatus()
                if not fit_status.isFitConverged():
                    continue

                # Skip degenerate fits before computing chi2/ndf.
                ndf = fit_status.getNdf()
                if ndf <= 0:
                    continue
                chi2 = fit_status.getChi2() / ndf
                h_chi2ndf.Fill(chi2)

                # Use fitTrack2MC for the MC link (already computed by fracMCsame)
                # NOTE(review): the dominant ID recomputed below is discarded;
                # presumably fitTrack2MC stores the same dominant-contributor
                # ID — confirm the two cannot disagree.
                mc_id = self.reco_tree.fitTrack2MC[i_reco]

                # Recompute purity to apply our cut
                frac, _dominant_id = self._fracMCsame(i_reco)

                if frac < self.purity_cut:
                    # Ghost track: no single MC particle dominates the hits
                    continue

                n_matched_reco += 1

                if mc_id in reconstructible_ids:
                    if mc_id not in matched_mc_this_event:
                        matched_mc_this_event.add(mc_id)
                    else:
                        n_clone_reco += 1

                    # Resolution histograms.
                    # NOTE(review): these fill for every matched track,
                    # including clones — the stated "first match only"
                    # intent is not enforced here; confirm if intended.
                    p_truth, _, _, _ = self._get_ptruth_first(mc_id)
                    x_t, y_t, _ = self._get_truth_pos_first(mc_id)
                    tx_t, ty_t = self._get_truth_slopes(mc_id)

                    if p_truth > 0:
                        try:
                            fitted_state = track.getFittedState()
                            p_reco = fitted_state.getMomMag()
                            mom = fitted_state.getMom()
                            pos = fitted_state.getPos()

                            dp_over_p = (p_reco - p_truth) / p_truth
                            h_dp_over_p.Fill(dp_over_p)
                            h_dp_vs_p.Fill(p_truth, dp_over_p)

                            h_dx.Fill(pos.X() - x_t)
                            h_dy.Fill(pos.Y() - y_t)

                            pz_reco = mom.Z()
                            if abs(pz_reco) > 1e-10:
                                tx_reco = mom.X() / pz_reco
                                ty_reco = mom.Y() / pz_reco
                                h_dtx.Fill(tx_reco - tx_t)
                                h_dty.Fill(ty_reco - ty_t)

                            h_p_matched.Fill(p_truth)
                        except Exception:
                            # getFittedState() can throw when no valid fitted
                            # state exists; skip resolution entries then.
                            pass

            n_matched_mc += len(matched_mc_this_event)

        # Compute metrics
        n_ghost_reco = n_total_reco - n_matched_reco

        efficiency = n_matched_mc / n_reconstructible if n_reconstructible > 0 else 0.0
        efficiency_unc = wilson_interval(n_matched_mc, n_reconstructible)

        clone_rate = n_clone_reco / n_matched_reco if n_matched_reco > 0 else 0.0
        clone_rate_unc = wilson_interval(n_clone_reco, n_matched_reco)

        ghost_rate = n_ghost_reco / n_total_reco if n_total_reco > 0 else 0.0
        ghost_rate_unc = wilson_interval(n_ghost_reco, n_total_reco)

        # Fit dp/p with Gaussian; keep the histogram RMS as a low-statistics
        # fallback when the fit is skipped or fails.
        dp_p_sigma = h_dp_over_p.GetRMS()
        dp_p_sigma_unc = h_dp_over_p.GetRMSError()
        if h_dp_over_p.GetEntries() > 20:
            fit_result = h_dp_over_p.Fit("gaus", "QS")
            # int(fit_result) is the ROOT fit status; 0 means success.
            if fit_result and int(fit_result) == 0:
                dp_p_sigma = fit_result.Parameter(2)
                dp_p_sigma_unc = fit_result.ParError(2)

        self.metrics = {
            "tracking_benchmark": {
                "n_events": {"value": int(n_events), "compare": "exact"},
                "n_reconstructible": {"value": int(n_reconstructible), "compare": "exact"},
                "n_total_reco": {"value": int(n_total_reco), "compare": "exact"},
                "efficiency": {
                    "value": round(efficiency, 6),
                    "uncertainty": round(efficiency_unc, 6),
                    "compare": "statistical",
                },
                "clone_rate": {
                    "value": round(clone_rate, 6),
                    "uncertainty": round(clone_rate_unc, 6),
                    "compare": "statistical",
                },
                "ghost_rate": {
                    "value": round(ghost_rate, 6),
                    "uncertainty": round(ghost_rate_unc, 6),
                    "compare": "statistical",
                },
                "dp_over_p_sigma": {
                    "value": round(dp_p_sigma, 6),
                    "uncertainty": round(dp_p_sigma_unc, 6),
                    "compare": "statistical",
                },
                "dx_rms": {
                    "value": round(h_dx.GetRMS(), 6),
                    "uncertainty": round(h_dx.GetRMSError(), 6),
                    "compare": "statistical",
                },
                "dy_rms": {
                    "value": round(h_dy.GetRMS(), 6),
                    "uncertainty": round(h_dy.GetRMSError(), 6),
                    "compare": "statistical",
                },
                "dtx_rms": {
                    "value": round(h_dtx.GetRMS(), 6),
                    "uncertainty": round(h_dtx.GetRMSError(), 6),
                    "compare": "statistical",
                },
                "dty_rms": {
                    "value": round(h_dty.GetRMS(), 6),
                    "uncertainty": round(h_dty.GetRMSError(), 6),
                    "compare": "statistical",
                },
            }
        }

        # Keep histogram handles alive (ROOT ownership) for save_histograms().
        self._histos = {
            "h_dp_over_p": h_dp_over_p,
            "h_dp_vs_p": h_dp_vs_p,
            "h_dx": h_dx,
            "h_dy": h_dy,
            "h_dtx": h_dtx,
            "h_dty": h_dty,
            "h_chi2ndf": h_chi2ndf,
            "h_p_truth": h_p_truth,
            "h_p_matched": h_p_matched,
        }

        return self.metrics
377
378 def save_json(self, output_path: str) -> None:
379 """Save metrics to JSON file."""
380 with open(output_path, "w") as f:
381 json.dump(self.metrics, f, indent=2)
382 print(f"Metrics saved to {output_path}")
383
384 def save_histograms(self, output_path: str) -> None:
385 """Save detailed histograms to a ROOT file."""
386 f_out = ROOT.TFile.Open(output_path, "recreate")
387 for name, hist in self._histos.items():
388 hist.Write(name)
389 f_out.Close()
390 print(f"Histograms saved to {output_path}")
391
392 def print_summary(self) -> None:
393 """Print a human-readable summary of the benchmark results."""
394 if not self.metrics:
395 print("No metrics computed yet. Call compute_metrics() first.")
396 return
397 m = self.metrics["tracking_benchmark"]
398 print("\n=== Tracking Benchmark Summary ===")
399 print(f" Events: {m['n_events']['value']}")
400 print(f" Reconstructible: {m['n_reconstructible']['value']}")
401 print(f" Total reco: {m['n_total_reco']['value']}")
402 print(f" Efficiency: {m['efficiency']['value']:.4f} +/- {m['efficiency']['uncertainty']:.4f}")
403 print(f" Clone rate: {m['clone_rate']['value']:.4f} +/- {m['clone_rate']['uncertainty']:.4f}")
404 print(f" Ghost rate: {m['ghost_rate']['value']:.4f} +/- {m['ghost_rate']['uncertainty']:.4f}")
405 print(f" dp/p sigma: {m['dp_over_p_sigma']['value']:.6f} +/- {m['dp_over_p_sigma']['uncertainty']:.6f}")
406 print(f" dx RMS: {m['dx_rms']['value']:.4f} +/- {m['dx_rms']['uncertainty']:.4f} cm")
407 print(f" dy RMS: {m['dy_rms']['value']:.4f} +/- {m['dy_rms']['uncertainty']:.4f} cm")
408 print(f" dtx RMS: {m['dtx_rms']['value']:.6f} +/- {m['dtx_rms']['uncertainty']:.6f}")
409 print(f" dty RMS: {m['dty_rms']['value']:.6f} +/- {m['dty_rms']['uncertainty']:.6f}")
410 print("==================================\n")
411
412 def __del__(self) -> None:
413 for f in [self.f_sim, self.f_reco, self.f_geo]:
414 if f and f.IsOpen():
415 f.Close()
None __init__(self, str sim_file, str reco_file, str geo_file, float purity_cut=0.70, int min_hits=25, int min_stations=3)
None save_histograms(self, str output_path)
tuple[float, float, float] _get_truth_pos_first(self, int mc_track_id)
tuple[float, float, float, float] _get_ptruth_first(self, int mc_track_id)
tuple[float, int] _fracMCsame(self, int reco_track_idx)
bool _is_reconstructible(self, int mc_track_id)
None save_json(self, str output_path)
tuple[float, float] _get_truth_slopes(self, int mc_track_id)
float wilson_interval(int k, int n)
int open(const char *, int)
Opens a file descriptor.