mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-12-22 22:25:01 +00:00
ai: fix summary file writing; make water encoder smaller
This commit is contained in:
parent
389216b391
commit
3e0ca6a315
3 changed files with 11 additions and 6 deletions
|
@ -59,7 +59,7 @@ class RainfallWaterContraster(object):
|
|||
|
||||
|
||||
def make_model(self):
|
||||
return model_rainfallwater_contrastive(batch_size=self.batch_size, **self.kwargs)
|
||||
return model_rainfallwater_contrastive(batch_size=self.batch_size, summary_file=self.filepath_summary, **self.kwargs)
|
||||
|
||||
|
||||
def load_model(self, filepath_checkpoint):
|
||||
|
|
|
@ -6,7 +6,7 @@ from .components.LayerContrastiveEncoder import LayerContrastiveEncoder
|
|||
from .components.LayerCheeseMultipleOut import LayerCheeseMultipleOut
|
||||
from .components.LossContrastive import LossContrastive
|
||||
|
||||
def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64, feature_dim=2048):
|
||||
def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64, feature_dim=2048, summary_file=None):
|
||||
logger.info(shape_rainfall)
|
||||
logger.info(shape_water)
|
||||
|
||||
|
@ -27,14 +27,18 @@ def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64,
|
|||
input_width=rainfall_width,
|
||||
input_height=rainfall_height,
|
||||
channels=rainfall_channels,
|
||||
feature_dim=feature_dim
|
||||
feature_dim=feature_dim,
|
||||
summary_file=summary_file,
|
||||
arch_name="convnext_tiny",
|
||||
)(input_rainfall)
|
||||
print("MAKE ENCODER water")
|
||||
water = LayerContrastiveEncoder(
|
||||
input_width=water_width,
|
||||
input_height=water_height,
|
||||
channels=water_channels,
|
||||
feature_dim=feature_dim
|
||||
feature_dim=feature_dim,
|
||||
arch_name="convnext_xtiny",
|
||||
summary_file=summary_file
|
||||
)(input_water)
|
||||
|
||||
|
||||
|
|
|
@ -6,6 +6,8 @@ from loguru import logger
|
|||
|
||||
import tensorflow as tf
|
||||
|
||||
from lib.dataset.read_metadata import read_metadata
|
||||
|
||||
from ..io.readfile import readfile
|
||||
from .shuffle import shuffle
|
||||
|
||||
|
@ -48,7 +50,6 @@ def make_dataset(filenames, metadata, compression_type="GZIP", parallel_reads_mu
|
|||
|
||||
|
||||
def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
|
||||
filepath_meta = os.path.join(dirpath_input, "metadata.json")
|
||||
filepaths = shuffle(list(filter(
|
||||
lambda filepath: str(filepath).endswith(".tfrecord.gz"),
|
||||
[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
|
||||
|
@ -59,7 +60,7 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_m
|
|||
filepaths_train = filepaths[:dataset_splitpoint]
|
||||
filepaths_validate = filepaths[dataset_splitpoint:]
|
||||
|
||||
metadata = json.loads(readfile(filepath_meta))
|
||||
metadata = read_metadata(dirpath_input)
|
||||
|
||||
dataset_train = make_dataset(filepaths_train, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
|
||||
dataset_validate = make_dataset(filepaths_validate, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
|
||||
|
|
Loading…
Reference in a new issue