Source code for nondefaced_detector.models.model

"""Model definition for nondefaced-detector."""


import os
import tensorflow as tf

from tensorflow.keras import layers, models


def ConvBNrelu(x, filters=32, kernel=3, strides=1, padding="same"):
    """Apply a Conv2D -> BatchNormalization -> ReLU sequence to a tensor.

    Parameters
    ----------
    x: :obj:`tf.Tensor`
        The input keras tensor the block is applied to.
    filters: int, optional, default=32
        Number of output filters of the convolution.
    kernel: int, optional, default=3
        Size of the 2D convolution window, forwarded as the kernel size
        argument of ``layers.Conv2D``.
    strides: int, optional, default=1
        Strides of the convolution, forwarded to ``layers.Conv2D``.
    padding: str, optional, default="same"
        Padding mode forwarded to ``layers.Conv2D`` ("valid" or "same").

    Returns
    -------
    :obj:`tf.Tensor`
        The output of the final ReLU activation.
    """
    stages = (
        layers.Conv2D(filters, kernel, strides=strides, padding=padding),
        layers.BatchNormalization(),
        layers.Activation("relu"),
    )

    # Thread the tensor through the three layers in order.
    out = x
    for stage in stages:
        out = stage(out)
    return out
def TruncatedSubmodel(input_layer):
    """Build the convolutional feature extractor shared by every submodel.

    The network is three stages with 8, 16 and 32 filters respectively.
    Each stage is two ``ConvBNrelu`` blocks (3x3 kernel, stride 1, "same"
    padding) followed by a ``MaxPooling2D`` layer. The output of the last
    stage is flattened.

    Parameters
    ----------
    input_layer: tf.keras.Input
        The input keras tensor the feature extractor is built on.

    Returns
    -------
    :obj:`tf.Tensor`
        The flattened feature tensor produced by the last stage.
    """
    features = input_layer
    for n_filters in (8, 16, 32):
        # Two conv blocks per stage, then downsample.
        for _ in range(2):
            features = ConvBNrelu(
                features, filters=n_filters, kernel=3, strides=1, padding="same"
            )
        features = layers.MaxPooling2D()(features)

    return layers.Flatten()(features)
def ClassifierHead(layer, dropout):
    """Attach the classification head to a flattened feature tensor.

    The head is a 256-unit ReLU dense layer, a dropout layer, and a
    single-unit sigmoid dense layer named ``"output_node"``.

    Parameters
    ----------
    layer: :obj:`tf.Tensor`
        The flattened feature tensor produced by the submodel(s).
    dropout: float
        Dropout rate passed to ``layers.Dropout``.

    Returns
    -------
    :obj:`tf.Tensor`
        The output tensor of the single-unit sigmoid layer.
    """
    hidden = layers.Dense(256, activation="relu")(layer)
    regularized = layers.Dropout(dropout)(hidden)
    return layers.Dense(1, activation="sigmoid", name="output_node")(regularized)
def Submodel(
    root_path,
    input_shape=(32, 32),
    dropout=0.4,
    name="axial",
    weights="axial",
    include_top=True,
    trainable=True,
):
    """Build a single-plane submodel, optionally loading saved weights.

    Parameters
    ----------
    root_path: str, Path
        Root directory that holds the per-submodel weight folders. Only
        used when ``weights`` is truthy.
    input_shape: tuple of ints, default=(32, 32)
        The shape of the input image; a trailing channel dimension of 1
        is appended.
    dropout: float, optional, default=0.4
        Dropout rate used by the classifier head (if included).
    name: str
        Name of the input layer. It is also the sub-folder of
        ``root_path`` from which ``best-wts.h5`` is loaded.
    weights: str
        When truthy, weights are loaded from
        ``<root_path>/<name>/best-wts.h5``. When falsy (e.g. ``None``), no
        weights are loaded. The value itself is not used to build the path.
    include_top: bool, default=True
        If True, the ClassifierHead block is appended to the features.
    trainable: bool, default=True
        If False, every layer of the model is frozen.

    Returns
    -------
    `tf.keras.Model`
        The submodel, ending either in the flattened features or in the
        classifier head output.
    """
    inputs = layers.Input(shape=input_shape + (1,), name=name)

    outputs = TruncatedSubmodel(inputs)
    if include_top:
        outputs = ClassifierHead(outputs, dropout)
    net = models.Model(inputs, outputs)

    if weights:
        net.load_weights(os.path.join(root_path, name, "best-wts.h5"))

    if trainable:
        return net

    # Freeze the whole network.
    for lyr in net.layers:
        lyr.trainable = False
    return net
