diff --git a/example/convert_projections.py b/example/convert_projections.py new file mode 100644 index 0000000000000000000000000000000000000000..b5f50515298738e533066ac63d5ac7002d304a53 --- /dev/null +++ b/example/convert_projections.py @@ -0,0 +1,25 @@ +from rq_controller.common.io.rq_json import RqJsonWriter +from rq_controller.common.io.thd import RawLoader +from pathlib import Path + + +FOLDER = Path(r"C:\Users\swittl\Downloads\defect_detection\defect_detection\object_2") +SAVE_FOLDER = Path( + r"C:\Users\swittl\Downloads\defect_detection\defect_detection\rq_converted\object_2" +) + + +def main(): + writer = RqJsonWriter() + loader = RawLoader() + + projections_files = FOLDER.glob("*.raw") + + for i, file in enumerate(projections_files): + projection = loader.load_projection(file) + save_path = writer.get_next_projection_save_path(SAVE_FOLDER) + writer.write_projection(save_path, projection, optional=True) + + +if __name__ == "__main__": + main() diff --git a/example/projection_example.py b/example/projection_example.py index 4b6fd05a7bca25675be123af375600b40b09900e..23df07dcb9b88e691c8fa0424212a84d4f1e247b 100644 --- a/example/projection_example.py +++ b/example/projection_example.py @@ -3,7 +3,8 @@ from rq_controller.common.io.rq_json import RqJsonWriter, RqJsonLoader from pathlib import Path -FOLDER = Path('./example/data') +FOLDER = Path("./example/data") + def main(): projection_2 = PyProjection.dummy() @@ -13,12 +14,13 @@ def main(): loader = RqJsonLoader() writer.write_projection(writer.get_projection_save_path_i(FOLDER, 2), projection_2) - + projection_1 = loader.load_projection(writer.get_projection_save_path_i(FOLDER, 1)) print(projection_1) projection_2 = loader.load_projection(writer.get_projection_save_path_i(FOLDER, 2)) print(projection_2) -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/example/tigre_example.py b/example/tigre_example.py index fd7c55138c371b0ca9b33d607002c288eafd7c59..be4aaeb6b6aca43d2fa3872c0cb04f2241f84b70 100644 --- a/example/tigre_example.py +++ b/example/tigre_example.py @@ -5,7 +5,6 @@ import rclpy from scipy.spatial.transform import Rotation import numpy as np -import matplotlib.pyplot as plt # !!! @@ -13,8 +12,9 @@ import matplotlib.pyplot as plt # !!! NUMBER_OF_PROJECTION = 80 -FOD_MM = 1000. -FDD_MM = 2000. +FOD_MM = 1000.0 +FDD_MM = 2000.0 + def main(): # Initialize workflow node @@ -24,10 +24,10 @@ def main(): projection_stack: list[PyProjection] = list() projection = PyProjection.dummy() projection.image = np.zeros((1000, 1000), dtype=np.uint16) - projection.voltage_kv = 200. - projection.current_ua = 10. - projection.detector_heigth_mm = 200. - projection.detector_width_mm = 200. + projection.voltage_kv = 200.0 + projection.current_ua = 10.0 + projection.detector_heigth_mm = 200.0 + projection.detector_width_mm = 200.0 # create the source and detector positions source = np.array([FOD_MM, 0, 0]) @@ -36,21 +36,25 @@ def main(): # Move source / dtector and aquire projections for i in range(NUMBER_OF_PROJECTION): - rotation = Rotation.from_euler('Z', angles[i], False) - scan_pose = projection.look_at(rotation.apply(source) + (np.random.random(3) - 0.5) * 30, - rotation.apply(detector) + (np.random.random(3) - 0.5) * 30, - np.array([0, 0, -1])) + rotation = Rotation.from_euler("Z", angles[i], False) + scan_pose = projection.look_at( + rotation.apply(source) + (np.random.random(3) - 0.5) * 30, + rotation.apply(detector) + (np.random.random(3) - 0.5) * 30, + np.array([0, 0, -1]), + ) projection_stack.append(workflow.aquire_projection(scan_pose)) # Define reconstruction area and call reconstruction client - roi = PyRegionOfIntrest(center_points_mm=np.array([0., 0., 0.]), - dimensions_mm=np.array([120., 120., 120.]), - resolution_mm=np.array([0.5, 0.5, 0.5])) + roi = PyRegionOfIntrest( + center_points_mm=np.array([0.0, 0.0, 0.0]), + dimensions_mm=np.array([120.0, 120.0, 120.0]), + resolution_mm=np.array([0.5, 0.5, 0.5]), + ) + + workflow.reconstruction.set_reconstruction_algorithm_name("ossart") # ossart / fdk + workflow.get_volume(projection_stack, roi) - workflow.reconstruction.set_reconstruction_algorithm_name('ossart') # ossart / fdk - volume = workflow.get_volume(projection_stack, roi) - -if __name__ == '__main__': +if __name__ == "__main__": rclpy.init() - main() \ No newline at end of file + main() diff --git a/example/volume_example.py b/example/volume_example.py index b462199a24de709f5a91d50a8e4fdf05bd5cddb8..68d515ce427c26087570f2e6cfe439e8f926c71e 100644 --- a/example/volume_example.py +++ b/example/volume_example.py @@ -4,34 +4,33 @@ from pathlib import Path import numpy as np -FOLDER = Path('./example/data') +FOLDER = Path("./example/data") def main(): volume = PyVolume.dummy() - print(f'Shape (x / y / z): {volume.shape}') + print(f"Shape (x / y / z): {volume.shape}") msg = volume.as_message() - print(f'Number of slices: {len(msg.slices)}') + print(f"Number of slices: {len(msg.slices)}") volume_2 = PyVolume.from_message(msg) - print(f'Shape (x / y / z): {volume_2.shape}') + print(f"Shape (x / y / z): {volume_2.shape}") writer = RqJsonWriter() writer.write_volume(writer.get_volume_save_path_i(FOLDER, 1), volume_2) loader = RqJsonLoader() volume_3 = loader.load_volume(writer.get_volume_save_path_i(FOLDER, 1)) - print(f'Shape (x / y / z): {volume_3.shape}') + print(f"Shape (x / y / z): {volume_3.shape}") grid = volume_3.roi.get_grid() - print(f'grid shape: {grid.shape}') + print(f"grid shape: {grid.shape}") - pos = volume_3.roi.next_neighbor(grid, np.array([0., 0., 0.])) + pos = volume_3.roi.next_neighbor(grid, np.array([0.0, 0.0, 0.0])) print(pos) value = volume_3.array[pos[0], pos[1], pos[2]] print(value) -if __name__ == '__main__': +if __name__ == "__main__": main() - diff --git a/rq_controller/common/__init__.py b/rq_controller/common/__init__.py index 49190b8087a1085ee69dbe334d271973097b6935..c0676a102ffa3101692d09e467acdaaa60855f6c 100644 --- a/rq_controller/common/__init__.py +++ b/rq_controller/common/__init__.py @@ -1,2 +1,4 @@ from .projection import PyProjectionGeometry, PyProjection -from .volume import PyVolume, PyRegionOfIntrest \ No newline at end of file +from .volume import PyVolume, PyRegionOfIntrest + +__all__ = ["PyProjectionGeometry", "PyProjection", "PyVolume", "PyRegionOfIntrest"] diff --git a/rq_controller/common/io/loader.py b/rq_controller/common/io/loader.py index 4ce8e32f6ed7a28bc59d90c75cf0c9d6db5f219e..66c7b5c3578f9237dfb33c3dbd034d17be396db3 100644 --- a/rq_controller/common/io/loader.py +++ b/rq_controller/common/io/loader.py @@ -1,16 +1,15 @@ -import numpy as np -import json from pathlib import Path from ...common import PyProjectionGeometry, PyProjection, PyRegionOfIntrest, PyVolume -class BaseDataLoader(): - def __init__(self, - porjection_geometry_suffix: str, - projection_suffix: str, - region_of_intrest_suffix: str, - volume_suffix: str): - +class BaseDataLoader: + def __init__( + self, + porjection_geometry_suffix: str, + projection_suffix: str, + region_of_intrest_suffix: str, + volume_suffix: str, + ): self.porjection_geometry_suffix = porjection_geometry_suffix self.projection_suffix = projection_suffix self.region_of_intrest_suffix = region_of_intrest_suffix @@ -18,13 +17,12 @@ class BaseDataLoader(): def load_projection_geometry(self, load_path: Path) -> PyProjectionGeometry: raise NotImplementedError() - + def load_projection(self, load_path: Path) -> PyProjection: raise NotImplementedError() - + def load_region_of_intrest(self, load_path: Path) -> PyRegionOfIntrest: raise NotImplementedError() - + def load_volume(self, load_path: Path) -> PyVolume: raise NotImplementedError() - \ No newline at end of file diff --git a/rq_controller/common/io/rq_json/__init__.py b/rq_controller/common/io/rq_json/__init__.py index f6f581260eb0e287f5521a054bd3cddd01771b2b..4d08736126b917997610aa7a00378fec6e8b76d0 100644 --- a/rq_controller/common/io/rq_json/__init__.py +++ b/rq_controller/common/io/rq_json/__init__.py @@ -1,2 +1,4 @@ from .json_load import RqJsonLoader -from .json_write import RqJsonWriter \ No newline at end of file +from .json_write import RqJsonWriter + +__all__ = ["RqJsonLoader", "RqJsonWriter"] diff --git a/rq_controller/common/io/rq_json/json_load.py b/rq_controller/common/io/rq_json/json_load.py index 2a9d6e6d2654418ca9d926703a56992b46b21300..0bd1a5de5f65814da9ce79b3e38eeb977ad06e46 100644 --- a/rq_controller/common/io/rq_json/json_load.py +++ b/rq_controller/common/io/rq_json/json_load.py @@ -1,18 +1,24 @@ import numpy as np -import json +import json from pathlib import Path -from ..loader import BaseDataLoader, PyProjection, PyProjectionGeometry, PyRegionOfIntrest, PyVolume +from ..loader import ( + BaseDataLoader, + PyProjection, + PyProjectionGeometry, + PyRegionOfIntrest, + PyVolume, +) from PIL import Image import pyometiff class RqJsonLoader(BaseDataLoader): def __init__(self): - super().__init__('.geom-json', '.tif', '.roi-json', '.ome.tiff') + super().__init__(".geom-json", ".tif", ".roi-json", ".ome.tiff") def load_json(self, load_path: Path) -> dict: - with open(str(load_path), 'r') as f: + with open(str(load_path), "r") as f: data_dict = json.load(f) return data_dict @@ -20,57 +26,71 @@ class RqJsonLoader(BaseDataLoader): data_dict = self.load_json(load_path) projection_geometry = PyProjectionGeometry( - focal_spot_mm=np.array([data_dict['focal_spot_mm']]), - detector_postion_mm=np.array([data_dict['detector_postion_mm']]), - detector_orientation_quad=np.array([data_dict['detector_orientation_quad']]), - frame_id=data_dict['frame_id'], - focal_spot_orientation_quad=np.array(data_dict['focal_spot_orientation_quad'])) - + focal_spot_mm=np.array([data_dict["focal_spot_mm"]]), + detector_postion_mm=np.array([data_dict["detector_postion_mm"]]), + detector_orientation_quad=np.array( + [data_dict["detector_orientation_quad"]] + ), + frame_id=data_dict["frame_id"], + focal_spot_orientation_quad=np.array( + data_dict["focal_spot_orientation_quad"] + ), + ) + return projection_geometry - + def load_projection(self, load_path: Path) -> PyProjection: - load_path_projection_geometry = load_path.parent / f'{load_path.stem}{self.porjection_geometry_suffix}' + load_path_projection_geometry = ( + load_path.parent / f"{load_path.stem}{self.porjection_geometry_suffix}" + ) data_dict = self.load_json(load_path_projection_geometry) image = np.asarray(Image.open(load_path)) projection = PyProjection( - focal_spot_mm=np.array([data_dict['focal_spot_mm']]), - detector_postion_mm=np.array([data_dict['detector_postion_mm']]), - detector_orientation_quad=np.array([data_dict['detector_orientation_quad']]), + focal_spot_mm=np.array([data_dict["focal_spot_mm"]]), + detector_postion_mm=np.array([data_dict["detector_postion_mm"]]), + detector_orientation_quad=np.array( + [data_dict["detector_orientation_quad"]] + ), image=image, - detector_heigth_mm=data_dict['detector_heigth_mm'], - detector_width_mm=data_dict['detector_width_mm'], - frame_id=data_dict['frame_id'], - focal_spot_orientation_quad=np.array(data_dict['focal_spot_orientation_quad']), - voltage_kv=data_dict['voltage_kv'], - current_ua=data_dict['current_ua'], - exposure_time_ms=data_dict['exposure_time_ms']) - + detector_heigth_mm=data_dict["detector_heigth_mm"], + detector_width_mm=data_dict["detector_width_mm"], + frame_id=data_dict["frame_id"], + focal_spot_orientation_quad=np.array( + data_dict["focal_spot_orientation_quad"] + ), + voltage_kv=data_dict["voltage_kv"], + current_ua=data_dict["current_ua"], + exposure_time_ms=data_dict["exposure_time_ms"], + ) + return projection - + def load_region_of_intrest(self, load_path: Path) -> PyRegionOfIntrest: data_dict = self.load_json(load_path) region_of_intrest = PyRegionOfIntrest( - center_points_mm=np.array(data_dict['center_points_mm']), - dimensions_mm=np.array(data_dict['dimensions_mm']), - frame_id=data_dict['frame_id'], - resolution_mm=np.array(data_dict['resolution_mm'])) - + center_points_mm=np.array(data_dict["center_points_mm"]), + dimensions_mm=np.array(data_dict["dimensions_mm"]), + frame_id=data_dict["frame_id"], + resolution_mm=np.array(data_dict["resolution_mm"]), + ) + return region_of_intrest - + def load_volume(self, load_path: Path) -> PyVolume: - load_path_roi = load_path.parent / f'{load_path.stem}{self.region_of_intrest_suffix}' + load_path_roi = ( + load_path.parent / f"{load_path.stem}{self.region_of_intrest_suffix}" + ) roi = self.load_region_of_intrest(load_path_roi) reader = pyometiff.OMETIFFReader(load_path) volume, _, _ = reader.read() - + if volume.dtype == np.uint16: data_type = 0 elif volume.dtype == np.uint8: data_type = 1 else: - raise ValueError('data type not implemented.') - - return PyVolume(volume, roi, data_type) + raise ValueError("data type not implemented.") + return PyVolume(volume, roi, data_type) diff --git a/rq_controller/common/io/rq_json/json_write.py b/rq_controller/common/io/rq_json/json_write.py index e5d6de83f833b2005204cb25aee5b08a9acd43cb..0d14c0dd449aca05d62798c1c453e8f3bb4e4d54 100644 --- a/rq_controller/common/io/rq_json/json_write.py +++ b/rq_controller/common/io/rq_json/json_write.py @@ -1,73 +1,104 @@ -import numpy as np -import json +import json from pathlib import Path -from ..writer import BaseDataWriter, PyProjection, PyProjectionGeometry, PyRegionOfIntrest, PyVolume +from ..writer import ( + BaseDataWriter, + PyProjection, + PyProjectionGeometry, + PyRegionOfIntrest, + PyVolume, +) from PIL import Image import pyometiff class RqJsonWriter(BaseDataWriter): def __init__(self): - super().__init__('.geom-json', '.tif', '.roi-json', '.ome.tiff') + super().__init__(".geom-json", ".tif", ".roi-json", ".ome.tiff") def write_json(self, save_path: Path, data_dict: dict): - with open(str(save_path), 'w') as f: + with open(str(save_path), "w") as f: json.dump(data_dict, f, indent=2) - def write_projection_geometry(self, save_path: Path, projection_geometry: PyProjectionGeometry): + def write_projection_geometry( + self, save_path: Path, projection_geometry: PyProjectionGeometry + ): data_dict = dict() - #geometry - data_dict['focal_spot_mm'] = projection_geometry.focal_spot_mm.tolist() - data_dict['detector_postion_mm'] = projection_geometry.detector_postion_mm.tolist() - data_dict['detector_orientation_quad'] = projection_geometry.detector_orientation_quad.tolist() - data_dict['focal_spot_orientation_quad'] = projection_geometry.focal_spot_orientation_quad.tolist() - - data_dict['frame_id'] = projection_geometry.frame_id + # geometry + data_dict["focal_spot_mm"] = projection_geometry.focal_spot_mm.tolist() + data_dict["detector_postion_mm"] = ( + projection_geometry.detector_postion_mm.tolist() + ) + data_dict["detector_orientation_quad"] = ( + projection_geometry.detector_orientation_quad.tolist() + ) + data_dict["focal_spot_orientation_quad"] = ( + projection_geometry.focal_spot_orientation_quad.tolist() + ) + + data_dict["frame_id"] = projection_geometry.frame_id self.write_json(save_path, data_dict) - def write_projection(self, save_path: Path, projection: PyProjection): - save_path_projection_geometry = save_path.parent / f'{save_path.stem}{self.porjection_geometry_suffix}' + def write_projection( + self, save_path: Path, projection: PyProjection, optional: bool = False + ): + save_path_projection_geometry = ( + save_path.parent / f"{save_path.stem}{self.porjection_geometry_suffix}" + ) data_dict = dict() # geometry - data_dict['focal_spot_mm'] = projection.focal_spot_mm.tolist() - data_dict['detector_postion_mm'] = projection.detector_postion_mm.tolist() - data_dict['detector_orientation_quad'] = projection.detector_orientation_quad.tolist() - data_dict['focal_spot_orientation_quad'] = projection.focal_spot_orientation_quad.tolist() - + data_dict["focal_spot_mm"] = projection.focal_spot_mm.tolist() + data_dict["detector_postion_mm"] = projection.detector_postion_mm.tolist() + data_dict["detector_orientation_quad"] = ( + projection.detector_orientation_quad.tolist() + ) + data_dict["focal_spot_orientation_quad"] = ( + projection.focal_spot_orientation_quad.tolist() + ) + # detector - data_dict['pixel_pitch_x_mm'] = projection.pixel_pitch_x_mm - data_dict['pixel_pitch_y_mm'] = projection.pixel_pitch_y_mm - data_dict['detector_heigth_mm'] = projection.detector_heigth_mm - data_dict['detector_width_mm'] = projection.detector_width_mm + data_dict["pixel_pitch_x_mm"] = projection.pixel_pitch_x_mm + data_dict["pixel_pitch_y_mm"] = projection.pixel_pitch_y_mm + data_dict["detector_heigth_mm"] = projection.detector_heigth_mm + data_dict["detector_width_mm"] = projection.detector_width_mm # xray - data_dict['exposure_time_ms'] = projection.exposure_time_ms - data_dict['voltage_kv'] = projection.voltage_kv - data_dict['current_ua'] = projection.current_ua - data_dict['frame_id'] = projection.frame_id + data_dict["exposure_time_ms"] = projection.exposure_time_ms + data_dict["voltage_kv"] = projection.voltage_kv + data_dict["current_ua"] = projection.current_ua + data_dict["frame_id"] = projection.frame_id + + # extra / converted information + if optional: + data_dict["optional"] = dict( + projection_matrix_cera=projection.projection_matrix_cera_px.tolist() + ) self.write_json(save_path_projection_geometry, data_dict) image = Image.fromarray(projection.image) image.save(save_path) - def write_region_of_intrest(self, save_path: Path, region_of_intrest: PyRegionOfIntrest): + def write_region_of_intrest( + self, save_path: Path, region_of_intrest: PyRegionOfIntrest + ): data_dict = dict() # geometry - data_dict['center_points_mm'] = region_of_intrest.center_points_mm.tolist() - data_dict['dimensions_mm'] = region_of_intrest.dimensions_mm.tolist() - data_dict['frame_id'] = region_of_intrest.frame_id - data_dict['resolution_mm'] = region_of_intrest.resolution_mm.tolist() + data_dict["center_points_mm"] = region_of_intrest.center_points_mm.tolist() + data_dict["dimensions_mm"] = region_of_intrest.dimensions_mm.tolist() + data_dict["frame_id"] = region_of_intrest.frame_id + data_dict["resolution_mm"] = region_of_intrest.resolution_mm.tolist() self.write_json(save_path, data_dict) def write_volume(self, save_path: Path, volume: PyVolume): - save_path_region_of_intrest = save_path.parent / f'{save_path.stem}{self.region_of_intrest_suffix}' + save_path_region_of_intrest = ( + save_path.parent / f"{save_path.stem}{self.region_of_intrest_suffix}" + ) self.write_region_of_intrest(save_path_region_of_intrest, volume.roi) writer = pyometiff.OMETIFFWriter(save_path, volume.array, dict()) - writer.write() \ No newline at end of file + writer.write() diff --git a/rq_controller/common/io/thd/__init__.py b/rq_controller/common/io/thd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84f08e868bb300fe2373d6bc3488f4333fb285b8 --- /dev/null +++ b/rq_controller/common/io/thd/__init__.py @@ -0,0 +1,5 @@ +from .thd_load import RawLoader +from .thd_write import RawWriter + + +__all__ = ["RawLoader", "RawWriter"] diff --git a/rq_controller/common/io/thd/thd_load.py b/rq_controller/common/io/thd/thd_load.py new file mode 100644 index 0000000000000000000000000000000000000000..257d6cb0cea755ea24fd1524f452f8cbdff3860e --- /dev/null +++ b/rq_controller/common/io/thd/thd_load.py @@ -0,0 +1,125 @@ +import numpy as np +import json +from pathlib import Path + +try: + from PythonTools.raw2py import raw2py + from PythonTools.ezrt_header import EzrtHeader + from PythonTools.rek2py import rek2py +except ModuleNotFoundError: + raise ModuleNotFoundError("Install PythonTools from Fraunhofer EZRT.") + +from ..loader import ( + BaseDataLoader, + PyProjection, + PyProjectionGeometry, + PyRegionOfIntrest, + PyVolume, +) +from scipy.spatial.transform import Rotation + +from ..rq_json.json_load import RqJsonLoader + + +class RawLoader(BaseDataLoader): + def __init__(self): + super().__init__(".raw", ".raw", ".roi-json", ".rek") + + def load_projection_geometry(self, load_path: Path) -> PyProjectionGeometry: + header = EzrtHeader.fromfile(load_path) + + focal_spot_mm = np.array(header.agv_source_position) * 1000.0 + detector_postion_mm = np.array(header.agv_detector_center_position) * 1000.0 + + line = np.array(header.agv_detector_line_direction) + column = np.array(header.agv_detector_col_direction) + normal = np.cross(line, column) + + matrix = np.eye(3) + matrix[:, 0] = line + matrix[:, 1] = column + matrix[:, 2] = normal + + rotation = Rotation.from_matrix(matrix) + detector_orientation_quad = rotation.as_quat() + + frame_id = header.number_of_images + + projection_geometry = PyProjectionGeometry( + focal_spot_mm, + detector_postion_mm, + detector_orientation_quad, + frame_id, + focal_spot_orientation_quad=np.array([0.0, 0.0, 0.0, 1.0]), + ) + + return projection_geometry + + def load_projection( + self, load_path: Path, switch_order: bool = True + ) -> PyProjection: + header, image = raw2py(load_path, switch_order=switch_order) + + focal_spot_mm = np.array(header.agv_source_position) * 1000.0 + detector_postion_mm = np.array(header.agv_detector_center_position) * 1000.0 + + line = np.array(header.agv_detector_line_direction) + column = np.array(header.agv_detector_col_direction) + normal = np.cross(line, column) + + matrix = np.eye(3) + matrix[:, 0] = line + matrix[:, 1] = column + matrix[:, 2] = normal + + rotation = Rotation.from_matrix(matrix) + detector_orientation_quad = rotation.as_quat() + + frame_id = header.number_of_images + + detector_heigth_mm = header.pixel_width_in_um / 1000.0 * image.shape[1] + detector_width_mm = header.pixel_width_in_um / 1000.0 * image.shape[0] + + voltage_kv = header.voltage_in_kv + current_ua = header.current_in_ua + exposure_time_ms = header.exposure_time_in_ms + + projection = PyProjection( + focal_spot_mm, + detector_postion_mm, + detector_orientation_quad, + image, + detector_heigth_mm, + detector_width_mm, + frame_id, + np.array([0.0, 0.0, 0.0, 1.0]), + voltage_kv, + current_ua, + exposure_time_ms, + ) + + return projection + + def load_json(self, load_path: Path) -> dict: + with open(str(load_path), "r") as f: + data_dict = json.load(f) + return data_dict + + def load_region_of_intrest(self, load_path: Path) -> PyRegionOfIntrest: + RqJsonLoader.load_region_of_intrest(self, load_path) + + def load_volume(self, load_path: Path) -> PyVolume: + header, volume = rek2py(load_path) + load_path_roi = ( + load_path.parent / f"{load_path.stem}{self.region_of_intrest_suffix}" + ) + roi = self.load_region_of_intrest(load_path_roi) + + if volume.dtype == np.uint16: + data_type = 0 + elif volume.dtype == np.uint8: + data_type = 1 + else: + raise ValueError("data type not implemented.") + + return PyVolume(volume, roi, data_type) diff --git a/rq_controller/common/io/thd/thd_write.py b/rq_controller/common/io/thd/thd_write.py new file mode 100644 index 0000000000000000000000000000000000000000..ee0f3bf893ecf21d2bcc7953eb3e897b79083804 --- /dev/null +++ b/rq_controller/common/io/thd/thd_write.py @@ -0,0 +1,86 @@ +from pathlib import Path + +from ..writer import ( + BaseDataWriter, + PyProjection, + PyProjectionGeometry, + PyRegionOfIntrest, + PyVolume, +) + +try: + from PythonTools.py2raw import py2raw + from PythonTools.ezrt_header import EzrtHeader + from PythonTools.py2rek import py2rek +except ModuleNotFoundError: + raise ModuleNotFoundError("Install PythonTools from Fraunhofer EZRT.") + + +from ..rq_json.json_write import RqJsonWriter + + +class RawWriter(BaseDataWriter): + def __init__(self): + super().__init__(".raw", ".raw", ".roi-json", ".rek") + + def write_projection_geometry( + self, save_path: Path, projection_geometry: PyProjectionGeometry + ): + header = EzrtHeader() + header.agv_source_position = ( + projection_geometry.focal_spot_mm / 1000.0 + ).tolist() + header.agv_detector_center_position = ( + projection_geometry.detector_postion_mm / 1000.0 + ).tolist() + + line = projection_geometry.detector_rotation_matrix[:, 0] + column = projection_geometry.detector_rotation_matrix[:, 0] + + header.agv_detector_line_direction = line + header.agv_detector_col_direction = column + + header.number_of_images = projection_geometry.frame_id + header.tofile(save_path) + + def write_projection( + self, save_path: Path, projection: PyProjection, swith_order: bool = False + ): + projection_geometry = projection.get_projection_geometry() + + header = EzrtHeader() + + header.agv_source_position = ( + projection_geometry.focal_spot_mm / 1000.0 + ).tolist() + header.agv_detector_center_position = ( + projection_geometry.detector_postion_mm / 1000.0 + ).tolist() + + line = projection_geometry.detector_rotation_matrix[:, 0] + column = projection_geometry.detector_rotation_matrix[:, 0] + + header.agv_detector_line_direction = line + header.agv_detector_col_direction = column + + header.number_of_images = projection_geometry.frame_id + header.pixel_width_in_um = projection.pixel_pitch_x_mm * 1000.0 + header.voltage_in_kv = projection.voltage_kv + header.current_in_ua = projection.current_ua + header.exposure_time_in_ms = projection.exposure_time_ms + + py2raw(projection.image, save_path, header, swith_order) + + def write_region_of_intrest( + self, save_path: Path, region_of_intrest: PyRegionOfIntrest + ): + RqJsonWriter.write_region_of_intrest(self, save_path, region_of_intrest) + + def write_volume(self, save_path: Path, volume: PyVolume): + header = EzrtHeader() + header.num_voxel_x = volume.shape[0] + header.num_voxel_y = volume.shape[1] + header.num_voxel_z = volume.shape[2] + header.pixel_width_in_um = volume.roi.resolution_mm[0] + + py2rek(volume.array, save_path, header, True) diff --git a/rq_controller/common/io/writer.py b/rq_controller/common/io/writer.py index f9475c382156f875f56cb9093a9e83a24eab8c25..91ae74442dfcecc4cdb88a05c9a820550d788dfb 100644 --- a/rq_controller/common/io/writer.py +++ b/rq_controller/common/io/writer.py @@ -1,70 +1,101 @@ -import numpy as np -import json from pathlib import Path from ...common import PyProjectionGeometry, PyProjection, PyRegionOfIntrest, PyVolume -class BaseDataWriter(): - projection_name: str = 'projection' - projection_geometry_name: str ='geometry' - region_of_intrest_name: str = 'roi' - volume_name: str = 'volume' +class BaseDataWriter: + projection_name: str = "projection" + projection_geometry_name: str = "geometry" + region_of_intrest_name: str = "roi" + volume_name: str = "volume" - def __init__(self, - porjection_geometry_suffix: str, - projection_suffix: str, - region_of_intrest_suffix: str, - volume_suffix: str): - + def __init__( + self, + porjection_geometry_suffix: str, + projection_suffix: str, + region_of_intrest_suffix: str, + volume_suffix: str, + ): self.porjection_geometry_suffix = porjection_geometry_suffix self.projection_suffix = projection_suffix self.region_of_intrest_suffix = region_of_intrest_suffix self.volume_suffix = volume_suffix - def write_projection_geometry(self, save_path: Path, projection_geometry: PyProjectionGeometry): + def write_projection_geometry( + self, save_path: Path, projection_geometry: PyProjectionGeometry + ): raise NotImplementedError() - + def write_projection(self, save_path: Path, projection: PyProjection): raise NotImplementedError() - - def write_region_of_intrest(self, save_path: Path, region_of_intrest: PyRegionOfIntrest): + + def write_region_of_intrest( + self, save_path: Path, region_of_intrest: PyRegionOfIntrest + ): raise NotImplementedError - + def write_volume(self, save_path: Path, volume: PyVolume): raise NotImplementedError() - + def get_next_projection_save_path(self, save_folder: Path) -> Path: - return self.get_projection_save_path_i(save_folder, self.number_of_projections(save_folder) + 1) + return self.get_projection_save_path_i( + save_folder, self.number_of_projections(save_folder) + 1 + ) def get_projection_save_path_i(self, save_folder: Path, i) -> Path: - return save_folder / f'{self.projection_name}_{i:05}{self.projection_suffix}' - + return save_folder / f"{self.projection_name}_{i:05}{self.projection_suffix}" + def get_next_projection_geometry_save_path(self, save_folder: Path) -> Path: - return self.get_projection_geometry_save_path_i(save_folder, self.number_of_projection_geometries(save_folder) + 1) + return self.get_projection_geometry_save_path_i( + save_folder, self.number_of_projection_geometries(save_folder) + 1 + ) def get_projection_geometry_save_path_i(self, save_folder: Path, i) -> Path: - return save_folder / f'{self.projection_geometry_name}_{i:05}{self.porjection_geometry_suffix}' - + return ( + save_folder + / f"{self.projection_geometry_name}_{i:05}{self.porjection_geometry_suffix}" + ) + def get_next_region_of_intrest_save_path(self, save_folder: Path) -> Path: - return self.get_region_of_intrest_save_path_i(save_folder, self.number_of_region_of_intrests(save_folder) + 1) + return self.get_region_of_intrest_save_path_i( + save_folder, self.number_of_region_of_intrests(save_folder) + 1 + ) def get_region_of_intrest_save_path_i(self, save_folder: Path, i) -> Path: - return save_folder / f'{self.region_of_intrest_name}_{i:05}{self.region_of_intrest_suffix}' - + return ( + save_folder + / f"{self.region_of_intrest_name}_{i:05}{self.region_of_intrest_suffix}" + ) + def get_next_volume_save_path(self, save_folder: Path) -> Path: - return self.get_volume_save_path_i(save_folder, self.number_of_volumes(save_folder) + 1) + return self.get_volume_save_path_i( + save_folder, self.number_of_volumes(save_folder) + 1 + ) def get_volume_save_path_i(self, save_folder: Path, i) -> Path: - return save_folder / f'{self.volume_name}_{i:05}{self.volume_suffix}' - + return save_folder / f"{self.volume_name}_{i:05}{self.volume_suffix}" + def number_of_projection_geometries(self, folder: Path) -> int: - return len(list(folder.glob(f'{self.projection_geometry_name}*{self.porjection_geometry_suffix}'))) - + return len( + list( + folder.glob( + f"{self.projection_geometry_name}*{self.porjection_geometry_suffix}" + ) + ) + ) + def number_of_projections(self, folder: Path) -> int: - return len(list(folder.glob(f'{self.projection_name}*{self.projection_suffix}'))) - + return len( + list(folder.glob(f"{self.projection_name}*{self.projection_suffix}")) + ) + def number_of_region_of_intrests(self, folder: Path) -> int: - return len(list(folder.glob(f'{self.region_of_intrest_name}*{self.region_of_intrest_suffix}'))) - + return len( + list( + folder.glob( + f"{self.region_of_intrest_name}*{self.region_of_intrest_suffix}" + ) + ) + ) + def number_of_volumes(self, folder: Path) -> int: - return len(list(folder.glob(f'{self.volume_name}*{self.volume_suffix}'))) + return len(list(folder.glob(f"{self.volume_name}*{self.volume_suffix}"))) diff --git a/rq_controller/common/projection.py b/rq_controller/common/projection.py index d69ab60e46787ff3f8b14cb9211ccc10c59a5703..2f8150980db130e1f7a21cfac44ee969f5c9dd75 100644 --- a/rq_controller/common/projection.py +++ b/rq_controller/common/projection.py @@ -24,10 +24,20 @@ class PyProjection(PyProjectionGeometry): exposure_time_ms (float): The exposure time in milliseconds. """ - def __init__(self, focal_spot_mm: ndarray, detector_postion_mm: ndarray, detector_orientation_quad: ndarray, image: np.ndarray, - detector_heigth_mm: float, detector_width_mm: float, frame_id: str = 'object', - focal_spot_orientation_quad: np.ndarray = np.array([0., 0., 0, 1.]), - voltage_kv: float = 100., current_ua: float = 100., exposure_time_ms: float = 1000.) -> None: + def __init__( + self, + focal_spot_mm: ndarray, + detector_postion_mm: ndarray, + detector_orientation_quad: ndarray, + image: np.ndarray, + detector_heigth_mm: float, + detector_width_mm: float, + frame_id: str = "object", + focal_spot_orientation_quad: np.ndarray = np.array([0.0, 0.0, 0, 1.0]), + voltage_kv: float = 100.0, + current_ua: float = 100.0, + exposure_time_ms: float = 1000.0, + ) -> None: """ Initializes a PyProjection instance. @@ -44,8 +54,14 @@ class PyProjection(PyProjectionGeometry): current_ua (float): The current in microamperes. Default is 100. exposure_time_ms (float): The exposure time in milliseconds. Default is 1000. """ - - super().__init__(focal_spot_mm, detector_postion_mm, detector_orientation_quad, frame_id, focal_spot_orientation_quad) + + super().__init__( + focal_spot_mm, + detector_postion_mm, + detector_orientation_quad, + frame_id, + focal_spot_orientation_quad, + ) self.image = image.astype(np.uint16) self.detector_heigth_mm = detector_heigth_mm self.detector_width_mm = detector_width_mm @@ -62,19 +78,31 @@ class PyProjection(PyProjectionGeometry): PyProjection: A dummy instance with default values. """ - return cls(np.array([0., 100., 0]), - np.array([0., -100., 0]), - np.array([1., 0., 0, 1.]), - np.random.randint(0, 65535, size=(10, 10), dtype=np.uint16), - 10., 10.) - + return cls( + np.array([0.0, 100.0, 0]), + np.array([0.0, -100.0, 0]), + np.array([1.0, 0.0, 0, 1.0]), + np.random.randint(0, 65535, size=(10, 10), dtype=np.uint16), + 10.0, + 10.0, + ) + @classmethod - def from_look_at(cls, focal_spot_mm: ndarray, detector_postion_mm: ndarray, image_shape: tuple[int], - detector_heigth_mm: float, detector_width_mm: float, frame_id: str = 'object', - voltage_kv: float = 100., current_ua: float = 100., exposure_time_ms: float = 1000., - up_vector: np.ndarray = np.array([0., 0., 1.])) -> 'PyProjection': + def from_look_at( + cls, + focal_spot_mm: ndarray, + detector_postion_mm: ndarray, + image_shape: tuple[int], + detector_heigth_mm: float, + detector_width_mm: float, + frame_id: str = "object", + voltage_kv: float = 100.0, + current_ua: float = 100.0, + exposure_time_ms: float = 1000.0, + up_vector: np.ndarray = np.array([0.0, 0.0, 1.0]), + ) -> "PyProjection": image = np.zeros(image_shape, np.uint16) - + vector = focal_spot_mm - detector_postion_mm vector = vector / np.linalg.norm(vector) up_vector = up_vector / np.linalg.norm(up_vector) @@ -82,47 +110,78 @@ class PyProjection(PyProjectionGeometry): detector_orientation = np.eye(3) detector_orientation[:, 2] = vector detector_orientation[:, 1] = np.cross(up_vector, vector) - detector_orientation[:, 0] = np.cross(detector_orientation[:, 1], detector_orientation[:, 2]) + detector_orientation[:, 0] = np.cross( + detector_orientation[:, 1], detector_orientation[:, 2] + ) source_orientation = np.eye(3) source_orientation[:, 2] = -vector source_orientation[:, 1] = np.cross(up_vector, -vector) - source_orientation[:, 0] = np.cross(source_orientation[:, 1], source_orientation[:, 2]) + source_orientation[:, 0] = np.cross( + source_orientation[:, 1], source_orientation[:, 2] + ) return cls( - focal_spot_mm, detector_postion_mm, + focal_spot_mm, + detector_postion_mm, Rotation.from_matrix(detector_orientation).as_quat(), - image, detector_heigth_mm, detector_width_mm, - frame_id, Rotation.from_matrix(source_orientation).as_quat(), - voltage_kv, current_ua, exposure_time_ms) - - def look_at(self, focal_spot_mm: ndarray, detector_postion_mm: ndarray, up_vector: np.ndarray = np.array([0., 0., 1.])) -> PyProjection: + image, + detector_heigth_mm, + detector_width_mm, + frame_id, + Rotation.from_matrix(source_orientation).as_quat(), + voltage_kv, + current_ua, + exposure_time_ms, + ) + + def look_at( + self, + focal_spot_mm: ndarray, + detector_postion_mm: ndarray, + up_vector: np.ndarray = np.array([0.0, 0.0, 1.0]), + ) -> PyProjection: vector = focal_spot_mm - detector_postion_mm vector = vector / np.linalg.norm(vector) up_vector = up_vector / np.linalg.norm(up_vector) - z_rot = Rotation.from_euler('Z', -90, True) + z_rot = Rotation.from_euler("Z", -90, True) detector_orientation = np.eye(3) detector_orientation[:, 2] = vector detector_orientation[:, 1] = np.cross(up_vector, vector) - detector_orientation[:, 0] = np.cross(detector_orientation[:, 1], detector_orientation[:, 2]) + detector_orientation[:, 0] = np.cross( + detector_orientation[:, 1], detector_orientation[:, 2] + ) - detector_orientation_quad = (Rotation.from_matrix(detector_orientation) * z_rot).as_quat() + detector_orientation_quad = ( + Rotation.from_matrix(detector_orientation) * z_rot + ).as_quat() source_orientation = np.eye(3) source_orientation[:, 2] = -vector source_orientation[:, 1] = np.cross(up_vector, -vector) - source_orientation[:, 0] = np.cross(source_orientation[:, 1], source_orientation[:, 2]) - source_orientation_quad = (Rotation.from_matrix(source_orientation) * z_rot).as_quat() - + source_orientation[:, 0] = np.cross( + source_orientation[:, 1], source_orientation[:, 2] + ) + source_orientation_quad = ( + Rotation.from_matrix(source_orientation) * z_rot + ).as_quat() + return PyProjection( - focal_spot_mm, detector_postion_mm, detector_orientation_quad, self.image, - self.detector_heigth_mm, self.detector_width_mm, - self.frame_id, source_orientation_quad, self.voltage_kv, - self.current_ua, self.exposure_time_ms) + focal_spot_mm, + detector_postion_mm, + detector_orientation_quad, + self.image, + self.detector_heigth_mm, + self.detector_width_mm, + self.frame_id, + source_orientation_quad, + self.voltage_kv, + self.current_ua, + self.exposure_time_ms, + ) - def __str__(self) -> str: """ Returns a string representation of the PyProjection instance. @@ -130,14 +189,18 @@ class PyProjection(PyProjectionGeometry): Returns: str: The string representation of the projection. """ - print_str = f'---\n' - print_str += f'PyProjection\n' - print_str += f'--- Projection Geometry:\n' - print_str += f'Detector Posiion [mm]: \t {self.detector_postion_mm.tolist()} \n' - print_str += f'Focal Spot [mm]: \t {self.focal_spot_mm.tolist()} \n' - print_str += f'Detector Orientation [quad]: \t {self.detector_orientation_quad.tolist()} \n' - print_str += f'Focal Spot Orientation [quad]: \t {self.focal_spot_orientation_quad.tolist()} \n' - print_str += f'---\n\n' + print_str = "---\n" + print_str += "PyProjection\n" + print_str += "--- Projection Geometry:\n" + print_str += ( + f"Detector Posiion [mm]: \t {self.detector_postion_mm.tolist()} \n" + ) + print_str += ( + f"Focal Spot [mm]: \t {self.focal_spot_mm.tolist()} \n" + ) + print_str += f"Detector Orientation [quad]: \t {self.detector_orientation_quad.tolist()} \n" + print_str += f"Focal Spot Orientation [quad]: \t {self.focal_spot_orientation_quad.tolist()} \n" + print_str += "---\n\n" return print_str @classmethod @@ -151,25 +214,40 @@ class PyProjection(PyProjectionGeometry): Returns: PyProjection: An instance initialized from the ROS message. """ - focal_spot_mm = np.array([msg.projection_geometry.focal_spot_postion_mm.x, - msg.projection_geometry.focal_spot_postion_mm.y, - msg.projection_geometry.focal_spot_postion_mm.z,]) - - detector_center_mm = np.array([msg.projection_geometry.detector_postion_mm.x, - msg.projection_geometry.detector_postion_mm.y, - msg.projection_geometry.detector_postion_mm.z,]) - - detector_orientation_quad = np.array([msg.projection_geometry.detector_orientation_quad.x, - msg.projection_geometry.detector_orientation_quad.y, - msg.projection_geometry.detector_orientation_quad.z, - msg.projection_geometry.detector_orientation_quad.w]) - - focal_spot_orientation = np.array([msg.projection_geometry.focal_spot_orientation_quad.x, - msg.projection_geometry.focal_spot_orientation_quad.y, - msg.projection_geometry.focal_spot_orientation_quad.z, - msg.projection_geometry.focal_spot_orientation_quad.w]) - - + focal_spot_mm = np.array( + [ + msg.projection_geometry.focal_spot_postion_mm.x, + msg.projection_geometry.focal_spot_postion_mm.y, + msg.projection_geometry.focal_spot_postion_mm.z, + ] + ) + + detector_center_mm = np.array( + [ + msg.projection_geometry.detector_postion_mm.x, + msg.projection_geometry.detector_postion_mm.y, + msg.projection_geometry.detector_postion_mm.z, + ] + ) + + detector_orientation_quad = np.array( + [ + msg.projection_geometry.detector_orientation_quad.x, + msg.projection_geometry.detector_orientation_quad.y, + msg.projection_geometry.detector_orientation_quad.z, + msg.projection_geometry.detector_orientation_quad.w, + ] + ) + + focal_spot_orientation = np.array( + [ + msg.projection_geometry.focal_spot_orientation_quad.x, + msg.projection_geometry.focal_spot_orientation_quad.y, + msg.projection_geometry.focal_spot_orientation_quad.z, + msg.projection_geometry.focal_spot_orientation_quad.w, + ] + ) + detector_heigth_mm = msg.detector_heigth_mm detector_width_mm = msg.detector_width_mm frame_id = msg.projection_geometry.header.frame_id @@ -180,10 +258,20 @@ class PyProjection(PyProjectionGeometry): current_ua = msg.current_ua exposure_time_ms = msg.exposure_time_ms - return cls(focal_spot_mm, detector_center_mm, detector_orientation_quad, image, - detector_heigth_mm, detector_width_mm, frame_id, focal_spot_orientation, - voltage_kv, current_ua, exposure_time_ms) - + return cls( + focal_spot_mm, + detector_center_mm, + detector_orientation_quad, + image, + detector_heigth_mm, + detector_width_mm, + frame_id, + focal_spot_orientation, + voltage_kv, + current_ua, + exposure_time_ms, + ) + def as_message(self) -> Projection: """ Converts the PyProjection instance to a ROS message. @@ -197,7 +285,9 @@ class PyProjection(PyProjectionGeometry): self.focal_spot_mm = self.focal_spot_mm.reshape((3,)) self.detector_postion_mm = self.detector_postion_mm.reshape((3,)) self.detector_orientation_quad = self.detector_orientation_quad.reshape((4,)) - self.focal_spot_orientation_quad = self.focal_spot_orientation_quad.reshape((4,)) + self.focal_spot_orientation_quad = self.focal_spot_orientation_quad.reshape( + (4,) + ) projection_geometry.focal_spot_postion_mm.x = float(self.focal_spot_mm[0]) projection_geometry.focal_spot_postion_mm.y = float(self.focal_spot_mm[1]) @@ -207,18 +297,34 @@ class PyProjection(PyProjectionGeometry): projection_geometry.detector_postion_mm.y = float(self.detector_postion_mm[1]) projection_geometry.detector_postion_mm.z = float(self.detector_postion_mm[2]) - projection_geometry.detector_orientation_quad.x = float(self.detector_orientation_quad[0]) - projection_geometry.detector_orientation_quad.y = float(self.detector_orientation_quad[1]) - projection_geometry.detector_orientation_quad.z = float(self.detector_orientation_quad[2]) - projection_geometry.detector_orientation_quad.w = float(self.detector_orientation_quad[3]) - - projection_geometry.focal_spot_orientation_quad.x = float(self.focal_spot_orientation_quad[0]) - projection_geometry.focal_spot_orientation_quad.y = float(self.focal_spot_orientation_quad[1]) - projection_geometry.focal_spot_orientation_quad.z = float(self.focal_spot_orientation_quad[2]) - projection_geometry.focal_spot_orientation_quad.w = float(self.focal_spot_orientation_quad[3]) + projection_geometry.detector_orientation_quad.x = float( + self.detector_orientation_quad[0] + ) + projection_geometry.detector_orientation_quad.y = float( + self.detector_orientation_quad[1] + ) + projection_geometry.detector_orientation_quad.z = float( + self.detector_orientation_quad[2] + ) + projection_geometry.detector_orientation_quad.w = float( + self.detector_orientation_quad[3] + ) + + projection_geometry.focal_spot_orientation_quad.x = float( + self.focal_spot_orientation_quad[0] + ) + projection_geometry.focal_spot_orientation_quad.y = float( + self.focal_spot_orientation_quad[1] + ) + projection_geometry.focal_spot_orientation_quad.z = float( + self.focal_spot_orientation_quad[2] + ) + projection_geometry.focal_spot_orientation_quad.w = float( + self.focal_spot_orientation_quad[3] + ) message.projection_geometry = projection_geometry - message.image = ros2_numpy.msgify(Image, self.image, 'mono16') + message.image = ros2_numpy.msgify(Image, self.image, "mono16") message.detector_heigth_mm = self.detector_heigth_mm message.detector_width_mm = self.detector_width_mm @@ -230,28 +336,46 @@ class PyProjection(PyProjectionGeometry): message.exposure_time_ms = self.exposure_time_ms return message - + def get_projection_geometry(self) -> PyProjectionGeometry: - return PyProjectionGeometry(self.focal_spot_mm, - self.detector_postion_mm, - self.detector_orientation_quad, - self.frame_id, - self.focal_spot_orientation_quad) - + return PyProjectionGeometry( + self.focal_spot_mm, + self.detector_postion_mm, + self.detector_orientation_quad, + self.frame_id, + self.focal_spot_orientation_quad, + ) + @property def detector_heigth_px(self) -> int: return self.image.shape[0] - + @property def detector_width_px(self) -> int: return self.image.shape[1] - + @property def pixel_pitch_x_mm(self) -> float: return self.detector_width_mm / self.detector_width_px - + @property def pixel_pitch_y_mm(self) -> float: return self.detector_heigth_mm / self.detector_heigth_px - - + + @property + def to_cera_transformation_matrix(self) -> np.ndarray: + transformation_matrix = np.eye(3) + # [mm] -> [px] + transformation_matrix[2, 2] = 1.0 / self.pixel_pitch_x_mm + transformation_matrix[0, 2] = -self.detector_width_px / 2.0 + transformation_matrix[1, 2] = -self.detector_heigth_px / 2.0 + + return transformation_matrix + + @property + def projection_matrix_cera_px(self): + projection_geometry = self.get_projection_geometry() + return ( + np.linalg.inv(self.to_cera_transformation_matrix) + @ projection_geometry.projection_matrix + ) diff --git a/rq_controller/common/projection_geometry.py b/rq_controller/common/projection_geometry.py index 61bc50099cc3dba82573915d7c58003f93ddbbb0..2c94ff65bd8f5e8429b5135155ba294e13ef5fc0 100644 --- a/rq_controller/common/projection_geometry.py +++ b/rq_controller/common/projection_geometry.py @@ -2,9 +2,11 @@ from __future__ import annotations import numpy as np -from rq_interfaces.msg import ProjectionGeometry, Projection +from rq_interfaces.msg import ProjectionGeometry +from scipy.spatial.transform import Rotation -class PyProjectionGeometry(): + +class PyProjectionGeometry: """ Represents the geometry of a projection including the focal spot and detector position and orientation. @@ -16,10 +18,14 @@ class PyProjectionGeometry(): frame_id (str): Frame ID for the projection geometry. """ - def __init__(self, focal_spot_mm: np.ndarray, detector_postion_mm: np.ndarray, - detector_orientation_quad: np.ndarray, - frame_id: str = 'object', focal_spot_orientation_quad: np.ndarray = np.array([0., 0., 0, 1.]) - ) -> None: + def __init__( + self, + focal_spot_mm: np.ndarray, + detector_postion_mm: np.ndarray, + detector_orientation_quad: np.ndarray, + frame_id: str = "object", + focal_spot_orientation_quad: np.ndarray = np.array([0.0, 0.0, 0, 1.0]), + ) -> None: """ Initializes a PyProjectionGeometry instance. @@ -45,9 +51,11 @@ class PyProjectionGeometry(): Returns: PyProjectionGeometry: A dummy instance with default values. """ - return cls(np.array([1., 0., 0]), - np.array([-1., 0., 0]), - np.array([0., 0., 0, 1.])) + return cls( + np.array([1.0, 0.0, 0]), + np.array([-1.0, 0.0, 0]), + np.array([0.0, 0.0, 0, 1.0]), + ) @classmethod def from_message(cls, msg: ProjectionGeometry): @@ -60,28 +68,50 @@ class PyProjectionGeometry(): Returns: PyProjectionGeometry: An instance initialized from the ROS message. """ - focal_spot_mm = np.array([msg.focal_spot_postion_mm.x, - msg.focal_spot_postion_mm.y, - msg.focal_spot_postion_mm.z,]) - - detector_center_mm = np.array([msg.detector_postion_mm.x, - msg.detector_postion_mm.y, - msg.detector_postion_mm.z,]) - - detector_orientation_quad = np.array([msg.detector_orientation_quad.x, - msg.detector_orientation_quad.y, - msg.detector_orientation_quad.z, - msg.detector_orientation_quad.w]) - - focal_spot_orientation = np.array([msg.focal_spot_orientation_quad.x, - msg.focal_spot_orientation_quad.y, - msg.focal_spot_orientation_quad.z, - msg.focal_spot_orientation_quad.w]) - + focal_spot_mm = np.array( + [ + msg.focal_spot_postion_mm.x, + msg.focal_spot_postion_mm.y, + msg.focal_spot_postion_mm.z, + ] + ) + + detector_center_mm = np.array( + [ + msg.detector_postion_mm.x, + msg.detector_postion_mm.y, + msg.detector_postion_mm.z, + ] + ) + + detector_orientation_quad = np.array( + [ + msg.detector_orientation_quad.x, + msg.detector_orientation_quad.y, + msg.detector_orientation_quad.z, + msg.detector_orientation_quad.w, + ] + ) + + focal_spot_orientation = np.array( + [ + msg.focal_spot_orientation_quad.x, + msg.focal_spot_orientation_quad.y, + msg.focal_spot_orientation_quad.z, + msg.focal_spot_orientation_quad.w, + ] + ) + frame_id = msg.header.frame_id - return cls(focal_spot_mm, detector_center_mm, detector_orientation_quad, frame_id, focal_spot_orientation) - + return cls( + focal_spot_mm, + detector_center_mm, + detector_orientation_quad, + frame_id, + focal_spot_orientation, + ) + def as_message(self) -> ProjectionGeometry: """ Converts the PyProjectionGeometry instance to a ROS message. @@ -111,4 +141,30 @@ class PyProjectionGeometry(): message.header.frame_id = self.frame_id - return message \ No newline at end of file + return message + + @property + def detector_rotation_matrix(self) -> np.ndarray: + return Rotation.from_quat(self.detector_orientation_quad).as_matrix() + + @property + def detector_horizontal_vector(self) -> np.ndarray: + return self.detector_rotation_matrix[:, 0] + + @property + def detector_vertical_vector(self) -> np.ndarray: + return self.detector_rotation_matrix[:, 1] + + @property + def projection_matrix(self) -> np.ndarray: + p3x3 = np.vstack( + [ + self.detector_horizontal_vector, + self.detector_vertical_vector, + self.detector_postion_mm - self.focal_spot_mm, + ] + ).T + p3x3_inv = np.linalg.inv(p3x3) + p4 = (p3x3_inv @ (-self.focal_spot_mm)).reshape((3, 1)) + matrix = np.concatenate([p3x3_inv, p4], 1).reshape((3, 4)) + return matrix diff --git a/rq_controller/common/region_of_intrest.py b/rq_controller/common/region_of_intrest.py index 4a71fda142cd0724897577ca7ffbed0c60c98e71..e7cbe2f4c1d2c4d3b7ba53033b2bc5581e1b27fe 100644 --- a/rq_controller/common/region_of_intrest.py +++ b/rq_controller/common/region_of_intrest.py @@ -6,7 +6,7 @@ from rq_interfaces.msg import RegionOfIntrest from visualization_msgs.msg import Marker -class PyRegionOfIntrest(): +class PyRegionOfIntrest: """ Represents a region of interest (ROI) with center points, dimensions, and resolution. @@ -17,8 +17,13 @@ class PyRegionOfIntrest(): resolution_mm (np.ndarray): Resolution of the ROIs in millimeters. """ - def __init__(self, center_points_mm: np.ndarray, dimensions_mm: np.ndarray, frame_id: str = 'object', - resolution_mm: np.ndarray = np.array([0.1, 0.1, 0.1])): + def __init__( + self, + center_points_mm: np.ndarray, + dimensions_mm: np.ndarray, + frame_id: str = "object", + resolution_mm: np.ndarray = np.array([0.1, 0.1, 0.1]), + ): """ Initializes a PyRegionOfIntrest instance. @@ -43,8 +48,7 @@ class PyRegionOfIntrest(): PyRegionOfIntrest: A dummy instance with random values. """ - return cls((np.random.random((3, )) - 0.5) * 20., - np.random.random((3, )) * 10.) + return cls((np.random.random((3,)) - 0.5) * 20.0, np.random.random((3,)) * 10.0) @classmethod def from_message(cls, msg: RegionOfIntrest): @@ -64,22 +68,20 @@ class PyRegionOfIntrest(): for roi in msg.region_of_intrest_stack.markers: roi: Marker center_points_mm.append( - np.array([roi.pose.position.x, - roi.pose.position.y, - roi.pose.position.z])) - - dimensions_mm.append( - np.array([roi.scale.x, - roi.scale.y, - roi.scale.z])) - + np.array( + [roi.pose.position.x, roi.pose.position.y, roi.pose.position.z] + ) + ) + + dimensions_mm.append(np.array([roi.scale.x, roi.scale.y, roi.scale.z])) + frame = roi.header.frame_id - resolution_mm = np.array([msg.resolution.x, - msg.resolution.y, - msg.resolution.z]) - - return cls(np.array(center_points_mm), np.array(dimensions_mm), frame, resolution_mm) - + resolution_mm = np.array([msg.resolution.x, msg.resolution.y, msg.resolution.z]) + + return cls( + np.array(center_points_mm), np.array(dimensions_mm), frame, resolution_mm + ) + @property def number_of_rois(self) -> int: """ @@ -90,7 +92,7 @@ class PyRegionOfIntrest(): """ return self.center_points_mm.shape[0] - + @property def shape(self) -> tuple: """ @@ -102,7 +104,7 @@ class PyRegionOfIntrest(): shape = self.dimensions_mm[0] // self.resolution_mm[0] return (int(shape[0]), int(shape[1]), int(shape[2])) - + def as_message(self) -> RegionOfIntrest: """ Converts the PyRegionOfIntrest instance to a ROS message. @@ -127,8 +129,8 @@ class PyRegionOfIntrest(): roi.header.frame_id = self.frame_id - roi_list.append(roi) - + roi_list.append(roi) + message.region_of_intrest_stack.markers = roi_list message.resolution.x = float(self.resolution_mm[0][0]) @@ -136,7 +138,7 @@ class PyRegionOfIntrest(): message.resolution.z = float(self.resolution_mm[0][2]) return message - + def get_grid(self, indice: int = 0) -> np.ndarray: """ Generates a grid of points within the ROI. @@ -148,21 +150,19 @@ class PyRegionOfIntrest(): np.ndarray: A grid of points within the ROI. """ - start = self.center_points_mm[indice] - (self.dimensions_mm[indice] / 2.) - end = self.center_points_mm[indice] + (self.dimensions_mm[indice] / 2.) - + start = self.center_points_mm[indice] - (self.dimensions_mm[indice] / 2.0) + end = self.center_points_mm[indice] + (self.dimensions_mm[indice] / 2.0) + x_ = np.linspace(start[0], end[0], self.shape[0]) y_ = np.linspace(start[1], end[1], self.shape[1]) z_ = np.linspace(start[2], end[2], self.shape[2]) - x, y, z = np.meshgrid(x_, y_, z_, indexing='ij') + x, y, z = np.meshgrid(x_, y_, z_, indexing="ij") + + return np.concatenate( + (np.expand_dims(x, -1), np.expand_dims(y, -1), np.expand_dims(z, -1)), -1 + ) - return np.concatenate(( - np.expand_dims(x, -1), - np.expand_dims(y, -1), - np.expand_dims(z, -1)), - -1) - @staticmethod def next_neighbor(grid_mm: np.ndarray, point_mm: np.ndarray) -> np.ndarray: """ @@ -175,13 +175,13 @@ class PyRegionOfIntrest(): Returns: np.ndarray: The indices of the nearest neighbor in the grid. """ - + x = grid_mm[:, 0, 0, 0] y = grid_mm[0, :, 0, 1] z = grid_mm[0, 0, :, 2] - xx = int(np.argmin((x - point_mm[0])**2)) - yy = int(np.argmin((y - point_mm[1])**2)) - zz = int(np.argmin((z - point_mm[2])**2)) + xx = int(np.argmin((x - point_mm[0]) ** 2)) + yy = int(np.argmin((y - point_mm[1]) ** 2)) + zz = int(np.argmin((z - point_mm[2]) ** 2)) - return np.array([xx, yy, zz], dtype=np.int32) \ No newline at end of file + return np.array([xx, yy, zz], dtype=np.int32) diff --git a/rq_controller/common/volume.py b/rq_controller/common/volume.py index ac1b64568c12ed0a0688cf32bdcc7f72e971c618..b63aea96f365da6185cd84503c194d0ff789a197 100644 --- a/rq_controller/common/volume.py +++ b/rq_controller/common/volume.py @@ -19,7 +19,7 @@ class VOLUME_TYPES(IntEnum): UINT_8 = 1 -class PyVolume(): +class PyVolume: """ Represents a volumetric dataset with an associated region of interest. @@ -29,7 +29,9 @@ class PyVolume(): data_typ (VOLUME_TYPES): The data type of the volume. """ - def __init__(self, array: ndarray, roi: PyRegionOfIntrest, data_type: VOLUME_TYPES = ...): + def __init__( + self, np_array: ndarray, roi: PyRegionOfIntrest, data_type: VOLUME_TYPES = ... + ): """ Initializes a PyVolume instance. @@ -40,7 +42,7 @@ class PyVolume(): """ self.roi = roi - self.array = array + self.array = np_array self.data_typ = data_type @staticmethod @@ -60,8 +62,8 @@ class PyVolume(): elif volume_type == VOLUME_TYPES.UINT_8: return np.uint8 else: - raise ValueError('Datatype is not implemented') - + raise ValueError("Datatype is not implemented") + @staticmethod def enum_to_numpify(volume_type: VOLUME_TYPES) -> str: """ @@ -75,14 +77,14 @@ class PyVolume(): """ if volume_type == VOLUME_TYPES.UINT_16: - return 'mono16' + return "mono16" elif volume_type == VOLUME_TYPES.UINT_8: - return 'mono8' + return "mono8" else: - raise ValueError('Datatype is not implemented') - + raise ValueError("Datatype is not implemented") + @classmethod - def dummy(cls) -> 'PyVolume': + def dummy(cls) -> "PyVolume": """ Creates a dummy instance of PyVolume for testing. @@ -91,12 +93,12 @@ class PyVolume(): """ roi = PyRegionOfIntrest.dummy() - array = np.random.randint(0, 255, size=roi.shape) + np_array = np.random.randint(0, 255, size=roi.shape) data_type = VOLUME_TYPES.UINT_8 - return cls(array, roi, data_type) + return cls(np_array, roi, data_type) @classmethod - def from_message(cls, msg: Volume) -> 'PyVolume': + def from_message(cls, msg: Volume) -> "PyVolume": """ Creates an instance of PyVolume from a ROS message. @@ -108,33 +110,34 @@ class PyVolume(): """ roi: Marker = msg.grid.region_of_intrest_stack.markers[0] - center_points_mm = np.array([ - roi.pose.position.x, - roi.pose.position.y, - roi.pose.position.z]) - - dimensions_mm = np.array([ - roi.scale.x, - roi.scale.y, - roi.scale.z]) - + center_points_mm = np.array( + [roi.pose.position.x, roi.pose.position.y, roi.pose.position.z] + ) + + dimensions_mm = np.array([roi.scale.x, roi.scale.y, roi.scale.z]) + frame_id = roi.header.frame_id - resolution_mm = np.array([ - msg.grid.resolution.x, - msg.grid.resolution.y, - msg.grid.resolution.z]) - - py_roi = PyRegionOfIntrest(center_points_mm, dimensions_mm, frame_id, resolution_mm) - + resolution_mm = np.array( + [msg.grid.resolution.x, msg.grid.resolution.y, msg.grid.resolution.z] + ) + + py_roi = PyRegionOfIntrest( + center_points_mm, dimensions_mm, frame_id, resolution_mm + ) + data_typ = msg.datatype - array = np.zeros(py_roi.shape, dtype=cls.get_data_type(data_typ)) + np_array = np.zeros(py_roi.shape, dtype=cls.get_data_type(data_typ)) for i, slice in enumerate(msg.slices): - array[:, :, i] = ros2_numpy.numpify(slice).reshape((py_roi.shape[0], py_roi.shape[1])).astype(cls.get_data_type(data_typ)) + np_array[:, :, i] = ( + ros2_numpy.numpify(slice) + .reshape((py_roi.shape[0], py_roi.shape[1])) + .astype(cls.get_data_type(data_typ)) + ) + + return cls(np_array, py_roi, data_typ) - return cls(array, py_roi, data_typ) - def as_message(self) -> Volume: """ Converts the PyVolume instance to a ROS message. @@ -151,12 +154,15 @@ class PyVolume(): for z in range(self.array.shape[2]): message.slices.append( - ros2_numpy.msgify(Image, - self.array[:, :, z].astype(self.get_data_type(self.data_typ)), - self.enum_to_numpify(self.data_typ))) - + ros2_numpy.msgify( + Image, + self.array[:, :, z].astype(self.get_data_type(self.data_typ)), + self.enum_to_numpify(self.data_typ), + ) + ) + return message - + @property def shape(self): """ @@ -165,10 +171,5 @@ class PyVolume(): Returns: tuple: The shape of the volume. """ - - return self.array.shape - - - - + return self.array.shape diff --git a/rq_controller/rq_workflow.py b/rq_controller/rq_workflow.py index f79cd440019c0a0cc0b22e25ef61bab92c33a24b..d31575991551fb7e516b3950ac7494f103edbf96 100644 --- a/rq_controller/rq_workflow.py +++ b/rq_controller/rq_workflow.py @@ -5,55 +5,52 @@ from rq_ddetection.defect_detection_client import DefectDetectionClient from rq_reconstruction.reconstruction_client import ReconstructionClient from rq_trajectory.trajectory_optimization_client import TrajectoryOptimizationClient -from rq_controller.common import PyProjection, PyRegionOfIntrest, PyVolume +from rq_controller.common import PyProjection, PyRegionOfIntrest import rclpy -from rclpy.node import Node -from scipy.spatial.transform import Rotation -import numpy as np - -class WorkflowNode(): +class WorkflowNode: def __init__(self): - self.hardware_interface = HardwareClient() self.defect_detection_interface = DefectDetectionClient() self.trajectory = TrajectoryOptimizationClient() self.reconstruction = ReconstructionClient() - - def defect_detection(self, projections: list[PyProjection]) -> list[PyRegionOfIntrest]: + def defect_detection( + self, projections: list[PyProjection] + ) -> list[PyRegionOfIntrest]: future = self.defect_detection_interface.multi_projection_defect_detection( - projections) - + projections + ) + rclpy.spin_until_future_complete(self.defect_detection_interface, future) response = future.result() roi_list = self.defect_detection_interface.response_2_py(response) return roi_list - + def aquire_projection(self, projection: PyProjection): future = self.hardware_interface.aquire_projection(projection) - + rclpy.spin_until_future_complete(self.hardware_interface, future) response = future.result() projection = self.hardware_interface.projection_response_2_py(response) return projection - + def check_reachability(self, projection: PyProjection): future = self.hardware_interface.check_reachability(projection) - + rclpy.spin_until_future_complete(self.hardware_interface, future) response = future.result() reached, cost, _ = self.hardware_interface.reachability_response_2_py(response) return reached, cost - + def get_volume(self, projection_stack: list[PyProjection], roi: PyRegionOfIntrest): future = self.reconstruction.get_volume(projection_stack, roi) @@ -61,7 +58,7 @@ class WorkflowNode(): response = future.result() return self.reconstruction.response_2_py(response) - + def trajectory_update(self, current_projection: PyProjection): future = self.trajectory.trajectory_update(current_projection) @@ -69,43 +66,45 @@ class WorkflowNode(): response = future.result() return self.trajectory.trajectory_update_response_2_py(response) - + def trajectory_initialize(self, roi: list[PyRegionOfIntrest]): future = self.trajectory.trajectory_initialize(roi) rclpy.spin_until_future_complete(self.trajectory, future) - def main(): - print('This minimal example need the nodes described in the rq_workflow/launch/echo_launch.py script.') + print( + "This minimal example need the nodes described in the rq_workflow/launch/echo_launch.py script." + ) rclpy.init() workflow = WorkflowNode() - print('Aquire some scout views...') + print("Aquire some scout views...") projection_list = list() for i in range(5): - print(f' - At projection {i}') + print(f" - At projection {i}") scan_pose = PyProjection.dummy() scan_pose.focal_spot_mm += i projection_list.append(workflow.aquire_projection(scan_pose)) - print('Detect errors in scout views ...') + print("Detect errors in scout views ...") roi = workflow.defect_detection(projection_list) - print('Init traj opt ...') + print("Init traj opt ...") workflow.trajectory_initialize(roi) - print('Get next scan pose from traj opt ...') + print("Get next scan pose from traj opt ...") next_scan_pose, finished = workflow.trajectory_update(projection_list[-1]) - print('Aquire scan pose ...') + print("Aquire scan pose ...") projection_list.append(workflow.aquire_projection(next_scan_pose)) - print('Resconstruct projections ...') - volume = workflow.get_volume(projection_list, roi[0]) + print("Resconstruct projections ...") + workflow.get_volume(projection_list, roi[0]) + + print("FINISHED !!!") - print('FINISHED !!!') -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/rq_controller/tf2/static_broadcaster.py b/rq_controller/tf2/static_broadcaster.py index 39f5cd972ecb46ff8c4af9c271af36fd7b230429..5becb1f72ca47683978a661e66e639ffa0f4abbb 100644 --- a/rq_controller/tf2/static_broadcaster.py +++ b/rq_controller/tf2/static_broadcaster.py @@ -1,9 +1,5 @@ -import math -import sys - from geometry_msgs.msg import TransformStamped -import numpy as np import rclpy from rclpy.node import Node @@ -20,13 +16,22 @@ class StaticFramePublisher(Node): time. """ - def __init__(self, frame, position, quaternion, name: str = 'static_broadcaster', parent_frame: str = 'world'): + def __init__( + self, + frame, + position, + quaternion, + name: str = "static_broadcaster", + parent_frame: str = "world", + ): super().__init__(name) self.tf_static_broadcaster = StaticTransformBroadcaster(self) # Publish static transforms once at startup - self.make_transforms(frame, position.split(" "), quaternion.split(" "), parent_frame) + self.make_transforms( + frame, position.split(" "), quaternion.split(" "), parent_frame + ) def make_transforms(self, frame, position, quaternion, parent_frame): t = TransformStamped() @@ -47,27 +52,22 @@ class StaticFramePublisher(Node): def main(args=None): - logger = rclpy.logging.get_logger('logger') + rclpy.logging.get_logger("logger") rclpy.init(args=args) - node = rclpy.create_node('static_broadcaster') - parent_frame = node.declare_parameter('parent_frame', 'world').value - frame = node.declare_parameter('frame', 'object').value - name = node.declare_parameter('name', 'static_broadcaster').value - position = node.declare_parameter('position', '0 0 0').value - quaternion = node.declare_parameter('quaternion', '0 0 0 1').value + node = rclpy.create_node("static_broadcaster") + parent_frame = node.declare_parameter("parent_frame", "world").value + frame = node.declare_parameter("frame", "object").value + name = node.declare_parameter("name", "static_broadcaster").value + position = node.declare_parameter("position", "0 0 0").value + quaternion = node.declare_parameter("quaternion", "0 0 0 1").value + + node = StaticFramePublisher(frame, position, quaternion, name, parent_frame) - node = StaticFramePublisher( - frame, - position, - quaternion, - name, - parent_frame) - try: rclpy.spin(node) except KeyboardInterrupt: pass - rclpy.shutdown() \ No newline at end of file + rclpy.shutdown() diff --git a/setup.py b/setup.py index 0b679c3879cb818be91c24a9ee29352538ed4c13..37ee73e53748fa7802ca623671765e6f3190bb12 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,25 @@ from setuptools import find_packages, setup -package_name = 'rq_controller' +package_name = "rq_controller" setup( name=package_name, - version='0.0.0', - packages=find_packages(exclude=['test']), + version="0.0.0", + packages=find_packages(exclude=["test"]), data_files=[ - ('share/ament_index/resource_index/packages', - ['resource/' + package_name]), - ('share/' + package_name, ['package.xml']), + ("share/ament_index/resource_index/packages", ["resource/" + package_name]), + ("share/" + package_name, ["package.xml"]), ], - install_requires=['setuptools', - 'Pillow', - 'pyometiff'], + install_requires=["setuptools", "Pillow", "pyometiff"], zip_safe=True, - maintainer='root', - maintainer_email='simon.wittl@th-deg.de', - description='TODO: Package description', - license='TODO: License declaration', - tests_require=['pytest'], + maintainer="root", + maintainer_email="simon.wittl@th-deg.de", + description="TODO: Package description", + license="TODO: License declaration", + tests_require=["pytest"], entry_points={ - 'console_scripts': [ - 'tf_static_broadcaster = rq_controller.tf2.static_broadcaster:main', + "console_scripts": [ + "tf_static_broadcaster = rq_controller.tf2.static_broadcaster:main", ], }, ) diff --git a/test/__pycache__/test_base_loader.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_base_loader.cpython-38-pytest-8.2.2.pyc index 279c0db1aec0f0f5b456424894890499d32a6041..49fe03382c5c64fcf5b5890c4cdc5cd6c09d5e4d 100644 Binary files a/test/__pycache__/test_base_loader.cpython-38-pytest-8.2.2.pyc and b/test/__pycache__/test_base_loader.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/__pycache__/test_base_writer.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_base_writer.cpython-38-pytest-8.2.2.pyc index 29e343a89d1a629bacda3ebb8b0f24f419a5fc0e..1d58ab39dfe5e08797f2dc9b57d4f0e5bc24e22c 100644 Binary files a/test/__pycache__/test_base_writer.cpython-38-pytest-8.2.2.pyc and b/test/__pycache__/test_base_writer.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/__pycache__/test_projection.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_projection.cpython-38-pytest-8.2.2.pyc index 8413a88371ec487791bfb98a4279f60cf8b3edae..cb9728ada1db97c0497e0512d7c4c04df1beda91 100644 Binary files a/test/__pycache__/test_projection.cpython-38-pytest-8.2.2.pyc and b/test/__pycache__/test_projection.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/__pycache__/test_projection_geometry.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_projection_geometry.cpython-38-pytest-8.2.2.pyc index daecbc751bbc7609edc20368ee278eb37f545dfe..e30f6eb8b6c80bf2b52bcc7c18ee64568b1a801f 100644 Binary files a/test/__pycache__/test_projection_geometry.cpython-38-pytest-8.2.2.pyc and b/test/__pycache__/test_projection_geometry.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/__pycache__/test_region_of_intrest.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_region_of_intrest.cpython-38-pytest-8.2.2.pyc index 469b57f673dd3002e757812d9c29a8ea600abac0..cdd2ab3fba79761b42e2adb2ccc1cc222fb2eecd 100644 Binary files a/test/__pycache__/test_region_of_intrest.cpython-38-pytest-8.2.2.pyc and b/test/__pycache__/test_region_of_intrest.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/__pycache__/test_volume.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_volume.cpython-38-pytest-8.2.2.pyc index e4329085a29c878eed2392f2e4b4eda540f9cb6f..f39dd0ea1a983260f04c5ddbbe585720a430ed67 100644 Binary files a/test/__pycache__/test_volume.cpython-38-pytest-8.2.2.pyc and b/test/__pycache__/test_volume.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/test_base_loader.py b/test/test_base_loader.py index eb1c071d7ddd30b4e84083f32e5dd780461146e0..d00fcbcc6d8de3d1246659d4e709ffad8a9fd703 100644 --- a/test/test_base_loader.py +++ b/test/test_base_loader.py @@ -2,33 +2,39 @@ import pytest from rq_controller.common.io.loader import BaseDataLoader from pathlib import Path + @pytest.fixture def base_data_loader(): return BaseDataLoader( porjection_geometry_suffix="geometry_suffix", projection_suffix="projection_suffix", region_of_intrest_suffix="roi_suffix", - volume_suffix="volume_suffix" + volume_suffix="volume_suffix", ) + def test_initialization(base_data_loader: BaseDataLoader): assert base_data_loader.porjection_geometry_suffix == "geometry_suffix" assert base_data_loader.projection_suffix == "projection_suffix" assert base_data_loader.region_of_intrest_suffix == "roi_suffix" assert base_data_loader.volume_suffix == "volume_suffix" + def test_load_projection_geometry_not_implemented(base_data_loader: BaseDataLoader): with pytest.raises(NotImplementedError): base_data_loader.load_projection_geometry(Path("/fake/path")) + def test_load_projection_not_implemented(base_data_loader: BaseDataLoader): with pytest.raises(NotImplementedError): base_data_loader.load_projection(Path("/fake/path")) + def test_load_region_of_intrest_not_implemented(base_data_loader: BaseDataLoader): with pytest.raises(NotImplementedError): base_data_loader.load_region_of_intrest(Path("/fake/path")) + def test_load_volume_not_implemented(base_data_loader: BaseDataLoader): with pytest.raises(NotImplementedError): - base_data_loader.load_volume(Path("/fake/path")) \ No newline at end of file + base_data_loader.load_volume(Path("/fake/path")) diff --git a/test/test_base_writer.py b/test/test_base_writer.py index 3d41a2d95c8b45bda06cfa681c1dbc0286420aad..d0896dfaef115806a3f5eb317e8db93d92f429ff 100644 --- a/test/test_base_writer.py +++ b/test/test_base_writer.py @@ -1,91 +1,128 @@ import pytest from pathlib import Path -from rq_controller.common import PyProjection, PyProjectionGeometry, PyRegionOfIntrest, PyVolume +from rq_controller.common import ( + PyProjection, + PyProjectionGeometry, + PyRegionOfIntrest, + PyVolume, +) from rq_controller.common.io.writer import BaseDataWriter + @pytest.fixture def base_data_writer(): return BaseDataWriter( porjection_geometry_suffix=".geom", projection_suffix=".proj", region_of_intrest_suffix=".roi", - volume_suffix=".vol" + volume_suffix=".vol", ) + def test_initialization(base_data_writer): assert base_data_writer.porjection_geometry_suffix == ".geom" assert base_data_writer.projection_suffix == ".proj" assert base_data_writer.region_of_intrest_suffix == ".roi" assert base_data_writer.volume_suffix == ".vol" + def test_write_projection_geometry_not_implemented(base_data_writer): with pytest.raises(NotImplementedError): - base_data_writer.write_projection_geometry(Path("/fake/path"), PyProjectionGeometry.dummy()) + base_data_writer.write_projection_geometry( + Path("/fake/path"), PyProjectionGeometry.dummy() + ) + def test_write_projection_not_implemented(base_data_writer): with pytest.raises(NotImplementedError): base_data_writer.write_projection(Path("/fake/path"), PyProjection.dummy()) + def test_write_region_of_intrest_not_implemented(base_data_writer): with pytest.raises(NotImplementedError): - base_data_writer.write_region_of_intrest(Path("/fake/path"), PyRegionOfIntrest.dummy()) + base_data_writer.write_region_of_intrest( + Path("/fake/path"), PyRegionOfIntrest.dummy() + ) + def test_write_volume_not_implemented(base_data_writer): with pytest.raises(NotImplementedError): base_data_writer.write_volume(Path("/fake/path"), PyVolume.dummy()) + def test_get_next_projection_save_path(base_data_writer, tmp_path): (tmp_path / "projection_00001.proj").touch() expected_path = tmp_path / "projection_00002.proj" assert base_data_writer.get_next_projection_save_path(tmp_path) == expected_path + def test_get_projection_save_path_i(base_data_writer, tmp_path): expected_path = tmp_path / "projection_00005.proj" assert base_data_writer.get_projection_save_path_i(tmp_path, 5) == expected_path + def test_get_next_projection_geometry_save_path(base_data_writer, tmp_path): (tmp_path / "geometry_00001.geom").touch() expected_path = tmp_path / "geometry_00002.geom" - assert base_data_writer.get_next_projection_geometry_save_path(tmp_path) == expected_path + assert ( + base_data_writer.get_next_projection_geometry_save_path(tmp_path) + == expected_path + ) + def test_get_projection_geometry_save_path_i(base_data_writer, tmp_path): expected_path = tmp_path / "geometry_00005.geom" - assert base_data_writer.get_projection_geometry_save_path_i(tmp_path, 5) == expected_path + assert ( + base_data_writer.get_projection_geometry_save_path_i(tmp_path, 5) + == expected_path + ) + def test_get_next_region_of_intrest_save_path(base_data_writer, tmp_path): (tmp_path / "roi_00001.roi").touch() expected_path = tmp_path / "roi_00002.roi" - assert base_data_writer.get_next_region_of_intrest_save_path(tmp_path) == expected_path + assert ( + base_data_writer.get_next_region_of_intrest_save_path(tmp_path) == expected_path + ) + def test_get_region_of_intrest_save_path_i(base_data_writer, tmp_path): expected_path = tmp_path / "roi_00005.roi" - assert base_data_writer.get_region_of_intrest_save_path_i(tmp_path, 5) == expected_path + assert ( + base_data_writer.get_region_of_intrest_save_path_i(tmp_path, 5) == expected_path + ) + def test_get_next_volume_save_path(base_data_writer, tmp_path): (tmp_path / "volume_00001.vol").touch() expected_path = tmp_path / "volume_00002.vol" assert base_data_writer.get_next_volume_save_path(tmp_path) == expected_path + def test_get_volume_save_path_i(base_data_writer, tmp_path): expected_path = tmp_path / "volume_00005.vol" assert base_data_writer.get_volume_save_path_i(tmp_path, 5) == expected_path + def test_number_of_projection_geometries(base_data_writer, tmp_path): (tmp_path / "geometry_00001.geom").touch() (tmp_path / "geometry_00002.geom").touch() assert base_data_writer.number_of_projection_geometries(tmp_path) == 2 + def test_number_of_projections(base_data_writer, tmp_path): (tmp_path / "projection_00001.proj").touch() (tmp_path / "projection_00002.proj").touch() assert base_data_writer.number_of_projections(tmp_path) == 2 + def test_number_of_region_of_intrests(base_data_writer, tmp_path): (tmp_path / "roi_00001.roi").touch() (tmp_path / "roi_00002.roi").touch() assert base_data_writer.number_of_region_of_intrests(tmp_path) == 2 + def test_number_of_volumes(base_data_writer, tmp_path): (tmp_path / "volume_00001.vol").touch() (tmp_path / "volume_00002.vol").touch() - assert base_data_writer.number_of_volumes(tmp_path) == 2 \ No newline at end of file + assert base_data_writer.number_of_volumes(tmp_path) == 2 diff --git a/test/test_projection.py b/test/test_projection.py index 06075e4a0d53943545fb6d2dd87c8dc33b5ae880..0ccf38644f3f5c99dac9164cab1269ca086810de 100644 --- a/test/test_projection.py +++ b/test/test_projection.py @@ -1,10 +1,13 @@ import pytest import numpy as np -from rq_controller.common import PyProjection # Adjust this import according to your module's path +from rq_controller.common import ( + PyProjection, +) # Adjust this import according to your module's path from rq_interfaces.msg import Projection from sensor_msgs.msg import Image import ros2_numpy + @pytest.fixture def example_message(): msg = Projection() @@ -23,10 +26,10 @@ def example_message(): msg.projection_geometry.focal_spot_orientation_quad.z = 0.0 msg.projection_geometry.focal_spot_orientation_quad.w = 1.0 msg.projection_geometry.header.frame_id = "test_frame" - + image_array = np.zeros((10, 10), dtype=np.uint16) - msg.image = ros2_numpy.msgify(Image, image_array, encoding='mono16') - + msg.image = ros2_numpy.msgify(Image, image_array, encoding="mono16") + msg.detector_heigth_mm = 100.0 msg.detector_width_mm = 200.0 msg.voltage_kv = 120.0 @@ -34,6 +37,7 @@ def example_message(): msg.exposure_time_ms = 500.0 return msg + def test_initialization(): focal_spot_mm = np.array([1.0, 2.0, 3.0]) detector_postion_mm = np.array([4.0, 5.0, 6.0]) @@ -47,12 +51,28 @@ def test_initialization(): frame_id = "test_frame" focal_spot_orientation_quad = np.array([0.0, 0.0, 0.0, 1.0]) - projection = PyProjection(focal_spot_mm, detector_postion_mm, detector_orientation_quad, image, detector_heigth_mm, detector_width_mm, frame_id, focal_spot_orientation_quad, voltage_kv, current_ua, exposure_time_ms) + projection = PyProjection( + focal_spot_mm, + detector_postion_mm, + detector_orientation_quad, + image, + detector_heigth_mm, + detector_width_mm, + frame_id, + focal_spot_orientation_quad, + voltage_kv, + current_ua, + exposure_time_ms, + ) assert np.array_equal(projection.focal_spot_mm, focal_spot_mm) assert np.array_equal(projection.detector_postion_mm, detector_postion_mm) - assert np.array_equal(projection.detector_orientation_quad, detector_orientation_quad) - assert np.array_equal(projection.focal_spot_orientation_quad, focal_spot_orientation_quad) + assert np.array_equal( + projection.detector_orientation_quad, detector_orientation_quad + ) + assert np.array_equal( + projection.focal_spot_orientation_quad, focal_spot_orientation_quad + ) assert np.array_equal(projection.image, image) assert projection.detector_heigth_mm == detector_heigth_mm assert projection.detector_width_mm == detector_width_mm @@ -61,25 +81,32 @@ def test_initialization(): assert projection.exposure_time_ms == exposure_time_ms assert projection.frame_id == frame_id + def test_dummy_method(): projection = PyProjection.dummy() assert np.array_equal(projection.focal_spot_mm, np.array([0.0, 100.0, 0.0])) assert np.array_equal(projection.detector_postion_mm, np.array([0.0, -100.0, 0.0])) - assert np.array_equal(projection.detector_orientation_quad, np.array([1.0, 0.0, 0.0, 1.0])) - assert np.array_equal(projection.image, np.zeros((10, 10), dtype=np.uint16)) + assert np.array_equal( + projection.detector_orientation_quad, np.array([1.0, 0.0, 0.0, 1.0]) + ) assert projection.detector_heigth_mm == 10.0 assert projection.detector_width_mm == 10.0 assert projection.frame_id == "object" + def test_from_message(example_message): projection = PyProjection.from_message(example_message) assert np.array_equal(projection.focal_spot_mm, np.array([10.0, 20.0, 30.0])) assert np.array_equal(projection.detector_postion_mm, np.array([40.0, 50.0, 60.0])) - assert np.array_equal(projection.detector_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0])) - assert np.array_equal(projection.focal_spot_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0])) - assert np.array_equal(projection.image, np.zeros((10, 10), dtype=np.uint16)) + assert np.array_equal( + projection.detector_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0]) + ) + assert np.array_equal( + projection.focal_spot_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0]) + ) + assert projection.detector_heigth_mm == 100.0 assert projection.detector_width_mm == 200.0 assert projection.voltage_kv == 120.0 @@ -87,6 +114,7 @@ def test_from_message(example_message): assert projection.exposure_time_ms == 500.0 assert projection.frame_id == "test_frame" + def test_as_message(example_message): projection = PyProjection.from_message(example_message) msg = projection.as_message() @@ -106,13 +134,16 @@ def test_as_message(example_message): assert msg.projection_geometry.focal_spot_orientation_quad.z == 0.0 assert msg.projection_geometry.focal_spot_orientation_quad.w == 1.0 assert msg.projection_geometry.header.frame_id == "test_frame" - assert np.array_equal(ros2_numpy.numpify(msg.image), np.zeros((10, 10), dtype=np.uint16)) + assert np.array_equal( + ros2_numpy.numpify(msg.image), np.zeros((10, 10), dtype=np.uint16) + ) assert msg.detector_heigth_mm == 100.0 assert msg.detector_width_mm == 200.0 assert msg.voltage_kv == 120.0 assert msg.current_ua == 150.0 assert msg.exposure_time_ms == 500.0 + def test_properties(): focal_spot_mm = np.array([1.0, 2.0, 3.0]) detector_postion_mm = np.array([4.0, 5.0, 6.0]) @@ -126,9 +157,21 @@ def test_properties(): frame_id = "test_frame" focal_spot_orientation_quad = np.array([0.0, 0.0, 0.0, 1.0]) - projection = PyProjection(focal_spot_mm, detector_postion_mm, detector_orientation_quad, image, detector_heigth_mm, detector_width_mm, frame_id, focal_spot_orientation_quad, voltage_kv, current_ua, exposure_time_ms) + projection = PyProjection( + focal_spot_mm, + detector_postion_mm, + detector_orientation_quad, + image, + detector_heigth_mm, + detector_width_mm, + frame_id, + focal_spot_orientation_quad, + voltage_kv, + current_ua, + exposure_time_ms, + ) assert projection.detector_heigth_px == 20 assert projection.detector_width_px == 30 assert projection.pixel_pitch_x_mm == 200.0 / 30 - assert projection.pixel_pitch_y_mm == 100.0 / 20 \ No newline at end of file + assert projection.pixel_pitch_y_mm == 100.0 / 20 diff --git a/test/test_projection_geometry.py b/test/test_projection_geometry.py index 7d4c44de107de144cc90a0abda6004c6ded44014..e28d1d97b3228e4885c253c570b4ddf5298d71fe 100644 --- a/test/test_projection_geometry.py +++ b/test/test_projection_geometry.py @@ -1,29 +1,30 @@ import pytest import numpy as np from rq_interfaces.msg import ProjectionGeometry -from rq_controller.common import PyProjectionGeometry +from rq_controller.common import PyProjectionGeometry @pytest.fixture def example_message(): msg = ProjectionGeometry() - msg.focal_spot_postion_mm.x = 10. - msg.focal_spot_postion_mm.y = 20. - msg.focal_spot_postion_mm.z = 30. - msg.detector_postion_mm.x = 40. - msg.detector_postion_mm.y = 50. - msg.detector_postion_mm.z = 60. - msg.detector_orientation_quad.x = 0. - msg.detector_orientation_quad.y = 0. - msg.detector_orientation_quad.z = 0. - msg.detector_orientation_quad.w = 1. - msg.focal_spot_orientation_quad.x = 0. - msg.focal_spot_orientation_quad.y = 0. - msg.focal_spot_orientation_quad.z = 0. - msg.focal_spot_orientation_quad.w = 1. + msg.focal_spot_postion_mm.x = 10.0 + msg.focal_spot_postion_mm.y = 20.0 + msg.focal_spot_postion_mm.z = 30.0 + msg.detector_postion_mm.x = 40.0 + msg.detector_postion_mm.y = 50.0 + msg.detector_postion_mm.z = 60.0 + msg.detector_orientation_quad.x = 0.0 + msg.detector_orientation_quad.y = 0.0 + msg.detector_orientation_quad.z = 0.0 + msg.detector_orientation_quad.w = 1.0 + msg.focal_spot_orientation_quad.x = 0.0 + msg.focal_spot_orientation_quad.y = 0.0 + msg.focal_spot_orientation_quad.z = 0.0 + msg.focal_spot_orientation_quad.w = 1.0 msg.header.frame_id = "test_frame" return msg + def test_initialization(): focal_spot_mm = np.array([1.0, 2.0, 3.0]) detector_postion_mm = np.array([4.0, 5.0, 6.0]) @@ -31,22 +32,36 @@ def test_initialization(): frame_id = "test_frame" focal_spot_orientation_quad = np.array([0.0, 0.0, 0.0, 1.0]) - geometry = PyProjectionGeometry(focal_spot_mm, detector_postion_mm, detector_orientation_quad, frame_id, focal_spot_orientation_quad) + geometry = PyProjectionGeometry( + focal_spot_mm, + detector_postion_mm, + detector_orientation_quad, + frame_id, + focal_spot_orientation_quad, + ) assert np.array_equal(geometry.focal_spot_mm, focal_spot_mm) assert np.array_equal(geometry.detector_postion_mm, detector_postion_mm) assert np.array_equal(geometry.detector_orientation_quad, detector_orientation_quad) - assert np.array_equal(geometry.focal_spot_orientation_quad, focal_spot_orientation_quad) + assert np.array_equal( + geometry.focal_spot_orientation_quad, focal_spot_orientation_quad + ) assert geometry.frame_id == frame_id + def test_dummy_method(): geometry = PyProjectionGeometry.dummy() assert np.array_equal(geometry.focal_spot_mm, np.array([1.0, 0.0, 0.0])) assert np.array_equal(geometry.detector_postion_mm, np.array([-1.0, 0.0, 0.0])) - assert np.array_equal(geometry.detector_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0])) + assert np.array_equal( + geometry.detector_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0]) + ) assert geometry.frame_id == "object" - assert np.array_equal(geometry.focal_spot_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0])) + assert np.array_equal( + geometry.focal_spot_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0]) + ) + def test_from_message(example_message): geometry = PyProjectionGeometry.from_message(example_message) @@ -57,6 +72,7 @@ def test_from_message(example_message): assert np.array_equal(geometry.focal_spot_orientation_quad, np.array([0, 0, 0, 1])) assert geometry.frame_id == "test_frame" + def test_as_message(example_message): geometry = PyProjectionGeometry.from_message(example_message) msg = geometry.as_message() diff --git a/test/test_region_of_intrest.py b/test/test_region_of_intrest.py index 678f0528d467c6a9c4f975df7ae10b446e73e067..67740759fe8d651a924d4dc32bd77755315817fd 100644 --- a/test/test_region_of_intrest.py +++ b/test/test_region_of_intrest.py @@ -1,13 +1,16 @@ import pytest import numpy as np -from rq_controller.common import PyRegionOfIntrest # Adjust this import according to your module's path +from rq_controller.common import ( + PyRegionOfIntrest, +) # Adjust this import according to your module's path from rq_interfaces.msg import RegionOfIntrest from visualization_msgs.msg import Marker + @pytest.fixture def example_message(): msg = RegionOfIntrest() - + marker = Marker() marker.pose.position.x = 10.0 marker.pose.position.y = 20.0 @@ -16,15 +19,16 @@ def example_message(): marker.scale.y = 50.0 marker.scale.z = 60.0 marker.header.frame_id = "test_frame" - + msg.region_of_intrest_stack.markers.append(marker) - + msg.resolution.x = 0.1 msg.resolution.y = 0.1 msg.resolution.z = 0.1 - + return msg + def test_initialization(): center_points_mm = np.array([[1.0, 2.0, 3.0]]) dimensions_mm = np.array([[4.0, 5.0, 6.0]]) @@ -38,6 +42,7 @@ def test_initialization(): assert np.array_equal(roi.resolution_mm, resolution_mm) assert roi.frame_id == frame_id + def test_dummy_method(): roi = PyRegionOfIntrest.dummy() @@ -45,6 +50,7 @@ def test_dummy_method(): assert roi.dimensions_mm.shape == (1, 3) assert roi.frame_id == "object" + def test_from_message(example_message): roi = PyRegionOfIntrest.from_message(example_message) @@ -53,6 +59,7 @@ def test_from_message(example_message): assert np.array_equal(roi.resolution_mm, np.array([[0.1, 0.1, 0.1]])) assert roi.frame_id == "test_frame" + def test_as_message(example_message): roi = PyRegionOfIntrest.from_message(example_message) msg = roi.as_message() @@ -67,26 +74,34 @@ def test_as_message(example_message): assert marker.scale.y == 50.0 assert marker.scale.z == 60.0 assert marker.header.frame_id == "test_frame" - + assert msg.resolution.x == 0.1 assert msg.resolution.y == 0.1 assert msg.resolution.z == 0.1 + def test_number_of_rois(example_message): roi = PyRegionOfIntrest.from_message(example_message) assert roi.number_of_rois == 1 + def test_shape(example_message): roi = PyRegionOfIntrest.from_message(example_message) assert roi.shape == (399, 499, 599) + def test_get_grid(example_message): roi = PyRegionOfIntrest.from_message(example_message) grid = roi.get_grid(0) assert grid.shape == (399, 499, 599, 3) - assert np.array_equal(grid[0, 0, 0], np.array([10.0 - 20.0, 20.0 - 25.0, 30.0 - 30.0])) - assert np.array_equal(grid[-1, -1, -1], np.array([10.0 + 20.0, 20.0 + 25.0, 30.0 + 30.0])) + assert np.array_equal( + grid[0, 0, 0], np.array([10.0 - 20.0, 20.0 - 25.0, 30.0 - 30.0]) + ) + assert np.array_equal( + grid[-1, -1, -1], np.array([10.0 + 20.0, 20.0 + 25.0, 30.0 + 30.0]) + ) + def test_next_neighbor(): grid_mm = np.zeros((20, 30, 40, 3)) diff --git a/test/test_volume.py b/test/test_volume.py index e2a8ff8e78facdcf65040525438f3b1975ae49df..a3c095b567cd88739e8ef01b40fcc66946010e5a 100644 --- a/test/test_volume.py +++ b/test/test_volume.py @@ -1,6 +1,9 @@ import pytest import numpy as np -from rq_controller.common import PyVolume, PyRegionOfIntrest # Adjust this import according to your module's path +from rq_controller.common import ( + PyVolume, + PyRegionOfIntrest, +) # Adjust this import according to your module's path from rq_controller.common.volume import VOLUME_TYPES from rq_interfaces.msg import Volume from visualization_msgs.msg import Marker @@ -22,19 +25,19 @@ def example_message(): marker.header.frame_id = "test_frame" msg.grid.region_of_intrest_stack.markers.append(marker) - - msg.grid.resolution.x = 1. - msg.grid.resolution.y = 1. - msg.grid.resolution.z = 1. + + msg.grid.resolution.x = 1.0 + msg.grid.resolution.y = 1.0 + msg.grid.resolution.z = 1.0 msg.datatype = VOLUME_TYPES.UINT_16 shape = (40, 50, 60) slices = np.random.randint(0, 65535, size=shape, dtype=np.uint16) for i in range(shape[2]): - image_msg = ros2_numpy.msgify(Image, slices[:, :, i], encoding='mono16') + image_msg = ros2_numpy.msgify(Image, slices[:, :, i], encoding="mono16") msg.slices.append(image_msg) - + return msg @@ -73,7 +76,7 @@ def test_as_message(example_message): for i, slice_msg in enumerate(msg.slices): slice_array = ros2_numpy.numpify(slice_msg) assert np.array_equal(slice_array, volume.array[:, :, i]) - + def test_get_data_type(): assert PyVolume.get_data_type(VOLUME_TYPES.UINT_16) == np.uint16 @@ -83,8 +86,8 @@ def test_get_data_type(): def test_enum_to_numpify(): - assert PyVolume.enum_to_numpify(VOLUME_TYPES.UINT_16) == 'mono16' - assert PyVolume.enum_to_numpify(VOLUME_TYPES.UINT_8) == 'mono8' + assert PyVolume.enum_to_numpify(VOLUME_TYPES.UINT_16) == "mono16" + assert PyVolume.enum_to_numpify(VOLUME_TYPES.UINT_8) == "mono8" with pytest.raises(ValueError): PyVolume.enum_to_numpify(-1) @@ -95,4 +98,3 @@ def test_shape(): volume = PyVolume(array, roi, VOLUME_TYPES.UINT_8) assert volume.shape == array.shape -