dlr: implement stage 2 WIP of opt regression support via rmse

....this is so stupid. There's a reason why diffusion models are a thing and we don't use image segmentation models for this!!
This commit is contained in:
Starbeamrainbowlabs 2024-12-20 19:36:10 +00:00
parent 58a7e22a4d
commit fe2f8b3821
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 62 additions and 25 deletions

View file

@ -234,47 +234,55 @@ def plot_metric(train, val, name, dir_output):
if PATH_CHECKPOINT is None:
loss_fn = None
if LOSS == "cross-entropy-dice":
loss_fn = LossCrossEntropyDice(log_cosh=DICE_LOG_COSH)
elif LOSS == "cross-entropy":
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
else:
raise Exception(f"Error: Unknown loss function '{LOSS}' (possible values: cross-entropy, cross-entropy-dice).")
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
loss=loss_fn,
metrics=[
metrics = [
"accuracy",
dice_coefficient,
mean_iou(),
sensitivity(), # How many true positives were accurately predicted
specificity # How many true negatives were accurately predicted?
],
specificity, # How many true negatives were accurately predicted?
]
if LOSS == "cross-entropy-dice":
loss_fn = LossCrossEntropyDice(log_cosh=DICE_LOG_COSH)
elif LOSS == "cross-entropy":
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
elif LOSS == "root-mean-squared-error":
loss_fn = tf.keras.metrics.RootMeanSquaredError()
metrics = [tf.keras.metrics.RootMeanSquaredError()] # Others don't make sense w/o this
else:
raise Exception(
f"Error: Unknown loss function '{LOSS}' (possible values: cross-entropy, cross-entropy-dice)."
)
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
loss=loss_fn,
metrics=metrics,
steps_per_execution=STEPS_PER_EXECUTION,
jit_compile=JIT_COMPILE
jit_compile=JIT_COMPILE,
)
logger.info(">>> Beginning training")
history = model.fit(dataset_train,
history = model.fit(
dataset_train,
validation_data=dataset_validate,
# test_data=dataset_test, # Nope, it doesn't have a param like this so it's time to do this the *hard* way
epochs=EPOCHS,
callbacks=[
CallbackExtraValidation({ # `model,` removed 'ref apparently exists by default????? ehhhh...???
CallbackExtraValidation(
{ # `model,` removed 'ref apparently exists by default????? ehhhh...???
"test": dataset_test # Can be None because it handles that
}),
}
),
tf.keras.callbacks.CSVLogger(
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
separator="\t"
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"), separator="\t"
),
CallbackCustomModelCheckpoint(
model_to_checkpoint=model,
filepath=os.path.join(
DIR_OUTPUT,
"checkpoints",
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5",
),
monitor="loss"
monitor="loss",
),
],
steps_per_epoch=STEPS_PER_EPOCH,
@ -283,12 +291,39 @@ if PATH_CHECKPOINT is None:
logger.info(">>> Training complete")
logger.info(">>> Plotting graphs")
plot_metric(history.history["loss"], history.history["val_loss"], "loss", DIR_OUTPUT)
plot_metric(history.history["accuracy"], history.history["val_accuracy"], "accuracy", DIR_OUTPUT)
plot_metric(history.history["metric_dice_coefficient"], history.history["val_metric_dice_coefficient"], "dice", DIR_OUTPUT)
plot_metric(history.history["one_hot_mean_iou"], history.history["val_one_hot_mean_iou"], "mean iou", DIR_OUTPUT)
plot_metric(history.history["sensitivity"], history.history["val_sensitivity"], "sensitivity", DIR_OUTPUT)
plot_metric(history.history["specificity"], history.history["val_specificity"], "specificity", DIR_OUTPUT)
plot_metric(
history.history["loss"], history.history["val_loss"], "loss", DIR_OUTPUT
)
plot_metric(
history.history["accuracy"],
history.history["val_accuracy"],
"accuracy",
DIR_OUTPUT,
)
plot_metric(
history.history["metric_dice_coefficient"],
history.history["val_metric_dice_coefficient"],
"dice",
DIR_OUTPUT,
)
plot_metric(
history.history["one_hot_mean_iou"],
history.history["val_one_hot_mean_iou"],
"mean iou",
DIR_OUTPUT,
)
plot_metric(
history.history["sensitivity"],
history.history["val_sensitivity"],
"sensitivity",
DIR_OUTPUT,
)
plot_metric(
history.history["specificity"],
history.history["val_specificity"],
"specificity",
DIR_OUTPUT,
)
# ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████

View file

@ -114,6 +114,8 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
if water_threshold is not None: # if water_threshold=None, then regression mode
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.float32)
# BUG it may be a problem we're [height, width, channel] here rather than [height, width], depending on how dlr works
else:
water = tf.expand_dims(water, axis=-1) # Stack to have a channel, since if water_threshold=None then we would end up with [height, width] instead of [height, width, channel] otherwise
if do_remove_isolated_pixels:
water = remove_isolated_pixels(water)