mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
dlr: fix crash
This commit is contained in:
parent
bcf198d47b
commit
36859746ff
1 changed files with 4 additions and 4 deletions
|
@ -249,10 +249,10 @@ def plot_predictions(filepath, input_items, colormap, model):
|
|||
figsize=(18, 14)
|
||||
)
|
||||
|
||||
def get_items_from_batched(dataset, count):
|
||||
def get_inputs_from_batched(dataset, count):
|
||||
result = []
|
||||
for batched in dataset:
|
||||
items = tf.unstack(batched, axis=0)
|
||||
items = tf.unstack(batched[0], axis=0)
|
||||
for item in items:
|
||||
result.append(item)
|
||||
if len(result) >= count:
|
||||
|
@ -261,13 +261,13 @@ def get_items_from_batched(dataset, count):
|
|||
|
||||
plot_predictions(
|
||||
os.path.join(DIR_OUTPUT, "predict_train.png"),
|
||||
get_items_from_batched(dataset_train, 4),
|
||||
get_inputs_from_batched(dataset_train, 4),
|
||||
colormap,
|
||||
model=model
|
||||
)
|
||||
plot_predictions(
|
||||
os.path.join(DIR_OUTPUT, "predict_validate.png"),
|
||||
get_items_from_batched(dataset_validate, 4),
|
||||
get_inputs_from_batched(dataset_validate, 4),
|
||||
colormap,
|
||||
model=model
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue