import artistlib.hardware

import rclpy.publisher
from rq_interfaces.srv import ReachabilityMapCheck, AquireProjection, ReachabilityCheck
from rq_interfaces.msg import Volume
from ..base_hardware_service import BaseHardwareInterfaceService
from rq_controller.common import PyProjection, PyRegionOfIntrest, PyVolume

try:
    import artistlib
    from artistlib import utility as artist_utility
except ModuleNotFoundError:
    raise ModuleNotFoundError("aRTist API not found")

import ros2_numpy
import numpy as np

from sensor_msgs.msg import Image
from shape_msgs.msg import Mesh, MeshTriangle
from geometry_msgs.msg import Point

import open3d as o3d

import rclpy

from pathfinding3d.core.diagonal_movement import DiagonalMovement
from pathfinding3d.core.grid import Grid
from pathfinding3d.finder.theta_star import ThetaStarFinder


class ArtistHardwareInterface(BaseHardwareInterfaceService):
    """
    ArtistHardwareInterface is a specialized hardware interface service node for interacting with the aRTist simulation software.

    The node drives the simulated X-ray source ("S") and detector ("D") through the aRTist API.
    It receives a collision mesh on the ``rq_artist_collision_mesh_mm`` topic and derives a voxel
    occupancy map from that mesh. It then uses the map and a Theta* path search to decide whether
    a scan pose can be reached.

    Attributes:
        interface_type (str): The type of interface.
        collission_mesh (o3d.geometry.TriangleMesh): The collision mesh used for reachability checks.
        collission_voxel_ocu_map (PyVolume): The voxel occupancy map for collision detection
            (1 = source and detector position are collision free, 0 = blocked).
        source_position_grid (np.ndarray): The grid of source positions for the reachability map.
        UNREACHABLE_COST (float): Cost reported for scan poses that cannot be reached.
    """

    interface_type: str = "aRTist"
    collission_mesh: o3d.geometry.TriangleMesh = None
    collission_voxel_ocu_map: PyVolume = None
    source_position_grid: np.ndarray = None
    UNREACHABLE_COST: float = 1000.0

    def __init__(self, node_name: str = "rq_hardware_interface_service"):
        """
        Initialize the ArtistHardwareInterface node, create publishers and subscribers, and set up the aRTist API.

        :param node_name: The name of the node. Defaults to 'rq_hardware_interface_service'.
        """
        super().__init__(node_name)
        self.api = artistlib.API()
        self.source = artistlib.hardware.XraySource(self.api)
        self.detector = artistlib.hardware.XrayDetector(self.api)
        self.mesh_publisher = self.create_publisher(
            Mesh, "rq_artist_collision_mesh_mm", 10
        )
        # The node also listens on the same topic so any publisher can replace the collision mesh.
        self.mesh_subscriper = self.create_subscription(
            Mesh, "rq_artist_collision_mesh_mm", self.set_collission_mesh_callback, 10
        )

    def aquire_projection_callback(
        self, request: AquireProjection.Request, response: AquireProjection.Response
    ):
        """
        Handle the aquire_projection service request by setting the X-ray source and detector positions and acquiring a projection image.

        :param request: The service request containing the scan pose information.
        :param response: The service response to be filled with the acquired projection.
        :return: The updated service response with the projection image.
        """
        projection_geometry = PyProjection.from_message(request.scan_pose)
        geometry_msg = request.scan_pose.projection_geometry

        # Move source
        self.api.rotate_from_quat("S", projection_geometry.focal_spot_orientation_quad)
        self.api.translate(
            "S",
            geometry_msg.focal_spot_postion_mm.x,
            geometry_msg.focal_spot_postion_mm.y,
            geometry_msg.focal_spot_postion_mm.z,
        )

        # Move detector
        self.api.rotate_from_quat("D", projection_geometry.detector_orientation_quad)
        self.api.translate(
            "D",
            geometry_msg.detector_postion_mm.x,
            geometry_msg.detector_postion_mm.y,
            geometry_msg.detector_postion_mm.z,
        )

        # Only set the voltage when it actually changes, to avoid recalculating the source spectrum.
        if not np.isclose(
            self.source.voltage_kv, request.scan_pose.voltage_kv, atol=0.1
        ):
            self.source.voltage_kv = request.scan_pose.voltage_kv

        # The message field is in uA while the source attribute is named in mA: convert explicitly.
        self.source.exposure_ma = request.scan_pose.current_ua / 1000.0

        projection_path = self.projection_tempfolder / "projection.tif"
        self.api.save_image(projection_path)
        projection_image = artist_utility.load_projection(
            projection_path, load_projection_geometry=False
        )[0]

        # Clip before the cast so out-of-range intensities saturate instead of wrapping around.
        image_msg = ros2_numpy.msgify(
            Image,
            np.clip(projection_image, 0, np.iinfo(np.uint16).max).astype(np.uint16),
            "mono16",
        )

        response.projection = request.scan_pose
        response.projection.image = image_msg

        self.projection_publisher.publish(image_msg)
        self.projection_geometry_publisher.publish(
            response.projection.projection_geometry
        )

        # Poses are handed to the base class divided by 1000 (presumably metres instead of mm).
        self.set_source_pose(
            projection_geometry.focal_spot_mm / 1000.0,
            projection_geometry.focal_spot_orientation_quad,
        )
        self.set_detector_pose(
            projection_geometry.detector_postion_mm / 1000.0,
            projection_geometry.detector_orientation_quad,
        )

        return response

    def check_reachability_callback(
        self, request: ReachabilityCheck.Request, response: ReachabilityCheck.Response
    ):
        """
        Handle the check_reachability service request by verifying if a scan pose is reachable.

        The requested focal spot and the focal spot of ``self.last_projection`` are each snapped
        to the nearest voxel of the occupancy map. The pose is reachable if both voxels are free
        and a Theta* path exists between them; the path cost is reported as ``cost``.

        :param request: The service request containing the scan pose to check.
        :param response: The service response to be filled with the reachability result.
        :return: The updated service response with the reachability status.
        """
        self.get_logger().info("Pose to check ...")

        if self.collission_voxel_ocu_map is None:
            response.status.success = False
            response.status.status_message = "Error: There is no collision voxel grid. Use the service `check_reachability_map`"
            return response

        scan_pose = PyProjection.from_message(request.scan_pose)

        end_position = self.collission_voxel_ocu_map.roi.next_neighbor(
            self.source_position_grid, scan_pose.focal_spot_mm
        )
        if not self._is_free_voxel(end_position):
            self.get_logger().info("End can not be reached ...")
            return self._set_reachability(
                response, request.scan_pose, False, self.UNREACHABLE_COST
            )

        if self.last_projection is None:
            # Nothing to move from yet: any free target voxel is reachable at zero cost.
            self.get_logger().info("First path! ...")
            return self._set_reachability(response, request.scan_pose, True, 0.0)

        start_position = self.collission_voxel_ocu_map.roi.next_neighbor(
            self.source_position_grid, self.last_projection.focal_spot_mm
        )
        if not self._is_free_voxel(start_position):
            self.get_logger().info("Start can not be reached ...")
            return self._set_reachability(
                response, request.scan_pose, False, self.UNREACHABLE_COST
            )

        path_cost = self.compute_path(start_position, end_position)
        if not np.isfinite(path_cost):
            # Both voxels are free but the path finder found no collision-free connection.
            self.get_logger().info("No collision-free path between start and end ...")
            return self._set_reachability(
                response, request.scan_pose, False, self.UNREACHABLE_COST
            )

        self._set_reachability(response, request.scan_pose, True, path_cost)
        self.reachabilty_publisher.publish(response.checked_scan_pose)
        self.get_logger().info(f"Path can be reached. Cost: {path_cost}")
        return response

    def _is_free_voxel(self, index) -> bool:
        """
        Return True if the occupancy map marks the voxel at ``index`` as free (non-zero).

        :param index: Voxel index; only the first three entries are used.
        """
        value = self.collission_voxel_ocu_map.array[index[0], index[1], index[2]]
        # ``value`` may be a scalar or a length-1 array (trailing channel axis); np.all handles both.
        return bool(np.all(value != 0))

    @staticmethod
    def _set_reachability(response, scan_pose_msg, reachable: bool, cost: float):
        """
        Fill the reachability part of ``response`` and return it.

        :param response: The ReachabilityCheck response to fill.
        :param scan_pose_msg: The scan pose message that was checked.
        :param reachable: Whether the scan pose can be reached.
        :param cost: The path cost to report.
        :return: ``response``.
        """
        response.checked_scan_pose.scan_pose = scan_pose_msg
        response.checked_scan_pose.reachable = reachable
        response.checked_scan_pose.cost = float(cost)
        # The check itself was carried out, regardless of whether the pose is reachable.
        response.status.success = True
        return response

    def check_reachability_map_callback(
        self,
        request: ReachabilityMapCheck.Request,
        response: ReachabilityMapCheck.Response,
    ):
        """
        Handle the check_reachability_map service request by generating a voxel occupancy map based on the collision mesh.

        For every candidate focal-spot position on the ROI grid, the detector is placed
        ``request.fdd_mm`` away from the focal spot along the beam through ``request.center``.
        A voxel is marked free (1) only if both the source and the detector position lie outside
        the collision mesh. Voxels on the ROI border, and a source exactly at the center, are
        marked blocked (0).

        :param request: The service request containing the region of interest and other parameters.
        :param response: The service response to be filled with the voxel occupancy map.
        :return: The updated service response with the voxel occupancy map.
        """
        if self.collission_mesh is None:
            response.status.success = False
            response.status.status_message = "Error: There is no collision mesh set. Use the topic `rq_artist_collision_mesh_mm`"
            return response

        roi = PyRegionOfIntrest.from_message(request.region_of_intrest)
        source_positions = roi.get_grid()  # presumably (X, Y, Z, 3) positions in mm
        center = np.array(
            [request.center.x, request.center.y, request.center.z]
        ).reshape((-1, 3))

        offsets = source_positions - center
        distances = np.linalg.norm(offsets, axis=-1, keepdims=True)
        # A source exactly at the center has no beam direction; avoid NaNs and block it below.
        has_direction = distances > 0.0
        directions = offsets / np.where(has_direction, distances, 1.0)
        # The beam runs from the source through the center, i.e. along -direction.
        detector_positions = source_positions - directions * request.fdd_mm

        source_free = self.compute_if_points_inside_mesh(source_positions)
        detector_free = self.compute_if_points_inside_mesh(detector_positions)
        # Both source and detector must be outside the collision mesh.
        check: np.ndarray = np.logical_and(
            np.logical_and(source_free, detector_free), has_direction
        )

        # Mark the ROI border as blocked along each spatial axis.
        check[[0, -1], :, :] = False
        check[:, [0, -1], :] = False
        check[:, :, [0, -1]] = False

        self.collission_voxel_ocu_map = PyVolume(
            check.astype(np.uint8), roi, Volume.TYPE_UINT8
        )
        self.source_position_grid = self.collission_voxel_ocu_map.roi.get_grid()

        response.ocupation_map = self.collission_voxel_ocu_map.as_message()
        response.status.success = True
        self.get_logger().info("Added collision voxel grid ...")

        return response

    def set_collission_mesh_callback(self, msg: Mesh):
        """
        Handle the setting of a new collision mesh by converting the received Mesh message to an Open3D mesh.

        Any previously computed occupancy map is discarded. A message without triangles, or with
        vertex indices that do not refer to existing vertices, clears the collision mesh instead
        of raising inside the subscription callback.

        :param msg: The received Mesh message containing the new collision mesh.
        """
        self.get_logger().info("... set new collision object")
        # The old occupancy map belongs to the old mesh and is no longer valid.
        self.collission_voxel_ocu_map = None
        self.source_position_grid = None

        vertices = np.array(
            [[vertex.x, vertex.y, vertex.z] for vertex in msg.vertices],
            dtype=np.float64,
        ).reshape(-1, 3)
        triangles = np.array(
            [
                [
                    triangle.vertex_indices[0],
                    triangle.vertex_indices[1],
                    triangle.vertex_indices[2],
                ]
                for triangle in msg.triangles
            ],
            dtype=np.int64,
        ).reshape(-1, 3)

        if len(triangles) == 0:
            self.collission_mesh = None
            self.get_logger().warning(
                "Received collision mesh without triangles; collision object cleared."
            )
            return
        if triangles.min() < 0 or triangles.max() >= len(vertices):
            self.collission_mesh = None
            self.get_logger().warning(
                "Received collision mesh with invalid vertex indices; collision object cleared."
            )
            return

        o3d_mesh = o3d.geometry.TriangleMesh()
        o3d_mesh.vertices = o3d.utility.Vector3dVector(vertices)
        o3d_mesh.triangles = o3d.utility.Vector3iVector(triangles.astype(np.int32))

        self.collission_mesh = o3d_mesh
        self.get_logger().info("Added collision object ...")

    @staticmethod
    def from_open3d_to_message(mesh: o3d.geometry.TriangleMesh) -> Mesh:
        """
        Convert an Open3D TriangleMesh to a ROS Mesh message.

        :param mesh: The Open3D TriangleMesh to convert.
        :return: The corresponding ROS Mesh message.
        """
        shape_mesh = Mesh()
        shape_mesh.vertices = [
            Point(x=float(x), y=float(y), z=float(z))
            for x, y, z in np.asarray(mesh.vertices)
        ]
        shape_mesh.triangles = [
            MeshTriangle(vertex_indices=[int(a), int(b), int(c)])
            for a, b, c in np.asarray(mesh.triangles)
        ]
        return shape_mesh

    def compute_if_points_inside_mesh(self, points: np.ndarray) -> np.ndarray:
        """
        Determine whether points lie outside the collision mesh using ray casting (parity test).

        Despite the method name, the returned mask is True for points OUTSIDE the mesh.

        :param points: An array of points with shape (..., 3).
        :return: A boolean array of shape (..., 1) where True indicates the point is outside the mesh.
        """
        points = np.asarray(points, dtype=np.float32)
        # One ray per point, starting at the point and pointing along (1, 1, 1).
        rays = np.concatenate((points, np.ones_like(points)), axis=-1)

        scene = o3d.t.geometry.RaycastingScene()
        scene.add_triangles(
            np.asarray(self.collission_mesh.vertices, dtype=np.float32),
            np.asarray(self.collission_mesh.triangles, dtype=np.uint32),
        )
        intersection_counts = scene.count_intersections(rays).numpy()
        # An even number of surface crossings means the ray origin is outside a closed mesh.
        is_outside = intersection_counts % 2 == 0
        return is_outside.reshape(points.shape[:-1] + (1,))

    def compute_path(self, start_indices, end_indices) -> float:
        """
        Compute the path cost between two points in the voxel grid using the Theta* algorithm.

        :param start_indices: The starting indices in the voxel grid.
        :param end_indices: The ending indices in the voxel grid.
        :return: The cost of the computed path, or ``float("inf")`` if no path exists.
        """
        matrix = np.asarray(self.collission_voxel_ocu_map.array)
        # pathfinding3d expects a 3-D matrix; drop a trailing singleton channel axis if present.
        if matrix.ndim == 4 and matrix.shape[-1] == 1:
            matrix = matrix[..., 0]

        grid = Grid(matrix=matrix)
        start = grid.node(start_indices[0], start_indices[1], start_indices[2])
        end = grid.node(end_indices[0], end_indices[1], end_indices[2])

        finder = ThetaStarFinder(diagonal_movement=DiagonalMovement.always)
        path, runs = finder.find_path(start, end, grid)
        path = [node.identifier for node in path]

        self.get_logger().info(f"Operations: {runs} path length {len(path)}")
        grid.cleanup()

        # An empty path means the finder could not connect start and end.
        if not path:
            return float("inf")
        return self.calculate_path_cost(path)

    def calculate_path_cost(self, path, z_weight: float = 3.0) -> float:
        """
        Calculate the cost of a path based on the Euclidean distance between points, with an optional weighting for the Z-axis.

        Voxel steps are scaled by ``roi.resolution_mm``, and the Z component is additionally
        multiplied by ``z_weight`` before taking the Euclidean length.

        :param path: The path as a list of voxel indices (only the first three entries are used).
        :param z_weight: The weighting factor for the Z-axis. Defaults to 3.0.
        :return: The calculated path cost (0.0 for paths with fewer than two points).
        """
        if len(path) < 2:
            return 0.0

        points = np.asarray(path, dtype=float)[:, :3]
        resolution = np.asarray(
            self.collission_voxel_ocu_map.roi.resolution_mm, dtype=float
        )[:3]
        weights = resolution * np.array([1.0, 1.0, z_weight])

        steps = np.diff(points, axis=0) * weights
        return float(np.linalg.norm(steps, axis=1).sum())


def main(args=None):
    """
    Start the aRTist hardware interface node and spin until interrupted.

    The node is always destroyed and the ROS context shut down on exit. This also holds on
    Ctrl+C or when node construction fails (e.g. the aRTist API is not reachable).

    :param args: Optional command line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)

    minimal_service = None
    try:
        minimal_service = ArtistHardwareInterface()
        rclpy.spin(minimal_service)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the node; exit quietly.
        pass
    finally:
        if minimal_service is not None:
            minimal_service.destroy_node()
        # The context may already be shut down (e.g. by a signal handler).
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    # Allow running this module directly as a script; starts the node via main().
    main()