from __future__ import annotations

from rq_interfaces.msg import Projection
from rq_interfaces.srv import AquireProjection, ReachabilityCheck, ReachabilityMapCheck

from rq_controller.common import PyProjection, PyRegionOfIntrest, PyVolume

import rclpy
from rclpy.node import Node
import numpy as np


class HardwareClient(Node):
    """
    ROS 2 node (namespace ``rq``) that wraps the hardware service clients.

    It offers asynchronous helpers to acquire a projection, to check whether a
    single scan pose is reachable, and to request a reachability map for a
    region of interest. It also offers static helpers that convert the service
    responses into the ``rq_controller.common`` Python types.
    """

    def __init__(self, **kwargs):
        """
        Initialize the node, create all service clients, and block until each service is available.

        :param kwargs: Additional keyword arguments forwarded to ``Node.__init__``.
        """
        super().__init__('rq_hardware_client', namespace="rq", **kwargs)
        # Each client must wait on *its own* service. Otherwise a missing
        # reachability service would only surface on the first call.
        self.cli_projection = self._create_client_and_wait(
            AquireProjection, 'aquire_projection', 'projection')
        self.cli_reachability = self._create_client_and_wait(
            ReachabilityCheck, 'check_reachability', 'reachability')
        self.cli_reachability_map = self._create_client_and_wait(
            ReachabilityMapCheck, 'check_reachability_map', 'reachability MAP')

        # Request objects are created once and refilled before every call.
        self.req_projection = AquireProjection.Request()
        self.req_reachability = ReachabilityCheck.Request()
        self.req_reachability_map = ReachabilityMapCheck.Request()

    def _create_client_and_wait(self, srv_type, srv_name: str, label: str):
        """
        Create a service client and block until its service is available.

        :param srv_type: The service type, e.g. ``AquireProjection``.
        :param srv_name: The service name passed to ``create_client``.
        :param label: Human-readable service name used in the waiting log message.
        :return: The created service client.
        """
        client = self.create_client(srv_type, srv_name)
        # Poll in 1 s steps so the wait is visible in the log; there is no overall timeout.
        while not client.wait_for_service(timeout_sec=1.0):
            self.get_logger().info(f'{label} service not available, waiting again...')
        return client

    def aquire_projection(self, projection: PyProjection) -> rclpy.Future:
        """
        Send a request to acquire a projection.

        :param projection: The PyProjection object containing the scan pose.
        :return: A Future object representing the result of the asynchronous service call.
        """
        self.req_projection.scan_pose = projection.as_message()
        return self.cli_projection.call_async(self.req_projection)

    def check_reachability(self, projection: PyProjection) -> rclpy.Future:
        """
        Send a request to check the reachability of a given projection.

        :param projection: The PyProjection object containing the scan pose.
        :return: A Future object representing the result of the asynchronous service call.
        """
        self.req_reachability.scan_pose = projection.as_message()
        return self.cli_reachability.call_async(self.req_reachability)

    def check_reachability_map(self, source_position_roi_mm: PyRegionOfIntrest, center_mm: np.ndarray, fdd_mm: float) -> rclpy.Future:
        """
        Send a request to check the reachability map for a given region of interest and focal detector distance.

        :param source_position_roi_mm: The PyRegionOfIntrest object containing the region of interest.
        :param center_mm: Center position in millimeters. Any array-like with
            exactly three elements is accepted, e.g. shape ``(3,)`` or ``(3, 1)``.
        :param fdd_mm: The focal detector distance in millimeters.
        :return: A Future object representing the result of the asynchronous service call.
        :raises ValueError: If ``center_mm`` does not contain exactly three elements.
        """
        # Convert to plain Python floats up front. Generated ROS 2 message setters
        # type-check float fields, so integer inputs (int arrays, an int fdd)
        # would otherwise be rejected on assignment.
        center = np.asarray(center_mm, dtype=float).reshape((3,))
        self.req_reachability_map.center.x = float(center[0])
        self.req_reachability_map.center.y = float(center[1])
        self.req_reachability_map.center.z = float(center[2])
        self.req_reachability_map.region_of_intrest = source_position_roi_mm.as_message()
        self.req_reachability_map.fdd_mm = float(fdd_mm)
        return self.cli_reachability_map.call_async(self.req_reachability_map)

    @staticmethod
    def projection_response_2_py(response: AquireProjection.Response) -> PyProjection:
        """
        Convert an AquireProjection response to a PyProjection object.

        :param response: The AquireProjection response message.
        :return: A PyProjection built from ``response.projection``.
        """
        return PyProjection.from_message(response.projection)

    @staticmethod
    def reachability_response_2_py(response: ReachabilityCheck.Response) -> tuple[bool, float, PyProjection]:
        """
        Convert a ReachabilityCheck response to a tuple containing reachability information.

        :param response: The ReachabilityCheck response message.
        :return: A tuple ``(reachable, cost, scan_pose)`` read from ``response.checked_scan_pose``.
        """
        checked = response.checked_scan_pose
        return (checked.reachable,
                checked.cost,
                PyProjection.from_message(checked.scan_pose))

    @staticmethod
    def reachability_map_response_2_py(response: ReachabilityMapCheck.Response) -> PyVolume:
        """
        Convert a ReachabilityMapCheck response to a PyVolume object.

        :param response: The ReachabilityMapCheck response message.
        :return: A PyVolume built from ``response.ocupation_map``.
        """
        return PyVolume.from_message(response.ocupation_map)
    

def main():
    """
    Demo entry point: check reachability of a dummy pose, acquire a projection, and display it.

    The node and the rclpy context are always cleaned up, even when a step
    raises (e.g. Ctrl-C while waiting for a service or a response).
    """
    from matplotlib import pyplot as plt

    rclpy.init()
    minimal_client = None
    try:
        minimal_client = HardwareClient()
        scan_pose = PyProjection.dummy()

        future = minimal_client.check_reachability(scan_pose)
        rclpy.spin_until_future_complete(minimal_client, future)

        response = future.result()
        reachability_check = minimal_client.reachability_response_2_py(response)
        minimal_client.get_logger().info(
            f'Pose is: {reachability_check[0]} \nMove cost: {reachability_check[1]}')

        future = minimal_client.aquire_projection(scan_pose)
        rclpy.spin_until_future_complete(minimal_client, future)

        response = future.result()
        projection = minimal_client.projection_response_2_py(response)
        minimal_client.get_logger().info('Received Projection ...')
    finally:
        # Release ROS resources on every exit path, not only on success.
        if minimal_client is not None:
            minimal_client.destroy_node()
        # Guard so an already shut-down context does not mask the original error.
        if rclpy.ok():
            rclpy.shutdown()

    # Plotting happens after ROS shutdown; it only needs the local image data.
    plt.imshow(projection.image)
    plt.show()


if __name__ == '__main__':
    main()