import numpy as np
import json 
from pathlib import Path

from ..loader import BaseDataLoader, PyProjection, PyProjectionGeometry, PyRegionOfIntrest, PyVolume
from PIL import Image
import pyometiff


class RqJsonLoader(BaseDataLoader):
    """Loader for JSON metadata files paired with TIFF / OME-TIFF image data.

    * ``load_projection_geometry`` / ``load_region_of_intrest`` parse a JSON file.
    * ``load_projection`` reads a 2-D image and the geometry JSON stored next to
      it under the same stem.
    * ``load_volume`` reads an OME-TIFF volume and the region-of-interest JSON
      stored next to it.
    """

    # Full multi-part suffix of volume files. ``Path.stem`` only strips the last
    # part ('.tiff'), so this is needed to derive the bare file name.
    _VOLUME_SUFFIX = '.ome.tiff'

    def __init__(self):
        super().__init__('.geom-json', '.tif', '.roi-json', self._VOLUME_SUFFIX)

    def load_json(self, load_path: Path) -> dict:
        """Read ``load_path`` and return its parsed JSON content."""
        # Explicit UTF-8 so parsing does not depend on the platform's locale
        # encoding (e.g. cp1252 on Windows).
        with open(load_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _geometry_kwargs(data_dict: dict) -> dict:
        """Build the geometry keyword arguments shared by projection types.

        Args:
            data_dict: parsed geometry JSON; must contain 'focal_spot_mm',
                'detector_postion_mm', 'detector_orientation_quad', 'frame_id'
                and 'focal_spot_orientation_quad'.

        Returns:
            Keyword arguments accepted by both ``PyProjectionGeometry`` and
            ``PyProjection``.

        Raises:
            KeyError: if one of the keys above is missing.
        """
        return dict(
            # Wrapping in an extra list prepends a length-1 leading axis.
            focal_spot_mm=np.array([data_dict['focal_spot_mm']]),
            detector_postion_mm=np.array([data_dict['detector_postion_mm']]),
            detector_orientation_quad=np.array([data_dict['detector_orientation_quad']]),
            frame_id=data_dict['frame_id'],
            # NOTE(review): unlike the arrays above this one is NOT wrapped in an
            # extra list, so it gets no leading length-1 axis -- confirm intended.
            focal_spot_orientation_quad=np.array(data_dict['focal_spot_orientation_quad']),
        )

    def load_projection_geometry(self, load_path: Path) -> PyProjectionGeometry:
        """Load a projection geometry from the JSON file ``load_path``.

        Raises:
            KeyError: if a required geometry key is missing from the JSON.
        """
        data_dict = self.load_json(load_path)
        return PyProjectionGeometry(**self._geometry_kwargs(data_dict))

    def load_projection(self, load_path: Path) -> PyProjection:
        """Load a projection image and its metadata.

        The metadata is read from ``<stem><projection geometry suffix>`` in the
        same directory as the image ``load_path``.

        Raises:
            KeyError: if a required metadata key is missing from the JSON.
        """
        load_path_projection_geometry = load_path.parent / f'{load_path.stem}{self.porjection_geometry_suffix}'
        data_dict = self.load_json(load_path_projection_geometry)
        # Close the file handle deterministically; the pixel data is
        # materialized into the array while the file is still open.
        with Image.open(load_path) as img:
            image = np.asarray(img)

        return PyProjection(
            image=image,
            detector_heigth_mm=data_dict['detector_heigth_mm'],
            detector_width_mm=data_dict['detector_width_mm'],
            voltage_kv=data_dict['voltage_kv'],
            current_ua=data_dict['current_ua'],
            exposure_time_ms=data_dict['exposure_time_ms'],
            **self._geometry_kwargs(data_dict))

    def load_region_of_intrest(self, load_path: Path) -> PyRegionOfIntrest:
        """Load a region of interest from the JSON file ``load_path``.

        Raises:
            KeyError: if 'center_points_mm', 'dimensions_mm', 'frame_id' or
                'resolution_mm' is missing from the JSON.
        """
        data_dict = self.load_json(load_path)

        return PyRegionOfIntrest(
            center_points_mm=np.array(data_dict['center_points_mm']),
            dimensions_mm=np.array(data_dict['dimensions_mm']),
            frame_id=data_dict['frame_id'],
            resolution_mm=np.array(data_dict['resolution_mm']))

    def _roi_path_for_volume(self, load_path: Path) -> Path:
        """Return the path of the region-of-interest file belonging to a volume.

        The original naming (``Path.stem`` + ROI suffix, i.e.
        ``<name>.ome<suffix>`` for ``<name>.ome.tiff``) is preferred when that
        file exists. Otherwise ``<name><suffix>``, with the full ``.ome.tiff``
        suffix stripped, is tried. If neither exists, the original path is
        returned so the resulting error names the same file as before.
        """
        suffix = self.region_of_intrest_suffix
        legacy_path = load_path.parent / f'{load_path.stem}{suffix}'
        if legacy_path.exists():
            return legacy_path

        name = load_path.name
        if name.lower().endswith(self._VOLUME_SUFFIX):
            stripped_path = load_path.parent / f'{name[:-len(self._VOLUME_SUFFIX)]}{suffix}'
            if stripped_path.exists():
                return stripped_path

        return legacy_path

    def load_volume(self, load_path: Path) -> PyVolume:
        """Load an OME-TIFF volume together with its region of interest.

        The volume dtype is mapped to an integer code passed to ``PyVolume``:
        uint16 -> 0, uint8 -> 1.

        Raises:
            ValueError: if the volume dtype is neither uint16 nor uint8.
        """
        roi = self.load_region_of_intrest(self._roi_path_for_volume(load_path))
        reader = pyometiff.OMETIFFReader(load_path)
        # Only the pixel array is used; the two metadata return values are ignored.
        volume, _, _ = reader.read()

        if volume.dtype == np.uint16:
            data_type = 0
        elif volume.dtype == np.uint8:
            data_type = 1
        else:
            raise ValueError('data type not implemented.')

        return PyVolume(volume, roi, data_type)