# Source code for nondefaced_detector.preprocess

"""Script to preprocess volumes"""

import functools
import os
import tempfile

import multiprocessing as mp
import tensorflow as tf

from tqdm import tqdm

from nondefaced_detector.preprocessing.conform import conform_data
from nondefaced_detector.helpers import utils
from nondefaced_detector.preprocessing.normalization import clip, normalize, standardize

AUTOTUNE = tf.data.experimental.AUTOTUNE


def preprocess(
    vol_path,
    conform_volume_to=(128, 128, 128),
    conform_zooms=(2.0, 2.0, 2.0),
    save_path=None,
    with_label=False,
):
    """Preprocess an input volume before prediction.

    The volume is loaded, clipped (``q=90``), standardized and normalized,
    written to a temporary ``.nii.gz`` file, and then conformed to
    ``conform_volume_to`` with voxel zooms ``conform_zooms``.

    Parameters
    ----------
    vol_path : str - Path or tuple of length 2 (str - Path, int)
        The path to the input volume. If the `with_label` flag is True,
        vol_path is required to be a tuple of size 2 - (vol_path, label).
    conform_volume_to : tuple of length 3, optional, default=(128, 128, 128)
        The shape the volume will be conformed to. Note: the pretrained model
        was trained using the conform size of (128, 128, 128) and assumes the
        volume shape as such.
    conform_zooms : tuple of length 3, optional, default=(2.0, 2.0, 2.0)
        The zoom of the resampled output.
    save_path : str - Path, optional
        Directory under which a ``preprocessed`` subdirectory is created to
        hold the output. If not provided, the output is saved under
        ``<directory containing vol_path>/preprocessed``.
    with_label : bool, optional, default=False
        If True, vol_path is required to be a tuple of 2 (vol_path, label).

    Returns
    -------
    str - Path, tuple (str - Path, label), or None
        Path to where the preprocessed volume is stored, or (Path, label) if
        with_label is True. If any step fails, the error is printed and None
        is returned instead of raising.
    """
    try:
        vpath = vol_path
        if with_label:
            if len(vol_path) != 2:
                raise ValueError(
                    "The vol_path must have length of 2 when with_label=True"
                )
            vpath, label = vol_path

        spath = os.path.join(os.path.dirname(vpath), "preprocessed")
        if save_path:
            spath = os.path.join(save_path, "preprocessed")
        os.makedirs(spath, exist_ok=True)

        volume, affine, _ = utils.load_vol(vpath)

        # Preprocessing
        volume = clip(volume, q=90)
        volume = standardize(volume)
        volume = normalize(volume)

        # The output keeps the input's basename, so inputs that share a
        # basename and an output directory will overwrite each other.
        conformed_path = os.path.join(spath, os.path.basename(vpath))

        # The intermediate (pre-conform) volume lives in a temporary file; the
        # context manager closes and deletes it even if conform_data raises.
        with tempfile.NamedTemporaryFile(
            suffix=".nii.gz", delete=True, dir=spath
        ) as tmp_preprocess_vol:
            utils.save_vol(tmp_preprocess_vol.name, volume, affine)
            conform_data(
                tmp_preprocess_vol.name,
                out_file=conformed_path,
                out_size=conform_volume_to,
                out_zooms=conform_zooms,
            )

        if with_label:
            return (conformed_path, label)
        return conformed_path
    except Exception as e:
        # Best-effort: report the failure and signal it with None so batch
        # callers can keep going.
        print(e)
        return None
def _available_cpu_count():
    """Return the number of CPUs this process may run on (at least 1)."""
    # os.sched_getaffinity is only available on some platforms (e.g. Linux);
    # fall back to os.cpu_count elsewhere.
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1


def preprocess_parallel(
    volume_filepaths,
    num_parallel_calls=AUTOTUNE,
    conform_volume_to=(128, 128, 128),
    conform_zooms=(2.0, 2.0, 2.0),
    save_path=None,
    with_label=True,
):
    """Preprocess multiple input volumes before prediction in parallel.

    Parameters
    ----------
    volume_filepaths : list of str - Path or list of tuple of length 2 [(str - Path, int), ...]
        A list of paths to the input volumes. If the `with_label` flag is
        True, volume_filepaths is required to be a list of tuples of size
        2 - (volume_filepath, label).
    num_parallel_calls : int or None, optional, default=AUTOTUNE
        Number of worker processes to use. None, AUTOTUNE, or any value
        < 1 means "use the number of CPUs available to this process". The
        count is capped at the number of inputs; with a single worker the
        volumes are processed sequentially in this process.
    conform_volume_to : tuple of length 3, optional, default=(128, 128, 128)
        The shape the volume will be conformed to. Note: the pretrained model
        was trained using the conform size of (128, 128, 128) and assumes the
        volume shape as such.
    conform_zooms : tuple of length 3, optional, default=(2.0, 2.0, 2.0)
        The zoom of the resampled output.
    save_path : str - Path, optional
        Directory under which a ``preprocessed`` subdirectory is created to
        hold the outputs. If not provided, each output is saved under
        ``<directory containing volume_filepath>/preprocessed``.
    with_label : bool, optional, default=True
        If True, each volume_filepath is required to be a tuple of 2
        (volume_filepath, label).

    Returns
    -------
    list or None
        One entry per input, in input order: the path where the preprocessed
        volume is stored, (Path, label) if with_label is True, or None for an
        input that failed. If validation or pool setup fails, the error is
        printed and None is returned.
    """
    try:
        if with_label:
            for pair in volume_filepaths:
                if len(pair) != 2:
                    raise ValueError(
                        "all items in 'volume_filepaths' must have length of 2, but"
                        " found at least one item with length != 2."
                    )

        map_fn = functools.partial(
            preprocess,
            conform_volume_to=conform_volume_to,
            conform_zooms=conform_zooms,
            save_path=save_path,
            with_label=with_label,
        )

        # AUTOTUNE is a negative sentinel (-1), which mp.Pool rejects, so it
        # is resolved here together with None.
        if num_parallel_calls is None or num_parallel_calls < 1:
            num_parallel_calls = _available_cpu_count()

        total = len(volume_filepaths)
        # Spawning more workers than there are inputs only adds overhead.
        num_workers = max(1, min(num_parallel_calls, total))

        print("Preprocessing {} examples".format(total))

        if num_workers == 1:
            return [map_fn(vf) for vf in tqdm(volume_filepaths, total=total)]

        # The context manager guarantees the worker processes are reaped; all
        # results are consumed by list() before the pool is torn down.
        with mp.Pool(num_workers) as pool:
            return list(
                tqdm(
                    pool.imap(func=map_fn, iterable=volume_filepaths),
                    total=total,
                )
            )
    except Exception as e:
        print(e)
        return None
def cleanup_files(*args):
    """Remove temporary files created during preprocessing.

    Paths that do not exist are skipped silently.

    Parameters
    ----------
    *args : str - Path
        Paths of the files to remove.
    """
    for p in args:
        # EAFP avoids the race between an existence check and the removal.
        try:
            os.remove(p)
        except FileNotFoundError:
            pass