diff --git a/.gitignore b/.gitignore index a6b232ae238214a64d2a431d84f258e751d37e5d..7ead58b674d77f1e0a8365223a6782df4aa527e4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /build/ /temp/ __pycache__ -*.egg-info \ No newline at end of file +*.egg-info +*temp.* \ No newline at end of file diff --git a/CyXTraX/__init__.py b/CyXTraX/__init__.py index d45a1322301b4c1dfe92aa53f58f27f6d39587be..604e872795dde098fda58faa76811b31382388e1 100644 --- a/CyXTraX/__init__.py +++ b/CyXTraX/__init__.py @@ -4,12 +4,17 @@ from .generate_atlas import record_atlas from .load_maps import load_atlas import importlib.resources +import numpy as np with importlib.resources.path('CyXTraX.data', 'olaf_6.stl') as file_path: olaf_stl = file_path - import numpy as np - olaf_mesh = MeshObject(olaf_stl, position_mm=np.array([1, 2, 3]), - orientation_quat=np.array([0., 0., 0., 1.])) \ No newline at end of file + orientation_quat=np.array([0., 0., 0., 1.])) + +with importlib.resources.path('CyXTraX.data', 'scones.stl') as file_path: + scones_stl = file_path + scones_mesh = MeshObject(scones_stl, + position_mm=np.array([0, 0, 0]), + orientation_quat=np.array([0., 0., 0., 1.])) \ No newline at end of file diff --git a/CyXTraX/artist_bridge.py b/CyXTraX/artist_bridge.py index a8a91235363127283803eb210c19361aa04548c0..5bf5ff3910619682770c060c29003ef762fdaab6 100644 --- a/CyXTraX/artist_bridge.py +++ b/CyXTraX/artist_bridge.py @@ -26,10 +26,20 @@ class CylindricalProjection: self.api = api self.x_resolution_mm = np.pi self.y_resolution_mm = np.pi + self.pixel_pitch_mm = 0.139 self.x_px = 2000 self.y_px = 2000 self.detector_radius_mm = 1000. self.objects = list() + self._cylindrical = True + + @property + def cylindrical_mode(self) -> bool: + return self._cylindrical + + @property + def cone_mode(self) -> bool: + return not self._cylindrical def translate(self, position: np.ndarray): # for obj in self.objects: @@ -70,18 +80,24 @@ class CylindricalProjection: self.x_resolution_mm = np.pi self.y_resolution_mm = np.pi self.detector_radius_mm = 1000. - + self._cylindrical = True + print('Care! The default pixel pitch etc is set. If changes are made change the state of this class ...') + def set_cone_mode(self): with importlib.resources.path('CyXTraX.data', 'cone.aRTist') as file_path: print(f'Load: {file_path}') self.api.load_project(file_path) + self.pixel_pitch_mm = 0.139 + self._cylindrical = False + print('Care! The default pixel pitch etc is set. If changes are made change the state of this class ...') +def load_mesh(cylindrical_projection: CylindricalProjection, object_list: list[MeshObject], cylindrical_mode: bool = True) -> bool: + if cylindrical_mode: + cylindrical_projection.set_cylindrical_mode() + else: + cylindrical_projection.set_cone_mode() - -def load_mesh(cylindrical_projection: CylindricalProjection, object_list: list[MeshObject]) -> bool: - cylindrical_projection.set_cylindrical_mode() - for mesh_object in object_list: object_id = cylindrical_projection.api.load_part(mesh_object.object_path) diff --git a/CyXTraX/data/scones.stl b/CyXTraX/data/scones.stl new file mode 100644 index 0000000000000000000000000000000000000000..14338bcb65d181b97b1f5a5e41401fb8d92880e7 Binary files /dev/null and b/CyXTraX/data/scones.stl differ diff --git a/CyXTraX/generate_atlas.py b/CyXTraX/generate_atlas.py index 7506562c7f9406d7002f2dd2b6dc531bdbc3eb71..5c08155dcc9c5988192a83713792fa9432bafc4d 100644 --- a/CyXTraX/generate_atlas.py +++ b/CyXTraX/generate_atlas.py @@ -7,16 +7,16 @@ import json - def to_dict(mesh_object = MeshObject) -> dict: return mesh_object.as_dict() + def record_atlas(cylindrical_projection: CylindricalProjection, mesh_list: list[MeshObject], atlas_name: str, save_folder: Path, map_bounding_box: tuple[float]=(-50., 50.), number_of_maps: int = 4) -> Path: grid = np.linspace(map_bounding_box[0], map_bounding_box[1], number_of_maps) - file_index = len(list(save_folder.glob('*.h5'))) + file_index = len(list(save_folder.glob('*.h5')))+1 save_path = save_folder / f'{file_index:05}_{atlas_name}.h5' file = h5py.File(save_path, 'a') projection = file.require_dataset( diff --git a/CyXTraX/mapping/__init__.py b/CyXTraX/mapping/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05cdb726d7dbc2a5899169dddc14c2db8c1b25e8 --- /dev/null +++ b/CyXTraX/mapping/__init__.py @@ -0,0 +1 @@ +from .mapping import map_geometry_2_projection, map_source_2_cylinder \ No newline at end of file diff --git a/CyXTraX/mapping.py b/CyXTraX/mapping/mapping.py similarity index 95% rename from CyXTraX/mapping.py rename to CyXTraX/mapping/mapping.py index cacc2f13b898f7bdd184a6f9d9e7ab4cc94bb6e5..f697a51bae3125517b9c8627adb535e0b7eda9e2 100644 --- a/CyXTraX/mapping.py +++ b/CyXTraX/mapping/mapping.py @@ -4,6 +4,7 @@ from jax import jit, vmap from jax.scipy.spatial.transform import Rotation + @jit def map_source_2_cylinder(source: jnp.ndarray, maps: jnp.ndarray, map_positions: jnp.ndarray, radius: float = 1000., angle_discretisation: int = 2000, pitch: float = jnp.pi) -> tuple[jnp.ndarray, jnp.ndarray]: @@ -25,7 +26,7 @@ def map_source_2_cylinder(source: jnp.ndarray, maps: jnp.ndarray, map_positions: v = jnp.expand_dims(angles[1], 0) # 1 x m w = jnp.arange(0, maps.shape[2])[jnp.newaxis, ...] #+0.0001 # 1 x m - uvw = jnp.concatenate([v, u, w], 0) # 3 x m + uvw = jnp.concatenate([v, u , w], 0) # 3 x m values = pix.flat_nd_linear_interpolate(maps, uvw) @@ -39,8 +40,8 @@ def cylindrical_angles(source: jnp.ndarray, map_positions: jnp.ndarray, radius: direction_norm_factor = jnp.linalg.norm(directions[:2], axis=0, keepdims=True) directions = directions / direction_norm_factor # m x 3 intersection = directions * radius - sign = jnp.where(intersection[2]>=0, 1, -1) - intersection = intersection * sign + # sign = jnp.where(intersection[2]>=0, 1, -1) + # intersection = intersection * sign z_value = intersection[2] / pitch @@ -105,7 +106,6 @@ def projection_matrix(source: jnp.ndarray, detector: jnp.ndarray, detector_orien rotation_matrix = Rotation.from_quat(detector_orientation).as_matrix() detector_horizontal_vector = -rotation_matrix[0, :, 0] detector_vertical_vector = -rotation_matrix[0, :, 1] - print(detector_horizontal_vector) p3x3 = jnp.vstack([detector_horizontal_vector, detector_vertical_vector, (detector - source).reshape((-1, 3))]).T diff --git a/scripts/01_aRTist_bridge.py b/scripts/01_aRTist_bridge.py index 918f6623ee2b6eed5f1f33f13308a3caa5dc6a2b..db602c54ba40b4782e351649f214a6b02068dd23 100644 --- a/scripts/01_aRTist_bridge.py +++ b/scripts/01_aRTist_bridge.py @@ -1,4 +1,4 @@ -from CyXTraX import CylindricalProjection, load_mesh, olaf_mesh, MeshObject +from CyXTraX import CylindricalProjection, load_mesh, scones_mesh, MeshObject import numpy as np from matplotlib import pyplot as plt @@ -7,20 +7,21 @@ def main(): # Create the cylindrical projection interface cyl_proj = CylindricalProjection() + # Load the test object 'olaf'. It is automatically installed to your site-packages - print(f'STL Path: {olaf_mesh.object_path}') + print(f'STL Path: {scones_mesh.object_path}') # Obect are loaded in a dataclass: MeshObject. - print(f'Olaf is a MeshObject: {isinstance(olaf_mesh, MeshObject)}') - olaf_mesh.position_mm = np.array([42, 69, 0]) - olaf_mesh.orientation_quat = np.array([0.1, -0.1, 0.11, 0.99]) + print(f'Olaf is a MeshObject: {isinstance(scones_mesh, MeshObject)}') + scones_mesh.position_mm = np.array([0, 0, 0]) + scones_mesh.orientation_quat = np.array([0., -0., 0., 0.1]) # The `load_mesh` function load a list of MeshObject - load_mesh(cyl_proj, [olaf_mesh]) + load_mesh(cyl_proj, [scones_mesh]) # Set local position of cylindrical projection - map_position = np.array([1, 2, 4]) + map_position = np.array([0, 0, 0]) # Generate local cylindrical projection half_rays = cyl_proj.compute_projection(map_position, output_full_ray_projection=False) diff --git a/scripts/02_generate_atlas.py b/scripts/02_generate_atlas.py index 9c1cfd3d98d2920b8daef3b7a2883814d20824ec..f7bd52fde4eb7eb28ff88242208e69ef1b627002 100644 --- a/scripts/02_generate_atlas.py +++ b/scripts/02_generate_atlas.py @@ -1,4 +1,4 @@ -from CyXTraX import CylindricalProjection, olaf_mesh, record_atlas +from CyXTraX import CylindricalProjection, olaf_mesh, record_atlas, load_mesh from pathlib import Path @@ -10,10 +10,11 @@ TEMP_FOLDER = FOLDER.parent.parent / 'temp' def main(): # Create the cylindrical projection interface cyl_proj = CylindricalProjection() - + load_mesh(cyl_proj, [olaf_mesh]) # make an atlas + save_path = record_atlas(cyl_proj, [olaf_mesh], 'test', TEMP_FOLDER, - map_bounding_box=[-20, 20], number_of_maps=2) + map_bounding_box=[-10, 20], number_of_maps=2) print(f'Atlas saved to the path: {save_path}') diff --git a/scripts/04_mapping.py b/scripts/04_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..9c512c4a07b8add9966aa2c16319c55a6a233d9d --- /dev/null +++ b/scripts/04_mapping.py @@ -0,0 +1,74 @@ +from CyXTraX import CylindricalProjection, load_atlas, load_mesh +from CyXTraX.mapping import map_geometry_2_projection, map_source_2_cylinder +from pathlib import Path +from jax import numpy as jnp +from jax.scipy.spatial.transform import Rotation +from artistlib import SAVEMODES, utility +from matplotlib import pyplot as plt + + +# Just get the temp folder to store the atlas +FOLDER = Path(__file__) +TEMP_FOLDER = FOLDER.parent.parent / 'temp' + + +def main(): + files = list(TEMP_FOLDER.glob('*.h5')) + if not files: + raise FileNotFoundError('No .h5 files in temp folder. Run Example 02!') + + maps, points, mesh_object_list = load_atlas(files[-1]) + cyl_proj = CylindricalProjection() + + load_mesh(cyl_proj, mesh_object_list, False) + print(f'Cone Beam Mode set: {cyl_proj.cone_mode}') + + rot = 30 + x, y = 2000 * jnp.cos(rot * jnp.pi / 180), 500 * jnp.sin(rot * jnp.pi / 180) + source_position = jnp.array([x, y, 0.]).reshape((1, 3)) + detector_position = jnp.array([-x, -y, -10.]).reshape((1, 3)) + detctor_orientation_scipy = Rotation.from_euler('xyz', [0, -90, rot], degrees=True) + + temp_projection = Path('temp.tif') + detctor_orientation = detctor_orientation_scipy.as_quat().reshape((1, -1)) + + cyl_proj.api.translate('S', source_position[0, 0], source_position[0, 1], source_position[0, 2]) + cyl_proj.api.translate('D', detector_position[0, 0], detector_position[0, 1], detector_position[0, 2]) + cyl_proj.api.rotate_from_quat('D', detctor_orientation[0]) + cyl_proj.api.save_image(temp_projection, save_mode=SAVEMODES.FLOAT_TIFF) + + projection, geometry = utility.load_projection(temp_projection) + + mapped_values, uv_px = map_geometry_2_projection( + source_position, + detector_position, + detctor_orientation, + points, + projection, + cyl_proj.pixel_pitch_mm) + + values_calc, angles_calc = map_source_2_cylinder( + source_position, + maps, + points, + cyl_proj.detector_radius_mm, + cyl_proj.x_px, + cyl_proj.x_resolution_mm) + + for i in range(values_calc.shape[0]): + print(values_calc[i], mapped_values[i]) + map_ = maps[:, :, i] + value = map_[int(angles_calc[0, i]), int(angles_calc[1, i])] + print(value) + + fig = plt.figure() + ax1 = fig.add_subplot(111) + ax1.imshow(projection) + ax1.scatter(uv_px[0], uv_px[1]) + + plt.show() + + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index d616638ff63e84bda8432900669f7338e7438a01..f21a4fc832c0cf6fc24ad03313d588888f23255f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,4 +15,12 @@ zip_safe = False CyXTraX = data/cone.aRTist data/cylinder.aRTist - data/olaf_6.stl \ No newline at end of file + data/olaf_6.stl + data/scones.stl + +install_requires = + numpy>=2.0.0 + jaxlib + jax + git+https://github.com/wittlsn/aRTist-PythonLib + scipy \ No newline at end of file