From 2274c1283f087657eac57b9636080a39d00f7734 Mon Sep 17 00:00:00 2001 From: swittl <simon.wittl@th-deg.de> Date: Wed, 30 Oct 2024 07:10:24 +0100 Subject: [PATCH] added data augmentation and --- CyXTraX/io/load_maps.py | 7 +-- CyXTraX/simulation/artist_bridge.py | 17 +++++-- CyXTraX/simulation/model/__init__.py | 4 ++ .../model/augemtation_parameters.py | 10 ++++ CyXTraX/simulation/model/noise.py | 36 ++++++++++++++ CyXTraX/simulation/model/xray.py | 38 ++++++++++++++ CyXTraX/util/__init__.py | 0 CyXTraX/util/datasets/__init__.py | 2 + CyXTraX/util/datasets/merge.py | 14 ++++++ CyXTraX/util/datasets/pipeline.py | 30 ++++++++++++ CyXTraX/util/visualisation/__init__.py | 1 + CyXTraX/util/visualisation/gif.py | 7 +++ CyXTraX/util/visualisation/plt.py | 25 ++++++++++ scripts/05_data_augmentation.py | 49 +++++++++++++++++++ scripts/06_generate_dataset.py | 23 +++++++++ setup.cfg | 4 +- 16 files changed, 258 insertions(+), 9 deletions(-) create mode 100644 CyXTraX/simulation/model/__init__.py create mode 100644 CyXTraX/simulation/model/augemtation_parameters.py create mode 100644 CyXTraX/simulation/model/noise.py create mode 100644 CyXTraX/simulation/model/xray.py create mode 100644 CyXTraX/util/__init__.py create mode 100644 CyXTraX/util/datasets/__init__.py create mode 100644 CyXTraX/util/datasets/merge.py create mode 100644 CyXTraX/util/datasets/pipeline.py create mode 100644 CyXTraX/util/visualisation/__init__.py create mode 100644 CyXTraX/util/visualisation/gif.py create mode 100644 CyXTraX/util/visualisation/plt.py create mode 100644 scripts/05_data_augmentation.py create mode 100644 scripts/06_generate_dataset.py diff --git a/CyXTraX/io/load_maps.py b/CyXTraX/io/load_maps.py index b2cddbd..eb42e50 100644 --- a/CyXTraX/io/load_maps.py +++ b/CyXTraX/io/load_maps.py @@ -1,6 +1,6 @@ import h5py from jax import numpy as jnp -from ..common.mesh_object import MeshObject +from CyXTraX.common.mesh_object import MeshObject import json from pathlib import Path @@ -13,11 +13,12 @@ def load_atlas(load_path: Path ) -> tuple[jnp.ndarray, jnp.ndarray, list[MeshObject]]: with h5py.File(load_path, 'r') as f: - maps = f['/maps'][:] - points = f['/positions'][:] + maps = jnp.array(f['/maps'][:]) + points = jnp.array(f['/positions'][:]) mesh_object_str = f.attrs['mesh_list'] mesh_object_dict = json.loads(mesh_object_str) mesh_object_list = list(map(from_dict, mesh_object_dict)) + return maps, points, mesh_object_list diff --git a/CyXTraX/simulation/artist_bridge.py b/CyXTraX/simulation/artist_bridge.py index e9243b6..9e749ba 100644 --- a/CyXTraX/simulation/artist_bridge.py +++ b/CyXTraX/simulation/artist_bridge.py @@ -4,7 +4,7 @@ import numpy as np import os from scipy.spatial.transform import Rotation import importlib.resources - +from time import sleep from CyXTraX.common.mesh_object import MeshObject # !!!!!!!!!!!!!!!!!!! @@ -21,6 +21,8 @@ from CyXTraX.common.mesh_object import MeshObject # The Script assumes this geometry and only moves all the objects! # Preview must be reseted!!! +GLOBAL_COUNTER = 0 + class CylindricalProjection: def __init__(self, api: API = API()) -> None: self.api = api @@ -55,11 +57,16 @@ class CylindricalProjection: self.api.rc.send(f'::XDetector::SetDownCurvedView;') - def compute_projection(self, position: np.ndarray, temp_file_path: Path = Path('temp.tiff'), output_full_ray_projection: bool = True) -> np.ndarray: + def compute_projection(self, position: np.ndarray, temp_file_path: Path = Path(r'C:\data'), output_full_ray_projection: bool = True) -> np.ndarray: + global GLOBAL_COUNTER self.translate(position) - self.api.save_image(temp_file_path, save_projection_geometry=False, save_mode=SAVEMODES.FLOAT_TIFF) - image = utility.load_projection(temp_file_path, load_projection_geometry=False)[0] - os.remove(temp_file_path) + temp_file = temp_file_path / f'temp_{GLOBAL_COUNTER:01}.tiff' + GLOBAL_COUNTER += 1 + GLOBAL_COUNTER = GLOBAL_COUNTER % 10 + self.api.save_image(temp_file, save_projection_geometry=False, save_mode=SAVEMODES.FLOAT_TIFF) + + image = utility.load_projection(temp_file, load_projection_geometry=False)[0] + os.remove(temp_file) if output_full_ray_projection: return self.convert_rays(image) diff --git a/CyXTraX/simulation/model/__init__.py b/CyXTraX/simulation/model/__init__.py new file mode 100644 index 0000000..ab04eac --- /dev/null +++ b/CyXTraX/simulation/model/__init__.py @@ -0,0 +1,4 @@ + +from .augemtation_parameters import AugemntationParameter +from .noise import add_white_noise, add_gaussian_blur +from .xray import add_material \ No newline at end of file diff --git a/CyXTraX/simulation/model/augemtation_parameters.py b/CyXTraX/simulation/model/augemtation_parameters.py new file mode 100644 index 0000000..51f5e1b --- /dev/null +++ b/CyXTraX/simulation/model/augemtation_parameters.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass +from jax import random + + +@dataclass +class AugemntationParameter: + augmentation_module: str + augmentation_function: str + function_parameter: dict + input_parameter: dict \ No newline at end of file diff --git a/CyXTraX/simulation/model/noise.py b/CyXTraX/simulation/model/noise.py new file mode 100644 index 0000000..76e1a47 --- /dev/null +++ b/CyXTraX/simulation/model/noise.py @@ -0,0 +1,36 @@ +from .augemtation_parameters import AugemntationParameter +import dm_pix as pix +from jax import numpy as jnp, random + + + +def add_white_noise(maps: jnp.ndarray, mu: jnp.ndarray, sigma: jnp.ndarray, key: random.PRNGKey + ) -> tuple[jnp.ndarray, AugemntationParameter, random.PRNGKey]: + key2use, key4next = random.split(key) + function_kwargs = dict(key=key2use, shape=maps.shape, dtype=maps.dtype) + input_kwargs = dict(mu=mu, sigma=sigma) + maps = maps.at[:].add(random.normal(**function_kwargs) * sigma + mu) + + parameters = AugemntationParameter( + add_gaussian_blur.__globals__['__name__'], + add_gaussian_blur.__name__, + function_kwargs, + input_kwargs) + + return maps, parameters, key4next + + +def add_gaussian_blur(maps: jnp.ndarray, sigma: jnp.ndarray, kernel_size: jnp.ndarray, key: random.PRNGKey + ) -> tuple[jnp.ndarray, AugemntationParameter, random.PRNGKey]: + key4next = key + function_kwargs = dict(sigma=sigma, kernel_size=kernel_size) + input_kwargs = dict(kernel_size=kernel_size, sigma=sigma) + maps = pix.gaussian_blur(maps, **function_kwargs) + + parameters = AugemntationParameter( + add_gaussian_blur.__globals__['__name__'], + add_gaussian_blur.__name__, + function_kwargs, + input_kwargs) + + return maps, parameters, key4next diff --git a/CyXTraX/simulation/model/xray.py b/CyXTraX/simulation/model/xray.py new file mode 100644 index 0000000..978af9c --- /dev/null +++ b/CyXTraX/simulation/model/xray.py @@ -0,0 +1,38 @@ +from .augemtation_parameters import AugemntationParameter + +import xraylib +import numpy as np +from jax import numpy as jnp, random + + +def get_attenuation_coefficient_monochrom(element: str, energy_keV: float) -> float: + element = xraylib.SymbolToAtomicNumber("C") + return xraylib.CS_Total(element, energy_keV) + + +def get_attenuation_coefficient_polychrom(element: str, energy_keV: float, bins: int) -> float: + energies = jnp.linspace(0, energy_keV, bins, endpoint=True) + for i in range(1, bins): + energies = energies.at[i].set(get_attenuation_coefficient_monochrom(element, float(energies[i]))) + return energies.sum() + + +def add_material(maps: jnp.ndarray, materilal: str, energy_keV: float, bins: int, key: random.PRNGKey): + key4next = key + function_kwargs = dict() + input_kwargs = dict(element=materilal, energy_keV=energy_keV, bins=bins) + attenuation_coefficient = get_attenuation_coefficient_polychrom(**input_kwargs) + + parameters = AugemntationParameter( + add_material.__globals__['__name__'], + add_material.__name__, + function_kwargs, + input_kwargs) + maps = maps.at[:].multiply(attenuation_coefficient) + + return maps, parameters, key4next + + +if __name__ == '__main__': + mu = get_attenuation_coefficient_polychrom('c', 100, 10) + print(mu) \ No newline at end of file diff --git a/CyXTraX/util/__init__.py b/CyXTraX/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/CyXTraX/util/datasets/__init__.py b/CyXTraX/util/datasets/__init__.py new file mode 100644 index 0000000..4a6739f --- /dev/null +++ b/CyXTraX/util/datasets/__init__.py @@ -0,0 +1,2 @@ +from .merge import merge_atlas_and_dict, merge_atlases, merge_lists +from .pipeline import generate_atlas_from_folder \ No newline at end of file diff --git a/CyXTraX/util/datasets/merge.py b/CyXTraX/util/datasets/merge.py new file mode 100644 index 0000000..d1aae3f --- /dev/null +++ b/CyXTraX/util/datasets/merge.py @@ -0,0 +1,14 @@ +from jax import numpy as jnp +from CyXTraX.common.mesh_object import MeshObject + + +def merge_atlases(carry_atlas: jnp.ndarray, atlas: jnp.ndarray) -> jnp.ndarray: + return carry_atlas.at[:].add(atlas) / 2. + + +def merge_lists(carry_list: list, x: list) -> list: + return carry_list.extend(x) + + +def merge_atlas_and_dict(atlas_one, mesh_list_one, atlas_two, mesh_list_two) -> tuple[jnp.ndarray, list]: + return merge_atlases(atlas_one, atlas_two), merge_lists(mesh_list_one, mesh_list_two) diff --git a/CyXTraX/util/datasets/pipeline.py b/CyXTraX/util/datasets/pipeline.py new file mode 100644 index 0000000..7f29167 --- /dev/null +++ b/CyXTraX/util/datasets/pipeline.py @@ -0,0 +1,30 @@ +from pathlib import Path +from CyXTraX.simulation import CylindricalProjection, load_mesh +from CyXTraX.common import MeshObject +from CyXTraX.io import record_atlas +from jax import random, numpy as jnp +from jax.scipy.spatial.transform import Rotation + +def generate_atlas_from_folder( + folder: Path, key: random.PRNGKey, + save_folder: Path, + map_bounding_box: tuple[float]=(-50., 50.), number_of_maps: int = 4): + randint_key, postion_key, orientation_key = random.split(key, 3) + stl_files = list(folder.glob('*.stl')) + number_of_files = len(stl_files) + + item_id = random.randint(randint_key, (1,), 0, number_of_files, dtype=jnp.int32) + random_euler = random.uniform(orientation_key, (3,), minval=-180, maxval=180) + rot: Rotation = Rotation.from_euler('xyz', random_euler) + stl_path: Path = stl_files[item_id[0]] + mesh_item = MeshObject( + stl_path, + random.truncated_normal(postion_key, -50, 50, (3,)) + random.uniform(postion_key, (3,), minval=-25, maxval=25), + rot.as_quat()) + mesh_list = [mesh_item] + cycl_proj = CylindricalProjection() + load_mesh(cycl_proj, mesh_list) + record_atlas(cycl_proj, mesh_list, + stl_path.stem, save_folder, + map_bounding_box, number_of_maps) + diff --git a/CyXTraX/util/visualisation/__init__.py b/CyXTraX/util/visualisation/__init__.py new file mode 100644 index 0000000..a637804 --- /dev/null +++ b/CyXTraX/util/visualisation/__init__.py @@ -0,0 +1 @@ +from .plt import save_map_plot_with_index \ No newline at end of file diff --git a/CyXTraX/util/visualisation/gif.py b/CyXTraX/util/visualisation/gif.py new file mode 100644 index 0000000..146a4fc --- /dev/null +++ b/CyXTraX/util/visualisation/gif.py @@ -0,0 +1,7 @@ +from CyXTraX.io import load_atlas +from pathlib import Path + + +def atlas_gif(load_path: Path): + maps, map_position, meshes = load_atlas(load_path) + \ No newline at end of file diff --git a/CyXTraX/util/visualisation/plt.py b/CyXTraX/util/visualisation/plt.py new file mode 100644 index 0000000..c992c37 --- /dev/null +++ b/CyXTraX/util/visualisation/plt.py @@ -0,0 +1,25 @@ +from matplotlib import pyplot as plt +from jax import numpy as jnp +from pathlib import Path + + +def save_map_plot_with_index(maps: jnp.ndarray, index: int, save_path: Path, figsize=(8, 8), + title: str = None): + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111) + ax.imshow(maps[:, :, index]) + ax.set_xlabel(r'$z$ / mm') + ax.set_ylabel(r'$\alpha$ / rad') + ax.set_xticks([0, int(maps.shape[0] // 2), int(maps.shape[0])]) + ax.set_xticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$']) + ax.set_yticks([0, int(maps.shape[1] // 2), int(maps.shape[1])]) + ax.set_yticklabels([r'$-\pi$', r'$0.$', r'$\pi$']) + + if title is not None: + ax.set_title(title) + + plt.tight_layout() + + plt.savefig(save_path) + plt.close(fig) + \ No newline at end of file diff --git a/scripts/05_data_augmentation.py b/scripts/05_data_augmentation.py new file mode 100644 index 0000000..2ab1212 --- /dev/null +++ b/scripts/05_data_augmentation.py @@ -0,0 +1,49 @@ +from CyXTraX.io import load_atlas +from CyXTraX.simulation.model import add_white_noise, add_gaussian_blur, add_material +from CyXTraX.util.visualisation import save_map_plot_with_index +from CyXTraX.util.datasets import merge_atlases + +from pathlib import Path +from jax import random + + +# Just get the temp folder to store the atlas +FOLDER = Path(__file__) +TEMP_FOLDER = FOLDER.parent.parent / 'temp' + + +def main(): + files = list(TEMP_FOLDER.glob('*.h5')) + if not files: + raise FileNotFoundError('No .h5 files in temp folder. Run Example 02!') + + map_index = 3 + maps, points, mesh_object_list = load_atlas(files[-1]) + random_key = random.PRNGKey(42) + save_map_plot_with_index(maps, map_index, TEMP_FOLDER / 'blank_map.png') + + map_with_white_noise, white_noise_parameter, random_key = add_white_noise(maps, 5.3, 10.1, random_key) + print(white_noise_parameter) + save_map_plot_with_index( + map_with_white_noise, map_index, TEMP_FOLDER / 'noise_map.png', + title='White Noise') + + map_with_gaussian_blur, gaussian_blur_parameter, random_key = add_gaussian_blur(maps, 1.1, 3, random_key) + print(gaussian_blur_parameter) + save_map_plot_with_index( + map_with_gaussian_blur, map_index, TEMP_FOLDER / 'gaussian_map.png', + title='Gaussian Blur') + + map_with_attenuation, attenuation_parameter, random_key = add_material(maps, 'c', 225, 30, random_key) + print(attenuation_parameter) + save_map_plot_with_index( + map_with_attenuation, map_index, TEMP_FOLDER / 'attenuation_c_map.png', + title='Attenuation C') + + merged_map = merge_atlases(map_with_gaussian_blur, map_with_white_noise) + save_map_plot_with_index( + merged_map, map_index, TEMP_FOLDER / 'merged_map.png', + title='Merged Map: Gaussian + Noise') + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/scripts/06_generate_dataset.py b/scripts/06_generate_dataset.py new file mode 100644 index 0000000..5d03fda --- /dev/null +++ b/scripts/06_generate_dataset.py @@ -0,0 +1,23 @@ +from CyXTraX.util.datasets import generate_atlas_from_folder +from pathlib import Path +from jax import random + + + +def main(): + # Dataset save folder + save_folder = Path(r'C:\data\XrayTransform') + # Dataset of multiple random stl files from: https://ten-thousand-models.appspot.com/ + load_folder = Path(r'C:\Users\swittl\Downloads\Thingi10K\Thingi10K\raw_meshes') + + number_of_samples = len(list(save_folder.glob('*.stl'))) + + for i in range(number_of_samples, number_of_samples + 100): + # start from number of samples to generate diffrent rng keys!!! + key = random.PRNGKey(i) + generate_atlas_from_folder( + load_folder, key, save_folder) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index f21a4fc..35504c4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,4 +23,6 @@ install_requires = jaxlib jax git+https://github.com/wittlsn/aRTist-PythonLib - scipy \ No newline at end of file + scipy + dm_pix + xraylib \ No newline at end of file -- GitLab