From a496530768e0cb4ce865367667071ca5c51d950b Mon Sep 17 00:00:00 2001
From: = <=>
Date: Fri, 22 Nov 2024 11:35:39 +0100
Subject: [PATCH] added save map

---
 scripts/01_aRTist_bridge.py                |  54 ++++---
 scripts/02_generate_atlas.py               |  23 +--
 scripts/03_load_atlas.py                   |  19 +--
 scripts/09_data_aug_pipeline.py            |   0
 src/cyxtrax/common/__init__.py             |   4 +-
 src/cyxtrax/io/__init__.py                 |   6 +-
 src/cyxtrax/io/save_map.py                 |  52 ++++++
 src/cyxtrax/mapping/__init__.py            |   5 +-
 src/cyxtrax/simulation/__init__.py         |   7 +-
 src/cyxtrax/simulation/artist_bridge.py    |   3 +-
 src/cyxtrax/simulation/model/__init__.py   |  10 +-
 src/cyxtrax/util/datasets/__init__.py      |  10 +-
 src/cyxtrax/util/visualisation/__init__.py |   7 +-
 src/cyxtrax/util/visualisation/plt.py      | 174 +++++++++++++--------
 14 files changed, 227 insertions(+), 147 deletions(-)
 delete mode 100644 scripts/09_data_aug_pipeline.py
 create mode 100644 src/cyxtrax/io/save_map.py

diff --git a/scripts/01_aRTist_bridge.py b/scripts/01_aRTist_bridge.py
index 2d93de9..0c99222 100644
--- a/scripts/01_aRTist_bridge.py
+++ b/scripts/01_aRTist_bridge.py
@@ -8,7 +8,7 @@ from pathlib import Path
 
 
 FOLDER = Path(__file__)
-TEMP_FOLDER = FOLDER.parent.parent / 'temp'
+TEMP_FOLDER = FOLDER.parent.parent / "temp"
 
 
 def main():
@@ -16,16 +16,14 @@ def main():
     # Create the cylindrical projection interface
     cyl_proj = CylindricalProjection()
 
-
     # Load the test object 'olaf'. It is automatically installed to your site-packages
-    print(f'STL Path: {scones_mesh.object_path}')
+    print(f"STL Path: {scones_mesh.object_path}")
 
     # Obect are loaded in a dataclass: MeshObject.
-    print(f'Mesh is a MeshObject: {isinstance(scones_mesh, MeshObject)}')
+    print(f"Mesh is a MeshObject: {isinstance(scones_mesh, MeshObject)}")
     scones_mesh.position_mm = np.array([0, 0, 0])
-    scones_mesh.orientation_quat = np.array([0., -0., 0., 0.1])
-    
-    
+    scones_mesh.orientation_quat = np.array([0.0, -0.0, 0.0, 0.1])
+
     # The `load_mesh` function load a list of MeshObject
     load_mesh(cyl_proj, [scones_mesh])
 
@@ -33,8 +31,12 @@ def main():
     map_position = np.array([0, 0, 0])
 
     # Generate local cylindrical projection
-    half_rays = cyl_proj.compute_projection(map_position, output_full_ray_projection=False)
-    full_rays = cyl_proj.compute_projection(map_position, output_full_ray_projection=True)
+    half_rays = cyl_proj.compute_projection(
+        map_position, output_full_ray_projection=False
+    )
+    full_rays = cyl_proj.compute_projection(
+        map_position, output_full_ray_projection=True
+    )
     # code_embedder:A end
 
     # Make some nice plots!
@@ -42,26 +44,26 @@ def main():
 
     ax1 = fig.add_subplot(121)
     ax1.imshow(half_rays)
-    ax1.set_xlabel(r'$Z$ / mm')
-    ax1.set_ylabel(r'$\alpha$ / rad')
-    ax1.set_title('Half Rays')
-    ax1.set_xticks([0, 1000, 2000]) 
-    ax1.set_xticklabels([r'$-3141.5$', r'$0.$',  r'$3141.5$']) 
-    ax1.set_yticks([0, 1000, 2000]) 
-    ax1.set_yticklabels([r'$-\pi$', r'$0.$',  r'$\pi$']) 
+    ax1.set_xlabel(r"$Z$ / mm")
+    ax1.set_ylabel(r"$\alpha$ / rad")
+    ax1.set_title("Half Rays")
+    ax1.set_xticks([0, 1000, 2000])
+    ax1.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
+    ax1.set_yticks([0, 1000, 2000])
+    ax1.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
 
     ax2 = fig.add_subplot(122)
     ax2.imshow(full_rays)
-    ax2.set_xlabel(r'$Z$ / mm')
-    ax2.set_ylabel(r'$\alpha$ / rad')
-    ax2.set_title('Full Rays')
-    ax2.set_xticks([0, 1000, 2000]) 
-    ax2.set_xticklabels([r'$-3141.5$', r'$0.$',  r'$3141.5$']) 
-    ax2.set_yticks([0, 1000, 2000]) 
-    ax2.set_yticklabels([r'$-\pi$', r'$0.$',  r'$\pi$']) 
+    ax2.set_xlabel(r"$Z$ / mm")
+    ax2.set_ylabel(r"$\alpha$ / rad")
+    ax2.set_title("Full Rays")
+    ax2.set_xticks([0, 1000, 2000])
+    ax2.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
+    ax2.set_yticks([0, 1000, 2000])
+    ax2.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
 
-    plt.savefig(TEMP_FOLDER / 'scones.png')
+    plt.savefig(TEMP_FOLDER / "scones.png")
 
 
-if __name__ == '__main__':
-    main()
\ No newline at end of file
+if __name__ == "__main__":
+    main()
diff --git a/scripts/02_generate_atlas.py b/scripts/02_generate_atlas.py
index d3417c6..a3401d6 100644
--- a/scripts/02_generate_atlas.py
+++ b/scripts/02_generate_atlas.py
@@ -6,7 +6,7 @@ from pathlib import Path
 
 # Just get the temp folder to store the atlas
 FOLDER = Path(__file__)
-TEMP_FOLDER = FOLDER.parent.parent / 'temp'
+TEMP_FOLDER = FOLDER.parent.parent / "temp"
 
 
 def main():
@@ -15,13 +15,18 @@ def main():
     cyl_proj = CylindricalProjection()
     load_mesh(cyl_proj, [olaf_mesh])
     # make an atlas
-    
-    save_path = record_atlas(cyl_proj, [olaf_mesh], 'test', TEMP_FOLDER, 
-                            map_bounding_box=[-10, 20], number_of_maps=2)
+
+    save_path = record_atlas(
+        cyl_proj,
+        [olaf_mesh],
+        "test",
+        TEMP_FOLDER,
+        map_bounding_box=[-10, 20],
+        number_of_maps=2,
+    )
     # code_embedder:A end
-    print(f'Atlas saved to the path: {save_path}')
-    
-    
+    print(f"Atlas saved to the path: {save_path}")
+
 
-if __name__ == '__main__':
-    main()
\ No newline at end of file
+if __name__ == "__main__":
+    main()
diff --git a/scripts/03_load_atlas.py b/scripts/03_load_atlas.py
index f20a8af..e0ed4ed 100644
--- a/scripts/03_load_atlas.py
+++ b/scripts/03_load_atlas.py
@@ -5,22 +5,23 @@ from pathlib import Path
 
 # Just get the temp folder to store the atlas
 FOLDER = Path(__file__)
-TEMP_FOLDER = FOLDER.parent.parent / 'temp'
+TEMP_FOLDER = FOLDER.parent.parent / "temp"
 
 
 def main():
     # code_embedder:A start
-    files = list(TEMP_FOLDER.glob('*.h5'))
+    files = list(TEMP_FOLDER.glob("*.h5"))
     if not files:
-        raise FileNotFoundError('No .h5 files in temp folder. Run Example 02!')
+        raise FileNotFoundError("No .h5 files in temp folder. Run Example 02!")
 
     maps, points, mesh_object_list = load_atlas(files[-1])
     # code_embedder:A end
