5"""Tracking performance benchmark metrics for straw tube spectrometer.
7Computes track finding efficiency, clone rate, ghost rate, and resolution
8metrics by comparing MC truth with reconstructed tracks. Designed to
9establish a GenFit baseline and later measure ACTS performance.
12from __future__
import annotations
20ROOT.gROOT.SetBatch(
True)
24 """Wilson score interval half-width for a binomial proportion.
36 Half-width of the 68% Wilson score interval (~1 sigma).
43 spread = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
48 """Compute tracking benchmark metrics from simulation and reconstruction files.
53 Path to MC simulation ROOT file (contains cbmsim tree).
55 Path to reconstruction ROOT file (contains ship_reco_sim tree).
57 Path to geometry ROOT file.
59 Minimum hit purity fraction for a reco track to be considered matched.
61 Minimum number of straw hits
for reconstructibility.
63 Minimum number of tracking stations crossed
for reconstructibility.
71 purity_cut: float = 0.70,
73 min_stations: int = 3,
79 self.
f_sim = ROOT.TFile.Open(sim_file,
"read")
82 self.
f_reco = ROOT.TFile.Open(reco_file,
"read")
85 self.
f_geo = ROOT.TFile.Open(geo_file,
"read")
87 self.
PDG = ROOT.TDatabasePDG.Instance()
89 self.
metrics: dict[str, Any] = {}
90 self.
_histos: dict[str, Any] = {}
93 """Check if an MC particle meets reconstructibility criteria.
95 A particle is reconstructible
if it
is a charged primary
with
96 hits
in >= min_stations tracking stations
and >= min_hits total
97 straw hits. This matches the cuts
in shipDigiReco.findTracks().
99 mc_track = self.sim_tree.MCTrack[mc_track_id]
102 if mc_track.GetMotherId() >= 0:
106 pdg_code = mc_track.GetPdgCode()
107 particle = self.
PDG.GetParticle(pdg_code)
108 if particle
is None or particle.Charge() == 0:
112 stations: set[int] = set()
114 for hit
in self.
sim_tree.strawtubesPoint:
115 if hit.GetTrackID() != mc_track_id:
118 det_id = hit.GetDetectorID()
119 station = int(det_id // 1_000_000)
120 stations.add(station)
125 """Get MC truth momentum at the first straw hit.
127 Follows the pattern from macro/ShipAna.py:getPtruthFirst().
129 for hit
in self.
sim_tree.strawtubesPoint:
130 if hit.GetTrackID() == mc_track_id:
131 px, py, pz = hit.GetPx(), hit.GetPy(), hit.GetPz()
132 p = math.sqrt(px**2 + py**2 + pz**2)
134 return -1.0, -1.0, -1.0, -1.0
137 """Get MC truth position at the first straw hit."""
138 for hit
in self.
sim_tree.strawtubesPoint:
139 if hit.GetTrackID() == mc_track_id:
140 return hit.GetX(), hit.GetY(), hit.GetZ()
144 """Get MC truth track slopes tx=px/pz, ty=py/pz at first straw hit."""
145 for hit
in self.
sim_tree.strawtubesPoint:
146 if hit.GetTrackID() == mc_track_id:
147 px, py, pz = hit.GetPx(), hit.GetPy(), hit.GetPz()
149 return px / pz, py / pz
154 """Get the hit purity and dominant MC track ID for a reco track.
156 Uses the Tracklets branch to access hit indices, then checks
157 which MC track contributed most hits.
159 tracklet = self.reco_tree.Tracklets[reco_track_idx]
160 hit_indices = tracklet.getList()
162 track_counts: dict[int, int] = {}
164 for idx
in hit_indices:
165 mc_id = self.
sim_tree.strawtubesPoint[idx].GetTrackID()
166 track_counts[mc_id] = track_counts.get(mc_id, 0) + 1
172 tmax = max(track_counts, key=track_counts.__getitem__)
173 frac = track_counts[tmax] / n_hits
if n_hits > 0
else 0.0
177 """Run the full benchmark analysis over all events.
182 Dictionary of metrics compatible with compare_metrics.py format.
184 n_events = self.sim_tree.GetEntries()
185 n_reco_events = self.reco_tree.GetEntries()
186 if n_events != n_reco_events:
187 print(f
"WARNING: sim has {n_events} events, reco has {n_reco_events}")
188 n_events = min(n_events, n_reco_events)
191 h_dp_over_p = ROOT.TH1D(
"h_dp_over_p",
"#Deltap/p;(p_{reco} - p_{truth})/p_{truth};Entries", 100, -0.5, 0.5)
192 h_dp_vs_p = ROOT.TH2D(
193 "h_dp_vs_p",
"#Deltap/p vs p_{truth};p_{truth} [GeV/c];#Deltap/p", 50, 0, 120, 100, -0.5, 0.5
195 h_dx = ROOT.TH1D(
"h_dx",
"#Deltax at first hit;x_{reco} - x_{truth} [cm];Entries", 100, -5.0, 5.0)
196 h_dy = ROOT.TH1D(
"h_dy",
"#Deltay at first hit;y_{reco} - y_{truth} [cm];Entries", 100, -5.0, 5.0)
197 h_dtx = ROOT.TH1D(
"h_dtx",
"#Deltat_{x};t_{x,reco} - t_{x,truth};Entries", 100, -0.01, 0.01)
198 h_dty = ROOT.TH1D(
"h_dty",
"#Deltat_{y};t_{y,reco} - t_{y,truth};Entries", 100, -0.01, 0.01)
199 h_chi2ndf = ROOT.TH1D(
"h_chi2ndf",
"#chi^{2}/ndf;#chi^{2}/ndf;Entries", 100, 0, 20)
200 h_p_truth = ROOT.TH1D(
"h_p_truth",
"p_{truth} (reconstructible);p [GeV/c];Entries", 50, 0, 120)
201 h_p_matched = ROOT.TH1D(
"h_p_matched",
"p_{truth} (matched);p [GeV/c];Entries", 50, 0, 120)
204 n_reconstructible = 0
210 for i_event
in range(n_events):
215 reconstructible_ids: set[int] = set()
216 n_mc_tracks = len(self.
sim_tree.MCTrack)
217 for mc_id
in range(n_mc_tracks):
219 reconstructible_ids.add(mc_id)
222 h_p_truth.Fill(p_truth)
224 n_reconstructible += len(reconstructible_ids)
228 n_total_reco += n_reco
231 matched_mc_this_event: set[int] = set()
233 for i_reco
in range(n_reco):
235 fit_status = track.getFitStatus()
236 if not fit_status.isFitConverged():
239 ndf = fit_status.getNdf()
242 chi2 = fit_status.getChi2() / ndf
246 mc_id = self.
reco_tree.fitTrack2MC[i_reco]
257 if mc_id
in reconstructible_ids:
258 if mc_id
not in matched_mc_this_event:
259 matched_mc_this_event.add(mc_id)
270 fitted_state = track.getFittedState()
271 p_reco = fitted_state.getMomMag()
272 mom = fitted_state.getMom()
273 pos = fitted_state.getPos()
275 dp_over_p = (p_reco - p_truth) / p_truth
276 h_dp_over_p.Fill(dp_over_p)
277 h_dp_vs_p.Fill(p_truth, dp_over_p)
279 h_dx.Fill(pos.X() - x_t)
280 h_dy.Fill(pos.Y() - y_t)
283 if abs(pz_reco) > 1e-10:
284 tx_reco = mom.X() / pz_reco
285 ty_reco = mom.Y() / pz_reco
286 h_dtx.Fill(tx_reco - tx_t)
287 h_dty.Fill(ty_reco - ty_t)
289 h_p_matched.Fill(p_truth)
293 n_matched_mc += len(matched_mc_this_event)
296 n_ghost_reco = n_total_reco - n_matched_reco
298 efficiency = n_matched_mc / n_reconstructible
if n_reconstructible > 0
else 0.0
301 clone_rate = n_clone_reco / n_matched_reco
if n_matched_reco > 0
else 0.0
304 ghost_rate = n_ghost_reco / n_total_reco
if n_total_reco > 0
else 0.0
308 dp_p_sigma = h_dp_over_p.GetRMS()
309 dp_p_sigma_unc = h_dp_over_p.GetRMSError()
310 if h_dp_over_p.GetEntries() > 20:
311 fit_result = h_dp_over_p.Fit(
"gaus",
"QS")
312 if fit_result
and int(fit_result) == 0:
313 dp_p_sigma = fit_result.Parameter(2)
314 dp_p_sigma_unc = fit_result.ParError(2)
317 "tracking_benchmark": {
318 "n_events": {
"value": int(n_events),
"compare":
"exact"},
319 "n_reconstructible": {
"value": int(n_reconstructible),
"compare":
"exact"},
320 "n_total_reco": {
"value": int(n_total_reco),
"compare":
"exact"},
322 "value": round(efficiency, 6),
323 "uncertainty": round(efficiency_unc, 6),
324 "compare":
"statistical",
327 "value": round(clone_rate, 6),
328 "uncertainty": round(clone_rate_unc, 6),
329 "compare":
"statistical",
332 "value": round(ghost_rate, 6),
333 "uncertainty": round(ghost_rate_unc, 6),
334 "compare":
"statistical",
337 "value": round(dp_p_sigma, 6),
338 "uncertainty": round(dp_p_sigma_unc, 6),
339 "compare":
"statistical",
342 "value": round(h_dx.GetRMS(), 6),
343 "uncertainty": round(h_dx.GetRMSError(), 6),
344 "compare":
"statistical",
347 "value": round(h_dy.GetRMS(), 6),
348 "uncertainty": round(h_dy.GetRMSError(), 6),
349 "compare":
"statistical",
352 "value": round(h_dtx.GetRMS(), 6),
353 "uncertainty": round(h_dtx.GetRMSError(), 6),
354 "compare":
"statistical",
357 "value": round(h_dty.GetRMS(), 6),
358 "uncertainty": round(h_dty.GetRMSError(), 6),
359 "compare":
"statistical",
365 "h_dp_over_p": h_dp_over_p,
366 "h_dp_vs_p": h_dp_vs_p,
371 "h_chi2ndf": h_chi2ndf,
372 "h_p_truth": h_p_truth,
373 "h_p_matched": h_p_matched,
379 """Save metrics to JSON file."""
380 with open(output_path,
"w")
as f:
381 json.dump(self.
metrics, f, indent=2)
382 print(f
"Metrics saved to {output_path}")
385 """Save detailed histograms to a ROOT file."""
386 f_out = ROOT.TFile.Open(output_path,
"recreate")
387 for name, hist
in self.
_histos.items():
390 print(f
"Histograms saved to {output_path}")
393 """Print a human-readable summary of the benchmark results."""
395 print(
"No metrics computed yet. Call compute_metrics() first.")
397 m = self.
metrics[
"tracking_benchmark"]
398 print(
"\n=== Tracking Benchmark Summary ===")
399 print(f
" Events: {m['n_events']['value']}")
400 print(f
" Reconstructible: {m['n_reconstructible']['value']}")
401 print(f
" Total reco: {m['n_total_reco']['value']}")
402 print(f
" Efficiency: {m['efficiency']['value']:.4f} +/- {m['efficiency']['uncertainty']:.4f}")
403 print(f
" Clone rate: {m['clone_rate']['value']:.4f} +/- {m['clone_rate']['uncertainty']:.4f}")
404 print(f
" Ghost rate: {m['ghost_rate']['value']:.4f} +/- {m['ghost_rate']['uncertainty']:.4f}")
405 print(f
" dp/p sigma: {m['dp_over_p_sigma']['value']:.6f} +/- {m['dp_over_p_sigma']['uncertainty']:.6f}")
406 print(f
" dx RMS: {m['dx_rms']['value']:.4f} +/- {m['dx_rms']['uncertainty']:.4f} cm")
407 print(f
" dy RMS: {m['dy_rms']['value']:.4f} +/- {m['dy_rms']['uncertainty']:.4f} cm")
408 print(f
" dtx RMS: {m['dtx_rms']['value']:.6f} +/- {m['dtx_rms']['uncertainty']:.6f}")
409 print(f
" dty RMS: {m['dty_rms']['value']:.6f} +/- {m['dty_rms']['uncertainty']:.6f}")
410 print(
"==================================\n")
# --- API index (extraction artifact, preserved as comments; not executable) ---
# compute_metrics(self) -> dict[str, Any]
# __init__(self, sim_file: str, reco_file: str, geo_file: str,
#          purity_cut: float = 0.70, min_hits: int = 25, min_stations: int = 3) -> None
# save_histograms(self, output_path: str) -> None
# _get_truth_pos_first(self, mc_track_id: int) -> tuple[float, float, float]
# _get_ptruth_first(self, mc_track_id: int) -> tuple[float, float, float, float]
# _fracMCsame(self, reco_track_idx: int) -> tuple[float, int]
# _is_reconstructible(self, mc_track_id: int) -> bool
# save_json(self, output_path: str) -> None
# _get_truth_slopes(self, mc_track_id: int) -> tuple[float, float]
# wilson_interval(k: int, n: int) -> float
# NOTE(review): the following entry is foreign residue (C's open(2)) and does
# not belong to this module:
# int open(const char *, int) — opens a file descriptor.