mirror of https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
ai: fix summary file writing; make water encoder smaller
parent 389216b391
commit 3e0ca6a315
3 changed files with 11 additions and 6 deletions
@@ -59,7 +59,7 @@ class RainfallWaterContraster(object):
 	
 	
 	def make_model(self):
-		return model_rainfallwater_contrastive(batch_size=self.batch_size, **self.kwargs)
+		return model_rainfallwater_contrastive(batch_size=self.batch_size, summary_file=self.filepath_summary, **self.kwargs)
 	
 	
 	def load_model(self, filepath_checkpoint):
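Note: this hunk threads the instance's filepath_summary down into the model builder as summary_file. As a minimal sketch of how such an argument is typically consumed in Keras (write_summary_to below is a hypothetical helper, not necessarily the repository's actual writer):

import tensorflow as tf

def write_summary_to(model, summary_file):
	"""Write model.summary() output to a file instead of stdout.
	Hypothetical helper; the repository's writer may differ."""
	if summary_file is None:
		return  # no-op when no path is given
	with open(summary_file, "w") as handle:
		# Keras calls print_fn once per line of the summary table.
		model.summary(print_fn=lambda line: handle.write(line + "\n"))

Passing print_fn is the standard way to redirect a Keras model summary, which is presumably what the "fix summary file writing" part of the commit message refers to.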
@@ -6,7 +6,7 @@ from .components.LayerContrastiveEncoder import LayerContrastiveEncoder
 from .components.LayerCheeseMultipleOut import LayerCheeseMultipleOut
 from .components.LossContrastive import LossContrastive
 
-def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64, feature_dim=2048):
+def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64, feature_dim=2048, summary_file=None):
 	logger.info(shape_rainfall)
 	logger.info(shape_water)
 	
@@ -27,14 +27,18 @@ def model_rainfallwater_contrastive(shape_rainfall, shape_water, batch_size=64,
 		input_width=rainfall_width,
 		input_height=rainfall_height,
 		channels=rainfall_channels,
-		feature_dim=feature_dim
+		feature_dim=feature_dim,
+		summary_file=summary_file,
+		arch_name="convnext_tiny",
 	)(input_rainfall)
 	print("MAKE ENCODER water")
 	water = LayerContrastiveEncoder(
 		input_width=water_width,
 		input_height=water_height,
 		channels=water_channels,
-		feature_dim=feature_dim
+		feature_dim=feature_dim,
+		arch_name="convnext_xtiny",
+		summary_file=summary_file
 	)(input_water)
 	
 	
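This hunk is where the water encoder is made smaller: its arch_name drops from "convnext_tiny" to "convnext_xtiny" while the rainfall encoder keeps the larger variant. A minimal sketch of how a name-keyed preset table can back such a parameter; the "convnext_tiny" entry matches the published ConvNeXt-T sizes, while the "xtiny" numbers are illustrative placeholders, not necessarily the repository's values:

# Hypothetical preset table mapping arch_name strings to ConvNeXt sizes.
CONVNEXT_PRESETS = {
	"convnext_tiny":  {"depths": [3, 3, 9, 3], "dims": [96, 192, 384, 768]},
	"convnext_xtiny": {"depths": [2, 2, 6, 2], "dims": [64, 128, 256, 512]},  # assumed
}

def convnext_config(arch_name):
	"""Look up the depth/width preset for a named ConvNeXt variant."""
	try:
		return CONVNEXT_PRESETS[arch_name]
	except KeyError:
		raise ValueError(f"unknown ConvNeXt variant: {arch_name!r}") from None

Keeping the choice behind a single string parameter lets the two encoder branches trade capacity independently without changing LayerContrastiveEncoder's call sites.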
@@ -6,6 +6,8 @@ from loguru import logger
 
 import tensorflow as tf
 
+from lib.dataset.read_metadata import read_metadata
+
 from ..io.readfile import readfile
 from .shuffle import shuffle
 
@@ -48,7 +50,6 @@ def make_dataset(filenames, metadata, compression_type="GZIP", parallel_reads_mu
 
 
 def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_multiplier=1.5):
-	filepath_meta = os.path.join(dirpath_input, "metadata.json")
 	filepaths = shuffle(list(filter(
 		lambda filepath: str(filepath).endswith(".tfrecord.gz"),
 		[ file.path for file in os.scandir(dirpath_input) ] # .path on a DirEntry object yields the absolute filepath
@@ -59,7 +60,7 @@ def dataset(dirpath_input, batch_size=64, train_percentage=0.8, parallel_reads_m
 	filepaths_train = filepaths[:dataset_splitpoint]
 	filepaths_validate = filepaths[dataset_splitpoint:]
 	
-	metadata = json.loads(readfile(filepath_meta))
+	metadata = read_metadata(dirpath_input)
 	
 	dataset_train = make_dataset(filepaths_train, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
 	dataset_validate = make_dataset(filepaths_validate, metadata, batch_size=batch_size, parallel_reads_multiplier=parallel_reads_multiplier)
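The dataset loader change replaces the inline json.loads(readfile(filepath_meta)) with a shared read_metadata(dirpath_input) helper, so the metadata path resolution and parsing live in one place. A sketch of what such a helper plausibly looks like, inferred from the code it replaces (the repository's actual implementation may differ, e.g. in error handling):

import json
import os

def read_metadata(dirpath_input):
	"""Read and parse metadata.json from a dataset directory.
	Sketch inferred from the replaced code, not the repo's exact helper."""
	filepath_meta = os.path.join(dirpath_input, "metadata.json")
	with open(filepath_meta, "r") as handle:
		return json.load(handle)

Centralising the path join and parse removes the duplicated filepath_meta bookkeeping seen in the deleted lines.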