-    print(f'Loaded Map Shape: {maps.shape}')
-    print(f'Loaded Map Centre Shape: {points.shape}')
-    print(f'First Object: {mesh_object_list[0]}')
+    print(f"Loaded Map Shape: {maps.shape}")
+    print(f"Loaded Map Centre Shape: {points.shape}")
+    print(f"First Object: {mesh_object_list[0]}")
 
-    voxel_maps(maps, TEMP_FOLDER / 'scones_maps.png', title='Scones Atlas')
+    voxel_maps(maps, TEMP_FOLDER / "scones_maps.png", title="Scones Atlas")
 
-if __name__ == '__main__':
-    main()
\ No newline at end of file
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/09_data_aug_pipeline.py b/scripts/09_data_aug_pipeline.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/cyxtrax/common/__init__.py b/src/cyxtrax/common/__init__.py
index 7c315e4..3b30905 100644
--- a/src/cyxtrax/common/__init__.py
+++ b/src/cyxtrax/common/__init__.py
@@ -1,6 +1,4 @@
 from .mesh_object import MeshObject
 
 
-__all__ = [
-    'MeshObject'
-]
\ No newline at end of file
+__all__ = ["MeshObject"]
diff --git a/src/cyxtrax/io/__init__.py b/src/cyxtrax/io/__init__.py
index 2fff89e..f73e110 100644
--- a/src/cyxtrax/io/__init__.py
+++ b/src/cyxtrax/io/__init__.py
@@ -1,8 +1,6 @@
 from .generate_atlas import record_atlas
 from .load_maps import load_atlas
+from .save_map import save_atlas
 
 
-__all__ = [
-    'record_atlas',
-    'load_atlas'
-]
\ No newline at end of file
+__all__ = ["record_atlas", "load_atlas", "save_atlas"]
diff --git a/src/cyxtrax/io/save_map.py b/src/cyxtrax/io/save_map.py
new file mode 100644
index 0000000..4290707
--- /dev/null
+++ b/src/cyxtrax/io/save_map.py
@@ -0,0 +1,52 @@
+from cyxtrax.common.mesh_object import MeshObject
+from pathlib import Path
+import numpy as np
+import h5py
+import json
+
+
+def to_dict(mesh_object=MeshObject) -> dict:
+    return mesh_object.as_dict()
+
+
+def save_atlas(
+    mesh_list: list[MeshObject],
+    maps: np.ndarray,
+    map_position: np.ndarray,
+    atlas_name: str,
+    save_folder: Path,
+):
+    save_path = save_folder / f"{atlas_name}.h5"
+    file = h5py.File(save_path, "a")
+
+    x_px = maps.shape[0]
+    y_px = maps.shape[1]
+    number_of_maps = maps.shape[2]
+
+    projection = file.require_dataset(
+        "maps",
+        shape=(
+            x_px,
+            y_px,
+            number_of_maps,
+        ),  # Initial shape with third dimension as 0
+        maxshape=(
+            x_px,
+            y_px,
+            number_of_maps,
+        ),
+        dtype=np.float32,
+    )
+
+    projection_points = file.require_dataset(
+        "positions",
+        shape=(3, number_of_maps**3),  # Initial shape with third dimension as 0
+        maxshape=(3, number_of_maps**3),
+        dtype=np.float32,
+    )
+
+    mesh_list_dict = list(map(to_dict, mesh_list))
+    file.attrs["mesh_list"] = json.dumps(mesh_list_dict)
+
+    projection[:] = maps
+    projection_points[:] = map_position
diff --git a/src/cyxtrax/mapping/__init__.py b/src/cyxtrax/mapping/__init__.py
index 122cf03..30956ad 100644
--- a/src/cyxtrax/mapping/__init__.py
+++ b/src/cyxtrax/mapping/__init__.py
@@ -1,6 +1,3 @@
 from .mapping import map_geometry_2_projection, map_source_2_cylinder
 
