# Source code for nondefaced_detector.helpers.utils

"""Helper functions for nondefaced-detector."""


import binascii
import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

from tensorflow.python.client import device_lib


def is_gz_file(filepath):
    """Tell whether ``filepath`` is a gzip file.

    A file qualifies only when its extension is ``.gz`` *and* its first two
    bytes are the gzip magic number (``1f 8b``). Files with any other
    extension are rejected without being opened.
    """
    _, extension = os.path.splitext(filepath)
    if extension != ".gz":
        return False

    with open(filepath, "rb") as handle:
        magic = handle.read(2)
    return magic == b"\x1f\x8b"
def save_vol(save_path, tensor_3d, affine):
    """Save a volume to ``save_path`` as a NIfTI image with float32 data dtype.

    Any missing parent directories of ``save_path`` are created first.

    Args:
        save_path: path to write the volume to.
        tensor_3d: 3D volume which needs to be saved.
        affine: image orientation, translation.
    """
    directory = os.path.dirname(save_path)
    # dirname is "" when save_path is a bare filename; os.makedirs("") would
    # raise, so only create directories when there is a directory component.
    if directory:
        # exist_ok avoids a race between checking for and creating the directory.
        os.makedirs(directory, exist_ok=True)

    volume = nib.Nifti1Image(tensor_3d, affine)
    volume.set_data_dtype(np.float32)
    nib.save(volume, save_path)
def load_vol(load_path):
    """Load a volume from disk with nibabel.

    Args:
        load_path: volume path to load.

    Returns:
        A 3-tuple ``(volume, affine, shape)``:
            volume: the image data as a numpy array (on-disk dtype preserved).
            affine: affine data specific to the volume.
            shape: shape of the loaded data array.

    Raises:
        ValueError: if ``load_path`` does not exist.
    """
    if not os.path.exists(load_path):
        raise ValueError("path doesn't exist")

    nib_vol = nib.load(load_path)
    # ``get_data()`` is deprecated (and removed in recent nibabel releases);
    # ``np.asanyarray(img.dataobj)`` is the documented drop-in equivalent.
    vol_data = np.asanyarray(nib_vol.dataobj)
    vol_affine = nib_vol.affine
    return np.array(vol_data), vol_affine, vol_data.shape
def grid_to_single(image_batch, label_image=False):
    """Tile a batch of images into one square mosaic.

    With ``n = int(sqrt(batch_size))``, the first ``n * n`` entries of
    ``image_batch`` are laid out row by row on an ``n`` x ``n`` grid; any
    remaining entries are left out.

    image_batch: array whose first axis indexes the images and whose next
                 two axes are the per-image height and width
    label_image: when False the mosaic has a trailing axis of size 3,
                 when True the mosaic is 2D

    return: the mosaic, allocated with ``np.zeros`` (so numpy's default
            float dtype)
    """
    batch_shape = image_batch.shape
    side = int(np.sqrt(batch_shape[0]))
    tile_h, tile_w = batch_shape[1], batch_shape[2]

    canvas_shape = (tile_h * side, tile_w * side)
    if not label_image:
        canvas_shape = canvas_shape + (3,)
    mosaic = np.zeros(canvas_shape)

    for idx in range(side * side):
        # Row-major placement: idx -> (grid row, grid column).
        grid_row, grid_col = divmod(idx, side)
        rows = slice(grid_row * tile_h, (grid_row + 1) * tile_h)
        cols = slice(grid_col * tile_w, (grid_col + 1) * tile_w)
        mosaic[rows, cols] = image_batch[idx]

    return mosaic
def imshow(*args, **kwargs):
    """Show one or more images in one row, possibly with different cmaps and titles.

    Usage:
        imshow(img1, title="myPlot")
        imshow(img1, img2, title=['title1', 'title2'])
        imshow(img1, img2, cmap='hot')
        imshow(img1, img2, cmap=['gray', 'Blues'])

    Args:
        *args: the images to display; each one gets its own subplot when
            more than one is given.
        **kwargs:
            cmap: a colormap name applied to every image, or a sequence with
                one colormap per image. Defaults to "gray" when several
                images are shown. For a single image it is forwarded to
                matplotlib only when passed explicitly, so the default
                single-image rendering is unchanged.
            title: a title string applied to every image, or a sequence with
                one title per image. Defaults to "".
            axis_off: when truthy, the axes of every plot are hidden.

    Raises:
        ValueError: if no images are given.
    """
    if not args:
        raise ValueError("No images given to imshow")

    cmap = kwargs.get("cmap", "gray")
    title = kwargs.get("title", "")
    axis_off = kwargs.get("axis_off", False)
    n = len(args)

    if n == 1:
        plt.title(title)
        # Previously an explicit cmap was silently dropped for a single image;
        # honour it when given, but keep matplotlib's default otherwise.
        if "cmap" in kwargs:
            plt.imshow(args[0], cmap=cmap, interpolation="none")
        else:
            plt.imshow(args[0], interpolation="none")
        # axis_off used to be ignored in the single-image case.
        if axis_off:
            plt.axis("off")
    else:
        # A single string applies to every subplot.
        cmaps = [cmap] * n if isinstance(cmap, str) else cmap
        titles = [title] * n if isinstance(title, str) else title

        plt.figure(figsize=(n * 5, 10))
        for i, image in enumerate(args):
            plt.subplot(1, n, i + 1)
            plt.title(titles[i])
            plt.imshow(image, cmap=cmaps[i])
            if axis_off:
                plt.axis("off")
    plt.show()
def get_available_gpus():
    """Return the names of the local devices whose device type is "GPU".

    The devices are enumerated with TensorFlow's
    ``device_lib.list_local_devices()``; the length of the returned list is
    the number of GPUs found.
    """
    gpu_names = []
    for device in device_lib.list_local_devices():
        if device.device_type == "GPU":
            gpu_names.append(device.name)
    return gpu_names


def schedule_steps(epoch, steps):
    """Pick the learning rate for ``epoch`` from a step schedule.

    epoch: current epoch number
    steps: sequence of entries where entry[0] is a learning rate and
           entry[1] is the epoch before which that rate applies

    return: the rate of the first entry whose entry[1] is greater than
            ``epoch``; if there is none, the rate of the last entry.
            The chosen rate is also printed.
    """
    for entry in steps:
        if epoch < entry[1]:
            learning_rate = entry[0]
            break
    else:
        # Past every boundary: stay on the final rate.
        learning_rate = steps[-1][0]

    print("Setting learning rate to {}".format(learning_rate))
    return learning_rate