Source code for nondefaced_detector.training.training

"""Training module for nondefaced-detector."""


import os
import math

import tensorflow as tf
import pandas as pd
import numpy as np

from sklearn.utils import class_weight
from tensorflow.keras import backend as K
from tensorflow.keras import metrics
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping

from nondefaced_detector.models import model as _model
from nondefaced_detector.dataloaders.dataset import get_dataset


def scheduler(epoch):
    """Return the learning rate to use for a given epoch.

    The rate is held at 0.001 while ``epoch < 3``; from then on it is
    ``0.001 * exp(0.1 * (10 - epoch))``, computed with ``tf.math.exp``.

    NOTE(review): the exponential is anchored at epoch 10 rather than 3, so
    the value at epoch 3 is 0.001 * exp(0.7) (about 0.002), i.e. higher than
    the constant phase, and it only falls back to 0.001 at epoch 10.
    Confirm this is intended.

    Parameters
    ----------
    epoch: int
        Index of the epoch the rate is requested for.

    Returns
    -------
    The float 0.001 during the constant phase, otherwise the value produced
    by scaling ``tf.math.exp``'s result.
    """
    base_rate = 0.001

    # Constant phase: keep the base rate for the first epochs.
    if epoch < 3:
        return base_rate

    decay_exponent = 0.1 * (10 - epoch)
    return base_rate * tf.math.exp(decay_exponent)


def train(
    csv_path,
    model_save_path,
    tfrecords_path,
    volume_shape=(128, 128, 128),
    image_size=(128, 128),
    dropout=0.2,
    batch_size=16,
    n_classes=2,
    n_epochs=15,
    mode="CV",
):
    """Train a model.

    One sub-model is trained for each of the "axial", "coronal" and
    "sagittal" planes, followed by a "combined" classifier that is built with
    ``wts_root`` set to the directory the sub-model checkpoints were written
    to. For every plane the following are written under ``model_save_path``:

    - ``weights/<plane>/best-wts.h5``: model weights checkpoint.
    - ``tb_logs/<plane>``: TensorBoard logs.
    - ``metrics/<plane>.json``: the per-epoch training history.

    Parameters
    ----------
    csv_path: str - Path
        Path to the directory containing ``training.csv`` (and
        ``validation.csv`` when ``mode="CV"``). The files provide the volume
        paths and labels in the columns ``X`` and ``Y``.
    model_save_path: str - Path
        Path to where the model weights, logs and metrics are saved.
        Created if it does not exist.
    tfrecords_path: str - Path
        Path to the preprocessed tfrecords. Training records are matched
        with ``data-train_*`` and validation records with ``data-valid_*``.
    volume_shape: tuple of size 3, optional, default=(128, 128, 128)
        The shape of the preprocessed volumes.
    image_size: tuple of size 2, optional, default=(128, 128)
        The shape of a 2D slice along each volume axis.
    dropout: float, optional, default=0.2
        Float between 0 and 1. Fraction of the input units to drop.
    batch_size: int, optional, default=16
        No. of training examples utilized in each iteration.
    n_classes: int, optional, default=2
        No. of unique classes to train the model on. Default assumption is
        a binary classifier.
    n_epochs: int, optional, default=15
        No. of complete passes through the training dataset.
    mode: str, optional, default="CV"
        One of "CV" or "full". Indicates the type of training to perform.
        "CV" trains with a validation set and early-stops on ``val_loss``;
        any other value trains without validation data and early-stops on
        the training ``loss``.

    Returns
    -------
    `tf.keras.callbacks.History`
        A History object that records several metrics such as
        training/validation loss/metrics at successive epochs. Only the
        history of the last model trained (the "combined" classifier) is
        returned; the histories of all planes are available as JSON files
        in the ``metrics`` directory.
    """
    use_validation = mode == "CV"

    # Read the training CSV once and take both columns from the same frame.
    train_df = pd.read_csv(os.path.join(csv_path, "training.csv"))
    train_paths = train_df["X"].values
    train_labels = train_df["Y"].values

    if use_validation:
        valid_csv_path = os.path.join(csv_path, "validation.csv")
        valid_paths = pd.read_csv(valid_csv_path)["X"].values

    # `classes` and `y` are keyword-only in recent scikit-learn releases;
    # passing them by keyword works on both old and new versions.
    weights = class_weight.compute_class_weight(
        class_weight="balanced",
        classes=np.unique(train_labels),
        y=train_labels,
    )
    # Keyed by position in the sorted unique labels.
    # NOTE(review): this matches the label values only if they are 0..k-1.
    weights = dict(enumerate(weights))

    planes = ["axial", "coronal", "sagittal", "combined"]

    os.makedirs(model_save_path, exist_ok=True)
    cp_save_path = os.path.join(model_save_path, "weights")
    logdir_path = os.path.join(model_save_path, "tb_logs")
    metrics_path = os.path.join(model_save_path, "metrics")
    os.makedirs(metrics_path, exist_ok=True)

    steps_per_epoch = math.ceil(len(train_paths) / batch_size)

    for plane in planes:
        logdir = os.path.join(logdir_path, plane)
        os.makedirs(logdir, exist_ok=True)
        tb_callback = TensorBoard(log_dir=logdir)

        os.makedirs(os.path.join(cp_save_path, plane), exist_ok=True)
        # NOTE(review): `save_best_only` is not set, so (with the Keras
        # default) this file is presumably rewritten every epoch despite its
        # name - confirm whether only the best weights should be kept.
        model_checkpoint = ModelCheckpoint(
            os.path.join(cp_save_path, plane, "best-wts.h5"),
            monitor="val_loss",
            save_weights_only=True,
            mode="min",
        )

        if plane != "combined":
            lr = 1e-3
            model = _model.Submodel(
                input_shape=image_size,
                dropout=dropout,
                name=plane,
                include_top=True,
                weights=None,
            )
        else:
            # The combined classifier is given the checkpoint root of the
            # sub-models trained in the previous iterations.
            lr = 5e-4
            model = _model.CombinedClassifier(
                input_shape=image_size,
                dropout=dropout,
                trainable=True,
                wts_root=cp_save_path,
            )

        print("Submodel: ", plane)

        # Fresh metric objects are created for every model.
        METRICS = [
            metrics.TruePositives(name="tp"),
            metrics.FalsePositives(name="fp"),
            metrics.TrueNegatives(name="tn"),
            metrics.FalseNegatives(name="fn"),
            metrics.BinaryAccuracy(name="accuracy"),
            metrics.Precision(name="precision"),
            metrics.Recall(name="recall"),
            metrics.AUC(name="auc"),
        ]

        model.compile(
            loss=tf.keras.losses.binary_crossentropy,
            optimizer=Adam(learning_rate=lr),
            metrics=METRICS,
        )

        dataset_train = get_dataset(
            file_pattern=os.path.join(tfrecords_path, "data-train_*"),
            n_classes=n_classes,
            batch_size=batch_size,
            volume_shape=volume_shape,
            plane=plane,
            shuffle_buffer_size=batch_size,
        )

        if use_validation:
            earlystopping = EarlyStopping(monitor="val_loss", patience=3)

            dataset_valid = get_dataset(
                file_pattern=os.path.join(tfrecords_path, "data-valid_*"),
                n_classes=n_classes,
                batch_size=batch_size,
                volume_shape=volume_shape,
                plane=plane,
                shuffle_buffer_size=batch_size,
            )

            validation_steps = math.ceil(len(valid_paths) / batch_size)

            history = model.fit(
                dataset_train,
                epochs=n_epochs,
                steps_per_epoch=steps_per_epoch,
                validation_data=dataset_valid,
                validation_steps=validation_steps,
                callbacks=[tb_callback, model_checkpoint, earlystopping],
                class_weight=weights,
            )
        else:
            # No validation data in this mode, so stop on the training loss.
            earlystopping = EarlyStopping(monitor="loss", patience=3)

            # summary() prints the table itself and returns None, so it is
            # called directly instead of being wrapped in print().
            model.summary()
            print("Steps/Epoch: ", steps_per_epoch)

            history = model.fit(
                dataset_train,
                epochs=n_epochs,
                steps_per_epoch=steps_per_epoch,
                callbacks=[tb_callback, model_checkpoint, earlystopping],
                class_weight=weights,
            )

        # Persist the per-epoch history of this plane.
        hist_df = pd.DataFrame(history.history)
        jsonfile = os.path.join(metrics_path, plane + ".json")
        with open(jsonfile, mode="w") as f:
            hist_df.to_json(f)

    return history