mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 06:05:01 +00:00
dlr: implement stage 2 WIP of opt regression support via rmse
....this is so stupid. There's a reason why diffusion models are a thing and we don't use image segmentation models for this!!
This commit is contained in:
parent
58a7e22a4d
commit
fe2f8b3821
2 changed files with 62 additions and 25 deletions
|
@ -234,47 +234,55 @@ def plot_metric(train, val, name, dir_output):
|
|||
|
||||
if PATH_CHECKPOINT is None:
|
||||
loss_fn = None
|
||||
metrics = [
|
||||
"accuracy",
|
||||
dice_coefficient,
|
||||
mean_iou(),
|
||||
sensitivity(), # How many true positives were accurately predicted
|
||||
specificity, # How many true negatives were accurately predicted?
|
||||
]
|
||||
if LOSS == "cross-entropy-dice":
|
||||
loss_fn = LossCrossEntropyDice(log_cosh=DICE_LOG_COSH)
|
||||
elif LOSS == "cross-entropy":
|
||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
elif LOSS == "root-mean-squared-error":
|
||||
loss_fn = tf.keras.metrics.RootMeanSquaredError()
|
||||
metrics = [tf.keras.metrics.RootMeanSquaredError()] # Others don't make sense w/o this
|
||||
else:
|
||||
raise Exception(f"Error: Unknown loss function '{LOSS}' (possible values: cross-entropy, cross-entropy-dice).")
|
||||
|
||||
raise Exception(
|
||||
f"Error: Unknown loss function '{LOSS}' (possible values: cross-entropy, cross-entropy-dice)."
|
||||
)
|
||||
|
||||
model.compile(
|
||||
optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
|
||||
loss=loss_fn,
|
||||
metrics=[
|
||||
"accuracy",
|
||||
dice_coefficient,
|
||||
mean_iou(),
|
||||
sensitivity(), # How many true positives were accurately predicted
|
||||
specificity # How many true negatives were accurately predicted?
|
||||
],
|
||||
metrics=metrics,
|
||||
steps_per_execution=STEPS_PER_EXECUTION,
|
||||
jit_compile=JIT_COMPILE
|
||||
jit_compile=JIT_COMPILE,
|
||||
)
|
||||
logger.info(">>> Beginning training")
|
||||
history = model.fit(dataset_train,
|
||||
history = model.fit(
|
||||
dataset_train,
|
||||
validation_data=dataset_validate,
|
||||
# test_data=dataset_test, # Nope, it doesn't have a param like this so it's time to do this the *hard* way
|
||||
epochs=EPOCHS,
|
||||
callbacks=[
|
||||
CallbackExtraValidation({ # `model,` removed 'ref apparently exists by default????? ehhhh...???
|
||||
"test": dataset_test # Can be None because it handles that
|
||||
}),
|
||||
CallbackExtraValidation(
|
||||
{ # `model,` removed 'ref apparently exists by default????? ehhhh...???
|
||||
"test": dataset_test # Can be None because it handles that
|
||||
}
|
||||
),
|
||||
tf.keras.callbacks.CSVLogger(
|
||||
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
|
||||
separator="\t"
|
||||
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"), separator="\t"
|
||||
),
|
||||
CallbackCustomModelCheckpoint(
|
||||
model_to_checkpoint=model,
|
||||
filepath=os.path.join(
|
||||
DIR_OUTPUT,
|
||||
"checkpoints",
|
||||
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
|
||||
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5",
|
||||
),
|
||||
monitor="loss"
|
||||
monitor="loss",
|
||||
),
|
||||
],
|
||||
steps_per_epoch=STEPS_PER_EPOCH,
|
||||
|
@ -282,13 +290,40 @@ if PATH_CHECKPOINT is None:
|
|||
)
|
||||
logger.info(">>> Training complete")
|
||||
logger.info(">>> Plotting graphs")
|
||||
|
||||
plot_metric(history.history["loss"], history.history["val_loss"], "loss", DIR_OUTPUT)
|
||||
plot_metric(history.history["accuracy"], history.history["val_accuracy"], "accuracy", DIR_OUTPUT)
|
||||
plot_metric(history.history["metric_dice_coefficient"], history.history["val_metric_dice_coefficient"], "dice", DIR_OUTPUT)
|
||||
plot_metric(history.history["one_hot_mean_iou"], history.history["val_one_hot_mean_iou"], "mean iou", DIR_OUTPUT)
|
||||
plot_metric(history.history["sensitivity"], history.history["val_sensitivity"], "sensitivity", DIR_OUTPUT)
|
||||
plot_metric(history.history["specificity"], history.history["val_specificity"], "specificity", DIR_OUTPUT)
|
||||
|
||||
plot_metric(
|
||||
history.history["loss"], history.history["val_loss"], "loss", DIR_OUTPUT
|
||||
)
|
||||
plot_metric(
|
||||
history.history["accuracy"],
|
||||
history.history["val_accuracy"],
|
||||
"accuracy",
|
||||
DIR_OUTPUT,
|
||||
)
|
||||
plot_metric(
|
||||
history.history["metric_dice_coefficient"],
|
||||
history.history["val_metric_dice_coefficient"],
|
||||
"dice",
|
||||
DIR_OUTPUT,
|
||||
)
|
||||
plot_metric(
|
||||
history.history["one_hot_mean_iou"],
|
||||
history.history["val_one_hot_mean_iou"],
|
||||
"mean iou",
|
||||
DIR_OUTPUT,
|
||||
)
|
||||
plot_metric(
|
||||
history.history["sensitivity"],
|
||||
history.history["val_sensitivity"],
|
||||
"sensitivity",
|
||||
DIR_OUTPUT,
|
||||
)
|
||||
plot_metric(
|
||||
history.history["specificity"],
|
||||
history.history["val_specificity"],
|
||||
"specificity",
|
||||
DIR_OUTPUT,
|
||||
)
|
||||
|
||||
|
||||
# ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████
|
||||
|
|
|
@ -114,6 +114,8 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
|
|||
if water_threshold is not None: # if water_threshold=None, then regression mode
|
||||
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.float32)
|
||||
# BUG it may be a problem we're [height, width, channel] here rather than [height, width], depending on how dlr works
|
||||
else:
|
||||
water = tf.expand_dims(water, axis=-1) # Stack to have a channel, since if water_threshold=None then we would end up with [height, width] instead of [height, width, channel] otherwise
|
||||
if do_remove_isolated_pixels:
|
||||
water = remove_isolated_pixels(water)
|
||||
|
||||
|
|
Loading…
Reference in a new issue