import artistlib.hardware

import rclpy.publisher
from rq_interfaces.srv import ReachabilityMapCheck
from rq_interfaces.msg import ServiceStatus, Volume
from ..base_hardware_service import BaseHardwareInterfaceService, AquireProjection, ReachabilityCheck, ReachabilityMapCheck
from rq_controller.common import PyProjection, PyRegionOfIntrest, PyVolume

try:
    import artistlib
    from artistlib import utility as artist_utility
except ModuleNotFoundError:
    raise ModuleNotFoundError('aRTist API not found')

import ros2_numpy
import numpy as np
from pathlib import Path

from sensor_msgs.msg import Image
from shape_msgs.msg import Mesh, MeshTriangle
from geometry_msgs.msg import Point

import open3d as o3d

import rclpy

from pathfinding3d.core.diagonal_movement import DiagonalMovement
from pathfinding3d.core.grid import Grid
from pathfinding3d.finder.a_star import AStarFinder
from pathfinding3d.finder.theta_star import ThetaStarFinder


class ArtistHardwareInterface(BaseHardwareInterfaceService):
    """Hardware-interface service backed by the aRTist X-ray simulator.

    Moves the simulated source ('S') and detector ('D') to requested scan
    poses and renders projections. It also answers reachability queries using
    a voxelised occupancy map derived from a collision mesh received on the
    ``rq_artist_collision_mesh_mm`` topic.
    """

    # Folder into which aRTist writes the rendered projection before it is
    # read back. NOTE(review): hard-coded Windows path; both aRTist and this
    # node must be able to access it.
    projection_tempfolder: Path = Path(r'C:\dev\rq_workflow\src\rq_hardware\temp')
    interface_type: str = 'aRTist'
    # Collision object set via `set_collission_mesh_callback`.
    collission_mesh: o3d.geometry.TriangleMesh = None
    # Occupancy map over the source-position ROI: non-zero = usable, 0 = blocked.
    collission_voxel_ocu_map: PyVolume = None
    # Grid of candidate source positions of the map's ROI (used for index lookup).
    source_position_grid: np.ndarray = None
    # Cost reported in a reachability response when a pose cannot be reached.
    unreachable_cost: float = 1000.

    def __init__(self, node_name: str = 'rq_hardware_interface_service'):
        """Connect to aRTist and set up the collision-mesh topic.

        Args:
            node_name: Name of the ROS 2 node.
        """
        super().__init__(node_name)
        self.api = artistlib.API()
        self.source = artistlib.hardware.XraySource(self.api)
        self.detector = artistlib.hardware.XrayDetector(self.api)
        self.mesh_publisher = self.create_publisher(Mesh, 'rq_artist_collision_mesh_mm', 10)
        self.mesh_subscriper = self.create_subscription(Mesh, 'rq_artist_collision_mesh_mm',
                                                        self.set_collission_mesh_callback, 10)

    def aquire_projection_callback(self,
                                   request: AquireProjection.Request,
                                   response: AquireProjection.Response):
        """Render a projection in aRTist at the requested scan pose.

        Positions source and detector and applies voltage/current. The image
        is rendered through a temporary TIFF and stored in
        ``response.projection``. The projection and its geometry are also
        published.
        """
        projection_geometry = PyProjection.from_message(request.scan_pose)
        geometry_msg = request.scan_pose.projection_geometry

        # move source
        source_position = geometry_msg.focal_spot_postion_mm
        self.api.rotate_from_quat('S', projection_geometry.focal_spot_orientation_quad)
        self.api.translate('S', source_position.x, source_position.y, source_position.z)

        # move detector
        detector_position = geometry_msg.detector_postion_mm
        self.api.rotate_from_quat('D', projection_geometry.detector_orientation_quad)
        self.api.translate('D', detector_position.x, detector_position.y, detector_position.z)

        # Only set the voltage when it differs, to avoid a recalculation of
        # the source spectrum.
        if not np.isclose(self.source.voltage_kv, request.scan_pose.voltage_kv, atol=0.1):
            self.source.voltage_kv = request.scan_pose.voltage_kv

        # NOTE(review): the request carries the current in uA, but the aRTist
        # property is named `exposure_ma`; confirm whether a /1000 conversion
        # is required.
        self.source.exposure_ma = request.scan_pose.current_ua

        projection_path = self.projection_tempfolder / 'projection.tif'
        self.api.save_image(projection_path)
        projection_image = artist_utility.load_projection(projection_path, load_projection_geometry=False)[0]

        # Note: this aliases the request's scan_pose message.
        response.projection = request.scan_pose
        response.projection.image = ros2_numpy.msgify(Image, projection_image.astype(np.uint16), 'mono16')

        self.projection_publisher.publish(response.projection)
        self.projection_geometry_publisher.publish(response.projection.projection_geometry)

        return response

    def check_reachability_callback(self,
                                    request: ReachabilityCheck.Request,
                                    response: ReachabilityCheck.Response):
        """Check whether the focal spot can travel to the requested pose.

        Start and end focal-spot positions are mapped to occupancy-map voxels.
        A Theta* path between them is then planned. Poses are unreachable when
        the end voxel is blocked, when the start voxel (the last projection)
        is blocked, or when no path exists. The very first pose (no last
        projection) is reachable with cost 0.

        ``status.success`` is False only when no occupancy map exists yet.
        """
        self.get_logger().info('Pose to check ...')

        if self.collission_voxel_ocu_map is None:
            response.status.success = False
            response.status.status_message = "Error: There is no collission voxel grid. Use the service `check_reachability_map`"
            return response

        # The check itself can be performed; `reachable` carries the answer.
        response.status.success = True
        response.checked_scan_pose.scan_pose = request.scan_pose

        scan_pose = PyProjection.from_message(request.scan_pose)
        end_position = self._map_index(scan_pose.focal_spot_mm)
        if not self._is_free(end_position):
            self.get_logger().info('End can not be reached ...')
            return self._mark_unreachable(response)

        if self.last_projection is None:
            # Nothing to travel from yet.
            response.checked_scan_pose.reachable = True
            response.checked_scan_pose.cost = 0.
            self.get_logger().info('First path! ...')
            return response

        start_position = self._map_index(self.last_projection.focal_spot_mm)
        if not self._is_free(start_position):
            self.get_logger().info('Start can not be reached ...')
            return self._mark_unreachable(response)

        path_cost = self.compute_path(start_position, end_position)
        if not np.isfinite(path_cost):
            self.get_logger().info('No collision-free path found ...')
            return self._mark_unreachable(response)

        response.checked_scan_pose.reachable = True
        response.checked_scan_pose.cost = float(path_cost)

        self.reachabilty_publisher.publish(response.checked_scan_pose)
        self.get_logger().info(f'Path can be reached. Cost: {path_cost}')
        return response

    def _map_index(self, position_mm):
        """Map a position to an occupancy-map index via ``roi.next_neighbor``."""
        return self.collission_voxel_ocu_map.roi.next_neighbor(self.source_position_grid, position_mm)

    def _is_free(self, index) -> bool:
        """Return True when the occupancy-map voxel at `index` is non-zero."""
        value = self.collission_voxel_ocu_map.array[index[0], index[1], index[2]]
        # The map may carry a trailing singleton axis, so reduce explicitly.
        return bool(np.all(value != 0))

    def _mark_unreachable(self, response):
        """Fill `response.checked_scan_pose` as unreachable and return it."""
        response.checked_scan_pose.reachable = False
        response.checked_scan_pose.cost = self.unreachable_cost
        return response

    def check_reachability_map_callback(self,
                                        request: ReachabilityMapCheck.Request,
                                        response: ReachabilityMapCheck.Response):
        """Build the occupancy map of usable source positions inside an ROI.

        For every grid position of the ROI, the source sits at that position.
        The detector sits ``request.fdd_mm`` further along the beam, which
        points from the source towards ``request.center``. A position is
        usable (1) only when both source and detector lie outside the
        collision mesh. The outermost layer of the map is always blocked.
        """
        if self.collission_mesh is None:
            response.status.success = False
            response.status.status_message = "Error: There is no collission mesh set. Use the topic `rq_artist_collision_mesh_mm`"
            return response

        roi = PyRegionOfIntrest.from_message(request.region_of_intrest)
        source_positions = roi.get_grid()
        center = np.array([request.center.x,
                           request.center.y,
                           request.center.z]).reshape((-1, 3))
        # Unit vectors from the centre to each candidate source position.
        directions = source_positions - center
        directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
        # Detector = focal spot moved fdd_mm along the beam (towards the centre).
        detector_positions = source_positions - directions * request.fdd_mm

        # True = outside the mesh; both ends must be collision-free.
        source_check = self.compute_if_points_inside_mesh(source_positions)
        detector_check = self.compute_if_points_inside_mesh(detector_positions)
        check: np.ndarray = np.logical_and(source_check, detector_check)

        # Block the first and last layer along each of the three axes.
        check[[0, -1], :, :] = 0
        check[:, [0, -1], :] = 0
        check[:, :, [0, -1]] = 0

        self.collission_voxel_ocu_map = PyVolume(check.astype(np.uint8), roi, Volume.TYPE_UINT8)
        self.source_position_grid = self.collission_voxel_ocu_map.roi.get_grid()

        response.ocupation_map = self.collission_voxel_ocu_map.as_message()
        response.status.success = True
        self.get_logger().info('Added collission voxel grid ...')

        return response

    def set_collission_mesh_callback(self, msg: Mesh):
        """Replace the collision mesh and invalidate the occupancy map.

        The map must be rebuilt via `check_reachability_map_callback`
        afterwards.
        """
        self.get_logger().info('... set new collission object')
        self.collission_voxel_ocu_map = None
        self.source_position_grid = None

        # reshape(-1, 3) keeps the arrays 2-D even for an empty message; explicit
        # dtypes match what Vector3dVector / Vector3iVector expect.
        vertices = np.array([[vertex.x, vertex.y, vertex.z] for vertex in msg.vertices],
                            dtype=np.float64).reshape(-1, 3)
        triangles = np.array([[triangle.vertex_indices[0],
                               triangle.vertex_indices[1],
                               triangle.vertex_indices[2]] for triangle in msg.triangles],
                             dtype=np.int32).reshape(-1, 3)

        o3d_mesh = o3d.geometry.TriangleMesh()
        o3d_mesh.vertices = o3d.utility.Vector3dVector(vertices)
        o3d_mesh.triangles = o3d.utility.Vector3iVector(triangles)

        self.collission_mesh = o3d_mesh
        self.get_logger().info('Added collission object ...')

    @staticmethod
    def from_open3d_to_message(mesh: o3d.geometry.TriangleMesh) -> Mesh:
        """Convert an Open3D triangle mesh into a ``shape_msgs/Mesh`` message."""
        shape_mesh = Mesh()

        for vertex in np.asarray(mesh.vertices):
            point = Point()
            point.x = float(vertex[0])
            point.y = float(vertex[1])
            point.z = float(vertex[2])
            shape_mesh.vertices.append(point)

        for triangle in np.asarray(mesh.triangles):
            mesh_triangle = MeshTriangle()
            mesh_triangle.vertex_indices = [int(triangle[0]), int(triangle[1]), int(triangle[2])]
            shape_mesh.triangles.append(mesh_triangle)

        return shape_mesh

    def compute_if_points_inside_mesh(self, points: np.ndarray) -> np.ndarray:
        """Classify points against the collision mesh by ray-crossing parity.

        Outside is 1, inside is 0.

        A ray with direction (1, 1, 1) is cast from every point. An even number
        of mesh crossings means the point is outside. The parity test assumes
        a closed (watertight) mesh.

        Args:
            points: Array of shape ``(..., 3)`` in the collision mesh's frame.

        Returns:
            Boolean array of shape ``(..., 1)``; True where the point is outside.
        """
        leading_shape = tuple(np.shape(points)[:-1])
        points = np.float32(points)
        rays = np.concatenate((points, np.ones_like(points)), -1)

        scene = o3d.t.geometry.RaycastingScene()
        scene.add_triangles(np.float32(self.collission_mesh.vertices),
                            np.uint32(self.collission_mesh.triangles))
        intersection_counts = scene.count_intersections(rays).numpy()
        is_outside = intersection_counts % 2 == 0
        return is_outside.reshape(leading_shape + (1,))

    def compute_path(self, start_indices, end_indices) -> float:
        """Plan a Theta* path through the occupancy map and return its cost.

        Args:
            start_indices: (x, y, z) voxel index of the start.
            end_indices: (x, y, z) voxel index of the goal.

        Returns:
            Cost from `calculate_path_cost`, or ``inf`` when no path exists.
        """
        grid = Grid(matrix=self.collission_voxel_ocu_map.array)
        start = grid.node(start_indices[0], start_indices[1], start_indices[2])
        end = grid.node(end_indices[0], end_indices[1], end_indices[2])

        finder = ThetaStarFinder(diagonal_movement=DiagonalMovement.always)
        path, runs = finder.find_path(start, end, grid)
        path = [p.identifier for p in path]

        self.get_logger().info(f'operations: {runs} path length {len(path)}')
        grid.cleanup()
        if not path:
            # The finder yields an empty path when the goal cannot be reached;
            # that must not be reported as a zero-cost path.
            return float('inf')
        return self.calculate_path_cost(path)

    def calculate_path_cost(self, path, z_weight: float = 3.0) -> float:
        """Return the resolution-scaled length of a voxel path.

        Each step is scaled by ``roi.resolution_mm``. The z component is
        additionally multiplied by `z_weight`, which penalises vertical
        travel.

        Args:
            path: Sequence of voxel indices; only the first three entries of
                each are used.
            z_weight: Extra weight on the z component of every step.
        """
        if len(path) < 2:
            return 0.0
        points = np.asarray([p[:3] for p in path], dtype=float)
        resolution = np.asarray(self.collission_voxel_ocu_map.roi.resolution_mm, dtype=float)[:3]
        scale = resolution * np.array([1.0, 1.0, z_weight])
        steps = np.diff(points, axis=0) * scale
        return float(np.linalg.norm(steps, axis=1).sum())



def main(args=None):
    """Start the aRTist hardware-interface node and spin until shutdown."""
    rclpy.init(args=args)

    minimal_service = ArtistHardwareInterface()
    try:
        rclpy.spin(minimal_service)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the node.
        pass
    finally:
        minimal_service.destroy_node()
        # The context may already have been shut down (e.g. by the SIGINT
        # handler); calling shutdown twice raises.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()