This commit is contained in:
Starbeamrainbowlabs 2023-03-03 21:54:45 +00:00
parent 9b13e9ca5b
commit 6ffda40d48
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -23,8 +23,8 @@ def dice_coefficient(y_true, y_pred):
def metric_dice_coefficient(y_true, y_pred):
y_pred = tf.math.argmax(y_pred, axis=-1)
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)
y_pred = tf.math.argmax(y_pred, axis=-1)
return dice_coefficient(y_true, y_pred)