ai: fix summary file writing; make water encoder smaller

Starbeamrainbowlabs 2022-09-02 17:51:45 +01:00
parent 389216b391
commit 3e0ca6a315
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
3 changed files with 11 additions and 6 deletions

@@ -59,7 +59,7 @@ class RainfallWaterContraster(object):
 	def make_model(self):
-		return model_rainfallwater_contrastive(batch_size=self.batch_size, **self.kwargs)
+		return model_rainfallwater_contrastive(batch_size=self.batch_size, summary_file=self.filepath_summary, **self.kwargs)
 	
 	def load_model(self, filepath_checkpoint):
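Note: summary_file is presumably the path the model summary gets written to; the fix threads the class's filepath_summary through to the model factory rather than silently dropping it. A minimal sketch of the underlying pattern, assuming a Keras model and a hypothetical write_summary() helper (not code from this repo):

	import tensorflow as tf
	
	def write_summary(model: tf.keras.Model, filepath: str) -> None:
		"""Write a model's layer-by-layer summary to a text file."""
		with open(filepath, "w") as handle:
			# model.summary() prints to stdout by default; print_fn
			# redirects each line into the file instead.
			model.summary(print_fn=lambda line: handle.write(line + "\n"))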

@@ -6,7 +6,7 @@ from .components.LayerContrastiveEncoder import LayerContrastiveEncoder
 from .components.LayerCheeseMultipleOut import LayerCheeseMultipleOut
 from .components.LossContrastive import LossContrastive
 
-def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64, feature_dim=2048):
+def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64, feature_dim=2048, summary_file=None):
 	logger.info(shape_rainfall)
 	logger.info(shape_water)
@@ -27,14 +27,18 @@ def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64, feature_dim=2048):
 		input_width=rainfall_width,
 		input_height=rainfall_height,
 		channels=rainfall_channels,
-		feature_dim=feature_dim
+		feature_dim=feature_dim,
+		summary_file=summary_file,
+		arch_name="convnext_tiny",
 	)(input_rainfall)
 	
 	print("MAKE ENCODER water")
 	water = LayerContrastiveEncoder(
 		input_width=water_width,
 		input_height=water_height,
 		channels=water_channels,
-		feature_dim=feature_dim
+		feature_dim=feature_dim,
+		arch_name="convnext_xtiny",
+		summary_file=summary_file
 	)(input_water)
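Note: the "make water encoder smaller" half of the commit shows up here: the rainfall encoder is pinned to arch_name="convnext_tiny" while the water encoder gets the smaller "convnext_xtiny" variant. A plausible sketch of how an arch_name parameter might select a preset; convnext_tiny's numbers are the published ConvNeXt-T configuration, while convnext_xtiny's are illustrative guesses, not this repo's actual values:

	# Hypothetical preset table keyed by arch_name.
	ARCH_PRESETS = {
		"convnext_tiny":  dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]),
		"convnext_xtiny": dict(depths=[3, 3, 6, 3], dims=[64, 128, 256, 512]),  # guessed: shallower and narrower
	}
	
	def make_encoder_config(arch_name="convnext_tiny"):
		if arch_name not in ARCH_PRESETS:
			raise ValueError(f"Unknown arch_name: {arch_name!r}")
		return ARCH_PRESETS[arch_name]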

@@ -6,6 +6,8 @@ from loguru import logger
 import tensorflow as tf
 
+from lib.dataset.read_metadata import read_metadata
+
 from ..io.readfile import readfile
 from .shuffle import shuffle
@@ -48,7 +50,6 @@ def make_dataset(filenames, metadata, compression_type="GZIP", parallel_reads_multiplier=1.5):
 def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
-	filepath_meta = os.path.join(dirpath_input, "metadata.json")
 	filepaths = shuffle(list(filter(
 		lambda filepath: str(filepath).endswith(".tfrecord.gz"),
 		[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
@@ -59,7 +60,7 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
 	filepaths_train = filepaths[:dataset_splitpoint]
 	filepaths_validate = filepaths[dataset_splitpoint:]
 	
-	metadata = json.loads(readfile(filepath_meta))
+	metadata = read_metadata(dirpath_input)
 	
 	dataset_train = make_dataset(filepaths_train, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
 	dataset_validate = make_dataset(filepaths_validate, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
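
Note: read_metadata(dirpath_input) replaces the inline json.loads(readfile(filepath_meta)), which is why the filepath_meta assignment in the previous hunk could be deleted. A sketch under the assumption that the helper simply centralises that same logic, using the stdlib in place of the repo's readfile():

	import json
	import os
	
	def read_metadata(dirpath_input):
		"""Read and parse metadata.json from a dataset directory."""
		filepath_meta = os.path.join(dirpath_input, "metadata.json")
		with open(filepath_meta, "r") as handle:
			return json.load(handle)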