-__all__ = [
-    'map_geometry_2_projection',
-    'map_source_2_cylinder'
-]
+__all__ = ["map_geometry_2_projection", "map_source_2_cylinder"]
diff --git a/src/cyxtrax/simulation/__init__.py b/src/cyxtrax/simulation/__init__.py
index ceddce2..3b5fe4e 100644
--- a/src/cyxtrax/simulation/__init__.py
+++ b/src/cyxtrax/simulation/__init__.py
@@ -1,8 +1,3 @@
 from .artist_bridge import CylindricalProjection, load_mesh, SAVEMODES, utility
 
-__all__ = [
-    'CylindricalProjection',
-    'load_mesh',
-    'SAVEMODES',
-    'utility'
-]
\ No newline at end of file
+__all__ = ["CylindricalProjection", "load_mesh", "SAVEMODES", "utility"]
diff --git a/src/cyxtrax/simulation/artist_bridge.py b/src/cyxtrax/simulation/artist_bridge.py
index 2a472cf..9d35f23 100644
--- a/src/cyxtrax/simulation/artist_bridge.py
+++ b/src/cyxtrax/simulation/artist_bridge.py
@@ -6,8 +6,10 @@ except ModuleNotFoundError:
     warn(
         "The module `artistlib`is not installed. The simulation module is not 100\% usable! \nInstall: https://github.com/wittlsn/aRTist-PythonLib"
     )
+
     def API():
         return None
+
     utility = None
     SAVEMODES = None
 from pathlib import Path
@@ -44,7 +46,6 @@ def set_cone_mode(api: API):
     with importlib.resources.path("cyxtrax.data", "cone.aRTist") as file_path:
         print(f"Load: {file_path}")
         api.load_project(file_path)
-       
 
 
 class CylindricalProjection:
diff --git a/src/cyxtrax/simulation/model/__init__.py b/src/cyxtrax/simulation/model/__init__.py
index ddd257a..187de9f 100644
--- a/src/cyxtrax/simulation/model/__init__.py
+++ b/src/cyxtrax/simulation/model/__init__.py
@@ -3,8 +3,8 @@ from .noise import add_white_noise, add_gaussian_blur
 from .xray import add_material
 
 __all__ = [
-    'AugemntationParameter',
-    'add_white_noise',
-    'add_gaussian_blur',
-    'add_material'
-]
\ No newline at end of file
+    "AugemntationParameter",
+    "add_white_noise",
+    "add_gaussian_blur",
+    "add_material",
+]
diff --git a/src/cyxtrax/util/datasets/__init__.py b/src/cyxtrax/util/datasets/__init__.py
index 8388c33..3550991 100644
--- a/src/cyxtrax/util/datasets/__init__.py
+++ b/src/cyxtrax/util/datasets/__init__.py
@@ -2,8 +2,8 @@ from .merge import merge_atlas_and_dict, merge_atlases, merge_lists
 from .pipeline import generate_atlas_from_folder
 
 __all__ = [
-    'merge_atlas_and_dict',
-    'merge_atlases',
-    'merge_lists',
-    'generate_atlas_from_folder'
-]
\ No newline at end of file
+    "merge_atlas_and_dict",
+    "merge_atlases",
+    "merge_lists",
+    "generate_atlas_from_folder",
+]
diff --git a/src/cyxtrax/util/visualisation/__init__.py b/src/cyxtrax/util/visualisation/__init__.py
index 9951940..ce653df 100644
--- a/src/cyxtrax/util/visualisation/__init__.py
+++ b/src/cyxtrax/util/visualisation/__init__.py
@@ -1,9 +1,4 @@
 from .plt import save_map_plot_with_index, voxel_cone_maps, voxel_maps
 from .gif import make_gif
 
-__all__ = [
-    'save_map_plot_with_index',
-    'voxel_cone_maps',
-    'make_gif',
-    'voxel_maps'
-]
\ No newline at end of file
+__all__ = ["save_map_plot_with_index", "voxel_cone_maps", "make_gif", "voxel_maps"]
diff --git a/src/cyxtrax/util/visualisation/plt.py b/src/cyxtrax/util/visualisation/plt.py
index 29704f6..1fd3e55 100644
--- a/src/cyxtrax/util/visualisation/plt.py
+++ b/src/cyxtrax/util/visualisation/plt.py
@@ -3,17 +3,18 @@ from jax import numpy as jnp
 from pathlib import Path
 
 
-def save_map_plot_with_index(maps: jnp.ndarray, index: int, 
-                             save_path: Path, figsize=(8, 8), title: str = None): 
+def save_map_plot_with_index(
+    maps: jnp.ndarray, index: int, save_path: Path, figsize=(8, 8), title: str = None
+):
     fig = plt.figure(figsize=figsize)
     ax = fig.add_subplot(111)
     ax.imshow(maps[:, :, index])
-    ax.set_xlabel(r'$z$ / mm')
-    ax.set_ylabel(r'$\alpha$ / rad')
-    ax.set_xticks([0, int(maps.shape[0] // 2), int(maps.shape[0])]) 
-    ax.set_xticklabels([r'$-3141.5$', r'$0.$',  r'$3141.5$']) 
-    ax.set_yticks([0, int(maps.shape[1] // 2), int(maps.shape[1])]) 
-    ax.set_yticklabels([r'$-\pi$', r'$0.$',  r'$\pi$'])    
+    ax.set_xlabel(r"$z$ / mm")
+    ax.set_ylabel(r"$\alpha$ / rad")
+    ax.set_xticks([0, int(maps.shape[0] // 2), int(maps.shape[0])])
+    ax.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
+    ax.set_yticks([0, int(maps.shape[1] // 2), int(maps.shape[1])])
+    ax.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
 
     if title is not None:
         ax.set_title(title)
@@ -22,56 +23,70 @@ def save_map_plot_with_index(maps: jnp.ndarray, index: int,
 
     plt.savefig(save_path)
     plt.close(fig)
-    
 
-def voxel_cone_maps(maps: jnp.ndarray, projection: jnp.ndarray, 
-                    cone_coordinates: jnp.ndarray, cylindrical_coordinates: jnp.ndarray,  
-                    save_path: Path, figsize=(24, 10), title: str = None, vmin: bool = False):
+
+def voxel_cone_maps(
+    maps: jnp.ndarray,
+    projection: jnp.ndarray,
+    cone_coordinates: jnp.ndarray,
+    cylindrical_coordinates: jnp.ndarray,
+    save_path: Path,
+    figsize=(24, 10),
+    title: str = None,
+    vmin: bool = False,
+):
     fig = plt.figure(figsize=figsize)
     subfig1: figure.Figure
     subfig2: figure.Figure
     subfig1, subfig2 = fig.subfigures(1, 2)
     ax1 = subfig1.add_subplot(111)
-    ax1.set_title('Projection Mapping')
+    ax1.set_title("Projection Mapping")
     im = ax1.imshow(projection, vmin=maps.min(), vmax=maps.max())
 
     counter = 0
     number_of_maps = maps.shape[2]
 
     color = jnp.linspace(0, 1, 9)[:number_of_maps]
-    cmap = plt.get_cmap('hsv')
+    cmap = plt.get_cmap("hsv")
     colors = cmap(color)
-    
-    ax1.set_xlabel(r'$u$ / px')
-    ax1.set_ylabel(r'$v$ / px')
-    subfig1.colorbar(im, ax=ax1, orientation='vertical', label='Ray Length / mm')
 
-    
+    ax1.set_xlabel(r"$u$ / px")
+    ax1.set_ylabel(r"$v$ / px")
+    subfig1.colorbar(im, ax=ax1, orientation="vertical", label="Ray Length / mm")
+
     ax2 = subfig2.add_subplot(111)
-    ax2.set_title('Cylindrical Mapping')
+    ax2.set_title("Cylindrical Mapping")
     ax2.axis(False)
 
-    side_length = int(jnp.round(number_of_maps ** (1. / 3.)))
+    side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0)))
 
     for z in [0, side_length - 1]:
         for y in [0, side_length - 1]:
             for x in [0, side_length - 1]:
-                position = x + y * side_length + z * side_length ** 2
+                position = x + y * side_length + z * side_length**2
                 if counter >= number_of_maps:
                     continue
-                ax = subfig2.add_subplot(3, 3, counter+1)
+                ax = subfig2.add_subplot(3, 3, counter + 1)
                 if vmin:
                     ax.imshow(maps[:, :, position].T, vmin=maps.min(), vmax=maps.max())
                 else:
                     ax.imshow(maps[:, :, position].T, vmin=maps.min())
-                ax.scatter(cylindrical_coordinates[0, position], cylindrical_coordinates[1, position],  c=colors[counter])
-                ax1.scatter(cone_coordinates[0, position], cone_coordinates[1, position], c=colors[counter])
-                ax.set_ylabel(r'$z$ / mm')
-                ax.set_xlabel(r'$\alpha$ / rad')
-                ax.set_yticks([0, 1000, 2000]) 
-                ax.set_yticklabels([r'$-3141.5$', r'$0.$',  r'$3141.5$']) 
-                ax.set_xticks([0, 1000, 2000]) 
-                ax.set_xticklabels([r'$-\pi$', r'$0.$',  r'$\pi$']) 
+                ax.scatter(
+                    cylindrical_coordinates[0, position],
+                    cylindrical_coordinates[1, position],
+                    c=colors[counter],
+                )
+                ax1.scatter(
+                    cone_coordinates[0, position],
+                    cone_coordinates[1, position],
+                    c=colors[counter],
+                )
+                ax.set_ylabel(r"$z$ / mm")
+                ax.set_xlabel(r"$\alpha$ / rad")
+                ax.set_yticks([0, 1000, 2000])
+                ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
+                ax.set_xticks([0, 1000, 2000])
+                ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
 
                 counter += 1
 
@@ -82,53 +97,68 @@ def voxel_cone_maps(maps: jnp.ndarray, projection: jnp.ndarray,
     plt.savefig(save_path)
 
 
-def voxel_cone_maps_values(maps: jnp.ndarray, projection: jnp.ndarray, 
-                    cone_coordinates: jnp.ndarray, cylindrical_coordinates: jnp.ndarray,  
-                    cylindrical_values: jnp.ndarray,
-                    save_path: Path, figsize=(24, 10), title: str = None):
+def voxel_cone_maps_values(
+    maps: jnp.ndarray,
+    projection: jnp.ndarray,
+    cone_coordinates: jnp.ndarray,
+    cylindrical_coordinates: jnp.ndarray,
+    cylindrical_values: jnp.ndarray,
+    save_path: Path,
+    figsize=(24, 10),
+    title: str = None,
+):
     fig = plt.figure(figsize=figsize)
     subfig1: figure.Figure
     subfig2: figure.Figure
     subfig1, subfig2 = fig.subfigures(1, 2)
     ax1 = subfig1.add_subplot(111)
-    ax1.set_title('Projection Mapping')
+    ax1.set_title("Projection Mapping")
     im = ax1.imshow(projection, vmin=maps.min(), vmax=maps.max())
 
     counter = 0
     number_of_maps = maps.shape[2]
 
     color = jnp.linspace(0, 1, 9)[:number_of_maps]
-    cmap = plt.get_cmap('hsv')
+    cmap = plt.get_cmap("hsv")
     colors = cmap(color)
-    
-    ax1.set_xlabel(r'$u$ / px')
-    ax1.set_ylabel(r'$v$ / px')
-    subfig1.colorbar(im, ax=ax1, orientation='vertical', label='Ray Length / mm')
 
-    
+    ax1.set_xlabel(r"$u$ / px")
+    ax1.set_ylabel(r"$v$ / px")
+    subfig1.colorbar(im, ax=ax1, orientation="vertical", label="Ray Length / mm")
+
     ax2 = subfig2.add_subplot(111)
-    ax2.set_title('Cylindrical Mapping')
+    ax2.set_title("Cylindrical Mapping")
     ax2.axis(False)
 
-    side_length = int(jnp.round(number_of_maps ** (1. / 3.)))
+    side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0)))
 
     for z in [0, side_length - 1]:
         for y in [0, side_length - 1]:
             for x in [0, side_length - 1]:
-                position = x + y * side_length + z * side_length ** 2
+                position = x + y * side_length + z * side_length**2
                 if counter >= number_of_maps:
                     continue
-                ax = subfig2.add_subplot(3, 3, counter+1)
-                ax.scatter(cylindrical_values[0], cylindrical_values[1], )
-                ax.scatter(cylindrical_coordinates[0, position], cylindrical_coordinates[1, position],  c=colors[counter])
-                ax1.scatter(cone_coordinates[0, position], cone_coordinates[1, position], c=colors[counter])
-                ax.set_ylabel(r'$z$ / mm')
-                ax.set_xlabel(r'$\alpha$ / rad')
-                ax.set_yticks([0, 1000, 2000]) 
-                ax.set_yticklabels([r'$-3141.5$', r'$0.$',  r'$3141.5$']) 
-                ax.set_xticks([0, 1000, 2000]) 
-                ax.set_xticklabels([r'$-\pi$', r'$0.$',  r'$\pi$']) 
-
+                ax = subfig2.add_subplot(3, 3, counter + 1)
+                ax.scatter(
+                    cylindrical_values[0],
+                    cylindrical_values[1],
+                )
+                ax.scatter(
+                    cylindrical_coordinates[0, position],
+                    cylindrical_coordinates[1, position],
+                    c=colors[counter],
+                )
+                ax1.scatter(
+                    cone_coordinates[0, position],
+                    cone_coordinates[1, position],
+                    c=colors[counter],
+                )
+                ax.set_ylabel(r"$z$ / mm")
+                ax.set_xlabel(r"$\alpha$ / rad")
+                ax.set_yticks([0, 1000, 2000])
+                ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
+                ax.set_xticks([0, 1000, 2000])
+                ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
 
                 counter += 1
 
@@ -139,38 +169,44 @@ def voxel_cone_maps_values(maps: jnp.ndarray, projection: jnp.ndarray,
     plt.savefig(save_path)
 
 
-def voxel_maps(maps: jnp.ndarray, save_path: Path, figsize=(24, 10), title: str = None, vmin: bool = False):
+def voxel_maps(
+    maps: jnp.ndarray,
+    save_path: Path,
+    figsize=(24, 10),
+    title: str = None,
+    vmin: bool = False,
+):
     fig = plt.figure(figsize=figsize)
 
     ax2 = fig.add_subplot(111)
-    ax2.set_title('Cylindrical Mapping')
+    ax2.set_title("Cylindrical Mapping")
     ax2.axis(False)
 
     number_of_maps = maps.shape[2]
-    side_length = int(jnp.round(number_of_maps ** (1. / 3.)))
+    side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0)))
     counter = 0
     for z in [0, side_length - 1]:
         for y in [0, side_length - 1]:
             for x in [0, side_length - 1]:
-                position = x + y * side_length + z * side_length ** 2
+                position = x + y * side_length + z * side_length**2
                 if counter >= number_of_maps:
                     continue
-                ax = fig.add_subplot(3, 3, counter+1)
+                ax = fig.add_subplot(3, 3, counter + 1)
                 if vmin:
                     ax.imshow(maps[:, :, position].T, vmin=maps.min(), vmax=maps.max())
                 else:
                     ax.imshow(maps[:, :, position].T, vmin=maps.min())
-                ax.set_ylabel(r'$z$ / mm')
-                ax.set_xlabel(r'$\alpha$ / rad')
-                ax.set_yticks([0, 1000, 2000]) 
-                ax.set_yticklabels([r'$-3141.5$', r'$0.$',  r'$3141.5$']) 
-                ax.set_xticks([0, 1000, 2000]) 
-                ax.set_xticklabels([r'$-\pi$', r'$0.$',  r'$\pi$']) 
+                ax.set_ylabel(r"$z$ / mm")
+                ax.set_xlabel(r"$\alpha$ / rad")
+                ax.set_yticks([0, 1000, 2000])
+                ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
+                ax.set_xticks([0, 1000, 2000])
+                ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
 
                 counter += 1
-                
+
     if title is not None:
         plt.title(title)
 
     plt.tight_layout()
-    plt.savefig(save_path)
\ No newline at end of file
+    plt.savefig(save_path)
-- 
GitLab