From 5688f0fd9bdbd424db480de50720c3120a74544b Mon Sep 17 00:00:00 2001 From: swittl <simon.wittl@th-deg.de> Date: Mon, 4 Nov 2024 08:51:17 +0100 Subject: [PATCH] added mapping example --- CyXTraX/io/generate_atlas.py | 20 ++-- CyXTraX/mapping/mapping.py | 4 + CyXTraX/util/visualisation/__init__.py | 3 +- CyXTraX/util/visualisation/gif.py | 18 +++- CyXTraX/util/visualisation/plt.py | 122 ++++++++++++++++++++++++- scripts/04_mapping.py | 4 +- scripts/06_generate_dataset.py | 4 +- scripts/07_presentation.py | 71 ++++++++++++++ scripts/08_make_gif.py | 13 +++ 9 files changed, 239 insertions(+), 20 deletions(-) create mode 100644 scripts/07_presentation.py create mode 100644 scripts/08_make_gif.py diff --git a/CyXTraX/io/generate_atlas.py b/CyXTraX/io/generate_atlas.py index 239f985..1df4b4f 100644 --- a/CyXTraX/io/generate_atlas.py +++ b/CyXTraX/io/generate_atlas.py @@ -21,14 +21,14 @@ def record_atlas(cylindrical_projection: CylindricalProjection, mesh_list: list[ file = h5py.File(save_path, 'a') projection = file.require_dataset( 'maps', - shape=(cylindrical_projection.x_px, cylindrical_projection.y_px, 0), # Initial shape with third dimension as 0 - maxshape=(cylindrical_projection.x_px, cylindrical_projection.y_px, None), + shape=(cylindrical_projection.x_px, cylindrical_projection.y_px, number_of_maps**3), # Initial shape with third dimension as 0 + maxshape=(cylindrical_projection.x_px, cylindrical_projection.y_px, number_of_maps**3), dtype=np.float32) projection_points = file.require_dataset( 'positions', - shape=(3, 0), # Initial shape with third dimension as 0 - maxshape=(3, None), + shape=(3, number_of_maps**3), # Initial shape with third dimension as 0 + maxshape=(3, number_of_maps**3), dtype=np.float32) @@ -38,7 +38,7 @@ def record_atlas(cylindrical_projection: CylindricalProjection, mesh_list: list[ file.attrs['mesh_list'] = json.dumps(mesh_list_dict) - + counter = 0 for x in grid: for y in grid: for z in grid: @@ -49,10 +49,10 @@ def record_atlas(cylindrical_projection: CylindricalProjection, mesh_list: list[ image = cylindrical_projection.compute_projection(position, output_full_ray_projection=True) current_size = projection.shape[2] new_size = current_size + 1 - projection.resize((cylindrical_projection.x_px, cylindrical_projection.y_px, new_size)) - projection[:, :, current_size:new_size] = image[:, :, np.newaxis].astype(np.float32) - - projection_points.resize((3, new_size)) - projection_points[:, current_size:new_size] = position.reshape((3, 1)) + #projection.resize((cylindrical_projection.x_px, cylindrical_projection.y_px, new_size)) + projection[:, :, counter:counter+1] = image[:, :, np.newaxis].astype(np.float32) + #projection_points.resize((3, new_size)) + projection_points[:, counter:counter+1] = position.reshape((3, 1)) + counter += 1 return save_path diff --git a/CyXTraX/mapping/mapping.py b/CyXTraX/mapping/mapping.py index f697a51..8a153a6 100644 --- a/CyXTraX/mapping/mapping.py +++ b/CyXTraX/mapping/mapping.py @@ -71,6 +71,10 @@ def map_geometry_2_projection(source: jnp.ndarray, detector: jnp.ndarray, detect projection_value: m """ + source = source.reshape((1, 3)) + detector = detector.reshape((1, 3)) + detector_orientation = detector_orientation.reshape((1, 4)) + projection_geometry = projection_matrix(source, detector, detector_orientation) # 3 x 4 map_positions_homogen = jnp.ones((1, map_positions.shape[1])) # 1 x m diff --git a/CyXTraX/util/visualisation/__init__.py b/CyXTraX/util/visualisation/__init__.py index a637804..589931f 100644 --- a/CyXTraX/util/visualisation/__init__.py +++ b/CyXTraX/util/visualisation/__init__.py @@ -1 +1,2 @@ -from .plt import save_map_plot_with_index \ No newline at end of file +from .plt import save_map_plot_with_index, voxel_cone_maps +from .gif import make_gif \ No newline at end of file diff --git a/CyXTraX/util/visualisation/gif.py b/CyXTraX/util/visualisation/gif.py index 146a4fc..68e0a04 100644 --- a/CyXTraX/util/visualisation/gif.py +++ b/CyXTraX/util/visualisation/gif.py @@ -1,7 +1,23 @@ from CyXTraX.io import load_atlas from pathlib import Path +from PIL import Image def atlas_gif(load_path: Path): maps, map_position, meshes = load_atlas(load_path) - \ No newline at end of file + + +def load_pil(path: Path) -> Image.Image: + return Image.open(path) + + +def make_gif(load_path: Path, pattern: str, name: str, duration_ms: float = 100): + frames = sorted(list(load_path.glob(pattern))) + frames = list(map(load_pil, frames)) + frames[0].save( + load_path / name, + save_all=True, + append_images=frames[1:], + duration=duration_ms, + loop=0) + diff --git a/CyXTraX/util/visualisation/plt.py b/CyXTraX/util/visualisation/plt.py index c992c37..799e528 100644 --- a/CyXTraX/util/visualisation/plt.py +++ b/CyXTraX/util/visualisation/plt.py @@ -1,10 +1,10 @@ -from matplotlib import pyplot as plt +from matplotlib import pyplot as plt, figure from jax import numpy as jnp from pathlib import Path -def save_map_plot_with_index(maps: jnp.ndarray, index: int, save_path: Path, figsize=(8, 8), - title: str = None): +def save_map_plot_with_index(maps: jnp.ndarray, index: int, + save_path: Path, figsize=(8, 8), title: str = None): fig = plt.figure(figsize=figsize) ax = fig.add_subplot(111) ax.imshow(maps[:, :, index]) @@ -22,4 +22,118 @@ def save_map_plot_with_index(maps: jnp.ndarray, index: int, save_path: Path, fig plt.savefig(save_path) plt.close(fig) - \ No newline at end of file + + +def voxel_cone_maps(maps: jnp.ndarray, projection: jnp.ndarray, + cone_coordinates: jnp.ndarray, cylindrical_coordinates: jnp.ndarray, + save_path: Path, figsize=(24, 10), title: str = None, vmin: bool = False): + fig = plt.figure(figsize=figsize) + subfig1: figure.Figure + subfig2: figure.Figure + subfig1, subfig2 = fig.subfigures(1, 2) + ax1 = subfig1.add_subplot(111) + ax1.set_title('Projection Mapping') + im = ax1.imshow(projection, vmin=maps.min(), vmax=maps.max()) + + counter = 0 + number_of_maps = maps.shape[2] + + color = jnp.linspace(0, 1, 9)[:number_of_maps] + cmap = plt.get_cmap('hsv') + colors = cmap(color) + + ax1.set_xlabel(r'$u$ / px') + ax1.set_ylabel(r'$v$ / px') + subfig1.colorbar(im, ax=ax1, orientation='vertical', label='Ray Length / mm') + + + ax2 = subfig2.add_subplot(111) + ax2.set_title('Cylindrical Mapping') + ax2.axis(False) + + side_length = int(jnp.round(number_of_maps ** (1. / 3.))) + + for z in [0, side_length - 1]: + for y in [0, side_length - 1]: + for x in [0, side_length - 1]: + position = x + y * side_length + z * side_length ** 2 + if counter >= number_of_maps: + continue + ax = subfig2.add_subplot(3, 3, counter+1) + if vmin: + ax.imshow(maps[:, :, position].T, vmin=maps.min(), vmax=maps.max()) + else: + ax.imshow(maps[:, :, position].T, vmin=maps.min()) + ax.scatter(cylindrical_coordinates[0, position], cylindrical_coordinates[1, position], c=colors[counter]) + ax1.scatter(cone_coordinates[0, position], cone_coordinates[1, position], c=colors[counter]) + ax.set_ylabel(r'$z$ / mm') + ax.set_xlabel(r'$\alpha$ / rad') + ax.set_yticks([0, 1000, 2000]) + ax.set_yticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) + ax.set_xticks([0, 1000, 2000]) + ax.set_xticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + + counter += 1 + + if title is not None: + plt.title(title) + + plt.tight_layout() + plt.savefig(save_path) + + +def voxel_cone_maps_values(maps: jnp.ndarray, projection: jnp.ndarray, + cone_coordinates: jnp.ndarray, cylindrical_coordinates: jnp.ndarray, + cylindrical_values: jnp.ndarray, + save_path: Path, figsize=(24, 10), title: str = None): + fig = plt.figure(figsize=figsize) + subfig1: figure.Figure + subfig2: figure.Figure + subfig1, subfig2 = fig.subfigures(1, 2) + ax1 = subfig1.add_subplot(111) + ax1.set_title('Projection Mapping') + im = ax1.imshow(projection, vmin=maps.min(), vmax=maps.max()) + + counter = 0 + number_of_maps = maps.shape[2] + + color = jnp.linspace(0, 1, 9)[:number_of_maps] + cmap = plt.get_cmap('hsv') + colors = cmap(color) + + ax1.set_xlabel(r'$u$ / px') + ax1.set_ylabel(r'$v$ / px') + subfig1.colorbar(im, ax=ax1, orientation='vertical', label='Ray Length / mm') + + + ax2 = subfig2.add_subplot(111) + ax2.set_title('Cylindrical Mapping') + ax2.axis(False) + + side_length = int(jnp.round(number_of_maps ** (1. / 3.))) + + for z in [0, side_length - 1]: + for y in [0, side_length - 1]: + for x in [0, side_length - 1]: + position = x + y * side_length + z * side_length ** 2 + if counter >= number_of_maps: + continue + ax = subfig2.add_subplot(3, 3, counter+1) + ax.scatter(cylindrical_values[0], cylindrical_values[1], ) + ax.scatter(cylindrical_coordinates[0, position], cylindrical_coordinates[1, position], c=colors[counter]) + ax1.scatter(cone_coordinates[0, position], cone_coordinates[1, position], c=colors[counter]) + ax.set_ylabel(r'$z$ / mm') + ax.set_xlabel(r'$\alpha$ / rad') + ax.set_yticks([0, 1000, 2000]) + ax.set_yticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) + ax.set_xticks([0, 1000, 2000]) + ax.set_xticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + + + counter += 1 + + if title is not None: + plt.title(title) + + plt.tight_layout() + plt.savefig(save_path) \ No newline at end of file diff --git a/scripts/04_mapping.py b/scripts/04_mapping.py index 657fe50..45f729b 100644 --- a/scripts/04_mapping.py +++ b/scripts/04_mapping.py @@ -82,8 +82,8 @@ def main(): ax2 = subfig2.add_subplot(111) ax2.set_title('Cylindrical Mapping') ax2.axis(False) - for i in range(3): - for j in range(3): + for _ in range(3): + for _ in range(3): if counter >= number_of_maps: continue ax = subfig2.add_subplot(3, 3, counter+1) diff --git a/scripts/06_generate_dataset.py b/scripts/06_generate_dataset.py index 516e803..d3fb72f 100644 --- a/scripts/06_generate_dataset.py +++ b/scripts/06_generate_dataset.py @@ -11,7 +11,7 @@ and needs 100gb of space. def main(): # Dataset save folder - save_folder = Path(r'C:\data\XrayTransform') + save_folder = Path(r'C:\data\XrayTransformLow') # Dataset of multiple random stl files from: https://ten-thousand-models.appspot.com/ load_folder = Path(r'C:\Users\swittl\Downloads\Thingi10K\Thingi10K\raw_meshes') @@ -22,7 +22,7 @@ def main(): print(i) key = random.PRNGKey(i) generate_atlas_from_folder( - load_folder, key, save_folder) + load_folder, key, save_folder, map_bounding_box=(-10, 10), number_of_maps=2) if __name__ == '__main__': diff --git a/scripts/07_presentation.py b/scripts/07_presentation.py new file mode 100644 index 0000000..8540a44 --- /dev/null +++ b/scripts/07_presentation.py @@ -0,0 +1,71 @@ +from CyXTraX.simulation import CylindricalProjection, load_mesh, SAVEMODES, utility +from CyXTraX.io import load_atlas +from CyXTraX.mapping import map_geometry_2_projection, map_source_2_cylinder +from CyXTraX.util.visualisation import voxel_cone_maps + +from pathlib import Path +from jax import numpy as jnp +from jax.scipy.spatial.transform import Rotation + +from cyxtrax_object import set_geometry_at_index, circular_trajectory, add_world_orign, zero_quat, x_90_quat, z_90_quat, add_offset + + + +# Just get the temp folder to store the atlas +FOLDER = Path(__file__) +TEMP_FOLDER = FOLDER.parent.parent / 'temp' +MAP_FOLDER = Path(r'C:\data\XrayTransformLow') + + +def main(): + files = list(MAP_FOLDER.glob('*.h5')) + if not files: + raise FileNotFoundError('No .h5 files in temp folder. Run Example 02!') + + maps, points, mesh_object_list = load_atlas(files[3]) + cyl_proj = CylindricalProjection() + + load_mesh(cyl_proj, mesh_object_list, False) + print(f'Cone Beam Mode set: {cyl_proj.cone_mode}') + + + number_of_projections = 30 + + fod_mm = jnp.array([1000,]) + fdd_mm = jnp.array([2000,]) + alpha_rad = jnp.linspace(0, jnp.pi, number_of_projections).reshape((-1, 1)) + + source, orientation_source, detector, orientation_detector = circular_trajectory(fod_mm, fdd_mm, alpha_rad) + source, orientation_source, detector, orientation_detector = add_world_orign( + source, orientation_source, detector, orientation_detector, jnp.array([1, 2., -1.]), jnp.array([0.8, 0.1, -0.1, 0.9])) + detector, orientation_detector = add_offset(detector, orientation_detector, jnp.array([1, 20., -1.]), zero_quat) + + temp_projection = Path('temp.tif') + + + + for i in range(number_of_projections): + set_geometry_at_index(source, orientation_source, detector, orientation_detector, i, temp_projection) + + projection, geometry = utility.load_projection(temp_projection) + mapped_values, uv_px = map_geometry_2_projection( + source[i], + detector[i], + orientation_detector[i], + points, + projection, + cyl_proj.pixel_pitch_mm) + + values_calc, angles_calc = map_source_2_cylinder( + source[i], + maps, + points, + cyl_proj.detector_radius_mm, + cyl_proj.x_px, + cyl_proj.x_resolution_mm) + + voxel_cone_maps(maps, projection, uv_px, angles_calc, TEMP_FOLDER / f'{i:03}_presentation.png') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/scripts/08_make_gif.py b/scripts/08_make_gif.py new file mode 100644 index 0000000..2498d0c --- /dev/null +++ b/scripts/08_make_gif.py @@ -0,0 +1,13 @@ +from CyXTraX.util.visualisation import make_gif +from pathlib import Path + +# Just get the temp folder to store the atlas +FOLDER = Path(__file__) +TEMP_FOLDER = FOLDER.parent.parent / 'temp' + +def main(): + make_gif(TEMP_FOLDER, '*_presentation.png', 'cyxtrax.gif', 333) + + +if __name__ == '__main__': + main() \ No newline at end of file -- GitLab