def CombinedClassifier(
    input_shape=(32, 32), dropout=0.4, wts_root=None, trainable=False, shared=False
):
    """Build the combined model that fuses the three plane-wise feature sets.

    The flattened features of the three planes are summed element-wise
    (``layers.Add``) and fed to a ClassifierHead, which outputs a
    sigmoid probability.

    Parameters
    ----------
    input_shape: tuple of ints, default=(32, 32)
        The shape of each input image (without the channel dimension).
    dropout: float, optional, default=0.4
        Float between 0 and 1. Fraction of the input units to drop.
    wts_root: str, Path, optional, default=None
        Root directory holding the ``axial``, ``coronal`` and ``sagittal``
        weight folders. Required when ``trainable`` is False.
    trainable: bool, default=False
        If False, the weights of the pre-trained single-plane submodels
        are copied into the feature-extraction branches and those layers
        are frozen. If True, the branches are left untouched.
    shared: bool, default=False
        If False, each plane gets its own feature-extraction branch with
        inputs named ``axial``, ``coronal`` and ``sagittal``. If True, a
        single branch is applied to three inputs named ``plane1``,
        ``plane2`` and ``plane3``.

    Returns
    -------
    `tf.keras.Model`
        The combined classifier model.

    Raises
    ------
    ValueError
        If ``trainable`` is False and ``wts_root`` is None, or if
        ``trainable`` is False and ``shared`` is True (the per-plane
        pre-trained weights cannot be mapped onto one shared branch).
    """
    if not trainable:
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if wts_root is None:
            raise ValueError(
                "wts_root must be provided when trainable=False so that the "
                "pre-trained submodel weights can be loaded."
            )
        if shared:
            raise ValueError(
                "trainable=False is not supported with shared=True: the "
                "per-plane pre-trained weights cannot be loaded into a "
                "single shared branch."
            )

    def build_branch(plane, pretrained):
        # Submodel's first positional parameter is ``root_path``, so every
        # argument is passed by keyword to avoid binding ``input_shape`` to it.
        return Submodel(
            root_path=wts_root,
            input_shape=input_shape,
            dropout=dropout,
            name=plane,
            weights=plane if pretrained else None,
            include_top=pretrained,
        )

    branches = {"axial": build_branch("axial", pretrained=False)}

    if not shared:
        branches["sagittal"] = build_branch("sagittal", pretrained=False)
        branches["coronal"] = build_branch("coronal", pretrained=False)

        # Input order (axial, coronal, sagittal) is kept as in the original
        # API; the merge order differs but Add is order-independent.
        input_features = [
            branches["axial"].inputs,
            branches["coronal"].inputs,
            branches["sagittal"].inputs,
        ]
        merge_features = [
            branches["axial"].outputs[0],
            branches["sagittal"].outputs[0],
            branches["coronal"].outputs[0],
        ]
    else:
        shared_branch = branches["axial"]
        input_features = [
            layers.Input(shape=input_shape + (1,), name=plane_name)
            for plane_name in ("plane1", "plane2", "plane3")
        ]
        merge_features = [shared_branch(plane) for plane in input_features]

    add_features = layers.Add()(merge_features)
    prob = ClassifierHead(add_features, dropout)
    model = models.Model(inputs=input_features, outputs=prob)

    if not trainable:
        for plane in ("axial", "coronal", "sagittal"):
            pretrained = build_branch(plane, pretrained=True)
            branch = branches[plane]

            # The branch layers are the same objects used by ``model``, so
            # updating them here updates the combined model. Index 0 (the
            # input layer) is skipped; zip stops at the branch's last layer,
            # so the pre-trained classifier head is ignored.
            for target, source in zip(branch.layers[1:], pretrained.layers[1:]):
                target.set_weights(source.get_weights())
                target.trainable = False

    return model