from __future__ import annotations

from rq_interfaces.srv import AquireProjection, ReachabilityCheck, ReachabilityMapCheck

from rq_controller.common import PyProjection, PyRegionOfIntrest, PyVolume

import rclpy
from rclpy.node import Node
import numpy as np


class HardwareClient(Node):
    """
    ROS 2 node (``rq_hardware_client`` in namespace ``rq``) that talks to the
    hardware services: acquiring projections and checking the reachability of
    single scan poses and of regions of interest.

    Every service call is asynchronous and returns the future produced by
    ``call_async``; the static ``*_response_2_py`` helpers convert the matching
    responses into the Python types from ``rq_controller.common``.
    """

    def __init__(self, **kwargs):
        """
        Initialize the node, create one client per service and block until each
        of the three services is available.

        :param kwargs: Additional keyword arguments forwarded to ``Node.__init__``.
        """
        super().__init__("rq_hardware_client", namespace="rq", **kwargs)

        # Each client waits for its *own* service, so construction cannot finish
        # while any of the three services is still unavailable.
        self.cli_projection = self.create_client(AquireProjection, "aquire_projection")
        self._wait_for_service(self.cli_projection, "projection")

        self.cli_reachability = self.create_client(
            ReachabilityCheck, "check_reachability"
        )
        self._wait_for_service(self.cli_reachability, "reachability")

        self.cli_reachability_map = self.create_client(
            ReachabilityMapCheck, "check_reachability_map"
        )
        self._wait_for_service(self.cli_reachability_map, "reachability MAP")

        # One request object per service, refilled on every call.
        # NOTE(review): reuse relies on call_async having serialized the request
        # before it returns — confirm before issuing overlapping calls.
        self.req_projection = AquireProjection.Request()
        self.req_reachability = ReachabilityCheck.Request()
        self.req_reachability_map = ReachabilityMapCheck.Request()

    def _wait_for_service(self, client, label: str) -> None:
        """
        Block until ``client``'s service is available, logging once per second.

        :param client: Service client returned by ``create_client``.
        :param label: Human-readable service name used in the log message.
        """
        while not client.wait_for_service(timeout_sec=1.0):
            self.get_logger().info(
                f"{label} service not available, waiting again..."
            )

    def aquire_projection(self, projection: PyProjection) -> rclpy.Future:
        """
        Send a request to acquire a projection at the given scan pose.

        :param projection: The PyProjection whose message form is sent as ``scan_pose``.
        :return: The future returned by ``call_async``; its result is an
            ``AquireProjection.Response``.
        """
        self.req_projection.scan_pose = projection.as_message()
        return self.cli_projection.call_async(self.req_projection)

    def check_reachability(self, projection: PyProjection) -> rclpy.Future:
        """
        Send a request to check whether the given scan pose is reachable.

        :param projection: The PyProjection whose message form is sent as ``scan_pose``.
        :return: The future returned by ``call_async``; its result is a
            ``ReachabilityCheck.Response``.
        """
        self.req_reachability.scan_pose = projection.as_message()
        return self.cli_reachability.call_async(self.req_reachability)

    def check_reachability_map(
        self,
        source_position_roi_mm: PyRegionOfIntrest,
        center_mm: np.ndarray,
        fdd_mm: float,
    ) -> rclpy.Future:
        """
        Send a request to compute the reachability map for a region of interest.

        :param source_position_roi_mm: Region of interest (mm) sent as ``region_of_intrest``.
        :param center_mm: Center position in millimeters; anything array-like with
            exactly three elements (e.g. shape ``(3,)``, ``(3, 1)`` or a list).
        :param fdd_mm: The focal detector distance in millimeters.
        :return: The future returned by ``call_async``; its result is a
            ``ReachabilityMapCheck.Response``.
        :raises ValueError: If ``center_mm`` does not contain exactly three values.
        """
        # Generated ROS message setters only accept Python ``float`` for float
        # fields (np.float64 qualifies, but ints, np.int64 and np.float32 do not),
        # so coerce everything explicitly.
        center = np.asarray(center_mm, dtype=float).reshape((3,))
        self.req_reachability_map.center.x = float(center[0])
        self.req_reachability_map.center.y = float(center[1])
        self.req_reachability_map.center.z = float(center[2])
        self.req_reachability_map.region_of_intrest = (
            source_position_roi_mm.as_message()
        )
        self.req_reachability_map.fdd_mm = float(fdd_mm)
        return self.cli_reachability_map.call_async(self.req_reachability_map)

    @staticmethod
    def projection_response_2_py(response: AquireProjection.Response) -> PyProjection:
        """
        Convert an AquireProjection response to a PyProjection.

        :param response: The AquireProjection response message.
        :return: ``PyProjection.from_message(response.projection)``.
        """
        return PyProjection.from_message(response.projection)

    @staticmethod
    def reachability_response_2_py(
        response: ReachabilityCheck.Response,
    ) -> tuple[bool, float, PyProjection]:
        """
        Convert a ReachabilityCheck response to plain Python values.

        :param response: The ReachabilityCheck response message.
        :return: ``(reachable, cost, scan_pose)`` taken from
            ``response.checked_scan_pose``; ``scan_pose`` is converted to a PyProjection.
        """
        checked = response.checked_scan_pose
        return (
            checked.reachable,
            checked.cost,
            PyProjection.from_message(checked.scan_pose),
        )

    @staticmethod
    def reachability_map_response_2_py(
        response: ReachabilityMapCheck.Response,
    ) -> PyVolume:
        """
        Convert a ReachabilityMapCheck response to a PyVolume.

        :param response: The ReachabilityMapCheck response message.
        :return: ``PyVolume.from_message(response.ocupation_map)``.
        """
        return PyVolume.from_message(response.ocupation_map)


def _result_or_raise(future, what: str):
    """
    Return ``future.result()``, raising a clear error if no response arrived.

    ``Future.result()`` re-raises any exception stored on the future and
    otherwise returns ``None`` when no result was set (e.g. spinning stopped
    early), which would later surface as an obscure ``AttributeError``.

    :param future: Future returned by one of the HardwareClient service calls.
    :param what: Name of the call, used in the error message.
    :raises RuntimeError: If the future holds no result.
    """
    response = future.result()
    if response is None:
        raise RuntimeError(f"{what} service call returned no response")
    return response


def main():
    """
    Demo entry point: check reachability of a dummy scan pose, acquire a
    projection for it and display the resulting image.

    The node is always destroyed and rclpy always shut down, even when a
    service call fails.
    """
    from matplotlib import pyplot as plt

    rclpy.init()
    try:
        minimal_client = HardwareClient()
        try:
            scan_pose = PyProjection.dummy()

            future = minimal_client.check_reachability(scan_pose)
            rclpy.spin_until_future_complete(minimal_client, future)
            reachable, cost, _ = minimal_client.reachability_response_2_py(
                _result_or_raise(future, "check_reachability")
            )
            minimal_client.get_logger().info(
                f"Pose is: {reachable} \nMove cost: {cost}"
            )

            future = minimal_client.aquire_projection(scan_pose)
            rclpy.spin_until_future_complete(minimal_client, future)
            projection = minimal_client.projection_response_2_py(
                _result_or_raise(future, "aquire_projection")
            )
            minimal_client.get_logger().info("Received Projection ...")
        finally:
            minimal_client.destroy_node()
    finally:
        # Runs even if the constructor is interrupted while waiting for services.
        rclpy.shutdown()

    # Plotting needs no ROS context, so it happens after shutdown.
    plt.imshow(projection.image)
    plt.show()


if __name__ == "__main__":
    main()