This commit is contained in:
Starbeamrainbowlabs 2023-03-03 22:45:34 +00:00
parent 0734201107
commit c909cfd3d1
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -3,9 +3,9 @@ import math
import tensorflow as tf
def make_one_hot_mean_iou():
def make_one_hot_mean_iou(classes=2):
iou = tf.keras.metrics.MeanIoU(num_classes=classes)
def one_hot_mean_iou(y_true, y_pred, classes=2):
def one_hot_mean_iou(y_true, y_pred, ):
"""Compute the mean IoU for one-hot tensors.
Args:
y_true (tf.Tensor): The ground truth label.