from __future__ import annotations

import numpy as np

from rq_interfaces.msg import RegionOfIntrest
from visualization_msgs.msg import Marker


class PyRegionOfIntrest:
    """
    Represents one or more regions of interest (ROIs) with center points, dimensions, and resolution.

    Every array attribute is stored with shape ``(-1, 3)``: one row per entry,
    columns x, y, z.

    Attributes:
        center_points_mm (np.ndarray): Center points of the ROIs in millimeters, shape (N, 3).
        dimensions_mm (np.ndarray): Edge lengths of the ROIs in millimeters, shape (N, 3).
        frame_id (str): Frame ID for the ROI.
        resolution_mm (np.ndarray): Resolution in millimeters, shape (M, 3). Only row 0 is
            read by ``shape`` and ``as_message``.
    """

    # Stored as an immutable tuple so that every instance gets a fresh array.
    # A shared np.ndarray default could be mutated through one instance and
    # would then silently change the resolution of every later instance.
    _DEFAULT_RESOLUTION_MM = (0.1, 0.1, 0.1)
    # Frame id used when none is given. It is also the fallback for messages
    # that contain no markers.
    _DEFAULT_FRAME_ID = 'object'
    # Tolerance added to dimension/resolution ratios before flooring. Without
    # it, binary rounding gives e.g. 1.0 // 0.1 == 9 instead of 10.
    _GRID_EPS = 1e-9

    def __init__(self, center_points_mm: np.ndarray, dimensions_mm: np.ndarray,
                 frame_id: str = _DEFAULT_FRAME_ID,
                 resolution_mm: np.ndarray | None = None):
        """
        Initializes a PyRegionOfIntrest instance.

        Args:
            center_points_mm (np.ndarray): Center points of the ROIs in millimeters; any
                array-like with a multiple of 3 elements (reshaped to (-1, 3)).
            dimensions_mm (np.ndarray): Dimensions of the ROIs in millimeters; reshaped to (-1, 3).
            frame_id (str): Frame ID for the ROI. Default is 'object'.
            resolution_mm (np.ndarray | None): Resolution of the ROIs in millimeters; reshaped
                to (-1, 3). Default (None) is [0.1, 0.1, 0.1].
        """

        if resolution_mm is None:
            # Build a fresh array per instance (see _DEFAULT_RESOLUTION_MM).
            resolution_mm = np.array(self._DEFAULT_RESOLUTION_MM)

        self.center_points_mm = np.asarray(center_points_mm).reshape((-1, 3))
        self.dimensions_mm = np.asarray(dimensions_mm).reshape((-1, 3))
        self.frame_id = frame_id
        self.resolution_mm = np.asarray(resolution_mm).reshape((-1, 3))

    def __repr__(self) -> str:
        return (f"{type(self).__name__}("
                f"center_points_mm={self.center_points_mm.tolist()!r}, "
                f"dimensions_mm={self.dimensions_mm.tolist()!r}, "
                f"frame_id={self.frame_id!r}, "
                f"resolution_mm={self.resolution_mm.tolist()!r})")

    @classmethod
    def dummy(cls):
        """
        Creates a dummy instance of PyRegionOfIntrest for testing.

        The single ROI has a center drawn uniformly from [-10, 10) and
        dimensions drawn uniformly from [0, 10) on each axis. The values come
        from NumPy's global random state.

        Returns:
            PyRegionOfIntrest: A dummy instance with random values.
        """

        return cls((np.random.random((3, )) - 0.5) * 20.,
                   np.random.random((3, )) * 10.)

    @classmethod
    def from_message(cls, msg: RegionOfIntrest):
        """
        Creates an instance of PyRegionOfIntrest from a ROS message.

        Each marker in ``msg.region_of_intrest_stack.markers`` contributes one
        ROI: its ``pose.position`` becomes the center and its ``scale`` becomes
        the dimensions.

        Args:
            msg (RegionOfIntrest): The ROS message containing ROI data.

        Returns:
            PyRegionOfIntrest: An instance initialized from the ROS message. An empty
            marker stack yields zero ROIs with the default frame id.
        """

        markers = list(msg.region_of_intrest_stack.markers)

        center_points_mm = np.array(
            [[m.pose.position.x, m.pose.position.y, m.pose.position.z] for m in markers],
            dtype=float)
        dimensions_mm = np.array(
            [[m.scale.x, m.scale.y, m.scale.z] for m in markers],
            dtype=float)

        # Only one frame id is stored for all ROIs, so the last marker's frame
        # wins. With no markers, fall back to the default instead of hitting an
        # unbound variable.
        frame = markers[-1].header.frame_id if markers else cls._DEFAULT_FRAME_ID

        resolution_mm = np.array([msg.resolution.x,
                                  msg.resolution.y,
                                  msg.resolution.z], dtype=float)

        return cls(center_points_mm, dimensions_mm, frame, resolution_mm)

    @property
    def number_of_rois(self) -> int:
        """
        Returns the number of regions of interest.

        Returns:
            int: The number of ROIs (rows of ``center_points_mm``).
        """

        return self.center_points_mm.shape[0]

    @property
    def shape(self) -> tuple:
        """
        Returns the shape of the ROI grid based on the dimensions and resolution.

        Computed as ``floor(dimensions / resolution)`` per axis, using the first
        ROI and the first resolution row. A small tolerance absorbs
        floating-point error, so 1.0 mm at 0.1 mm gives 10 rather than 9.

        Returns:
            tuple: The shape of the ROI grid as three ints (x, y, z).
        """

        ratio = self.dimensions_mm[0] / self.resolution_mm[0]
        steps = np.floor(ratio + self._GRID_EPS)
        return tuple(int(step) for step in steps)

    def as_message(self) -> RegionOfIntrest:
        """
        Converts the PyRegionOfIntrest instance to a ROS message.

        Each ROI becomes one Marker: the center goes into ``pose.position``,
        the dimensions into ``scale``, and ``frame_id`` into ``header``. The
        first resolution row goes into ``message.resolution``.

        Returns:
            RegionOfIntrest: The ROS message representing the ROI.
        """

        message = RegionOfIntrest()
        markers = []

        for i in range(self.number_of_rois):
            center = self.center_points_mm[i]
            size = self.dimensions_mm[i]

            marker = Marker()
            marker.pose.position.x = float(center[0])
            marker.pose.position.y = float(center[1])
            marker.pose.position.z = float(center[2])

            marker.scale.x = float(size[0])
            marker.scale.y = float(size[1])
            marker.scale.z = float(size[2])

            marker.header.frame_id = self.frame_id

            markers.append(marker)

        message.region_of_intrest_stack.markers = markers

        resolution = self.resolution_mm[0]
        message.resolution.x = float(resolution[0])
        message.resolution.y = float(resolution[1])
        message.resolution.z = float(resolution[2])

        return message

    def get_grid(self, indice: int = 0) -> np.ndarray:
        """
        Generates a grid of points within the ROI.

        Each axis is sampled with ``np.linspace`` from ``center - dims / 2`` to
        ``center + dims / 2``, endpoints included. The number of samples per
        axis is ``self.shape``.

        NOTE(review): ``self.shape`` is derived from ROI 0, so for other
        indices with different dimensions the spacing will not equal the
        resolution. Confirm that all ROIs are meant to share one size.

        Args:
            indice (int): Index of the ROI. Default is 0.

        Returns:
            np.ndarray: Grid of shape ``(*self.shape, 3)``. ``grid[i, j, k]`` is
            the (x, y, z) point, using 'ij' indexing.
        """

        center = self.center_points_mm[indice]
        half_extent = self.dimensions_mm[indice] / 2.
        start = center - half_extent
        end = center + half_extent

        # Compute the property once instead of once per axis.
        counts = self.shape
        axes = [np.linspace(start[k], end[k], counts[k]) for k in range(3)]

        x, y, z = np.meshgrid(*axes, indexing='ij')
        return np.stack((x, y, z), axis=-1)

    @staticmethod
    def next_neighbor(grid_mm: np.ndarray, point_mm: np.ndarray) -> np.ndarray:
        """
        Finds the nearest neighbor in the grid to a given point.

        Each axis is searched independently. This is valid for axis-aligned
        rectilinear grids such as those returned by ``get_grid``.

        Args:
            grid_mm (np.ndarray): The grid of points, shape (X, Y, Z, 3).
            point_mm (np.ndarray): The point to find the nearest neighbor for (x, y, z).

        Returns:
            np.ndarray: The int32 indices (ix, iy, iz) of the nearest grid point.
        """

        axes = (grid_mm[:, 0, 0, 0], grid_mm[0, :, 0, 1], grid_mm[0, 0, :, 2])
        targets = (point_mm[0], point_mm[1], point_mm[2])

        indices = [int(np.argmin(np.abs(axis - target)))
                   for axis, target in zip(axes, targets)]

        return np.array(indices, dtype=np.int32)