import numpy as np
import json 
from pathlib import Path

from ..loader import BaseDataLoader, PyProjection, PyProjectionGeometry, PyRegionOfIntrest
from PIL import Image


class RqJsonLoader(BaseDataLoader):
    """Loader for JSON-described projection data with TIFF projection images.

    Projection geometries and regions of interest are read from JSON files.
    Projection images are read with PIL and paired with a geometry JSON file
    that shares the image's stem and lives in the same directory.
    """

    def __init__(self):
        # Suffixes handed to BaseDataLoader, presumably in its
        # (projection geometry, projection image, region of interest) order
        # — NOTE(review): confirm against BaseDataLoader.__init__.
        super().__init__('.geom-json', '.tif', '.roi-json')

    def load_json(self, load_path: Path) -> dict:
        """Parse the JSON file at ``load_path`` and return its top-level value.

        The file is decoded as UTF-8 (the encoding JSON interchange requires)
        instead of the platform's locale-dependent default.
        """
        with open(load_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _geometry_arrays(data_dict: dict) -> dict:
        """Build the geometry keyword arguments shared by projections and geometries.

        Each value is wrapped in an extra list before conversion, so every
        resulting array has a leading axis of length 1.
        NOTE(review): presumably the shape the Py* constructors expect — confirm.
        Raises ``KeyError`` if any of the geometry keys is missing.
        """
        return {
            'focal_spot_mm': np.array([data_dict['focal_spot_mm']]),
            'detector_postion_mm': np.array([data_dict['detector_postion_mm']]),
            'detector_orientation_quad': np.array([data_dict['detector_orientation_quad']]),
        }

    def load_projection_geometry(self, load_path: Path) -> PyProjectionGeometry:
        """Load a projection geometry from the JSON file at ``load_path``.

        Reads the keys ``focal_spot_mm``, ``detector_postion_mm``,
        ``detector_orientation_quad`` and ``frame_id``; a missing key raises
        ``KeyError``.
        """
        data_dict = self.load_json(load_path)

        return PyProjectionGeometry(
            **self._geometry_arrays(data_dict),
            frame_id=data_dict['frame_id'])

    def load_projection(self, load_path: Path) -> PyProjection:
        """Load the projection image at ``load_path`` together with its geometry.

        The geometry JSON is expected next to the image, named
        ``<image stem>.<geometry suffix>``. Besides the geometry keys it must
        provide ``detector_heigth_mm``, ``detector_width_mm`` and ``frame_id``;
        a missing key raises ``KeyError``.
        """
        # The geometry suffix registered in __init__ ('.geom-json') already
        # carries a leading dot; strip it so the joined name has exactly one
        # separator. lstrip also keeps this correct if the base class stores
        # the suffix without a dot. (Attribute name must match BaseDataLoader.)
        geometry_suffix = self.porjection_geometry_suffix.lstrip('.')
        load_path_projection_geometry = load_path.parent / f'{load_path.stem}.{geometry_suffix}'
        data_dict = self.load_json(load_path_projection_geometry)

        # Use the image as a context manager so the underlying file handle is
        # released; np.asarray copies the pixel data out before closing.
        with Image.open(load_path) as pil_image:
            image = np.asarray(pil_image)

        return PyProjection(
            **self._geometry_arrays(data_dict),
            image=image,
            detector_heigth_mm=data_dict['detector_heigth_mm'],
            detector_width_mm=data_dict['detector_width_mm'],
            frame_id=data_dict['frame_id'])

    def load_region_of_intrest(self, load_path: Path) -> PyRegionOfIntrest:
        """Load a region of interest from the JSON file at ``load_path``.

        Reads the ``start_point_mm`` and ``end_point_mm`` keys. Unlike the
        geometry fields, these are converted without an extra wrapping list.
        A missing key raises ``KeyError``.
        """
        data_dict = self.load_json(load_path)

        return PyRegionOfIntrest(
            start_point_mm=np.array(data_dict['start_point_mm']),
            end_point_mm=np.array(data_dict['end_point_mm']))