dir: add missing functions to .load() custom objs

apparently metrics are also required to be included here...
This commit is contained in:
Starbeamrainbowlabs 2023-03-09 19:26:57 +00:00
parent ad52ae9241
commit 96eae636ea
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -163,6 +163,10 @@ else:
model = tf.keras.models.load_model(PATH_CHECKPOINT, custom_objects={
# Tell Tensorflow about our custom layers so that it can deserialise models that use them
"LossCrossEntropyDice": LossCrossEntropyDice,
"metric_dice_coefficient": dice_coefficient,
"make_sensitivity": sensitivity,
"specificity": specificity,
"make_one_hot_mean_iou": mean_iou
})