diff --git a/aimodel/src/lib/ai/components/CallbackNBatchCsv.py b/aimodel/src/lib/ai/components/CallbackNBatchCsv.py index de7491b..e04058c 100644 --- a/aimodel/src/lib/ai/components/CallbackNBatchCsv.py +++ b/aimodel/src/lib/ai/components/CallbackNBatchCsv.py @@ -9,7 +9,7 @@ class CallbackNBatchCsv(tf.keras.callbacks.Callback): self.n_batches = n_batches self.separator = separator - self.handle = handle_open(filepath) + self.handle = handle_open(filepath, "w") self.batches_seen = 0 diff --git a/aimodel/src/lib/io/handle_open.py b/aimodel/src/lib/io/handle_open.py index a167ea0..bd3f046 100644 --- a/aimodel/src/lib/io/handle_open.py +++ b/aimodel/src/lib/io/handle_open.py @@ -2,7 +2,10 @@ import io import gzip -def handle_open(filepath, mode): +def handle_open(filepath, mode, force_textwrite_gzip=True): + if mode == "w" and mode.endswith(".gz") and force_textwrite_gzip: + mode = "wt" + if filepath.endswith(".gz"): return gzip.open(filepath, mode) else: