Source code for nondefaced_detector.inference

"""Standalone inference script for held-out test dataset."""

import argparse
import os

import tensorflow as tf
from tensorflow.keras import metrics
from tensorflow.keras.optimizers import Adam

from nondefaced_detector.models.model import CombinedClassifier
from nondefaced_detector.dataloaders.dataset import get_dataset


[docs]def inference(tfrecords_path, weights_path, wts_root): """ Inference function to reproduce original model scores. This script can be run as a standalone using python inference.py. For more information try: `python inference.py -h` Parameters ---------- tfrecords_path: str The path to directory containing preprocessed tfrecords. weights_path: str The path to the combined model weights. A copy of the weights can be found here: https://gin.g-node.org/shashankbansal56/nondefaced-detector-reproducibility/src/master/pretrained_weights/combined wts_root: str The path to the root directory of all the model weights. A copy of the weights can be found here: https://gin.g-node.org/shashankbansal56/nondefaced-detector-reproducibility/src/master/pretrained_weights """ model = CombinedClassifier( input_shape=(128, 128), dropout=0.4, wts_root=wts_root, trainable=False ) model.load_weights(os.path.abspath(weights_path)) model.trainable = False dataset_test = get_dataset( file_pattern=os.path.join(tfrecords_path, "data-test_*"), n_classes=2, batch_size=16, volume_shape=(128, 128, 128), plane="combined", mode="test", ) METRICS = [ metrics.BinaryAccuracy(name="accuracy"), metrics.Precision(name="precision"), metrics.Recall(name="recall"), metrics.AUC(name="auc"), ] model.compile( loss=tf.keras.losses.binary_crossentropy, optimizer=Adam(learning_rate=1e-3), metrics=METRICS, ) model.evaluate(dataset_test)
if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "tfrecords", metavar="tfrecords_path", help="Path to tfrecords." ) parser.add_argument( "model_path", metavar="model_path", help="Path to pretrained model weights." ) args = parser.parse_args() tfrecords_path = args.tfrecords model_path = args.model_path combined_path = os.path.join(model_path, "combined/best-wts.h5") inference(tfrecords_path, combined_path, model_path)