diff --git a/src/cyxtrax/io/load_maps.py b/src/cyxtrax/io/load_maps.py index 3364018fdd0fd5d5b22ea6a8c1f656c8d0caffa6..054578a897a550b14a4b8f830ba20732c124f59e 100644 --- a/src/cyxtrax/io/load_maps.py +++ b/src/cyxtrax/io/load_maps.py @@ -13,9 +13,14 @@ def load_atlas(load_path: Path) -> tuple[jnp.ndarray, jnp.ndarray, list[MeshObje with h5py.File(load_path, "r") as f: maps = jnp.array(f["/maps"][:]) points = jnp.array(f["/positions"][:]) + + mesh_object_dict = load_atlas_dict(load_path) + mesh_object_list = list(map(from_dict, mesh_object_dict)) + return maps, points, mesh_object_list + +def load_atlas_dict(load_path: Path): + with h5py.File(load_path, "r") as f: mesh_object_str = f.attrs["mesh_list"] mesh_object_dict = json.loads(mesh_object_str) - mesh_object_list = list(map(from_dict, mesh_object_dict)) - - return maps, points, mesh_object_list + return mesh_object_dict