mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 06:05:01 +00:00
dlr: implement stage 2 WIP of opt regression support via rmse
....this is so stupid. There's a reason why diffusion models are a thing and we don't use image segmentation models for this!!
This commit is contained in:
parent
58a7e22a4d
commit
fe2f8b3821
2 changed files with 62 additions and 25 deletions
|
@ -234,47 +234,55 @@ def plot_metric(train, val, name, dir_output):
|
||||||
|
|
||||||
if PATH_CHECKPOINT is None:
|
if PATH_CHECKPOINT is None:
|
||||||
loss_fn = None
|
loss_fn = None
|
||||||
|
metrics = [
|
||||||
|
"accuracy",
|
||||||
|
dice_coefficient,
|
||||||
|
mean_iou(),
|
||||||
|
sensitivity(), # How many true positives were accurately predicted
|
||||||
|
specificity, # How many true negatives were accurately predicted?
|
||||||
|
]
|
||||||
if LOSS == "cross-entropy-dice":
|
if LOSS == "cross-entropy-dice":
|
||||||
loss_fn = LossCrossEntropyDice(log_cosh=DICE_LOG_COSH)
|
loss_fn = LossCrossEntropyDice(log_cosh=DICE_LOG_COSH)
|
||||||
elif LOSS == "cross-entropy":
|
elif LOSS == "cross-entropy":
|
||||||
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||||
|
elif LOSS == "root-mean-squared-error":
|
||||||
|
loss_fn = tf.keras.metrics.RootMeanSquaredError()
|
||||||
|
metrics = [tf.keras.metrics.RootMeanSquaredError()] # Others don't make sense w/o this
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Error: Unknown loss function '{LOSS}' (possible values: cross-entropy, cross-entropy-dice).")
|
raise Exception(
|
||||||
|
f"Error: Unknown loss function '{LOSS}' (possible values: cross-entropy, cross-entropy-dice)."
|
||||||
|
)
|
||||||
|
|
||||||
model.compile(
|
model.compile(
|
||||||
optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
|
optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
|
||||||
loss=loss_fn,
|
loss=loss_fn,
|
||||||
metrics=[
|
metrics=metrics,
|
||||||
"accuracy",
|
|
||||||
dice_coefficient,
|
|
||||||
mean_iou(),
|
|
||||||
sensitivity(), # How many true positives were accurately predicted
|
|
||||||
specificity # How many true negatives were accurately predicted?
|
|
||||||
],
|
|
||||||
steps_per_execution=STEPS_PER_EXECUTION,
|
steps_per_execution=STEPS_PER_EXECUTION,
|
||||||
jit_compile=JIT_COMPILE
|
jit_compile=JIT_COMPILE,
|
||||||
)
|
)
|
||||||
logger.info(">>> Beginning training")
|
logger.info(">>> Beginning training")
|
||||||
history = model.fit(dataset_train,
|
history = model.fit(
|
||||||
|
dataset_train,
|
||||||
validation_data=dataset_validate,
|
validation_data=dataset_validate,
|
||||||
# test_data=dataset_test, # Nope, it doesn't have a param like this so it's time to do this the *hard* way
|
# test_data=dataset_test, # Nope, it doesn't have a param like this so it's time to do this the *hard* way
|
||||||
epochs=EPOCHS,
|
epochs=EPOCHS,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
CallbackExtraValidation({ # `model,` removed 'ref apparently exists by default????? ehhhh...???
|
CallbackExtraValidation(
|
||||||
"test": dataset_test # Can be None because it handles that
|
{ # `model,` removed 'ref apparently exists by default????? ehhhh...???
|
||||||
}),
|
"test": dataset_test # Can be None because it handles that
|
||||||
|
}
|
||||||
|
),
|
||||||
tf.keras.callbacks.CSVLogger(
|
tf.keras.callbacks.CSVLogger(
|
||||||
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"),
|
filename=os.path.join(DIR_OUTPUT, "metrics.tsv"), separator="\t"
|
||||||
separator="\t"
|
|
||||||
),
|
),
|
||||||
CallbackCustomModelCheckpoint(
|
CallbackCustomModelCheckpoint(
|
||||||
model_to_checkpoint=model,
|
model_to_checkpoint=model,
|
||||||
filepath=os.path.join(
|
filepath=os.path.join(
|
||||||
DIR_OUTPUT,
|
DIR_OUTPUT,
|
||||||
"checkpoints",
|
"checkpoints",
|
||||||
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5"
|
"checkpoint_e{epoch:d}_loss{loss:.3f}.hdf5",
|
||||||
),
|
),
|
||||||
monitor="loss"
|
monitor="loss",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
steps_per_epoch=STEPS_PER_EPOCH,
|
steps_per_epoch=STEPS_PER_EPOCH,
|
||||||
|
@ -282,13 +290,40 @@ if PATH_CHECKPOINT is None:
|
||||||
)
|
)
|
||||||
logger.info(">>> Training complete")
|
logger.info(">>> Training complete")
|
||||||
logger.info(">>> Plotting graphs")
|
logger.info(">>> Plotting graphs")
|
||||||
|
|
||||||
plot_metric(history.history["loss"], history.history["val_loss"], "loss", DIR_OUTPUT)
|
plot_metric(
|
||||||
plot_metric(history.history["accuracy"], history.history["val_accuracy"], "accuracy", DIR_OUTPUT)
|
history.history["loss"], history.history["val_loss"], "loss", DIR_OUTPUT
|
||||||
plot_metric(history.history["metric_dice_coefficient"], history.history["val_metric_dice_coefficient"], "dice", DIR_OUTPUT)
|
)
|
||||||
plot_metric(history.history["one_hot_mean_iou"], history.history["val_one_hot_mean_iou"], "mean iou", DIR_OUTPUT)
|
plot_metric(
|
||||||
plot_metric(history.history["sensitivity"], history.history["val_sensitivity"], "sensitivity", DIR_OUTPUT)
|
history.history["accuracy"],
|
||||||
plot_metric(history.history["specificity"], history.history["val_specificity"], "specificity", DIR_OUTPUT)
|
history.history["val_accuracy"],
|
||||||
|
"accuracy",
|
||||||
|
DIR_OUTPUT,
|
||||||
|
)
|
||||||
|
plot_metric(
|
||||||
|
history.history["metric_dice_coefficient"],
|
||||||
|
history.history["val_metric_dice_coefficient"],
|
||||||
|
"dice",
|
||||||
|
DIR_OUTPUT,
|
||||||
|
)
|
||||||
|
plot_metric(
|
||||||
|
history.history["one_hot_mean_iou"],
|
||||||
|
history.history["val_one_hot_mean_iou"],
|
||||||
|
"mean iou",
|
||||||
|
DIR_OUTPUT,
|
||||||
|
)
|
||||||
|
plot_metric(
|
||||||
|
history.history["sensitivity"],
|
||||||
|
history.history["val_sensitivity"],
|
||||||
|
"sensitivity",
|
||||||
|
DIR_OUTPUT,
|
||||||
|
)
|
||||||
|
plot_metric(
|
||||||
|
history.history["specificity"],
|
||||||
|
history.history["val_specificity"],
|
||||||
|
"specificity",
|
||||||
|
DIR_OUTPUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████
|
# ██ ███ ██ ███████ ███████ ██████ ███████ ███ ██ ██████ ███████
|
||||||
|
|
|
@ -114,6 +114,8 @@ def parse_item(metadata, output_size=100, input_size="same", water_threshold=0.1
|
||||||
if water_threshold is not None: # if water_threshold=None, then regression mode
|
if water_threshold is not None: # if water_threshold=None, then regression mode
|
||||||
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.float32)
|
water = tf.cast(tf.math.greater_equal(water, water_threshold), dtype=tf.float32)
|
||||||
# BUG it may be a problem we're [height, width, channel] here rather than [height, width], depending on how dlr works
|
# BUG it may be a problem we're [height, width, channel] here rather than [height, width], depending on how dlr works
|
||||||
|
else:
|
||||||
|
water = tf.expand_dims(water, axis=-1) # Stack to have a channel, since if water_threshold=None then we would end up with [height, width] instead of [height, width, channel] otherwise
|
||||||
if do_remove_isolated_pixels:
|
if do_remove_isolated_pixels:
|
||||||
water = remove_isolated_pixels(water)
|
water = remove_isolated_pixels(water)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue