From a496530768e0cb4ce865367667071ca5c51d950b Mon Sep 17 00:00:00 2001 From: = <=> Date: Fri, 22 Nov 2024 11:35:39 +0100 Subject: [PATCH] added save map --- scripts/01_aRTist_bridge.py | 54 ++++--- scripts/02_generate_atlas.py | 23 +-- scripts/03_load_atlas.py | 19 +-- scripts/09_data_aug_pipeline.py | 0 src/cyxtrax/common/__init__.py | 4 +- src/cyxtrax/io/__init__.py | 6 +- src/cyxtrax/io/save_map.py | 52 ++++++ src/cyxtrax/mapping/__init__.py | 5 +- src/cyxtrax/simulation/__init__.py | 7 +- src/cyxtrax/simulation/artist_bridge.py | 3 +- src/cyxtrax/simulation/model/__init__.py | 10 +- src/cyxtrax/util/datasets/__init__.py | 10 +- src/cyxtrax/util/visualisation/__init__.py | 7 +- src/cyxtrax/util/visualisation/plt.py | 174 +++++++++++++-------- 14 files changed, 227 insertions(+), 147 deletions(-) delete mode 100644 scripts/09_data_aug_pipeline.py create mode 100644 src/cyxtrax/io/save_map.py diff --git a/scripts/01_aRTist_bridge.py b/scripts/01_aRTist_bridge.py index 2d93de9..0c99222 100644 --- a/scripts/01_aRTist_bridge.py +++ b/scripts/01_aRTist_bridge.py @@ -8,7 +8,7 @@ from pathlib import Path FOLDER = Path(__file__) -TEMP_FOLDER = FOLDER.parent.parent / 'temp' +TEMP_FOLDER = FOLDER.parent.parent / "temp" def main(): @@ -16,16 +16,14 @@ def main(): # Create the cylindrical projection interface cyl_proj = CylindricalProjection() - # Load the test object 'olaf'. It is automatically installed to your site-packages - print(f'STL Path: {scones_mesh.object_path}') + print(f"STL Path: {scones_mesh.object_path}") # Obect are loaded in a dataclass: MeshObject. - print(f'Mesh is a MeshObject: {isinstance(scones_mesh, MeshObject)}') + print(f"Mesh is a MeshObject: {isinstance(scones_mesh, MeshObject)}") scones_mesh.position_mm = np.array([0, 0, 0]) - scones_mesh.orientation_quat = np.array([0., -0., 0., 0.1]) - - + scones_mesh.orientation_quat = np.array([0.0, -0.0, 0.0, 0.1]) + # The `load_mesh` function load a list of MeshObject load_mesh(cyl_proj, [scones_mesh]) @@ -33,8 +31,12 @@ def main(): map_position = np.array([0, 0, 0]) # Generate local cylindrical projection - half_rays = cyl_proj.compute_projection(map_position, output_full_ray_projection=False) - full_rays = cyl_proj.compute_projection(map_position, output_full_ray_projection=True) + half_rays = cyl_proj.compute_projection( + map_position, output_full_ray_projection=False + ) + full_rays = cyl_proj.compute_projection( + map_position, output_full_ray_projection=True + ) # code_embedder:A end # Make some nice plots! @@ -42,26 +44,26 @@ def main(): ax1 = fig.add_subplot(121) ax1.imshow(half_rays) - ax1.set_xlabel(r'$Z$ / mm') - ax1.set_ylabel(r'$\alpha$ / rad') - ax1.set_title('Half Rays') - ax1.set_xticks([0, 1000, 2000]) - ax1.set_xticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax1.set_yticks([0, 1000, 2000]) - ax1.set_yticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + ax1.set_xlabel(r"$Z$ / mm") + ax1.set_ylabel(r"$\alpha$ / rad") + ax1.set_title("Half Rays") + ax1.set_xticks([0, 1000, 2000]) + ax1.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax1.set_yticks([0, 1000, 2000]) + ax1.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) ax2 = fig.add_subplot(122) ax2.imshow(full_rays) - ax2.set_xlabel(r'$Z$ / mm') - ax2.set_ylabel(r'$\alpha$ / rad') - ax2.set_title('Full Rays') - ax2.set_xticks([0, 1000, 2000]) - ax2.set_xticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax2.set_yticks([0, 1000, 2000]) - ax2.set_yticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + ax2.set_xlabel(r"$Z$ / mm") + ax2.set_ylabel(r"$\alpha$ / rad") + ax2.set_title("Full Rays") + ax2.set_xticks([0, 1000, 2000]) + ax2.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax2.set_yticks([0, 1000, 2000]) + ax2.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) - plt.savefig(TEMP_FOLDER / 'scones.png') + plt.savefig(TEMP_FOLDER / "scones.png") -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/scripts/02_generate_atlas.py b/scripts/02_generate_atlas.py index d3417c6..a3401d6 100644 --- a/scripts/02_generate_atlas.py +++ b/scripts/02_generate_atlas.py @@ -6,7 +6,7 @@ from pathlib import Path # Just get the temp folder to store the atlas FOLDER = Path(__file__) -TEMP_FOLDER = FOLDER.parent.parent / 'temp' +TEMP_FOLDER = FOLDER.parent.parent / "temp" def main(): @@ -15,13 +15,18 @@ def main(): cyl_proj = CylindricalProjection() load_mesh(cyl_proj, [olaf_mesh]) # make an atlas - - save_path = record_atlas(cyl_proj, [olaf_mesh], 'test', TEMP_FOLDER, - map_bounding_box=[-10, 20], number_of_maps=2) + + save_path = record_atlas( + cyl_proj, + [olaf_mesh], + "test", + TEMP_FOLDER, + map_bounding_box=[-10, 20], + number_of_maps=2, + ) # code_embedder:A end - print(f'Atlas saved to the path: {save_path}') - - + print(f"Atlas saved to the path: {save_path}") + -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/scripts/03_load_atlas.py b/scripts/03_load_atlas.py index f20a8af..e0ed4ed 100644 --- a/scripts/03_load_atlas.py +++ b/scripts/03_load_atlas.py @@ -5,22 +5,23 @@ from pathlib import Path # Just get the temp folder to store the atlas FOLDER = Path(__file__) -TEMP_FOLDER = FOLDER.parent.parent / 'temp' +TEMP_FOLDER = FOLDER.parent.parent / "temp" def main(): # code_embedder:A start - files = list(TEMP_FOLDER.glob('*.h5')) + files = list(TEMP_FOLDER.glob("*.h5")) if not files: - raise FileNotFoundError('No .h5 files in temp folder. Run Example 02!') + raise FileNotFoundError("No .h5 files in temp folder. Run Example 02!") maps, points, mesh_object_list = load_atlas(files[-1]) # code_embedder:A end - print(f'Loaded Map Shape: {maps.shape}') - print(f'Loaded Map Centre Shape: {points.shape}') - print(f'First Object: {mesh_object_list[0]}') + print(f"Loaded Map Shape: {maps.shape}") + print(f"Loaded Map Centre Shape: {points.shape}") + print(f"First Object: {mesh_object_list[0]}") - voxel_maps(maps, TEMP_FOLDER / 'scones_maps.png', title='Scones Atlas') + voxel_maps(maps, TEMP_FOLDER / "scones_maps.png", title="Scones Atlas") -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/scripts/09_data_aug_pipeline.py b/scripts/09_data_aug_pipeline.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/cyxtrax/common/__init__.py b/src/cyxtrax/common/__init__.py index 7c315e4..3b30905 100644 --- a/src/cyxtrax/common/__init__.py +++ b/src/cyxtrax/common/__init__.py @@ -1,6 +1,4 @@ from .mesh_object import MeshObject -__all__ = [ - 'MeshObject' -] \ No newline at end of file +__all__ = ["MeshObject"] diff --git a/src/cyxtrax/io/__init__.py b/src/cyxtrax/io/__init__.py index 2fff89e..f73e110 100644 --- a/src/cyxtrax/io/__init__.py +++ b/src/cyxtrax/io/__init__.py @@ -1,8 +1,6 @@ from .generate_atlas import record_atlas from .load_maps import load_atlas +from .save_map import save_atlas -__all__ = [ - 'record_atlas', - 'load_atlas' -] \ No newline at end of file +__all__ = ["record_atlas", "load_atlas", "save_atlas"] diff --git a/src/cyxtrax/io/save_map.py b/src/cyxtrax/io/save_map.py new file mode 100644 index 0000000..4290707 --- /dev/null +++ b/src/cyxtrax/io/save_map.py @@ -0,0 +1,52 @@ +from cyxtrax.common.mesh_object import MeshObject +from pathlib import Path +import numpy as np +import h5py +import json + + +def to_dict(mesh_object=MeshObject) -> dict: + return mesh_object.as_dict() + + +def save_atlas( + mesh_list: list[MeshObject], + maps: np.ndarray, + map_position: np.ndarray, + atlas_name: str, + save_folder: Path, +): + save_path = save_folder / f"{atlas_name}.h5" + file = h5py.File(save_path, "a") + + x_px = maps.shape[0] + y_px = maps.shape[1] + number_of_maps = maps.shape[2] + + projection = file.require_dataset( + "maps", + shape=( + x_px, + y_px, + number_of_maps, + ), # Initial shape with third dimension as 0 + maxshape=( + x_px, + y_px, + number_of_maps, + ), + dtype=np.float32, + ) + + projection_points = file.require_dataset( + "positions", + shape=(3, number_of_maps**3), # Initial shape with third dimension as 0 + maxshape=(3, number_of_maps**3), + dtype=np.float32, + ) + + mesh_list_dict = list(map(to_dict, mesh_list)) + file.attrs["mesh_list"] = json.dumps(mesh_list_dict) + + projection[:] = maps + projection_points[:] = map_position diff --git a/src/cyxtrax/mapping/__init__.py b/src/cyxtrax/mapping/__init__.py index 122cf03..30956ad 100644 --- a/src/cyxtrax/mapping/__init__.py +++ b/src/cyxtrax/mapping/__init__.py @@ -1,6 +1,3 @@ from .mapping import map_geometry_2_projection, map_source_2_cylinder -__all__ = [ - 'map_geometry_2_projection', - 'map_source_2_cylinder' -] +__all__ = ["map_geometry_2_projection", "map_source_2_cylinder"] diff --git a/src/cyxtrax/simulation/__init__.py b/src/cyxtrax/simulation/__init__.py index ceddce2..3b5fe4e 100644 --- a/src/cyxtrax/simulation/__init__.py +++ b/src/cyxtrax/simulation/__init__.py @@ -1,8 +1,3 @@ from .artist_bridge import CylindricalProjection, load_mesh, SAVEMODES, utility -__all__ = [ - 'CylindricalProjection', - 'load_mesh', - 'SAVEMODES', - 'utility' -] \ No newline at end of file +__all__ = ["CylindricalProjection", "load_mesh", "SAVEMODES", "utility"] diff --git a/src/cyxtrax/simulation/artist_bridge.py b/src/cyxtrax/simulation/artist_bridge.py index 2a472cf..9d35f23 100644 --- a/src/cyxtrax/simulation/artist_bridge.py +++ b/src/cyxtrax/simulation/artist_bridge.py @@ -6,8 +6,10 @@ except ModuleNotFoundError: warn( "The module `artistlib`is not installed. The simulation module is not 100\% usable! \nInstall: https://github.com/wittlsn/aRTist-PythonLib" ) + def API(): return None + utility = None SAVEMODES = None from pathlib import Path @@ -44,7 +46,6 @@ def set_cone_mode(api: API): with importlib.resources.path("cyxtrax.data", "cone.aRTist") as file_path: print(f"Load: {file_path}") api.load_project(file_path) - class CylindricalProjection: diff --git a/src/cyxtrax/simulation/model/__init__.py b/src/cyxtrax/simulation/model/__init__.py index ddd257a..187de9f 100644 --- a/src/cyxtrax/simulation/model/__init__.py +++ b/src/cyxtrax/simulation/model/__init__.py @@ -3,8 +3,8 @@ from .noise import add_white_noise, add_gaussian_blur from .xray import add_material __all__ = [ - 'AugemntationParameter', - 'add_white_noise', - 'add_gaussian_blur', - 'add_material' -] \ No newline at end of file + "AugemntationParameter", + "add_white_noise", + "add_gaussian_blur", + "add_material", +] diff --git a/src/cyxtrax/util/datasets/__init__.py b/src/cyxtrax/util/datasets/__init__.py index 8388c33..3550991 100644 --- a/src/cyxtrax/util/datasets/__init__.py +++ b/src/cyxtrax/util/datasets/__init__.py @@ -2,8 +2,8 @@ from .merge import merge_atlas_and_dict, merge_atlases, merge_lists from .pipeline import generate_atlas_from_folder __all__ = [ - 'merge_atlas_and_dict', - 'merge_atlases', - 'merge_lists', - 'generate_atlas_from_folder' -] \ No newline at end of file + "merge_atlas_and_dict", + "merge_atlases", + "merge_lists", + "generate_atlas_from_folder", +] diff --git a/src/cyxtrax/util/visualisation/__init__.py b/src/cyxtrax/util/visualisation/__init__.py index 9951940..ce653df 100644 --- a/src/cyxtrax/util/visualisation/__init__.py +++ b/src/cyxtrax/util/visualisation/__init__.py @@ -1,9 +1,4 @@ from .plt import save_map_plot_with_index, voxel_cone_maps, voxel_maps from .gif import make_gif -__all__ = [ - 'save_map_plot_with_index', - 'voxel_cone_maps', - 'make_gif', - 'voxel_maps' -] \ No newline at end of file +__all__ = ["save_map_plot_with_index", "voxel_cone_maps", "make_gif", "voxel_maps"] diff --git a/src/cyxtrax/util/visualisation/plt.py b/src/cyxtrax/util/visualisation/plt.py index 29704f6..1fd3e55 100644 --- a/src/cyxtrax/util/visualisation/plt.py +++ b/src/cyxtrax/util/visualisation/plt.py @@ -3,17 +3,18 @@ from jax import numpy as jnp from pathlib import Path -def save_map_plot_with_index(maps: jnp.ndarray, index: int, - save_path: Path, figsize=(8, 8), title: str = None): +def save_map_plot_with_index( + maps: jnp.ndarray, index: int, save_path: Path, figsize=(8, 8), title: str = None +): fig = plt.figure(figsize=figsize) ax = fig.add_subplot(111) ax.imshow(maps[:, :, index]) - ax.set_xlabel(r'$z$ / mm') - ax.set_ylabel(r'$\alpha$ / rad') - ax.set_xticks([0, int(maps.shape[0] // 2), int(maps.shape[0])]) - ax.set_xticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax.set_yticks([0, int(maps.shape[1] // 2), int(maps.shape[1])]) - ax.set_yticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + ax.set_xlabel(r"$z$ / mm") + ax.set_ylabel(r"$\alpha$ / rad") + ax.set_xticks([0, int(maps.shape[0] // 2), int(maps.shape[0])]) + ax.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax.set_yticks([0, int(maps.shape[1] // 2), int(maps.shape[1])]) + ax.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) if title is not None: ax.set_title(title) @@ -22,56 +23,70 @@ def save_map_plot_with_index(maps: jnp.ndarray, index: int, plt.savefig(save_path) plt.close(fig) - -def voxel_cone_maps(maps: jnp.ndarray, projection: jnp.ndarray, - cone_coordinates: jnp.ndarray, cylindrical_coordinates: jnp.ndarray, - save_path: Path, figsize=(24, 10), title: str = None, vmin: bool = False): + +def voxel_cone_maps( + maps: jnp.ndarray, + projection: jnp.ndarray, + cone_coordinates: jnp.ndarray, + cylindrical_coordinates: jnp.ndarray, + save_path: Path, + figsize=(24, 10), + title: str = None, + vmin: bool = False, +): fig = plt.figure(figsize=figsize) subfig1: figure.Figure subfig2: figure.Figure subfig1, subfig2 = fig.subfigures(1, 2) ax1 = subfig1.add_subplot(111) - ax1.set_title('Projection Mapping') + ax1.set_title("Projection Mapping") im = ax1.imshow(projection, vmin=maps.min(), vmax=maps.max()) counter = 0 number_of_maps = maps.shape[2] color = jnp.linspace(0, 1, 9)[:number_of_maps] - cmap = plt.get_cmap('hsv') + cmap = plt.get_cmap("hsv") colors = cmap(color) - - ax1.set_xlabel(r'$u$ / px') - ax1.set_ylabel(r'$v$ / px') - subfig1.colorbar(im, ax=ax1, orientation='vertical', label='Ray Length / mm') - + ax1.set_xlabel(r"$u$ / px") + ax1.set_ylabel(r"$v$ / px") + subfig1.colorbar(im, ax=ax1, orientation="vertical", label="Ray Length / mm") + ax2 = subfig2.add_subplot(111) - ax2.set_title('Cylindrical Mapping') + ax2.set_title("Cylindrical Mapping") ax2.axis(False) - side_length = int(jnp.round(number_of_maps ** (1. / 3.))) + side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0))) for z in [0, side_length - 1]: for y in [0, side_length - 1]: for x in [0, side_length - 1]: - position = x + y * side_length + z * side_length ** 2 + position = x + y * side_length + z * side_length**2 if counter >= number_of_maps: continue - ax = subfig2.add_subplot(3, 3, counter+1) + ax = subfig2.add_subplot(3, 3, counter + 1) if vmin: ax.imshow(maps[:, :, position].T, vmin=maps.min(), vmax=maps.max()) else: ax.imshow(maps[:, :, position].T, vmin=maps.min()) - ax.scatter(cylindrical_coordinates[0, position], cylindrical_coordinates[1, position], c=colors[counter]) - ax1.scatter(cone_coordinates[0, position], cone_coordinates[1, position], c=colors[counter]) - ax.set_ylabel(r'$z$ / mm') - ax.set_xlabel(r'$\alpha$ / rad') - ax.set_yticks([0, 1000, 2000]) - ax.set_yticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax.set_xticks([0, 1000, 2000]) - ax.set_xticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + ax.scatter( + cylindrical_coordinates[0, position], + cylindrical_coordinates[1, position], + c=colors[counter], + ) + ax1.scatter( + cone_coordinates[0, position], + cone_coordinates[1, position], + c=colors[counter], + ) + ax.set_ylabel(r"$z$ / mm") + ax.set_xlabel(r"$\alpha$ / rad") + ax.set_yticks([0, 1000, 2000]) + ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax.set_xticks([0, 1000, 2000]) + ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) counter += 1 @@ -82,53 +97,68 @@ def voxel_cone_maps(maps: jnp.ndarray, projection: jnp.ndarray, plt.savefig(save_path) -def voxel_cone_maps_values(maps: jnp.ndarray, projection: jnp.ndarray, - cone_coordinates: jnp.ndarray, cylindrical_coordinates: jnp.ndarray, - cylindrical_values: jnp.ndarray, - save_path: Path, figsize=(24, 10), title: str = None): +def voxel_cone_maps_values( + maps: jnp.ndarray, + projection: jnp.ndarray, + cone_coordinates: jnp.ndarray, + cylindrical_coordinates: jnp.ndarray, + cylindrical_values: jnp.ndarray, + save_path: Path, + figsize=(24, 10), + title: str = None, +): fig = plt.figure(figsize=figsize) subfig1: figure.Figure subfig2: figure.Figure subfig1, subfig2 = fig.subfigures(1, 2) ax1 = subfig1.add_subplot(111) - ax1.set_title('Projection Mapping') + ax1.set_title("Projection Mapping") im = ax1.imshow(projection, vmin=maps.min(), vmax=maps.max()) counter = 0 number_of_maps = maps.shape[2] color = jnp.linspace(0, 1, 9)[:number_of_maps] - cmap = plt.get_cmap('hsv') + cmap = plt.get_cmap("hsv") colors = cmap(color) - - ax1.set_xlabel(r'$u$ / px') - ax1.set_ylabel(r'$v$ / px') - subfig1.colorbar(im, ax=ax1, orientation='vertical', label='Ray Length / mm') - + ax1.set_xlabel(r"$u$ / px") + ax1.set_ylabel(r"$v$ / px") + subfig1.colorbar(im, ax=ax1, orientation="vertical", label="Ray Length / mm") + ax2 = subfig2.add_subplot(111) - ax2.set_title('Cylindrical Mapping') + ax2.set_title("Cylindrical Mapping") ax2.axis(False) - side_length = int(jnp.round(number_of_maps ** (1. / 3.))) + side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0))) for z in [0, side_length - 1]: for y in [0, side_length - 1]: for x in [0, side_length - 1]: - position = x + y * side_length + z * side_length ** 2 + position = x + y * side_length + z * side_length**2 if counter >= number_of_maps: continue - ax = subfig2.add_subplot(3, 3, counter+1) - ax.scatter(cylindrical_values[0], cylindrical_values[1], ) - ax.scatter(cylindrical_coordinates[0, position], cylindrical_coordinates[1, position], c=colors[counter]) - ax1.scatter(cone_coordinates[0, position], cone_coordinates[1, position], c=colors[counter]) - ax.set_ylabel(r'$z$ / mm') - ax.set_xlabel(r'$\alpha$ / rad') - ax.set_yticks([0, 1000, 2000]) - ax.set_yticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax.set_xticks([0, 1000, 2000]) - ax.set_xticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) - + ax = subfig2.add_subplot(3, 3, counter + 1) + ax.scatter( + cylindrical_values[0], + cylindrical_values[1], + ) + ax.scatter( + cylindrical_coordinates[0, position], + cylindrical_coordinates[1, position], + c=colors[counter], + ) + ax1.scatter( + cone_coordinates[0, position], + cone_coordinates[1, position], + c=colors[counter], + ) + ax.set_ylabel(r"$z$ / mm") + ax.set_xlabel(r"$\alpha$ / rad") + ax.set_yticks([0, 1000, 2000]) + ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax.set_xticks([0, 1000, 2000]) + ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) counter += 1 @@ -139,38 +169,44 @@ def voxel_cone_maps_values(maps: jnp.ndarray, projection: jnp.ndarray, plt.savefig(save_path) -def voxel_maps(maps: jnp.ndarray, save_path: Path, figsize=(24, 10), title: str = None, vmin: bool = False): +def voxel_maps( + maps: jnp.ndarray, + save_path: Path, + figsize=(24, 10), + title: str = None, + vmin: bool = False, +): fig = plt.figure(figsize=figsize) ax2 = fig.add_subplot(111) - ax2.set_title('Cylindrical Mapping') + ax2.set_title("Cylindrical Mapping") ax2.axis(False) number_of_maps = maps.shape[2] - side_length = int(jnp.round(number_of_maps ** (1. / 3.))) + side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0))) counter = 0 for z in [0, side_length - 1]: for y in [0, side_length - 1]: for x in [0, side_length - 1]: - position = x + y * side_length + z * side_length ** 2 + position = x + y * side_length + z * side_length**2 if counter >= number_of_maps: continue - ax = fig.add_subplot(3, 3, counter+1) + ax = fig.add_subplot(3, 3, counter + 1) if vmin: ax.imshow(maps[:, :, position].T, vmin=maps.min(), vmax=maps.max()) else: ax.imshow(maps[:, :, position].T, vmin=maps.min()) - ax.set_ylabel(r'$z$ / mm') - ax.set_xlabel(r'$\alpha$ / rad') - ax.set_yticks([0, 1000, 2000]) - ax.set_yticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) - ax.set_xticks([0, 1000, 2000]) - ax.set_xticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + ax.set_ylabel(r"$z$ / mm") + ax.set_xlabel(r"$\alpha$ / rad") + ax.set_yticks([0, 1000, 2000]) + ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"]) + ax.set_xticks([0, 1000, 2000]) + ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"]) counter += 1 - + if title is not None: plt.title(title) plt.tight_layout() - plt.savefig(save_path) \ No newline at end of file + plt.savefig(save_path) -- GitLab