import numpy as np
import json 
from pathlib import Path

from ..writer import BaseDataWriter, PyProjection, PyProjectionGeometry, PyRegionOfIntrest, PyVolume
from PIL import Image
import pyometiff


class RqJsonWriter(BaseDataWriter):
    """Write projections, geometries, regions of interest and volumes to disk.

    Metadata is serialised as JSON indented by two spaces; projection images
    are saved through Pillow and volumes through ``pyometiff``.
    """

    def __init__(self):
        # Suffixes handed to BaseDataWriter; presumably (projection geometry,
        # projection image, region of interest, volume) in the order the base
        # class expects -- TODO confirm against BaseDataWriter.
        super().__init__('.geom-json', '.tif', '.roi-json', '.ome.tiff')

    @staticmethod
    def _json_default(obj):
        """Convert NumPy values that the ``json`` module cannot encode natively.

        ``json`` only accepts genuine ``int``/``float``/``bool`` instances;
        NumPy scalars such as ``np.int64`` or ``np.float32`` are not, so they
        are unwrapped with ``.item()``. Stray arrays become nested lists.

        Raises:
            TypeError: for any other unsupported object, matching what
                ``json`` itself would raise without a ``default`` hook.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

    @staticmethod
    def _geometry_dict(source) -> dict:
        """Return the four geometry entries shared by projection and geometry files.

        ``source`` must expose ``focal_spot_mm``, ``detector_postion_mm``,
        ``detector_orientation_quad`` and ``focal_spot_orientation_quad``
        objects with a ``tolist()`` method. The key names (including the
        ``detector_postion_mm`` spelling) are part of the on-disk format and
        must not be changed here.
        """
        return {
            'focal_spot_mm': source.focal_spot_mm.tolist(),
            'detector_postion_mm': source.detector_postion_mm.tolist(),
            'detector_orientation_quad': source.detector_orientation_quad.tolist(),
            'focal_spot_orientation_quad': source.focal_spot_orientation_quad.tolist(),
        }

    def write_json(self, save_path: Path, data_dict: dict):
        """Serialise ``data_dict`` to ``save_path`` as JSON indented by 2 spaces.

        NumPy scalars/arrays inside ``data_dict`` are converted to plain
        Python values instead of raising ``TypeError``.
        """
        with open(str(save_path), 'w', encoding='utf-8') as f:
            json.dump(data_dict, f, indent=2, default=self._json_default)

    def write_projection_geometry(self, save_path: Path, projection_geometry: PyProjectionGeometry):
        """Write the geometry entries and frame id of ``projection_geometry`` to ``save_path``."""
        data_dict = self._geometry_dict(projection_geometry)
        data_dict['frame_id'] = projection_geometry.frame_id

        self.write_json(save_path, data_dict)

    def write_projection(self, save_path: Path, projection: PyProjection):
        """Write ``projection.image`` to ``save_path`` plus a JSON metadata sidecar.

        The sidecar is placed next to the image, named with the image's stem
        followed by the projection-geometry suffix, and holds the geometry,
        detector, X-ray and frame-id fields of ``projection``.
        """
        # Attribute spelling presumably matches BaseDataWriter (not visible
        # here) -- do not rename it in this class alone.
        save_path_projection_geometry = save_path.parent / f'{save_path.stem}{self.porjection_geometry_suffix}'

        # geometry
        data_dict = self._geometry_dict(projection)

        # detector ('detector_heigth_mm' spelling is part of the file format)
        data_dict['pixel_pitch_x_mm'] = projection.pixel_pitch_x_mm
        data_dict['pixel_pitch_y_mm'] = projection.pixel_pitch_y_mm
        data_dict['detector_heigth_mm'] = projection.detector_heigth_mm
        data_dict['detector_width_mm'] = projection.detector_width_mm

        # xray
        data_dict['exposure_time_ms'] = projection.exposure_time_ms
        data_dict['voltage_kv'] = projection.voltage_kv
        data_dict['current_ua'] = projection.current_ua
        data_dict['frame_id'] = projection.frame_id

        self.write_json(save_path_projection_geometry, data_dict)

        image = Image.fromarray(projection.image)
        image.save(save_path)

    def write_region_of_intrest(self, save_path: Path, region_of_intrest: PyRegionOfIntrest):
        """Write the region of interest's centre, dimensions, frame id and resolution as JSON."""
        data_dict = {
            'center_points_mm': region_of_intrest.center_points_mm.tolist(),
            'dimensions_mm': region_of_intrest.dimensions_mm.tolist(),
            'frame_id': region_of_intrest.frame_id,
            'resolution_mm': region_of_intrest.resolution_mm.tolist(),
        }

        self.write_json(save_path, data_dict)

    def write_volume(self, save_path: Path, volume: PyVolume):
        """Write ``volume.array`` as OME-TIFF to ``save_path`` and ``volume.roi`` as a JSON sidecar."""
        # NOTE(review): Path.stem strips only the last suffix, so
        # 'vol.ome.tiff' yields the sidecar 'vol.ome.roi-json'. Kept as-is in
        # case a reader relies on this name -- confirm before changing.
        save_path_region_of_intrest = save_path.parent / f'{save_path.stem}{self.region_of_intrest_suffix}'
        self.write_region_of_intrest(save_path_region_of_intrest, volume.roi)

        # NOTE(review): positional arguments rely on the installed pyometiff's
        # OMETIFFWriter parameter order and its default dimension order; no
        # OME metadata is written -- confirm against the pinned version.
        writer = pyometiff.OMETIFFWriter(save_path, volume.array, dict())
        writer.write()