Skip to content
Snippets Groups Projects
load.py 2.05 KiB
import numpy as np
import json 
from pathlib import Path

from ..loader import BaseDataLoader, PyProjection, PyProjectionGeometry, PyRegionOfIntrest
from PIL import Image


class RqJsonLoader(BaseDataLoader):
    """Loader for projection geometries, projection images and regions of
    interest stored as JSON side-car files plus TIFF images.

    The suffixes ``'.geom-json'``, ``'.tif'`` and ``'.roi-json'`` are passed to
    ``BaseDataLoader.__init__``; presumably they map to the projection-geometry,
    projection and ROI file kinds in that order -- TODO confirm against
    ``BaseDataLoader``.
    """

    def __init__(self):
        super().__init__('.geom-json', '.tif', '.roi-json')

    def load_json(self, load_path: Path) -> dict:
        """Read and parse the JSON document at *load_path*.

        The file is decoded explicitly as UTF-8 (the encoding required for JSON
        by RFC 8259) so the result does not depend on the platform locale.
        """
        with open(load_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _pose_arrays(data_dict: dict) -> dict:
        """Build the pose keyword arguments shared by projections and geometries.

        Each JSON value is wrapped in an extra list before conversion, so the
        resulting arrays gain a leading axis of length 1 (presumably a batch
        axis expected by the ``Py*`` types -- TODO confirm).
        """
        return {
            'focal_spot_mm': np.array([data_dict['focal_spot_mm']]),
            'detector_postion_mm': np.array([data_dict['detector_postion_mm']]),
            'detector_orientation_quad': np.array([data_dict['detector_orientation_quad']]),
        }

    def _projection_geometry_path(self, load_path: Path) -> Path:
        """Return the geometry side-car path sharing *load_path*'s stem.

        The configured suffix may be stored with or without its leading dot;
        it is normalized so the result always has exactly one dot before the
        suffix (e.g. ``frame.tif`` -> ``frame.geom-json``, never
        ``frame..geom-json``).
        """
        suffix = self.porjection_geometry_suffix
        if not suffix.startswith('.'):
            suffix = f'.{suffix}'
        return load_path.parent / f'{load_path.stem}{suffix}'

    def load_projection_geometry(self, load_path: Path) -> PyProjectionGeometry:
        """Load a projection geometry from the JSON file at *load_path*.

        Expects the keys ``focal_spot_mm``, ``detector_postion_mm``,
        ``detector_orientation_quad`` and ``frame_id``.

        Raises:
            KeyError: if one of the required keys is missing.
        """
        data_dict = self.load_json(load_path)
        return PyProjectionGeometry(
            **self._pose_arrays(data_dict),
            frame_id=data_dict['frame_id'])

    def load_projection(self, load_path: Path) -> PyProjection:
        """Load the projection image at *load_path* together with its geometry.

        The geometry is read from the side-car JSON file that has the same stem
        as *load_path* and the projection-geometry suffix. Besides the pose keys
        it must provide ``detector_heigth_mm``, ``detector_width_mm`` and
        ``frame_id``.

        Raises:
            KeyError: if one of the required keys is missing.
        """
        data_dict = self.load_json(self._projection_geometry_path(load_path))
        # Close the image file deterministically; np.asarray forces PIL to load
        # the pixel data while the file is still open.
        with Image.open(load_path) as img:
            image = np.asarray(img)

        return PyProjection(
            **self._pose_arrays(data_dict),
            image=image,
            detector_heigth_mm=data_dict['detector_heigth_mm'],
            detector_width_mm=data_dict['detector_width_mm'],
            frame_id=data_dict['frame_id'])

    def load_region_of_intrest(self, load_path: Path) -> PyRegionOfIntrest:
        """Load a region of interest from the JSON file at *load_path*.

        Expects the keys ``start_point_mm`` and ``end_point_mm``. Unlike the
        pose fields, these values are converted without an extra leading axis.

        Raises:
            KeyError: if one of the required keys is missing.
        """
        data_dict = self.load_json(load_path)
        return PyRegionOfIntrest(
            start_point_mm=np.array(data_dict['start_point_mm']),
            end_point_mm=np.array(data_dict['end_point_mm']))