y_true is one-hot, convert to sparse

This commit is contained in:
Starbeamrainbowlabs 2023-03-03 22:20:11 +00:00
parent c7b577ab29
commit bc734a29c6
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
3 changed files with 3 additions and 0 deletions

View file

@ -14,6 +14,7 @@ def one_hot_mean_iou(y_true, y_pred, classes=2):
"""
y_pred = tf.math.argmax(y_pred, axis=-1)
y_true = tf.math.argmax(y_true, axis=-1)
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)

View file

@ -4,6 +4,7 @@ import tensorflow as tf
def sensitivity(y_true, y_pred):
y_pred = tf.math.argmax(y_pred, axis=-1)
y_true = tf.math.argmax(y_true, axis=-1)
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)

View file

@ -14,6 +14,7 @@ def specificity(y_pred, y_true):
Specificity score
"""
y_pred = tf.math.argmax(y_pred, axis=-1)
y_true = tf.math.argmax(y_true, axis=-1)
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)