mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
dlr: fix crash
This commit is contained in:
parent
bcf198d47b
commit
36859746ff
1 changed files with 4 additions and 4 deletions
|
@ -249,10 +249,10 @@ def plot_predictions(filepath, input_items, colormap, model):
|
||||||
figsize=(18, 14)
|
figsize=(18, 14)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_items_from_batched(dataset, count):
|
def get_inputs_from_batched(dataset, count):
|
||||||
result = []
|
result = []
|
||||||
for batched in dataset:
|
for batched in dataset:
|
||||||
items = tf.unstack(batched, axis=0)
|
items = tf.unstack(batched[0], axis=0)
|
||||||
for item in items:
|
for item in items:
|
||||||
result.append(item)
|
result.append(item)
|
||||||
if len(result) >= count:
|
if len(result) >= count:
|
||||||
|
@ -261,13 +261,13 @@ def get_items_from_batched(dataset, count):
|
||||||
|
|
||||||
plot_predictions(
|
plot_predictions(
|
||||||
os.path.join(DIR_OUTPUT, "predict_train.png"),
|
os.path.join(DIR_OUTPUT, "predict_train.png"),
|
||||||
get_items_from_batched(dataset_train, 4),
|
get_inputs_from_batched(dataset_train, 4),
|
||||||
colormap,
|
colormap,
|
||||||
model=model
|
model=model
|
||||||
)
|
)
|
||||||
plot_predictions(
|
plot_predictions(
|
||||||
os.path.join(DIR_OUTPUT, "predict_validate.png"),
|
os.path.join(DIR_OUTPUT, "predict_validate.png"),
|
||||||
get_items_from_batched(dataset_validate, 4),
|
get_inputs_from_batched(dataset_validate, 4),
|
||||||
colormap,
|
colormap,
|
||||||
model=model
|
model=model
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue