diff --git a/scripts/01_aRTist_bridge.py b/scripts/01_aRTist_bridge.py index 2d93de9ed2a942eca9d6c6b1a6d6d00e61f118b4..0c992223e735ce532768b5bee1e00ac4707d2a1d 100644 --- a/scripts/01_aRTist_bridge.py +++ b/scripts/01_aRTist_bridge.py @@ -8,7 +8,7 @@ from pathlib import Path FOLDER = Path(__file__) -TEMP_FOLDER = FOLDER.parent.parent / 'temp' +TEMP_FOLDER = FOLDER.parent.parent / "temp" def main(): @@ -16,16 +16,14 @@ def main(): # Create the cylindrical projection interface cyl_proj = CylindricalProjection() - # Load the test object 'olaf'. It is automatically installed to your site-packages - print(f'STL Path: {scones_mesh.object_path}') + print(f"STL Path: {scones_mesh.object_path}") # Obect are loaded in a dataclass: MeshObject. - print(f'Mesh is a MeshObject: {isinstance(scones_mesh, MeshObject)}') + print(f"Mesh is a MeshObject: {isinstance(scones_mesh, MeshObject)}") scones_mesh.position_mm = np.array([0, 0, 0]) - scones_mesh.orientation_quat = np.array([0., -0., 0., 0.1]) - - + scones_mesh.orientation_quat = np.array([0.0, -0.0, 0.0, 0.1]) + # The `load_mesh` function load a list of MeshObject load_mesh(cyl_proj, [scones_mesh]) @@ -33,8 +31,12 @@ def main(): map_position = np.array([0, 0, 0]) # Generate local cylindrical projection - half_rays = cyl_proj.compute_projection(map_position, output_full_ray_projection=False) - full_rays = cyl_proj.compute_projection(map_position, output_full_ray_projection=True) + half_rays = cyl_proj.compute_projection( + map_position, output_full_ray_projection=False + ) + full_rays = cyl_proj.compute_projection( + map_position, output_full_ray_projection=True + ) # code_embedder:A end # Make some nice plots! @@ -42,26 +44,26 @@ def main(): ax1 = fig.add_subplot(121) ax1.imshow(half_rays) - ax1.set_xlabel(r'$Z$ / mm') - ax1.set_ylabel(r'$\alpha$ / rad') - ax1.set_title('Half Rays') - ax1.set_xticks([0, 1000, 2000]) - ax1.set_xticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax1.set_yticks([0, 1000, 2000]) - ax1.set_yticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + ax1.set_xlabel(r"$Z$ / mm") + ax1.set_ylabel(r"$\alpha$ / rad") + ax1.set_title("Half Rays") + ax1.set_xticks([0, 1000, 2000]) + ax1.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax1.set_yticks([0, 1000, 2000]) + ax1.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) ax2 = fig.add_subplot(122) ax2.imshow(full_rays) - ax2.set_xlabel(r'$Z$ / mm') - ax2.set_ylabel(r'$\alpha$ / rad') - ax2.set_title('Full Rays') - ax2.set_xticks([0, 1000, 2000]) - ax2.set_xticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax2.set_yticks([0, 1000, 2000]) - ax2.set_yticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + ax2.set_xlabel(r"$Z$ / mm") + ax2.set_ylabel(r"$\alpha$ / rad") + ax2.set_title("Full Rays") + ax2.set_xticks([0, 1000, 2000]) + ax2.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax2.set_yticks([0, 1000, 2000]) + ax2.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) - plt.savefig(TEMP_FOLDER / 'scones.png') + plt.savefig(TEMP_FOLDER / "scones.png") -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/scripts/02_generate_atlas.py b/scripts/02_generate_atlas.py index d3417c626bb60df82a64937dd7693b7311e549a9..a3401d6aeac5c37fd0945a8664b12c33f8b54a82 100644 --- a/scripts/02_generate_atlas.py +++ b/scripts/02_generate_atlas.py @@ -6,7 +6,7 @@ from pathlib import Path # Just get the temp folder to store the atlas FOLDER = Path(__file__) -TEMP_FOLDER = FOLDER.parent.parent / 'temp' +TEMP_FOLDER = FOLDER.parent.parent / "temp" def main(): @@ -15,13 +15,18 @@ def main(): cyl_proj = CylindricalProjection() load_mesh(cyl_proj, [olaf_mesh]) # make an atlas - - save_path = record_atlas(cyl_proj, [olaf_mesh], 'test', TEMP_FOLDER, - map_bounding_box=[-10, 20], number_of_maps=2) + + save_path = record_atlas( + cyl_proj, + [olaf_mesh], + "test", + TEMP_FOLDER, + map_bounding_box=[-10, 20], + number_of_maps=2, + ) # code_embedder:A end - print(f'Atlas saved to the path: {save_path}') - - + print(f"Atlas saved to the path: {save_path}") + -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/scripts/03_load_atlas.py b/scripts/03_load_atlas.py index f20a8afcc7ce574865d2a7e765e171a003bdb044..e0ed4ed656b35843f68abcfc8ed5c19ff9d583b0 100644 --- a/scripts/03_load_atlas.py +++ b/scripts/03_load_atlas.py @@ -5,22 +5,23 @@ from pathlib import Path # Just get the temp folder to store the atlas FOLDER = Path(__file__) -TEMP_FOLDER = FOLDER.parent.parent / 'temp' +TEMP_FOLDER = FOLDER.parent.parent / "temp" def main(): # code_embedder:A start - files = list(TEMP_FOLDER.glob('*.h5')) + files = list(TEMP_FOLDER.glob("*.h5")) if not files: - raise FileNotFoundError('No .h5 files in temp folder. Run Example 02!') + raise FileNotFoundError("No .h5 files in temp folder. Run Example 02!") maps, points, mesh_object_list = load_atlas(files[-1]) # code_embedder:A end - print(f'Loaded Map Shape: {maps.shape}') - print(f'Loaded Map Centre Shape: {points.shape}') - print(f'First Object: {mesh_object_list[0]}') + print(f"Loaded Map Shape: {maps.shape}") + print(f"Loaded Map Centre Shape: {points.shape}") + print(f"First Object: {mesh_object_list[0]}") - voxel_maps(maps, TEMP_FOLDER / 'scones_maps.png', title='Scones Atlas') + voxel_maps(maps, TEMP_FOLDER / "scones_maps.png", title="Scones Atlas") -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/scripts/09_data_aug_pipeline.py b/scripts/09_data_aug_pipeline.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/src/cyxtrax/common/__init__.py b/src/cyxtrax/common/__init__.py index 7c315e4bdfc523b5f0b0ecc157a3dcc366044e11..3b30905a6749b89d564f2cb19c98ed62b7be642f 100644 --- a/src/cyxtrax/common/__init__.py +++ b/src/cyxtrax/common/__init__.py @@ -1,6 +1,4 @@ from .mesh_object import MeshObject -__all__ = [ - 'MeshObject' -] \ No newline at end of file +__all__ = ["MeshObject"] diff --git a/src/cyxtrax/io/__init__.py b/src/cyxtrax/io/__init__.py index 2fff89e6e43e6ba570c599b610827e48cf68a198..f73e11006a7178bea608543438dd320b976650be 100644 --- a/src/cyxtrax/io/__init__.py +++ b/src/cyxtrax/io/__init__.py @@ -1,8 +1,6 @@ from .generate_atlas import record_atlas from .load_maps import load_atlas +from .save_map import save_atlas -__all__ = [ - 'record_atlas', - 'load_atlas' -] \ No newline at end of file +__all__ = ["record_atlas", "load_atlas", "save_atlas"] diff --git a/src/cyxtrax/io/save_map.py b/src/cyxtrax/io/save_map.py new file mode 100644 index 0000000000000000000000000000000000000000..429070770cf312eb02ded5406348e7e3d97b0f27 --- /dev/null +++ b/src/cyxtrax/io/save_map.py @@ -0,0 +1,52 @@ +from cyxtrax.common.mesh_object import MeshObject +from pathlib import Path +import numpy as np +import h5py +import json + + +def to_dict(mesh_object=MeshObject) -> dict: + return mesh_object.as_dict() + + +def save_atlas( + mesh_list: list[MeshObject], + maps: np.ndarray, + map_position: np.ndarray, + atlas_name: str, + save_folder: Path, +): + save_path = save_folder / f"{atlas_name}.h5" + file = h5py.File(save_path, "a") + + x_px = maps.shape[0] + y_px = maps.shape[1] + number_of_maps = maps.shape[2] + + projection = file.require_dataset( + "maps", + shape=( + x_px, + y_px, + number_of_maps, + ), # Initial shape with third dimension as 0 + maxshape=( + x_px, + y_px, + number_of_maps, + ), + dtype=np.float32, + ) + + projection_points = file.require_dataset( + "positions", + shape=(3, number_of_maps**3), # Initial shape with third dimension as 0 + maxshape=(3, number_of_maps**3), + dtype=np.float32, + ) + + mesh_list_dict = list(map(to_dict, mesh_list)) + file.attrs["mesh_list"] = json.dumps(mesh_list_dict) + + projection[:] = maps + projection_points[:] = map_position diff --git a/src/cyxtrax/mapping/__init__.py b/src/cyxtrax/mapping/__init__.py index 122cf03b2fe8d11b76711526d46360f17f553186..30956ad80a26f94a42a7f94c9dd2321cbc40eae9 100644 --- a/src/cyxtrax/mapping/__init__.py +++ b/src/cyxtrax/mapping/__init__.py @@ -1,6 +1,3 @@ from .mapping import map_geometry_2_projection, map_source_2_cylinder -__all__ = [ - 'map_geometry_2_projection', - 'map_source_2_cylinder' -] +__all__ = ["map_geometry_2_projection", "map_source_2_cylinder"] diff --git a/src/cyxtrax/simulation/__init__.py b/src/cyxtrax/simulation/__init__.py index ceddce2bc31b45b48999c027cb574f4fba0a2b2c..3b5fe4ee7274a17030859c9147303a56f39d650d 100644 --- a/src/cyxtrax/simulation/__init__.py +++ b/src/cyxtrax/simulation/__init__.py @@ -1,8 +1,3 @@ from .artist_bridge import CylindricalProjection, load_mesh, SAVEMODES, utility -__all__ = [ - 'CylindricalProjection', - 'load_mesh', - 'SAVEMODES', - 'utility' -] \ No newline at end of file +__all__ = ["CylindricalProjection", "load_mesh", "SAVEMODES", "utility"] diff --git a/src/cyxtrax/simulation/artist_bridge.py b/src/cyxtrax/simulation/artist_bridge.py index 2a472cff07340b599498569643aaba862269c335..9d35f2300415c243d9e8be4ee535bc9c3d4ca7b7 100644 --- a/src/cyxtrax/simulation/artist_bridge.py +++ b/src/cyxtrax/simulation/artist_bridge.py @@ -6,8 +6,10 @@ except ModuleNotFoundError: warn( "The module `artistlib`is not installed. The simulation module is not 100\% usable! \nInstall: https://github.com/wittlsn/aRTist-PythonLib" ) + def API(): return None + utility = None SAVEMODES = None from pathlib import Path @@ -44,7 +46,6 @@ def set_cone_mode(api: API): with importlib.resources.path("cyxtrax.data", "cone.aRTist") as file_path: print(f"Load: {file_path}") api.load_project(file_path) - class CylindricalProjection: diff --git a/src/cyxtrax/simulation/model/__init__.py b/src/cyxtrax/simulation/model/__init__.py index ddd257a4ea7a8a6ad86558ce53971d1a321947ee..187de9f0bbf17f07c8294967c4693fb5539be783 100644 --- a/src/cyxtrax/simulation/model/__init__.py +++ b/src/cyxtrax/simulation/model/__init__.py @@ -3,8 +3,8 @@ from .noise import add_white_noise, add_gaussian_blur from .xray import add_material __all__ = [ - 'AugemntationParameter', - 'add_white_noise', - 'add_gaussian_blur', - 'add_material' -] \ No newline at end of file + "AugemntationParameter", + "add_white_noise", + "add_gaussian_blur", + "add_material", +] diff --git a/src/cyxtrax/util/datasets/__init__.py b/src/cyxtrax/util/datasets/__init__.py index 8388c33666b5dd9fd516da6dab5af5ce7b9f3d64..35509910e32257a2813cd6ab5bec1cb1a44bd91b 100644 --- a/src/cyxtrax/util/datasets/__init__.py +++ b/src/cyxtrax/util/datasets/__init__.py @@ -2,8 +2,8 @@ from .merge import merge_atlas_and_dict, merge_atlases, merge_lists from .pipeline import generate_atlas_from_folder __all__ = [ - 'merge_atlas_and_dict', - 'merge_atlases', - 'merge_lists', - 'generate_atlas_from_folder' -] \ No newline at end of file + "merge_atlas_and_dict", + "merge_atlases", + "merge_lists", + "generate_atlas_from_folder", +] diff --git a/src/cyxtrax/util/visualisation/__init__.py b/src/cyxtrax/util/visualisation/__init__.py index 995194029d4ad90bb3a622024fdb845cc1096783..ce653dfb43782a1c2720b84d004df0d7f779ac7f 100644 --- a/src/cyxtrax/util/visualisation/__init__.py +++ b/src/cyxtrax/util/visualisation/__init__.py @@ -1,9 +1,4 @@ from .plt import save_map_plot_with_index, voxel_cone_maps, voxel_maps from .gif import make_gif -__all__ = [ - 'save_map_plot_with_index', - 'voxel_cone_maps', - 'make_gif', - 'voxel_maps' -] \ No newline at end of file +__all__ = ["save_map_plot_with_index", "voxel_cone_maps", "make_gif", "voxel_maps"] diff --git a/src/cyxtrax/util/visualisation/plt.py b/src/cyxtrax/util/visualisation/plt.py index 29704f68e05e151f2855e1ab6c9be46b4f02e1f8..1fd3e55c3765a78cfa3b9a4d930a2323ce453a9e 100644 --- a/src/cyxtrax/util/visualisation/plt.py +++ b/src/cyxtrax/util/visualisation/plt.py @@ -3,17 +3,18 @@ from jax import numpy as jnp from pathlib import Path -def save_map_plot_with_index(maps: jnp.ndarray, index: int, - save_path: Path, figsize=(8, 8), title: str = None): +def save_map_plot_with_index( + maps: jnp.ndarray, index: int, save_path: Path, figsize=(8, 8), title: str = None +): fig = plt.figure(figsize=figsize) ax = fig.add_subplot(111) ax.imshow(maps[:, :, index]) - ax.set_xlabel(r'$z$ / mm') - ax.set_ylabel(r'$\alpha$ / rad') - ax.set_xticks([0, int(maps.shape[0] // 2), int(maps.shape[0])]) - ax.set_xticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax.set_yticks([0, int(maps.shape[1] // 2), int(maps.shape[1])]) - ax.set_yticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + ax.set_xlabel(r"$z$ / mm") + ax.set_ylabel(r"$\alpha$ / rad") + ax.set_xticks([0, int(maps.shape[0] // 2), int(maps.shape[0])]) + ax.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax.set_yticks([0, int(maps.shape[1] // 2), int(maps.shape[1])]) + ax.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) if title is not None: ax.set_title(title) @@ -22,56 +23,70 @@ def save_map_plot_with_index(maps: jnp.ndarray, index: int, plt.savefig(save_path) plt.close(fig) - -def voxel_cone_maps(maps: jnp.ndarray, projection: jnp.ndarray, - cone_coordinates: jnp.ndarray, cylindrical_coordinates: jnp.ndarray, - save_path: Path, figsize=(24, 10), title: str = None, vmin: bool = False): + +def voxel_cone_maps( + maps: jnp.ndarray, + projection: jnp.ndarray, + cone_coordinates: jnp.ndarray, + cylindrical_coordinates: jnp.ndarray, + save_path: Path, + figsize=(24, 10), + title: str = None, + vmin: bool = False, +): fig = plt.figure(figsize=figsize) subfig1: figure.Figure subfig2: figure.Figure subfig1, subfig2 = fig.subfigures(1, 2) ax1 = subfig1.add_subplot(111) - ax1.set_title('Projection Mapping') + ax1.set_title("Projection Mapping") im = ax1.imshow(projection, vmin=maps.min(), vmax=maps.max()) counter = 0 number_of_maps = maps.shape[2] color = jnp.linspace(0, 1, 9)[:number_of_maps] - cmap = plt.get_cmap('hsv') + cmap = plt.get_cmap("hsv") colors = cmap(color) - - ax1.set_xlabel(r'$u$ / px') - ax1.set_ylabel(r'$v$ / px') - subfig1.colorbar(im, ax=ax1, orientation='vertical', label='Ray Length / mm') - + ax1.set_xlabel(r"$u$ / px") + ax1.set_ylabel(r"$v$ / px") + subfig1.colorbar(im, ax=ax1, orientation="vertical", label="Ray Length / mm") + ax2 = subfig2.add_subplot(111) - ax2.set_title('Cylindrical Mapping') + ax2.set_title("Cylindrical Mapping") ax2.axis(False) - side_length = int(jnp.round(number_of_maps ** (1. / 3.))) + side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0))) for z in [0, side_length - 1]: for y in [0, side_length - 1]: for x in [0, side_length - 1]: - position = x + y * side_length + z * side_length ** 2 + position = x + y * side_length + z * side_length**2 if counter >= number_of_maps: continue - ax = subfig2.add_subplot(3, 3, counter+1) + ax = subfig2.add_subplot(3, 3, counter + 1) if vmin: ax.imshow(maps[:, :, position].T, vmin=maps.min(), vmax=maps.max()) else: ax.imshow(maps[:, :, position].T, vmin=maps.min()) - ax.scatter(cylindrical_coordinates[0, position], cylindrical_coordinates[1, position], c=colors[counter]) - ax1.scatter(cone_coordinates[0, position], cone_coordinates[1, position], c=colors[counter]) - ax.set_ylabel(r'$z$ / mm') - ax.set_xlabel(r'$\alpha$ / rad') - ax.set_yticks([0, 1000, 2000]) - ax.set_yticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax.set_xticks([0, 1000, 2000]) - ax.set_xticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + ax.scatter( + cylindrical_coordinates[0, position], + cylindrical_coordinates[1, position], + c=colors[counter], + ) + ax1.scatter( + cone_coordinates[0, position], + cone_coordinates[1, position], + c=colors[counter], + ) + ax.set_ylabel(r"$z$ / mm") + ax.set_xlabel(r"$\alpha$ / rad") + ax.set_yticks([0, 1000, 2000]) + ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax.set_xticks([0, 1000, 2000]) + ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) counter += 1 @@ -82,53 +97,68 @@ def voxel_cone_maps(maps: jnp.ndarray, projection: jnp.ndarray, plt.savefig(save_path) -def voxel_cone_maps_values(maps: jnp.ndarray, projection: jnp.ndarray, - cone_coordinates: jnp.ndarray, cylindrical_coordinates: jnp.ndarray, - cylindrical_values: jnp.ndarray, - save_path: Path, figsize=(24, 10), title: str = None): +def voxel_cone_maps_values( + maps: jnp.ndarray, + projection: jnp.ndarray, + cone_coordinates: jnp.ndarray, + cylindrical_coordinates: jnp.ndarray, + cylindrical_values: jnp.ndarray, + save_path: Path, + figsize=(24, 10), + title: str = None, +): fig = plt.figure(figsize=figsize) subfig1: figure.Figure subfig2: figure.Figure subfig1, subfig2 = fig.subfigures(1, 2) ax1 = subfig1.add_subplot(111) - ax1.set_title('Projection Mapping') + ax1.set_title("Projection Mapping") im = ax1.imshow(projection, vmin=maps.min(), vmax=maps.max()) counter = 0 number_of_maps = maps.shape[2] color = jnp.linspace(0, 1, 9)[:number_of_maps] - cmap = plt.get_cmap('hsv') + cmap = plt.get_cmap("hsv") colors = cmap(color) - - ax1.set_xlabel(r'$u$ / px') - ax1.set_ylabel(r'$v$ / px') - subfig1.colorbar(im, ax=ax1, orientation='vertical', label='Ray Length / mm') - + ax1.set_xlabel(r"$u$ / px") + ax1.set_ylabel(r"$v$ / px") + subfig1.colorbar(im, ax=ax1, orientation="vertical", label="Ray Length / mm") + ax2 = subfig2.add_subplot(111) - ax2.set_title('Cylindrical Mapping') + ax2.set_title("Cylindrical Mapping") ax2.axis(False) - side_length = int(jnp.round(number_of_maps ** (1. / 3.))) + side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0))) for z in [0, side_length - 1]: for y in [0, side_length - 1]: for x in [0, side_length - 1]: - position = x + y * side_length + z * side_length ** 2 + position = x + y * side_length + z * side_length**2 if counter >= number_of_maps: continue - ax = subfig2.add_subplot(3, 3, counter+1) - ax.scatter(cylindrical_values[0], cylindrical_values[1], ) - ax.scatter(cylindrical_coordinates[0, position], cylindrical_coordinates[1, position], c=colors[counter]) - ax1.scatter(cone_coordinates[0, position], cone_coordinates[1, position], c=colors[counter]) - ax.set_ylabel(r'$z$ / mm') - ax.set_xlabel(r'$\alpha$ / rad') - ax.set_yticks([0, 1000, 2000]) - ax.set_yticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax.set_xticks([0, 1000, 2000]) - ax.set_xticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) - + ax = subfig2.add_subplot(3, 3, counter + 1) + ax.scatter( + cylindrical_values[0], + cylindrical_values[1], + ) + ax.scatter( + cylindrical_coordinates[0, position], + cylindrical_coordinates[1, position], + c=colors[counter], + ) + ax1.scatter( + cone_coordinates[0, position], + cone_coordinates[1, position], + c=colors[counter], + ) + ax.set_ylabel(r"$z$ / mm") + ax.set_xlabel(r"$\alpha$ / rad") + ax.set_yticks([0, 1000, 2000]) + ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax.set_xticks([0, 1000, 2000]) + ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) counter += 1 @@ -139,38 +169,44 @@ def voxel_cone_maps_values(maps: jnp.ndarray, projection: jnp.ndarray, plt.savefig(save_path) -def voxel_maps(maps: jnp.ndarray, save_path: Path, figsize=(24, 10), title: str = None, vmin: bool = False): +def voxel_maps( + maps: jnp.ndarray, + save_path: Path, + figsize=(24, 10), + title: str = None, + vmin: bool = False, +): fig = plt.figure(figsize=figsize) ax2 = fig.add_subplot(111) - ax2.set_title('Cylindrical Mapping') + ax2.set_title("Cylindrical Mapping") ax2.axis(False) number_of_maps = maps.shape[2] - side_length = int(jnp.round(number_of_maps ** (1. / 3.))) + side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0))) counter = 0 for z in [0, side_length - 1]: for y in [0, side_length - 1]: for x in [0, side_length - 1]: - position = x + y * side_length + z * side_length ** 2 + position = x + y * side_length + z * side_length**2 if counter >= number_of_maps: continue - ax = fig.add_subplot(3, 3, counter+1) + ax = fig.add_subplot(3, 3, counter + 1) if vmin: ax.imshow(maps[:, :, position].T, vmin=maps.min(), vmax=maps.max()) else: ax.imshow(maps[:, :, position].T, vmin=maps.min()) - ax.set_ylabel(r'$z$ / mm') - ax.set_xlabel(r'$\alpha$ / rad') - ax.set_yticks([0, 1000, 2000]) - ax.set_yticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax.set_xticks([0, 1000, 2000]) - ax.set_xticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + ax.set_ylabel(r"$z$ / mm") + ax.set_xlabel(r"$\alpha$ / rad") + ax.set_yticks([0, 1000, 2000]) + ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax.set_xticks([0, 1000, 2000]) + ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) counter += 1 - + if title is not None: plt.title(title) plt.tight_layout() - plt.savefig(save_path) \ No newline at end of file + plt.savefig(save_path)