dlr: fix crash

This commit is contained in:
Starbeamrainbowlabs 2023-01-06 17:13:35 +00:00
parent bcf198d47b
commit 36859746ff
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -249,10 +249,10 @@ def plot_predictions(filepath, input_items, colormap, model):
figsize=(18, 14)
)
def get_items_from_batched(dataset, count):
def get_inputs_from_batched(dataset, count):
result = []
for batched in dataset:
items = tf.unstack(batched, axis=0)
items = tf.unstack(batched[0], axis=0)
for item in items:
result.append(item)
if len(result) >= count:
@ -261,13 +261,13 @@ def get_items_from_batched(dataset, count):
plot_predictions(
os.path.join(DIR_OUTPUT, "predict_train.png"),
get_items_from_batched(dataset_train, 4),
get_inputs_from_batched(dataset_train, 4),
colormap,
model=model
)
plot_predictions(
os.path.join(DIR_OUTPUT, "predict_validate.png"),
get_items_from_batched(dataset_validate, 4),
get_inputs_from_batched(dataset_validate, 4),
colormap,
model=model
)