diff --git a/.vscode/settings.json b/.vscode/settings.json index 5626c31da66120d42e9ee0b15020095690f170b4..9a75dc7d3af5969eec4933d5304871bf3049c37c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -12,5 +12,10 @@ "C:\\ci\\ws\\install\\lib\\python3.8\\dist-packages", "C:\\dev\\ros2_humble\\Lib\\site-packages", "" - ] + ], + "python.testing.pytestArgs": [ + "test" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } \ No newline at end of file diff --git a/rq_controller/common/__pycache__/__init__.cpython-38.pyc b/rq_controller/common/__pycache__/__init__.cpython-38.pyc index b5a7ce552844d6e4a6c1828b88db4a6f0d8786a2..c4ac8808716f0af57d37032d44cd890bef1dd36e 100644 Binary files a/rq_controller/common/__pycache__/__init__.cpython-38.pyc and b/rq_controller/common/__pycache__/__init__.cpython-38.pyc differ diff --git a/rq_controller/common/__pycache__/projection.cpython-38.pyc b/rq_controller/common/__pycache__/projection.cpython-38.pyc index 84787faff2867b84f7673e065471b45fd6f6c070..292a7253e286283d2a9f7b7e8b5c8f92156b3bb6 100644 Binary files a/rq_controller/common/__pycache__/projection.cpython-38.pyc and b/rq_controller/common/__pycache__/projection.cpython-38.pyc differ diff --git a/rq_controller/common/__pycache__/projection_geometry.cpython-38.pyc b/rq_controller/common/__pycache__/projection_geometry.cpython-38.pyc index e7efc4a86b91246f1a0538ae704fa85f48a9b64b..a849f8bd44ad63633ea5991038e17486b43fa1a9 100644 Binary files a/rq_controller/common/__pycache__/projection_geometry.cpython-38.pyc and b/rq_controller/common/__pycache__/projection_geometry.cpython-38.pyc differ diff --git a/rq_controller/common/projection.py b/rq_controller/common/projection.py index ea5c1ee5010bf24e08ba4d471d4dc3c860de0c39..813f573acc921aa5949caccef5ab4c019c35c0b7 100644 --- a/rq_controller/common/projection.py +++ b/rq_controller/common/projection.py @@ -11,10 +11,39 @@ import ros2_numpy class PyProjection(PyProjectionGeometry): + """ + Represents a projection including image data and associated geometry. + + Attributes: + image (np.ndarray): The projection image data. + detector_heigth_mm (float): The height of the detector in millimeters. + detector_width_mm (float): The width of the detector in millimeters. + voltage_kv (float): The voltage in kilovolts. + current_ua (float): The current in microamperes. + exposure_time_ms (float): The exposure time in milliseconds. + """ + def __init__(self, focal_spot_mm: ndarray, detector_postion_mm: ndarray, detector_orientation_quad: ndarray, image: np.ndarray, detector_heigth_mm: float, detector_width_mm: float, frame_id: str = 'object', focal_spot_orientation_quad: np.ndarray = np.array([0., 0., 0, 1.]), voltage_kv: float = 100., current_ua: float = 100., exposure_time_ms: float = 1000.) -> None: + """ + Initializes a PyProjection instance. + + Args: + focal_spot_mm (ndarray): Position of the focal spot in millimeters. + detector_postion_mm (ndarray): Position of the detector in millimeters. + detector_orientation_quad (ndarray): Orientation of the detector as a quaternion. + image (np.ndarray): The projection image data. + detector_heigth_mm (float): The height of the detector in millimeters. + detector_width_mm (float): The width of the detector in millimeters. + frame_id (str): Frame ID for the projection geometry. Default is 'object'. + focal_spot_orientation_quad (np.ndarray): Orientation of the focal spot as a quaternion. Default is [0., 0., 0, 1.]. + voltage_kv (float): The voltage in kilovolts. Default is 100. + current_ua (float): The current in microamperes. Default is 100. + exposure_time_ms (float): The exposure time in milliseconds. Default is 1000. + """ + super().__init__(focal_spot_mm, detector_postion_mm, detector_orientation_quad, frame_id, focal_spot_orientation_quad) self.image = image.astype(np.uint16) self.detector_heigth_mm = detector_heigth_mm @@ -25,6 +54,13 @@ class PyProjection(PyProjectionGeometry): @classmethod def dummy(cls): + """ + Creates a dummy instance of PyProjection for testing. + + Returns: + PyProjection: A dummy instance with default values. + """ + return cls(np.array([0., 100., 0]), np.array([0., -100., 0]), np.array([1., 0., 0, 1.]), @@ -32,6 +68,12 @@ class PyProjection(PyProjectionGeometry): 10., 10.) def __str__(self) -> str: + """ + Returns a string representation of the PyProjection instance. + + Returns: + str: The string representation of the projection. + """ print_str = f'---\n' print_str += f'PyProjection\n' print_str += f'--- Projection Geometry:\n' @@ -42,9 +84,17 @@ class PyProjection(PyProjectionGeometry): print_str += f'---\n\n' return print_str - @classmethod def from_message(cls, msg: Projection): + """ + Creates an instance of PyProjection from a ROS message. + + Args: + msg (Projection): The ROS message containing projection data. + + Returns: + PyProjection: An instance initialized from the ROS message. + """ focal_spot_mm = np.array([msg.projection_geometry.focal_spot_postion_mm.x, msg.projection_geometry.focal_spot_postion_mm.y, msg.projection_geometry.focal_spot_postion_mm.z,]) @@ -79,6 +129,12 @@ class PyProjection(PyProjectionGeometry): voltage_kv, current_ua, exposure_time_ms) def as_message(self) -> Projection: + """ + Converts the PyProjection instance to a ROS message. + + Returns: + Projection: The ROS message representing the projection. + """ message = Projection() projection_geometry = ProjectionGeometry() diff --git a/rq_controller/common/projection_geometry.py b/rq_controller/common/projection_geometry.py index 8548dcd31f59430aca4c799197a3a15565092dc4..61bc50099cc3dba82573915d7c58003f93ddbbb0 100644 --- a/rq_controller/common/projection_geometry.py +++ b/rq_controller/common/projection_geometry.py @@ -5,10 +5,32 @@ import numpy as np from rq_interfaces.msg import ProjectionGeometry, Projection class PyProjectionGeometry(): + """ + Represents the geometry of a projection including the focal spot and detector position and orientation. + + Attributes: + focal_spot_mm (np.ndarray): Position of the focal spot in millimeters. + detector_postion_mm (np.ndarray): Position of the detector in millimeters. + detector_orientation_quad (np.ndarray): Orientation of the detector as a quaternion. + focal_spot_orientation_quad (np.ndarray): Orientation of the focal spot as a quaternion. + frame_id (str): Frame ID for the projection geometry. + """ + def __init__(self, focal_spot_mm: np.ndarray, detector_postion_mm: np.ndarray, detector_orientation_quad: np.ndarray, frame_id: str = 'object', focal_spot_orientation_quad: np.ndarray = np.array([0., 0., 0, 1.]) ) -> None: + """ + Initializes a PyProjectionGeometry instance. + + Args: + focal_spot_mm (np.ndarray): Position of the focal spot in millimeters. + detector_postion_mm (np.ndarray): Position of the detector in millimeters. + detector_orientation_quad (np.ndarray): Orientation of the detector as a quaternion. + frame_id (str): Frame ID for the projection geometry. Default is 'object'. + focal_spot_orientation_quad (np.ndarray): Orientation of the focal spot as a quaternion. + """ + self.focal_spot_mm = focal_spot_mm self.detector_postion_mm = detector_postion_mm self.focal_spot_orientation_quad = focal_spot_orientation_quad @@ -17,12 +39,27 @@ class PyProjectionGeometry(): @classmethod def dummy(cls): + """ + Creates a dummy instance of PyProjectionGeometry for testing. + + Returns: + PyProjectionGeometry: A dummy instance with default values. + """ return cls(np.array([1., 0., 0]), np.array([-1., 0., 0]), np.array([0., 0., 0, 1.])) @classmethod def from_message(cls, msg: ProjectionGeometry): + """ + Creates an instance of PyProjectionGeometry from a ROS message. + + Args: + msg (ProjectionGeometry): The ROS message containing projection geometry data. + + Returns: + PyProjectionGeometry: An instance initialized from the ROS message. + """ focal_spot_mm = np.array([msg.focal_spot_postion_mm.x, msg.focal_spot_postion_mm.y, msg.focal_spot_postion_mm.z,]) @@ -46,6 +83,12 @@ class PyProjectionGeometry(): return cls(focal_spot_mm, detector_center_mm, detector_orientation_quad, frame_id, focal_spot_orientation) def as_message(self) -> ProjectionGeometry: + """ + Converts the PyProjectionGeometry instance to a ROS message. + + Returns: + ProjectionGeometry: The ROS message representing the projection geometry. + """ message = ProjectionGeometry() message.focal_spot_postion_mm.x = self.focal_spot_mm[0] diff --git a/rq_controller/common/region_of_intrest.py b/rq_controller/common/region_of_intrest.py index f067df7eaf4ff5e1c5b48367e6e3203f138fedf8..4a71fda142cd0724897577ca7ffbed0c60c98e71 100644 --- a/rq_controller/common/region_of_intrest.py +++ b/rq_controller/common/region_of_intrest.py @@ -7,8 +7,28 @@ from visualization_msgs.msg import Marker class PyRegionOfIntrest(): + """ + Represents a region of interest (ROI) with center points, dimensions, and resolution. + + Attributes: + center_points_mm (np.ndarray): Center points of the ROIs in millimeters. + dimensions_mm (np.ndarray): Dimensions of the ROIs in millimeters. + frame_id (str): Frame ID for the ROI. + resolution_mm (np.ndarray): Resolution of the ROIs in millimeters. + """ + def __init__(self, center_points_mm: np.ndarray, dimensions_mm: np.ndarray, frame_id: str = 'object', resolution_mm: np.ndarray = np.array([0.1, 0.1, 0.1])): + """ + Initializes a PyRegionOfIntrest instance. + + Args: + center_points_mm (np.ndarray): Center points of the ROIs in millimeters. + dimensions_mm (np.ndarray): Dimensions of the ROIs in millimeters. + frame_id (str): Frame ID for the ROI. Default is 'object'. + resolution_mm (np.ndarray): Resolution of the ROIs in millimeters. Default is [0.1, 0.1, 0.1]. + """ + self.center_points_mm = center_points_mm.reshape((-1, 3)) self.dimensions_mm = dimensions_mm.reshape((-1, 3)) self.frame_id = frame_id @@ -16,11 +36,28 @@ class PyRegionOfIntrest(): @classmethod def dummy(cls): + """ + Creates a dummy instance of PyRegionOfIntrest for testing. + + Returns: + PyRegionOfIntrest: A dummy instance with random values. + """ + return cls((np.random.random((3, )) - 0.5) * 20., np.random.random((3, )) * 10.) @classmethod def from_message(cls, msg: RegionOfIntrest): + """ + Creates an instance of PyRegionOfIntrest from a ROS message. + + Args: + msg (RegionOfIntrest): The ROS message containing ROI data. + + Returns: + PyRegionOfIntrest: An instance initialized from the ROS message. + """ + center_points_mm = list() dimensions_mm = list() @@ -45,14 +82,35 @@ class PyRegionOfIntrest(): @property def number_of_rois(self) -> int: + """ + Returns the number of regions of interest. + + Returns: + int: The number of ROIs. + """ + return self.center_points_mm.shape[0] @property def shape(self) -> tuple: + """ + Returns the shape of the ROI grid based on the dimensions and resolution. + + Returns: + tuple: The shape of the ROI grid. + """ + shape = self.dimensions_mm[0] // self.resolution_mm[0] return (int(shape[0]), int(shape[1]), int(shape[2])) def as_message(self) -> RegionOfIntrest: + """ + Converts the PyRegionOfIntrest instance to a ROS message. + + Returns: + RegionOfIntrest: The ROS message representing the ROI. + """ + message = RegionOfIntrest() roi_list = list() @@ -80,6 +138,16 @@ class PyRegionOfIntrest(): return message def get_grid(self, indice: int = 0) -> np.ndarray: + """ + Generates a grid of points within the ROI. + + Args: + indice (int): Index of the ROI. Default is 0. + + Returns: + np.ndarray: A grid of points within the ROI. + """ + start = self.center_points_mm[indice] - (self.dimensions_mm[indice] / 2.) end = self.center_points_mm[indice] + (self.dimensions_mm[indice] / 2.) @@ -97,6 +165,16 @@ class PyRegionOfIntrest(): @staticmethod def next_neighbor(grid_mm: np.ndarray, point_mm: np.ndarray) -> np.ndarray: + """ + Finds the nearest neighbor in the grid to a given point. + + Args: + grid_mm (np.ndarray): The grid of points. + point_mm (np.ndarray): The point to find the nearest neighbor for. + + Returns: + np.ndarray: The indices of the nearest neighbor in the grid. + """ x = grid_mm[:, 0, 0, 0] y = grid_mm[0, :, 0, 1] diff --git a/rq_controller/common/volume.py b/rq_controller/common/volume.py index 3ac025ed40081d0c09566ec3e223bf09c656a5dd..ac1b64568c12ed0a0688cf32bdcc7f72e971c618 100644 --- a/rq_controller/common/volume.py +++ b/rq_controller/common/volume.py @@ -11,18 +11,50 @@ import ros2_numpy class VOLUME_TYPES(IntEnum): + """ + Enum representing volume data types. + """ + UINT_16 = 0 UINT_8 = 1 class PyVolume(): + """ + Represents a volumetric dataset with an associated region of interest. + + Attributes: + roi (PyRegionOfIntrest): The region of interest associated with the volume. + array (ndarray): The volumetric data array. + data_typ (VOLUME_TYPES): The data type of the volume. + """ + def __init__(self, array: ndarray, roi: PyRegionOfIntrest, data_type: VOLUME_TYPES = ...): + """ + Initializes a PyVolume instance. + + Args: + array (ndarray): The volumetric data array. + roi (PyRegionOfIntrest): The region of interest associated with the volume. + data_type (VOLUME_TYPES): The data type of the volume. Default is VOLUME_TYPES.UINT_8. + """ + self.roi = roi self.array = array self.data_typ = data_type @staticmethod def get_data_type(volume_type: VOLUME_TYPES) -> np.dtype: + """ + Gets the numpy data type corresponding to the given volume type. + + Args: + volume_type (VOLUME_TYPES): The volume data type. + + Returns: + np.dtype: The corresponding numpy data type. + """ + if volume_type == VOLUME_TYPES.UINT_16: return np.uint16 elif volume_type == VOLUME_TYPES.UINT_8: @@ -32,6 +64,16 @@ class PyVolume(): @staticmethod def enum_to_numpify(volume_type: VOLUME_TYPES) -> str: + """ + Converts the volume type enum to a string for use in ROS2 numpy conversion. + + Args: + volume_type (VOLUME_TYPES): The volume data type. + + Returns: + str: The corresponding string for ROS2 numpy conversion. + """ + if volume_type == VOLUME_TYPES.UINT_16: return 'mono16' elif volume_type == VOLUME_TYPES.UINT_8: @@ -40,14 +82,31 @@ class PyVolume(): raise ValueError('Datatype is not implemented') @classmethod - def dummy(cls): + def dummy(cls) -> 'PyVolume': + """ + Creates a dummy instance of PyVolume for testing. + + Returns: + PyVolume: A dummy instance with random data. + """ + roi = PyRegionOfIntrest.dummy() array = np.random.randint(0, 255, size=roi.shape) data_type = VOLUME_TYPES.UINT_8 return cls(array, roi, data_type) @classmethod - def from_message(cls, msg: Volume): + def from_message(cls, msg: Volume) -> 'PyVolume': + """ + Creates an instance of PyVolume from a ROS message. + + Args: + msg (Volume): The ROS message containing volume data. + + Returns: + PyVolume: An instance initialized from the ROS message. + """ + roi: Marker = msg.grid.region_of_intrest_stack.markers[0] center_points_mm = np.array([ roi.pose.position.x, @@ -77,6 +136,13 @@ class PyVolume(): return cls(array, py_roi, data_typ) def as_message(self) -> Volume: + """ + Converts the PyVolume instance to a ROS message. + + Returns: + Volume: The ROS message representing the volume. + """ + message = Volume() message.datatype = self.data_typ @@ -93,6 +159,13 @@ class PyVolume(): @property def shape(self): + """ + Returns the shape of the volumetric data array. + + Returns: + tuple: The shape of the volume. + """ + return self.array.shape diff --git a/test/__pycache__/test_copyright.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_copyright.cpython-38-pytest-8.2.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0a8d69338629d501cf0136fa81e7f066901c8a9 Binary files /dev/null and b/test/__pycache__/test_copyright.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/__pycache__/test_flake8.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_flake8.cpython-38-pytest-8.2.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df33d1d7ce4cf030f75a9d7d2595cb2db409dcf0 Binary files /dev/null and b/test/__pycache__/test_flake8.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/__pycache__/test_pep257.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_pep257.cpython-38-pytest-8.2.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6355ad401a01ed7cd9a34112ca34b92937168cf Binary files /dev/null and b/test/__pycache__/test_pep257.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/__pycache__/test_projection.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_projection.cpython-38-pytest-8.2.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8413a88371ec487791bfb98a4279f60cf8b3edae Binary files /dev/null and b/test/__pycache__/test_projection.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/__pycache__/test_projection_geometry.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_projection_geometry.cpython-38-pytest-8.2.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daecbc751bbc7609edc20368ee278eb37f545dfe Binary files /dev/null and b/test/__pycache__/test_projection_geometry.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/__pycache__/test_region_of_intrest.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_region_of_intrest.cpython-38-pytest-8.2.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..469b57f673dd3002e757812d9c29a8ea600abac0 Binary files /dev/null and b/test/__pycache__/test_region_of_intrest.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/__pycache__/test_volume.cpython-38-pytest-8.2.2.pyc b/test/__pycache__/test_volume.cpython-38-pytest-8.2.2.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4329085a29c878eed2392f2e4b4eda540f9cb6f Binary files /dev/null and b/test/__pycache__/test_volume.cpython-38-pytest-8.2.2.pyc differ diff --git a/test/test_copyright.py b/test/test_copyright.py deleted file mode 100644 index 97a39196e84db97954341162a6d2e7f771d938c0..0000000000000000000000000000000000000000 --- a/test/test_copyright.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2015 Open Source Robotics Foundation, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ament_copyright.main import main -import pytest - - -# Remove the `skip` decorator once the source file(s) have a copyright header -@pytest.mark.skip(reason='No copyright header has been placed in the generated source file.') -@pytest.mark.copyright -@pytest.mark.linter -def test_copyright(): - rc = main(argv=['.', 'test']) - assert rc == 0, 'Found errors' diff --git a/test/test_flake8.py b/test/test_flake8.py deleted file mode 100644 index 27ee1078ff077cc3a0fec75b7d023101a68164d1..0000000000000000000000000000000000000000 --- a/test/test_flake8.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2017 Open Source Robotics Foundation, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ament_flake8.main import main_with_errors -import pytest - - -@pytest.mark.flake8 -@pytest.mark.linter -def test_flake8(): - rc, errors = main_with_errors(argv=[]) - assert rc == 0, \ - 'Found %d code style errors / warnings:\n' % len(errors) + \ - '\n'.join(errors) diff --git a/test/test_pep257.py b/test/test_pep257.py deleted file mode 100644 index b234a3840f4c5bd38f043638c8622b8f240e1185..0000000000000000000000000000000000000000 --- a/test/test_pep257.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2015 Open Source Robotics Foundation, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ament_pep257.main import main -import pytest - - -@pytest.mark.linter -@pytest.mark.pep257 -def test_pep257(): - rc = main(argv=['.', 'test']) - assert rc == 0, 'Found code style errors / warnings' diff --git a/test/test_projection.py b/test/test_projection.py new file mode 100644 index 0000000000000000000000000000000000000000..06075e4a0d53943545fb6d2dd87c8dc33b5ae880 --- /dev/null +++ b/test/test_projection.py @@ -0,0 +1,134 @@ +import pytest +import numpy as np +from rq_controller.common import PyProjection # Adjust this import according to your module's path +from rq_interfaces.msg import Projection +from sensor_msgs.msg import Image +import ros2_numpy + +@pytest.fixture +def example_message(): + msg = Projection() + msg.projection_geometry.focal_spot_postion_mm.x = 10.0 + msg.projection_geometry.focal_spot_postion_mm.y = 20.0 + msg.projection_geometry.focal_spot_postion_mm.z = 30.0 + msg.projection_geometry.detector_postion_mm.x = 40.0 + msg.projection_geometry.detector_postion_mm.y = 50.0 + msg.projection_geometry.detector_postion_mm.z = 60.0 + msg.projection_geometry.detector_orientation_quad.x = 0.0 + msg.projection_geometry.detector_orientation_quad.y = 0.0 + msg.projection_geometry.detector_orientation_quad.z = 0.0 + msg.projection_geometry.detector_orientation_quad.w = 1.0 + msg.projection_geometry.focal_spot_orientation_quad.x = 0.0 + msg.projection_geometry.focal_spot_orientation_quad.y = 0.0 + msg.projection_geometry.focal_spot_orientation_quad.z = 0.0 + msg.projection_geometry.focal_spot_orientation_quad.w = 1.0 + msg.projection_geometry.header.frame_id = "test_frame" + + image_array = np.zeros((10, 10), dtype=np.uint16) + msg.image = ros2_numpy.msgify(Image, image_array, encoding='mono16') + + msg.detector_heigth_mm = 100.0 + msg.detector_width_mm = 200.0 + msg.voltage_kv = 120.0 + msg.current_ua = 150.0 + msg.exposure_time_ms = 500.0 + return msg + +def test_initialization(): + focal_spot_mm = np.array([1.0, 2.0, 3.0]) + detector_postion_mm = np.array([4.0, 5.0, 6.0]) + detector_orientation_quad = np.array([0.0, 0.0, 0.0, 1.0]) + image = np.zeros((10, 10), dtype=np.uint16) + detector_heigth_mm = 10.0 + detector_width_mm = 20.0 + voltage_kv = 100.0 + current_ua = 200.0 + exposure_time_ms = 1000.0 + frame_id = "test_frame" + focal_spot_orientation_quad = np.array([0.0, 0.0, 0.0, 1.0]) + + projection = PyProjection(focal_spot_mm, detector_postion_mm, detector_orientation_quad, image, detector_heigth_mm, detector_width_mm, frame_id, focal_spot_orientation_quad, voltage_kv, current_ua, exposure_time_ms) + + assert np.array_equal(projection.focal_spot_mm, focal_spot_mm) + assert np.array_equal(projection.detector_postion_mm, detector_postion_mm) + assert np.array_equal(projection.detector_orientation_quad, detector_orientation_quad) + assert np.array_equal(projection.focal_spot_orientation_quad, focal_spot_orientation_quad) + assert np.array_equal(projection.image, image) + assert projection.detector_heigth_mm == detector_heigth_mm + assert projection.detector_width_mm == detector_width_mm + assert projection.voltage_kv == voltage_kv + assert projection.current_ua == current_ua + assert projection.exposure_time_ms == exposure_time_ms + assert projection.frame_id == frame_id + +def test_dummy_method(): + projection = PyProjection.dummy() + + assert np.array_equal(projection.focal_spot_mm, np.array([0.0, 100.0, 0.0])) + assert np.array_equal(projection.detector_postion_mm, np.array([0.0, -100.0, 0.0])) + assert np.array_equal(projection.detector_orientation_quad, np.array([1.0, 0.0, 0.0, 1.0])) + assert np.array_equal(projection.image, np.zeros((10, 10), dtype=np.uint16)) + assert projection.detector_heigth_mm == 10.0 + assert projection.detector_width_mm == 10.0 + assert projection.frame_id == "object" + +def test_from_message(example_message): + projection = PyProjection.from_message(example_message) + + assert np.array_equal(projection.focal_spot_mm, np.array([10.0, 20.0, 30.0])) + assert np.array_equal(projection.detector_postion_mm, np.array([40.0, 50.0, 60.0])) + assert np.array_equal(projection.detector_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0])) + assert np.array_equal(projection.focal_spot_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0])) + assert np.array_equal(projection.image, np.zeros((10, 10), dtype=np.uint16)) + assert projection.detector_heigth_mm == 100.0 + assert projection.detector_width_mm == 200.0 + assert projection.voltage_kv == 120.0 + assert projection.current_ua == 150.0 + assert projection.exposure_time_ms == 500.0 + assert projection.frame_id == "test_frame" + +def test_as_message(example_message): + projection = PyProjection.from_message(example_message) + msg = projection.as_message() + + assert msg.projection_geometry.focal_spot_postion_mm.x == 10.0 + assert msg.projection_geometry.focal_spot_postion_mm.y == 20.0 + assert msg.projection_geometry.focal_spot_postion_mm.z == 30.0 + assert msg.projection_geometry.detector_postion_mm.x == 40.0 + assert msg.projection_geometry.detector_postion_mm.y == 50.0 + assert msg.projection_geometry.detector_postion_mm.z == 60.0 + assert msg.projection_geometry.detector_orientation_quad.x == 0.0 + assert msg.projection_geometry.detector_orientation_quad.y == 0.0 + assert msg.projection_geometry.detector_orientation_quad.z == 0.0 + assert msg.projection_geometry.detector_orientation_quad.w == 1.0 + assert msg.projection_geometry.focal_spot_orientation_quad.x == 0.0 + assert msg.projection_geometry.focal_spot_orientation_quad.y == 0.0 + assert msg.projection_geometry.focal_spot_orientation_quad.z == 0.0 + assert msg.projection_geometry.focal_spot_orientation_quad.w == 1.0 + assert msg.projection_geometry.header.frame_id == "test_frame" + assert np.array_equal(ros2_numpy.numpify(msg.image), np.zeros((10, 10), dtype=np.uint16)) + assert msg.detector_heigth_mm == 100.0 + assert msg.detector_width_mm == 200.0 + assert msg.voltage_kv == 120.0 + assert msg.current_ua == 150.0 + assert msg.exposure_time_ms == 500.0 + +def test_properties(): + focal_spot_mm = np.array([1.0, 2.0, 3.0]) + detector_postion_mm = np.array([4.0, 5.0, 6.0]) + detector_orientation_quad = np.array([0.0, 0.0, 0.0, 1.0]) + image = np.zeros((20, 30), dtype=np.uint16) + detector_heigth_mm = 100.0 + detector_width_mm = 200.0 + voltage_kv = 100.0 + current_ua = 200.0 + exposure_time_ms = 1000.0 + frame_id = "test_frame" + focal_spot_orientation_quad = np.array([0.0, 0.0, 0.0, 1.0]) + + projection = PyProjection(focal_spot_mm, detector_postion_mm, detector_orientation_quad, image, detector_heigth_mm, detector_width_mm, frame_id, focal_spot_orientation_quad, voltage_kv, current_ua, exposure_time_ms) + + assert projection.detector_heigth_px == 20 + assert projection.detector_width_px == 30 + assert projection.pixel_pitch_x_mm == 200.0 / 30 + assert projection.pixel_pitch_y_mm == 100.0 / 20 \ No newline at end of file diff --git a/test/test_projection_geometry.py b/test/test_projection_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4c44de107de144cc90a0abda6004c6ded44014 --- /dev/null +++ b/test/test_projection_geometry.py @@ -0,0 +1,78 @@ +import pytest +import numpy as np +from rq_interfaces.msg import ProjectionGeometry +from rq_controller.common import PyProjectionGeometry + + +@pytest.fixture +def example_message(): + msg = ProjectionGeometry() + msg.focal_spot_postion_mm.x = 10. + msg.focal_spot_postion_mm.y = 20. + msg.focal_spot_postion_mm.z = 30. + msg.detector_postion_mm.x = 40. + msg.detector_postion_mm.y = 50. + msg.detector_postion_mm.z = 60. + msg.detector_orientation_quad.x = 0. + msg.detector_orientation_quad.y = 0. + msg.detector_orientation_quad.z = 0. + msg.detector_orientation_quad.w = 1. + msg.focal_spot_orientation_quad.x = 0. + msg.focal_spot_orientation_quad.y = 0. + msg.focal_spot_orientation_quad.z = 0. + msg.focal_spot_orientation_quad.w = 1. + msg.header.frame_id = "test_frame" + return msg + +def test_initialization(): + focal_spot_mm = np.array([1.0, 2.0, 3.0]) + detector_postion_mm = np.array([4.0, 5.0, 6.0]) + detector_orientation_quad = np.array([0.0, 0.0, 0.0, 1.0]) + frame_id = "test_frame" + focal_spot_orientation_quad = np.array([0.0, 0.0, 0.0, 1.0]) + + geometry = PyProjectionGeometry(focal_spot_mm, detector_postion_mm, detector_orientation_quad, frame_id, focal_spot_orientation_quad) + + assert np.array_equal(geometry.focal_spot_mm, focal_spot_mm) + assert np.array_equal(geometry.detector_postion_mm, detector_postion_mm) + assert np.array_equal(geometry.detector_orientation_quad, detector_orientation_quad) + assert np.array_equal(geometry.focal_spot_orientation_quad, focal_spot_orientation_quad) + assert geometry.frame_id == frame_id + +def test_dummy_method(): + geometry = PyProjectionGeometry.dummy() + + assert np.array_equal(geometry.focal_spot_mm, np.array([1.0, 0.0, 0.0])) + assert np.array_equal(geometry.detector_postion_mm, np.array([-1.0, 0.0, 0.0])) + assert np.array_equal(geometry.detector_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0])) + assert geometry.frame_id == "object" + assert np.array_equal(geometry.focal_spot_orientation_quad, np.array([0.0, 0.0, 0.0, 1.0])) + +def test_from_message(example_message): + geometry = PyProjectionGeometry.from_message(example_message) + + assert np.array_equal(geometry.focal_spot_mm, np.array([10, 20, 30])) + assert np.array_equal(geometry.detector_postion_mm, np.array([40, 50, 60])) + assert np.array_equal(geometry.detector_orientation_quad, np.array([0, 0, 0, 1])) + assert np.array_equal(geometry.focal_spot_orientation_quad, np.array([0, 0, 0, 1])) + assert geometry.frame_id == "test_frame" + +def test_as_message(example_message): + geometry = PyProjectionGeometry.from_message(example_message) + msg = geometry.as_message() + + assert msg.focal_spot_postion_mm.x == 10 + assert msg.focal_spot_postion_mm.y == 20 + assert msg.focal_spot_postion_mm.z == 30 + assert msg.detector_postion_mm.x == 40 + assert msg.detector_postion_mm.y == 50 + assert msg.detector_postion_mm.z == 60 + assert msg.detector_orientation_quad.x == 0 + assert msg.detector_orientation_quad.y == 0 + assert msg.detector_orientation_quad.z == 0 + assert msg.detector_orientation_quad.w == 1 + assert msg.focal_spot_orientation_quad.x == 0 + assert msg.focal_spot_orientation_quad.y == 0 + assert msg.focal_spot_orientation_quad.z == 0 + assert msg.focal_spot_orientation_quad.w == 1 + assert msg.header.frame_id == "test_frame" diff --git a/test/test_region_of_intrest.py b/test/test_region_of_intrest.py new file mode 100644 index 0000000000000000000000000000000000000000..678f0528d467c6a9c4f975df7ae10b446e73e067 --- /dev/null +++ b/test/test_region_of_intrest.py @@ -0,0 +1,100 @@ +import pytest +import numpy as np +from rq_controller.common import PyRegionOfIntrest # Adjust this import according to your module's path +from rq_interfaces.msg import RegionOfIntrest +from visualization_msgs.msg import Marker + +@pytest.fixture +def example_message(): + msg = RegionOfIntrest() + + marker = Marker() + marker.pose.position.x = 10.0 + marker.pose.position.y = 20.0 + marker.pose.position.z = 30.0 + marker.scale.x = 40.0 + marker.scale.y = 50.0 + marker.scale.z = 60.0 + marker.header.frame_id = "test_frame" + + msg.region_of_intrest_stack.markers.append(marker) + + msg.resolution.x = 0.1 + msg.resolution.y = 0.1 + msg.resolution.z = 0.1 + + return msg + +def test_initialization(): + center_points_mm = np.array([[1.0, 2.0, 3.0]]) + dimensions_mm = np.array([[4.0, 5.0, 6.0]]) + resolution_mm = np.array([[0.1, 0.1, 0.1]]) + frame_id = "test_frame" + + roi = PyRegionOfIntrest(center_points_mm, dimensions_mm, frame_id, resolution_mm) + + assert np.array_equal(roi.center_points_mm, center_points_mm) + assert np.array_equal(roi.dimensions_mm, dimensions_mm) + assert np.array_equal(roi.resolution_mm, resolution_mm) + assert roi.frame_id == frame_id + +def test_dummy_method(): + roi = PyRegionOfIntrest.dummy() + + assert roi.center_points_mm.shape == (1, 3) + assert roi.dimensions_mm.shape == (1, 3) + assert roi.frame_id == "object" + +def test_from_message(example_message): + roi = PyRegionOfIntrest.from_message(example_message) + + assert np.array_equal(roi.center_points_mm, np.array([[10.0, 20.0, 30.0]])) + assert np.array_equal(roi.dimensions_mm, np.array([[40.0, 50.0, 60.0]])) + assert np.array_equal(roi.resolution_mm, np.array([[0.1, 0.1, 0.1]])) + assert roi.frame_id == "test_frame" + +def test_as_message(example_message): + roi = PyRegionOfIntrest.from_message(example_message) + msg = roi.as_message() + + assert len(msg.region_of_intrest_stack.markers) == 1 + marker = msg.region_of_intrest_stack.markers[0] + + assert marker.pose.position.x == 10.0 + assert marker.pose.position.y == 20.0 + assert marker.pose.position.z == 30.0 + assert marker.scale.x == 40.0 + assert marker.scale.y == 50.0 + assert marker.scale.z == 60.0 + assert marker.header.frame_id == "test_frame" + + assert msg.resolution.x == 0.1 + assert msg.resolution.y == 0.1 + assert msg.resolution.z == 0.1 + +def test_number_of_rois(example_message): + roi = PyRegionOfIntrest.from_message(example_message) + assert roi.number_of_rois == 1 + +def test_shape(example_message): + roi = PyRegionOfIntrest.from_message(example_message) + assert roi.shape == (399, 499, 599) + +def test_get_grid(example_message): + roi = PyRegionOfIntrest.from_message(example_message) + grid = roi.get_grid(0) + + assert grid.shape == (399, 499, 599, 3) + assert np.array_equal(grid[0, 0, 0], np.array([10.0 - 20.0, 20.0 - 25.0, 30.0 - 30.0])) + assert np.array_equal(grid[-1, -1, -1], np.array([10.0 + 20.0, 20.0 + 25.0, 30.0 + 30.0])) + +def test_next_neighbor(): + grid_mm = np.zeros((20, 30, 40, 3)) + grid_mm[:, :, :, 0] = np.linspace(0, 20, 20).reshape(-1, 1, 1) + grid_mm[:, :, :, 1] = np.linspace(0, 30, 30).reshape(1, -1, 1) + grid_mm[:, :, :, 2] = np.linspace(0, 40, 40).reshape(1, 1, -1) + + point_mm = np.array([11.0, 16.0, 21.0]) + index = PyRegionOfIntrest.next_neighbor(grid_mm, point_mm) + + assert np.array_equal(index, [10, 15, 20]) diff --git a/test/test_volume.py b/test/test_volume.py new file mode 100644 index 0000000000000000000000000000000000000000..e2a8ff8e78facdcf65040525438f3b1975ae49df --- /dev/null +++ b/test/test_volume.py @@ -0,0 +1,98 @@ +import pytest +import numpy as np +from rq_controller.common import PyVolume, PyRegionOfIntrest # Adjust this import according to your module's path +from rq_controller.common.volume import VOLUME_TYPES +from rq_interfaces.msg import Volume +from visualization_msgs.msg import Marker +from sensor_msgs.msg import Image +import ros2_numpy + + +@pytest.fixture +def example_message(): + msg = Volume() + + marker = Marker() + marker.pose.position.x = 10.0 + marker.pose.position.y = 20.0 + marker.pose.position.z = 30.0 + marker.scale.x = 40.0 + marker.scale.y = 50.0 + marker.scale.z = 60.0 + marker.header.frame_id = "test_frame" + + msg.grid.region_of_intrest_stack.markers.append(marker) + + msg.grid.resolution.x = 1. + msg.grid.resolution.y = 1. + msg.grid.resolution.z = 1. + + msg.datatype = VOLUME_TYPES.UINT_16 + + shape = (40, 50, 60) + slices = np.random.randint(0, 65535, size=shape, dtype=np.uint16) + for i in range(shape[2]): + image_msg = ros2_numpy.msgify(Image, slices[:, :, i], encoding='mono16') + msg.slices.append(image_msg) + + return msg + + +def test_initialization(): + roi = PyRegionOfIntrest.dummy() + array = np.random.randint(0, 255, size=roi.shape, dtype=np.uint8) + volume = PyVolume(array, roi, VOLUME_TYPES.UINT_8) + + assert np.array_equal(volume.array, array) + assert volume.roi == roi + assert volume.data_typ == VOLUME_TYPES.UINT_8 + + +def test_dummy_method(): + volume = PyVolume.dummy() + + assert volume.array.shape == volume.roi.shape + assert volume.data_typ == VOLUME_TYPES.UINT_8 + + +def test_from_message(example_message): + volume = PyVolume.from_message(example_message) + + assert volume.array.shape == (40, 50, 60) + assert volume.data_typ == VOLUME_TYPES.UINT_16 + assert volume.roi.frame_id == "test_frame" + + +def test_as_message(example_message): + volume = PyVolume.from_message(example_message) + msg = volume.as_message() + + assert msg.datatype == VOLUME_TYPES.UINT_16 + assert len(msg.slices) == 60 + + for i, slice_msg in enumerate(msg.slices): + slice_array = ros2_numpy.numpify(slice_msg) + assert np.array_equal(slice_array, volume.array[:, :, i]) + + +def test_get_data_type(): + assert PyVolume.get_data_type(VOLUME_TYPES.UINT_16) == np.uint16 + assert PyVolume.get_data_type(VOLUME_TYPES.UINT_8) == np.uint8 + with pytest.raises(ValueError): + PyVolume.get_data_type(-1) + + +def test_enum_to_numpify(): + assert PyVolume.enum_to_numpify(VOLUME_TYPES.UINT_16) == 'mono16' + assert PyVolume.enum_to_numpify(VOLUME_TYPES.UINT_8) == 'mono8' + with pytest.raises(ValueError): + PyVolume.enum_to_numpify(-1) + + +def test_shape(): + roi = PyRegionOfIntrest.dummy() + array = np.random.randint(0, 255, size=roi.shape, dtype=np.uint8) + volume = PyVolume(array, roi, VOLUME_TYPES.UINT_8) + + assert volume.shape == array.shape +