dlr: fixup argmax first

This commit is contained in:
Starbeamrainbowlabs 2023-03-03 21:51:24 +00:00
parent 7453c607ed
commit 9b13e9ca5b
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 2 additions and 4 deletions

View file

@ -3,11 +3,10 @@ import math
import tensorflow as tf
def sensitivity(y_true, y_pred):
y_pred = tf.math.argmax(y_pred, axis=-1)
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)
y_pred = tf.math.argmax(y_pred, axis=-1)
recall = tf.keras.metrics.Recall()
recall.update_state(y_true, y_pred)
return recall.result()

View file

@ -13,11 +13,10 @@ def specificity(y_pred, y_true):
Returns:
Specificity score
"""
y_pred = tf.math.argmax(y_pred, axis=-1)
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)
y_pred = tf.math.argmax(y_pred, axis=-1)
neg_y_true = 1 - y_true
neg_y_pred = 1 - y_pred
fp = K.sum(neg_y_true * y_pred)