Skip to content
Snippets Groups Projects
Commit a4965307 authored by ='s avatar =
Browse files

added save map

parent e6066370
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,7 @@ from pathlib import Path
FOLDER = Path(__file__)
TEMP_FOLDER = FOLDER.parent.parent / 'temp'
TEMP_FOLDER = FOLDER.parent.parent / "temp"
def main():
......@@ -16,16 +16,14 @@ def main():
# Create the cylindrical projection interface
cyl_proj = CylindricalProjection()
# Load the test object 'olaf'. It is automatically installed to your site-packages
print(f'STL Path: {scones_mesh.object_path}')
print(f"STL Path: {scones_mesh.object_path}")
# Obect are loaded in a dataclass: MeshObject.
print(f'Mesh is a MeshObject: {isinstance(scones_mesh, MeshObject)}')
print(f"Mesh is a MeshObject: {isinstance(scones_mesh, MeshObject)}")
scones_mesh.position_mm = np.array([0, 0, 0])
scones_mesh.orientation_quat = np.array([0., -0., 0., 0.1])
scones_mesh.orientation_quat = np.array([0.0, -0.0, 0.0, 0.1])
# The `load_mesh` function load a list of MeshObject
load_mesh(cyl_proj, [scones_mesh])
......@@ -33,8 +31,12 @@ def main():
map_position = np.array([0, 0, 0])
# Generate local cylindrical projection
half_rays = cyl_proj.compute_projection(map_position, output_full_ray_projection=False)
full_rays = cyl_proj.compute_projection(map_position, output_full_ray_projection=True)
half_rays = cyl_proj.compute_projection(
map_position, output_full_ray_projection=False
)
full_rays = cyl_proj.compute_projection(
map_position, output_full_ray_projection=True
)
# code_embedder:A end
# Make some nice plots!
......@@ -42,26 +44,26 @@ def main():
ax1 = fig.add_subplot(121)
ax1.imshow(half_rays)
ax1.set_xlabel(r'$Z$ / mm')
ax1.set_ylabel(r'$\alpha$ / rad')
ax1.set_title('Half Rays')
ax1.set_xticks([0, 1000, 2000])
ax1.set_xticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$'])
ax1.set_yticks([0, 1000, 2000])
ax1.set_yticklabels([r'$-\pi$', r'$0.$', r'$\pi$'])
ax1.set_xlabel(r"$Z$ / mm")
ax1.set_ylabel(r"$\alpha$ / rad")
ax1.set_title("Half Rays")
ax1.set_xticks([0, 1000, 2000])
ax1.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
ax1.set_yticks([0, 1000, 2000])
ax1.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
ax2 = fig.add_subplot(122)
ax2.imshow(full_rays)
ax2.set_xlabel(r'$Z$ / mm')
ax2.set_ylabel(r'$\alpha$ / rad')
ax2.set_title('Full Rays')
ax2.set_xticks([0, 1000, 2000])
ax2.set_xticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$'])
ax2.set_yticks([0, 1000, 2000])
ax2.set_yticklabels([r'$-\pi$', r'$0.$', r'$\pi$'])
ax2.set_xlabel(r"$Z$ / mm")
ax2.set_ylabel(r"$\alpha$ / rad")
ax2.set_title("Full Rays")
ax2.set_xticks([0, 1000, 2000])
ax2.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
ax2.set_yticks([0, 1000, 2000])
ax2.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
plt.savefig(TEMP_FOLDER / 'scones.png')
plt.savefig(TEMP_FOLDER / "scones.png")
if __name__ == '__main__':
main()
\ No newline at end of file
if __name__ == "__main__":
main()
......@@ -6,7 +6,7 @@ from pathlib import Path
# Just get the temp folder to store the atlas
FOLDER = Path(__file__)
TEMP_FOLDER = FOLDER.parent.parent / 'temp'
TEMP_FOLDER = FOLDER.parent.parent / "temp"
def main():
......@@ -15,13 +15,18 @@ def main():
cyl_proj = CylindricalProjection()
load_mesh(cyl_proj, [olaf_mesh])
# make an atlas
save_path = record_atlas(cyl_proj, [olaf_mesh], 'test', TEMP_FOLDER,
map_bounding_box=[-10, 20], number_of_maps=2)
save_path = record_atlas(
cyl_proj,
[olaf_mesh],
"test",
TEMP_FOLDER,
map_bounding_box=[-10, 20],
number_of_maps=2,
)
# code_embedder:A end
print(f'Atlas saved to the path: {save_path}')
print(f"Atlas saved to the path: {save_path}")
if __name__ == '__main__':
main()
\ No newline at end of file
if __name__ == "__main__":
main()
......@@ -5,22 +5,23 @@ from pathlib import Path
# Just get the temp folder to store the atlas
FOLDER = Path(__file__)
TEMP_FOLDER = FOLDER.parent.parent / 'temp'
TEMP_FOLDER = FOLDER.parent.parent / "temp"
def main():
# code_embedder:A start
files = list(TEMP_FOLDER.glob('*.h5'))
files = list(TEMP_FOLDER.glob("*.h5"))
if not files:
raise FileNotFoundError('No .h5 files in temp folder. Run Example 02!')
raise FileNotFoundError("No .h5 files in temp folder. Run Example 02!")
maps, points, mesh_object_list = load_atlas(files[-1])
# code_embedder:A end
print(f'Loaded Map Shape: {maps.shape}')
print(f'Loaded Map Centre Shape: {points.shape}')
print(f'First Object: {mesh_object_list[0]}')
print(f"Loaded Map Shape: {maps.shape}")
print(f"Loaded Map Centre Shape: {points.shape}")
print(f"First Object: {mesh_object_list[0]}")
voxel_maps(maps, TEMP_FOLDER / 'scones_maps.png', title='Scones Atlas')
voxel_maps(maps, TEMP_FOLDER / "scones_maps.png", title="Scones Atlas")
if __name__ == '__main__':
main()
\ No newline at end of file
if __name__ == "__main__":
main()
from .mesh_object import MeshObject
__all__ = [
'MeshObject'
]
\ No newline at end of file
__all__ = ["MeshObject"]
from .generate_atlas import record_atlas
from .load_maps import load_atlas
from .save_map import save_atlas
__all__ = [
'record_atlas',
'load_atlas'
]
\ No newline at end of file
__all__ = ["record_atlas", "load_atlas", "save_atlas"]
from cyxtrax.common.mesh_object import MeshObject
from pathlib import Path
import numpy as np
import h5py
import json
def to_dict(mesh_object=MeshObject) -> dict:
return mesh_object.as_dict()
def save_atlas(
mesh_list: list[MeshObject],
maps: np.ndarray,
map_position: np.ndarray,
atlas_name: str,
save_folder: Path,
):
save_path = save_folder / f"{atlas_name}.h5"
file = h5py.File(save_path, "a")
x_px = maps.shape[0]
y_px = maps.shape[1]
number_of_maps = maps.shape[2]
projection = file.require_dataset(
"maps",
shape=(
x_px,
y_px,
number_of_maps,
), # Initial shape with third dimension as 0
maxshape=(
x_px,
y_px,
number_of_maps,
),
dtype=np.float32,
)
projection_points = file.require_dataset(
"positions",
shape=(3, number_of_maps**3), # Initial shape with third dimension as 0
maxshape=(3, number_of_maps**3),
dtype=np.float32,
)
mesh_list_dict = list(map(to_dict, mesh_list))
file.attrs["mesh_list"] = json.dumps(mesh_list_dict)
projection[:] = maps
projection_points[:] = map_position
from .mapping import map_geometry_2_projection, map_source_2_cylinder
__all__ = [
'map_geometry_2_projection',
'map_source_2_cylinder'
]
__all__ = ["map_geometry_2_projection", "map_source_2_cylinder"]
from .artist_bridge import CylindricalProjection, load_mesh, SAVEMODES, utility
__all__ = [
'CylindricalProjection',
'load_mesh',
'SAVEMODES',
'utility'
]
\ No newline at end of file
__all__ = ["CylindricalProjection", "load_mesh", "SAVEMODES", "utility"]
......@@ -6,8 +6,10 @@ except ModuleNotFoundError:
warn(
"The module `artistlib`is not installed. The simulation module is not 100\% usable! \nInstall: https://github.com/wittlsn/aRTist-PythonLib"
)
def API():
return None
utility = None
SAVEMODES = None
from pathlib import Path
......@@ -44,7 +46,6 @@ def set_cone_mode(api: API):
with importlib.resources.path("cyxtrax.data", "cone.aRTist") as file_path:
print(f"Load: {file_path}")
api.load_project(file_path)
class CylindricalProjection:
......
......@@ -3,8 +3,8 @@ from .noise import add_white_noise, add_gaussian_blur
from .xray import add_material
__all__ = [
'AugemntationParameter',
'add_white_noise',
'add_gaussian_blur',
'add_material'
]
\ No newline at end of file
"AugemntationParameter",
"add_white_noise",
"add_gaussian_blur",
"add_material",
]
......@@ -2,8 +2,8 @@ from .merge import merge_atlas_and_dict, merge_atlases, merge_lists
from .pipeline import generate_atlas_from_folder
__all__ = [
'merge_atlas_and_dict',
'merge_atlases',
'merge_lists',
'generate_atlas_from_folder'
]
\ No newline at end of file
"merge_atlas_and_dict",
"merge_atlases",
"merge_lists",
"generate_atlas_from_folder",
]
from .plt import save_map_plot_with_index, voxel_cone_maps, voxel_maps
from .gif import make_gif
__all__ = [
'save_map_plot_with_index',
'voxel_cone_maps',
'make_gif',
'voxel_maps'
]
\ No newline at end of file
__all__ = ["save_map_plot_with_index", "voxel_cone_maps", "make_gif", "voxel_maps"]
......@@ -3,17 +3,18 @@ from jax import numpy as jnp
from pathlib import Path
def save_map_plot_with_index(maps: jnp.ndarray, index: int,
save_path: Path, figsize=(8, 8), title: str = None):
def save_map_plot_with_index(
maps: jnp.ndarray, index: int, save_path: Path, figsize=(8, 8), title: str = None
):
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
ax.imshow(maps[:, :, index])
ax.set_xlabel(r'$z$ / mm')
ax.set_ylabel(r'$\alpha$ / rad')
ax.set_xticks([0, int(maps.shape[0] // 2), int(maps.shape[0])])
ax.set_xticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$'])
ax.set_yticks([0, int(maps.shape[1] // 2), int(maps.shape[1])])
ax.set_yticklabels([r'$-\pi$', r'$0.$', r'$\pi$'])
ax.set_xlabel(r"$z$ / mm")
ax.set_ylabel(r"$\alpha$ / rad")
ax.set_xticks([0, int(maps.shape[0] // 2), int(maps.shape[0])])
ax.set_xticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
ax.set_yticks([0, int(maps.shape[1] // 2), int(maps.shape[1])])
ax.set_yticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
if title is not None:
ax.set_title(title)
......@@ -22,56 +23,70 @@ def save_map_plot_with_index(maps: jnp.ndarray, index: int,
plt.savefig(save_path)
plt.close(fig)
def voxel_cone_maps(maps: jnp.ndarray, projection: jnp.ndarray,
cone_coordinates: jnp.ndarray, cylindrical_coordinates: jnp.ndarray,
save_path: Path, figsize=(24, 10), title: str = None, vmin: bool = False):
def voxel_cone_maps(
maps: jnp.ndarray,
projection: jnp.ndarray,
cone_coordinates: jnp.ndarray,
cylindrical_coordinates: jnp.ndarray,
save_path: Path,
figsize=(24, 10),
title: str = None,
vmin: bool = False,
):
fig = plt.figure(figsize=figsize)
subfig1: figure.Figure
subfig2: figure.Figure
subfig1, subfig2 = fig.subfigures(1, 2)
ax1 = subfig1.add_subplot(111)
ax1.set_title('Projection Mapping')
ax1.set_title("Projection Mapping")
im = ax1.imshow(projection, vmin=maps.min(), vmax=maps.max())
counter = 0
number_of_maps = maps.shape[2]
color = jnp.linspace(0, 1, 9)[:number_of_maps]
cmap = plt.get_cmap('hsv')
cmap = plt.get_cmap("hsv")
colors = cmap(color)
ax1.set_xlabel(r'$u$ / px')
ax1.set_ylabel(r'$v$ / px')
subfig1.colorbar(im, ax=ax1, orientation='vertical', label='Ray Length / mm')
ax1.set_xlabel(r"$u$ / px")
ax1.set_ylabel(r"$v$ / px")
subfig1.colorbar(im, ax=ax1, orientation="vertical", label="Ray Length / mm")
ax2 = subfig2.add_subplot(111)
ax2.set_title('Cylindrical Mapping')
ax2.set_title("Cylindrical Mapping")
ax2.axis(False)
side_length = int(jnp.round(number_of_maps ** (1. / 3.)))
side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0)))
for z in [0, side_length - 1]:
for y in [0, side_length - 1]:
for x in [0, side_length - 1]:
position = x + y * side_length + z * side_length ** 2
position = x + y * side_length + z * side_length**2
if counter >= number_of_maps:
continue
ax = subfig2.add_subplot(3, 3, counter+1)
ax = subfig2.add_subplot(3, 3, counter + 1)
if vmin:
ax.imshow(maps[:, :, position].T, vmin=maps.min(), vmax=maps.max())
else:
ax.imshow(maps[:, :, position].T, vmin=maps.min())
ax.scatter(cylindrical_coordinates[0, position], cylindrical_coordinates[1, position], c=colors[counter])
ax1.scatter(cone_coordinates[0, position], cone_coordinates[1, position], c=colors[counter])
ax.set_ylabel(r'$z$ / mm')
ax.set_xlabel(r'$\alpha$ / rad')
ax.set_yticks([0, 1000, 2000])
ax.set_yticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$'])
ax.set_xticks([0, 1000, 2000])
ax.set_xticklabels([r'$-\pi$', r'$0.$', r'$\pi$'])
ax.scatter(
cylindrical_coordinates[0, position],
cylindrical_coordinates[1, position],
c=colors[counter],
)
ax1.scatter(
cone_coordinates[0, position],
cone_coordinates[1, position],
c=colors[counter],
)
ax.set_ylabel(r"$z$ / mm")
ax.set_xlabel(r"$\alpha$ / rad")
ax.set_yticks([0, 1000, 2000])
ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
ax.set_xticks([0, 1000, 2000])
ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
counter += 1
......@@ -82,53 +97,68 @@ def voxel_cone_maps(maps: jnp.ndarray, projection: jnp.ndarray,
plt.savefig(save_path)
def voxel_cone_maps_values(maps: jnp.ndarray, projection: jnp.ndarray,
cone_coordinates: jnp.ndarray, cylindrical_coordinates: jnp.ndarray,
cylindrical_values: jnp.ndarray,
save_path: Path, figsize=(24, 10), title: str = None):
def voxel_cone_maps_values(
maps: jnp.ndarray,
projection: jnp.ndarray,
cone_coordinates: jnp.ndarray,
cylindrical_coordinates: jnp.ndarray,
cylindrical_values: jnp.ndarray,
save_path: Path,
figsize=(24, 10),
title: str = None,
):
fig = plt.figure(figsize=figsize)
subfig1: figure.Figure
subfig2: figure.Figure
subfig1, subfig2 = fig.subfigures(1, 2)
ax1 = subfig1.add_subplot(111)
ax1.set_title('Projection Mapping')
ax1.set_title("Projection Mapping")
im = ax1.imshow(projection, vmin=maps.min(), vmax=maps.max())
counter = 0
number_of_maps = maps.shape[2]
color = jnp.linspace(0, 1, 9)[:number_of_maps]
cmap = plt.get_cmap('hsv')
cmap = plt.get_cmap("hsv")
colors = cmap(color)
ax1.set_xlabel(r'$u$ / px')
ax1.set_ylabel(r'$v$ / px')
subfig1.colorbar(im, ax=ax1, orientation='vertical', label='Ray Length / mm')
ax1.set_xlabel(r"$u$ / px")
ax1.set_ylabel(r"$v$ / px")
subfig1.colorbar(im, ax=ax1, orientation="vertical", label="Ray Length / mm")
ax2 = subfig2.add_subplot(111)
ax2.set_title('Cylindrical Mapping')
ax2.set_title("Cylindrical Mapping")
ax2.axis(False)
side_length = int(jnp.round(number_of_maps ** (1. / 3.)))
side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0)))
for z in [0, side_length - 1]:
for y in [0, side_length - 1]:
for x in [0, side_length - 1]:
position = x + y * side_length + z * side_length ** 2
position = x + y * side_length + z * side_length**2
if counter >= number_of_maps:
continue
ax = subfig2.add_subplot(3, 3, counter+1)
ax.scatter(cylindrical_values[0], cylindrical_values[1], )
ax.scatter(cylindrical_coordinates[0, position], cylindrical_coordinates[1, position], c=colors[counter])
ax1.scatter(cone_coordinates[0, position], cone_coordinates[1, position], c=colors[counter])
ax.set_ylabel(r'$z$ / mm')
ax.set_xlabel(r'$\alpha$ / rad')
ax.set_yticks([0, 1000, 2000])
ax.set_yticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$'])
ax.set_xticks([0, 1000, 2000])
ax.set_xticklabels([r'$-\pi$', r'$0.$', r'$\pi$'])
ax = subfig2.add_subplot(3, 3, counter + 1)
ax.scatter(
cylindrical_values[0],
cylindrical_values[1],
)
ax.scatter(
cylindrical_coordinates[0, position],
cylindrical_coordinates[1, position],
c=colors[counter],
)
ax1.scatter(
cone_coordinates[0, position],
cone_coordinates[1, position],
c=colors[counter],
)
ax.set_ylabel(r"$z$ / mm")
ax.set_xlabel(r"$\alpha$ / rad")
ax.set_yticks([0, 1000, 2000])
ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
ax.set_xticks([0, 1000, 2000])
ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
counter += 1
......@@ -139,38 +169,44 @@ def voxel_cone_maps_values(maps: jnp.ndarray, projection: jnp.ndarray,
plt.savefig(save_path)
def voxel_maps(maps: jnp.ndarray, save_path: Path, figsize=(24, 10), title: str = None, vmin: bool = False):
def voxel_maps(
maps: jnp.ndarray,
save_path: Path,
figsize=(24, 10),
title: str = None,
vmin: bool = False,
):
fig = plt.figure(figsize=figsize)
ax2 = fig.add_subplot(111)
ax2.set_title('Cylindrical Mapping')
ax2.set_title("Cylindrical Mapping")
ax2.axis(False)
number_of_maps = maps.shape[2]
side_length = int(jnp.round(number_of_maps ** (1. / 3.)))
side_length = int(jnp.round(number_of_maps ** (1.0 / 3.0)))
counter = 0
for z in [0, side_length - 1]:
for y in [0, side_length - 1]:
for x in [0, side_length - 1]:
position = x + y * side_length + z * side_length ** 2
position = x + y * side_length + z * side_length**2
if counter >= number_of_maps:
continue
ax = fig.add_subplot(3, 3, counter+1)
ax = fig.add_subplot(3, 3, counter + 1)
if vmin:
ax.imshow(maps[:, :, position].T, vmin=maps.min(), vmax=maps.max())
else:
ax.imshow(maps[:, :, position].T, vmin=maps.min())
ax.set_ylabel(r'$z$ / mm')
ax.set_xlabel(r'$\alpha$ / rad')
ax.set_yticks([0, 1000, 2000])
ax.set_yticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$'])
ax.set_xticks([0, 1000, 2000])
ax.set_xticklabels([r'$-\pi$', r'$0.$', r'$\pi$'])
ax.set_ylabel(r"$z$ / mm")
ax.set_xlabel(r"$\alpha$ / rad")
ax.set_yticks([0, 1000, 2000])
ax.set_yticklabels([r"$-3141.5$", r"$0.$", r"$3141.5$"])
ax.set_xticks([0, 1000, 2000])
ax.set_xticklabels([r"$-\pi$", r"$0.$", r"$\pi$"])
counter += 1
if title is not None:
plt.title(title)
plt.tight_layout()
plt.savefig(save_path)
\ No newline at end of file
plt.savefig(save_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment