FairShip
Loading...
Searching...
No Matches
ShipGeoConfig.py
Go to the documentation of this file.
1# SPDX-License-Identifier: LGPL-3.0-or-later
2# SPDX-FileCopyrightText: Copyright CERN for the benefit of the SHiP Collaboration
3
4from __future__ import annotations
5
6import json
7import os
8import pickle
9
10
class AttrDict(dict):
    """
    dict class that can address its keys as fields, e.g.
    d['key'] = 1
    assert d.key == 1

    The aliasing works by making the instance's ``__dict__`` the mapping
    itself, so item writes and attribute writes touch the same storage.
    """

    def __new__(cls, *args, **kwargs):
        # Alias attributes to items at allocation time, not only in __init__:
        # pickle, copy and deepcopy rebuild instances through __new__ without
        # calling __init__.  Without this, a rebuilt AttrDict gets a separate,
        # fresh __dict__, so ``d.x = 1`` no longer updates ``d['x']`` after
        # unpickling, and attribute reads fail outright after deepcopy.
        self = super().__new__(cls)
        self.__dict__ = self
        return self

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Idempotent with __new__; kept so the aliasing still holds for a
        # subclass whose own __new__ does not delegate to this one.
        self.__dict__ = self

    def clone(self) -> AttrDict:
        """Return a copy in which nested AttrDict values are cloned recursively.

        Every other value (including lists) is shared with the original.
        """
        result = AttrDict()
        for k, v in self.items():
            if isinstance(v, AttrDict):
                result[k] = v.clone()
            else:
                result[k] = v
        return result
30
31
33 def __init__(self, *args, **kwargs) -> None:
34 super().__init__(*args, **kwargs)
35
36 def loads(self, buff: bytes):
37 rv = pickle.loads(buff)
38 self.clear()
39 self.update(rv)
40 return self
41
42 def loads_json(self, json_str: str):
43 """Deserialize config from JSON string"""
44
45 def dict_to_attrdict(d):
46 """Recursively convert dict to AttrDict"""
47 if isinstance(d, dict):
48 result = AttrDict()
49 for k, v in d.items():
50 result[k] = dict_to_attrdict(v)
51 return result
52 elif isinstance(d, list):
53 return [dict_to_attrdict(item) for item in d]
54 else:
55 return d
56
57 rv = json.loads(json_str)
58 self.clear()
59 # Convert nested dicts to AttrDict
60 for k, v in rv.items():
61 self[k] = dict_to_attrdict(v)
62 return self
63
64 def clone(self) -> Config:
65 result = Config()
66 for k, v in self.items():
67 if isinstance(v, AttrDict):
68 result[k] = v.clone()
69 else:
70 result[k] = v
71 return result
72
73 def dumps(self) -> bytes:
74 return pickle.dumps(self)
75
76 def dumps_json(self) -> str:
77 """Serialize config to JSON string"""
78 return json.dumps(self, indent=2, default=str)
79
80 def load(self, filename):
81 with open(os.path.expandvars(filename), "rb") as fh:
82 self.loads(fh.read())
83 return self
84
85 def dump(self, filename) -> int:
86 with open(os.path.expandvars(filename), "wb") as fh:
87 return fh.write(self.dumps())
88
89 def __str__(self) -> str:
90 return "ShipGeoConfig:\n " + "\n ".join(
91 [f"{k}: {self[k].__str__()}" for k in sorted(self.keys()) if not k.startswith("_")]
92 )
93
94
def load_from_root_file(root_file, key: str = "ShipGeo") -> Config:
    """
    Load configuration from ROOT file.

    Automatically detects and handles both formats:
    - New format: JSON string (stored as std::string or TObjString)
    - Old format: Pickled Python object

    Args:
        root_file: Either an already-open ROOT.TFile object, or a path
            (``str`` or ``os.PathLike``) to a ROOT file. A file opened here
            from a path is closed again before returning.
        key: The key name for the stored config (default: 'ShipGeo')

    Returns:
        Config object with the loaded configuration

    Raises:
        OSError: if ``root_file`` is a path and the file cannot be opened.
        ValueError: if no object named ``key`` is found in the file.

    Warning:
        The legacy format is decoded with ``pickle.loads``, which can execute
        arbitrary code; only load files from trusted sources.
    """
    own_file = False
    if isinstance(root_file, (str, os.PathLike)):
        # ROOT is only needed to open a path; a caller passing a TFile has
        # necessarily imported it already.
        import ROOT

        path = os.fspath(root_file)
        root_file = ROOT.TFile.Open(path)
        # TFile.Open returns a null (falsy) handle when the file cannot be
        # opened. Without this check the failure surfaces later as an obscure
        # null-pointer error, and Close() in the finally block raises again,
        # masking it.
        if not root_file or root_file.IsZombie():
            raise OSError(f"Could not open ROOT file '{path}'")
        own_file = True

    try:
        # Get the object (could be std::string or TObjString)
        config_obj = root_file.Get(key)
        if not config_obj:
            raise ValueError(f"No object with key '{key}' found in ROOT file")

        # Convert to Python string
        content_str = str(config_obj)

        # Auto-detect format: JSON documents written by dumps_json start with
        # "{"; the legacy pickle stream does not.
        if content_str.startswith("{"):
            config = Config()
            config.loads_json(content_str)
        else:
            # Legacy pickle stored as a string. latin-1 maps code points 0-255
            # one-to-one back onto bytes for the unpickler.
            # NOTE(review): pickle.loads on file content - trusted input only.
            config = pickle.loads(content_str.encode("latin-1"))

        # A pickled payload may be a plain mapping rather than a Config.
        if not isinstance(config, Config):
            config = Config(config)

        return config

    finally:
        if own_file:
            root_file.Close()
None __init__(self, *args, **kwargs)
AttrDict clone(self)
Config clone(self)
None __init__(self, *args, **kwargs)
def load(self, filename)
def loads_json(self, str json_str)
def loads(self, bytes buff)
int dump(self, filename)
int open(const char *, int)
Opens a file descriptor.