Skip to content
Snippets Groups Projects
Commit 2274c128 authored by Simon Wittl's avatar Simon Wittl
Browse files

added data augmentation and

parent 21a6bd9f
No related branches found
No related tags found
No related merge requests found
Showing
with 258 additions and 9 deletions
import h5py
from jax import numpy as jnp
from ..common.mesh_object import MeshObject
from CyXTraX.common.mesh_object import MeshObject
import json
from pathlib import Path
......@@ -13,11 +13,12 @@ def load_atlas(load_path: Path
) -> tuple[jnp.ndarray, jnp.ndarray, list[MeshObject]]:
with h5py.File(load_path, 'r') as f:
maps = f['/maps'][:]
points = f['/positions'][:]
maps = jnp.array(f['/maps'][:])
points = jnp.array(f['/positions'][:])
mesh_object_str = f.attrs['mesh_list']
mesh_object_dict = json.loads(mesh_object_str)
mesh_object_list = list(map(from_dict, mesh_object_dict))
return maps, points, mesh_object_list
......@@ -4,7 +4,7 @@ import numpy as np
import os
from scipy.spatial.transform import Rotation
import importlib.resources
from time import sleep
from CyXTraX.common.mesh_object import MeshObject
# !!!!!!!!!!!!!!!!!!!
......@@ -21,6 +21,8 @@ from CyXTraX.common.mesh_object import MeshObject
# The Script assumes this geometry and only moves all the objects!
# Preview must be reseted!!!
GLOBAL_COUNTER = 0
class CylindricalProjection:
def __init__(self, api: API = API()) -> None:
self.api = api
......@@ -55,11 +57,16 @@ class CylindricalProjection:
self.api.rc.send(f'::XDetector::SetDownCurvedView;')
def compute_projection(self, position: np.ndarray, temp_file_path: Path = Path('temp.tiff'), output_full_ray_projection: bool = True) -> np.ndarray:
def compute_projection(self, position: np.ndarray, temp_file_path: Path = Path(r'C:\data'), output_full_ray_projection: bool = True) -> np.ndarray:
global GLOBAL_COUNTER
self.translate(position)
self.api.save_image(temp_file_path, save_projection_geometry=False, save_mode=SAVEMODES.FLOAT_TIFF)
image = utility.load_projection(temp_file_path, load_projection_geometry=False)[0]
os.remove(temp_file_path)
temp_file = temp_file_path / f'temp_{GLOBAL_COUNTER:01}.tiff'
GLOBAL_COUNTER += 1
GLOBAL_COUNTER = GLOBAL_COUNTER % 10
self.api.save_image(temp_file, save_projection_geometry=False, save_mode=SAVEMODES.FLOAT_TIFF)
image = utility.load_projection(temp_file, load_projection_geometry=False)[0]
os.remove(temp_file)
if output_full_ray_projection:
return self.convert_rays(image)
......
from .augemtation_parameters import AugemntationParameter
from .noise import add_white_noise, add_gaussian_blur
from .xray import add_material
\ No newline at end of file
from dataclasses import dataclass
from jax import random
@dataclass
class AugemntationParameter:
augmentation_module: str
augmentation_function: str
function_parameter: dict
input_parameter: dict
\ No newline at end of file
from .augemtation_parameters import AugemntationParameter
import dm_pix as pix
from jax import numpy as jnp, random
def add_white_noise(maps: jnp.ndarray, mu: jnp.ndarray, sigma: jnp.ndarray, key: random.PRNGKey
) -> tuple[jnp.ndarray, AugemntationParameter, random.PRNGKey]:
key2use, key4next = random.split(key)
function_kwargs = dict(key=key2use, shape=maps.shape, dtype=maps.dtype)
input_kwargs = dict(mu=mu, sigma=sigma)
maps = maps.at[:].add(random.normal(**function_kwargs) * sigma + mu)
parameters = AugemntationParameter(
add_gaussian_blur.__globals__['__name__'],
add_gaussian_blur.__name__,
function_kwargs,
input_kwargs)
return maps, parameters, key4next
def add_gaussian_blur(maps: jnp.ndarray, sigma: jnp.ndarray, kernel_size: jnp.ndarray, key: random.PRNGKey
) -> tuple[jnp.ndarray, AugemntationParameter, random.PRNGKey]:
key4next = key
function_kwargs = dict(sigma=sigma, kernel_size=kernel_size)
input_kwargs = dict(kernel_size=kernel_size, sigma=sigma)
maps = pix.gaussian_blur(maps, **function_kwargs)
parameters = AugemntationParameter(
add_gaussian_blur.__globals__['__name__'],
add_gaussian_blur.__name__,
function_kwargs,
input_kwargs)
return maps, parameters, key4next
from .augemtation_parameters import AugemntationParameter
import xraylib
import numpy as np
from jax import numpy as jnp, random
def get_attenuation_coefficient_monochrom(element: str, energy_keV: float) -> float:
element = xraylib.SymbolToAtomicNumber("C")
return xraylib.CS_Total(element, energy_keV)
def get_attenuation_coefficient_polychrom(element: str, energy_keV: float, bins: int) -> float:
energies = jnp.linspace(0, energy_keV, bins, endpoint=True)
for i in range(1, bins):
energies = energies.at[i].set(get_attenuation_coefficient_monochrom(element, float(energies[i])))
return energies.sum()
def add_material(maps: jnp.ndarray, materilal: str, energy_keV: float, bins: int, key: random.PRNGKey):
key4next = key
function_kwargs = dict()
input_kwargs = dict(element=materilal, energy_keV=energy_keV, bins=bins)
attenuation_coefficient = get_attenuation_coefficient_polychrom(**input_kwargs)
parameters = AugemntationParameter(
add_material.__globals__['__name__'],
add_material.__name__,
function_kwargs,
input_kwargs)
maps = maps.at[:].multiply(attenuation_coefficient)
return maps, parameters, key4next
if __name__ == '__main__':
mu = get_attenuation_coefficient_polychrom('c', 100, 10)
print(mu)
\ No newline at end of file
from .merge import merge_atlas_and_dict, merge_atlases, merge_lists
from .pipeline import generate_atlas_from_folder
\ No newline at end of file
from jax import numpy as jnp
from CyXTraX.common.mesh_object import MeshObject
def merge_atlases(carry_atlas: jnp.ndarray, atlas: jnp.ndarray) -> jnp.ndarray:
return carry_atlas.at[:].add(atlas) / 2.
def merge_lists(carry_list: list, x: list) -> list:
return carry_list.extend(x)
def merge_atlas_and_dict(atlas_one, mesh_list_one, atlas_two, mesh_list_two) -> tuple[jnp.ndarray, list]:
return merge_atlases(atlas_one, atlas_two), merge_lists(mesh_list_one, mesh_list_two)
from pathlib import Path
from CyXTraX.simulation import CylindricalProjection, load_mesh
from CyXTraX.common import MeshObject
from CyXTraX.io import record_atlas
from jax import random, numpy as jnp
from jax.scipy.spatial.transform import Rotation
def generate_atlas_from_folder(
folder: Path, key: random.PRNGKey,
save_folder: Path,
map_bounding_box: tuple[float]=(-50., 50.), number_of_maps: int = 4):
randint_key, postion_key, orientation_key = random.split(key, 3)
stl_files = list(folder.glob('*.stl'))
number_of_files = len(stl_files)
item_id = random.randint(randint_key, (1,), 0, number_of_files, dtype=jnp.int32)
random_euler = random.uniform(orientation_key, (3,), minval=-180, maxval=180)
rot: Rotation = Rotation.from_euler('xyz', random_euler)
stl_path: Path = stl_files[item_id[0]]
mesh_item = MeshObject(
stl_path,
random.truncated_normal(postion_key, -50, 50, (3,)) + random.uniform(postion_key, (3,), minval=-25, maxval=25),
rot.as_quat())
mesh_list = [mesh_item]
cycl_proj = CylindricalProjection()
load_mesh(cycl_proj, mesh_list)
record_atlas(cycl_proj, mesh_list,
stl_path.stem, save_folder,
map_bounding_box, number_of_maps)
from .plt import save_map_plot_with_index
\ No newline at end of file
from CyXTraX.io import load_atlas
from pathlib import Path
def atlas_gif(load_path: Path):
maps, map_position, meshes = load_atlas(load_path)
\ No newline at end of file
from matplotlib import pyplot as plt
from jax import numpy as jnp
from pathlib import Path
def save_map_plot_with_index(maps: jnp.ndarray, index: int, save_path: Path, figsize=(8, 8),
title: str = None):
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
ax.imshow(maps[:, :, index])
ax.set_xlabel(r'$z$ / mm')
ax.set_ylabel(r'$\alpha$ / rad')
ax.set_xticks([0, int(maps.shape[0] // 2), int(maps.shape[0])])
ax.set_xticklabels([r'$-3141.5$', r'$0.$', r'$3141.5$'])
ax.set_yticks([0, int(maps.shape[1] // 2), int(maps.shape[1])])
ax.set_yticklabels([r'$-\pi$', r'$0.$', r'$\pi$'])
if title is not None:
ax.set_title(title)
plt.tight_layout()
plt.savefig(save_path)
plt.close(fig)
\ No newline at end of file
from CyXTraX.io import load_atlas
from CyXTraX.simulation.model import add_white_noise, add_gaussian_blur, add_material
from CyXTraX.util.visualisation import save_map_plot_with_index
from CyXTraX.util.datasets import merge_atlases
from pathlib import Path
from jax import random
# Just get the temp folder to store the atlas
FOLDER = Path(__file__)
TEMP_FOLDER = FOLDER.parent.parent / 'temp'
def main():
files = list(TEMP_FOLDER.glob('*.h5'))
if not files:
raise FileNotFoundError('No .h5 files in temp folder. Run Example 02!')
map_index = 3
maps, points, mesh_object_list = load_atlas(files[-1])
random_key = random.PRNGKey(42)
save_map_plot_with_index(maps, map_index, TEMP_FOLDER / 'blank_map.png')
map_with_white_noise, white_noise_parameter, random_key = add_white_noise(maps, 5.3, 10.1, random_key)
print(white_noise_parameter)
save_map_plot_with_index(
map_with_white_noise, map_index, TEMP_FOLDER / 'noise_map.png',
title='White Noise')
map_with_gaussian_blur, gaussian_blur_parameter, random_key = add_gaussian_blur(maps, 1.1, 3, random_key)
print(gaussian_blur_parameter)
save_map_plot_with_index(
map_with_gaussian_blur, map_index, TEMP_FOLDER / 'gaussian_map.png',
title='Gaussian Blur')
map_with_attenuation, attenuation_parameter, random_key = add_material(maps, 'c', 225, 30, random_key)
print(attenuation_parameter)
save_map_plot_with_index(
map_with_attenuation, map_index, TEMP_FOLDER / 'attenuation_c_map.png',
title='Attenuation C')
merged_map = merge_atlases(map_with_gaussian_blur, map_with_white_noise)
save_map_plot_with_index(
merged_map, map_index, TEMP_FOLDER / 'merged_map.png',
title='Merged Map: Gaussian + Noise')
if __name__ == '__main__':
main()
\ No newline at end of file
from CyXTraX.util.datasets import generate_atlas_from_folder
from pathlib import Path
from jax import random
def main():
# Dataset save folder
save_folder = Path(r'C:\data\XrayTransform')
# Dataset of multiple random stl files from: https://ten-thousand-models.appspot.com/
load_folder = Path(r'C:\Users\swittl\Downloads\Thingi10K\Thingi10K\raw_meshes')
number_of_samples = len(list(save_folder.glob('*.stl')))
for i in range(number_of_samples, number_of_samples + 100):
# start from number of samples to generate diffrent rng keys!!!
key = random.PRNGKey(i)
generate_atlas_from_folder(
load_folder, key, save_folder)
if __name__ == '__main__':
main()
\ No newline at end of file
......@@ -23,4 +23,6 @@ install_requires =
jaxlib
jax
git+https://github.com/wittlsn/aRTist-PythonLib
scipy
\ No newline at end of file
scipy
dm_pix
xraylib
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment