mirror of
https://github.com/sbrl/research-rainfallradar
synced 2024-11-22 09:13:01 +00:00
ResNetRSV2 → ConvNeXt
ironically this makes the model simpler o/
This commit is contained in:
parent
3d614d105b
commit
51cf08a386
5 changed files with 188 additions and 25 deletions
|
@ -1,17 +1,17 @@
|
|||
import tensorflow as tf
|
||||
|
||||
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
|
||||
# from transformers import TFConvNextModel, ConvNextConfig
|
||||
# from tensorflow.keras.applications.resnet_v2 import ResNet50V2
|
||||
from ..helpers.summarywriter import summarylogger
|
||||
from .convnext import make_convnext
|
||||
|
||||
class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
||||
|
||||
def __init__(self, input_width, input_height, channels, feature_dim=200, **kwargs):
|
||||
def __init__(self, input_width, input_height, channels, feature_dim=2048, **kwargs):
|
||||
"""Creates a new contrastive learning encoder layer.
|
||||
Note that the input format MUST be channels_last. This is because Tensorflow/Keras' Dense layer does NOT support specifying an axis. Go complain to them, not me.
|
||||
While this is intended for contrastive learning, this can (in theory) be used anywhere as it's just a generic wrapper layer.
|
||||
The key feature here is that it does not care about the input size or the number of channels.
|
||||
Currently it uses a ResNetV2 internally, but an upgrade to ConvNeXt is planned once Tensorflow Keras' implementation comes out of nightly and into stable.
|
||||
We would use ResNetRS (as it's technically superior), but the implementation is bad and in places outright *wrong* O.o
|
||||
Currently it uses a ConvNeXt internally, but an upgrade to Tensorflow's internal ConvNeXt implementation is planned once it comes out of nightly and into stable.
|
||||
|
||||
Args:
|
||||
feature_dim (int, optional): The size of the features dimension in the output shape. Note that there are *two* feature dimensions outputted - one for the left, and one for the right. They will both be in the form [ batch_size, feature_dim ]. Set to a low value (e.g. 25) to be able to plot a sensible a parallel coordinates graph. Defaults to 200.
|
||||
|
@ -26,20 +26,16 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
|||
self.param_channels = channels
|
||||
self.param_feature_dim = feature_dim
|
||||
|
||||
"""The main ResNet model that forms the encoder.
|
||||
Note that both the left AND the right go through the SAME encoder!s
|
||||
"""The main ConvNeXt model that forms the encoder.
|
||||
"""
|
||||
self.encoder = ResNet50V2(
|
||||
include_top=False,
|
||||
input_shape=(self.param_channels, self.param_input_width, self.param_input_height),
|
||||
weights=None,
|
||||
pooling=None,
|
||||
data_format="channels_first"
|
||||
self.encoder = make_convnext(
|
||||
input_shape = (self.param_input_width, self.param_input_height, self.param_channels),
|
||||
classifier_activation = tf.nn.relu, # this is not actually a classifier, but rather a feature encoder
|
||||
num_classes = self.param_feature_dim # size of the feature dimension, see the line above this one
|
||||
)
|
||||
"""Small sequential stack of layers that control the size of the outputted feature dimension.
|
||||
"""
|
||||
self.embedding = tf.keras.layers.Dense(self.param_feature_dim)
|
||||
self.embedding_input_shape = [None, 2048] # The output shape of the above ResNet AFTER reshaping.
|
||||
# """Small sequential stack of layers that control the size of the outputted feature dimension.
|
||||
# """
|
||||
# self.embedding = tf.keras.layers.Dense(self.param_feature_dim)
|
||||
|
||||
summarylogger(self.encoder)
|
||||
|
||||
|
@ -59,9 +55,10 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
|
|||
def call(self, input_thing):
|
||||
result = self.encoder(input_thing)
|
||||
|
||||
shape_ksize = result.shape[1]
|
||||
result = tf.nn.avg_pool(result, ksize=shape_ksize, strides=1, padding="VALID")
|
||||
# The encoder is handled by the ConvNeXt model \o/
|
||||
# shape_ksize = result.shape[1]
|
||||
# result = tf.nn.avg_pool(result, ksize=shape_ksize, strides=1, padding="VALID")
|
||||
|
||||
target_shape = [ -1, result.shape[-1] ]
|
||||
result = self.embedding(tf.reshape(result, target_shape))
|
||||
# target_shape = [ -1, result.shape[-1] ]
|
||||
# result = self.embedding(tf.reshape(result, target_shape))
|
||||
return result
|
164
aimodel/src/lib/ai/components/convnext.py
Normal file
164
aimodel/src/lib/ai/components/convnext.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
from unicodedata import name
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_addons as tfa
|
||||
|
||||
|
||||
# Code from https://github.com/leanderme/ConvNeXt-Tensorflow/blob/main/ConvNeXt.ipynb
|
||||
from .LayerConvNeXtGamma import LayerConvNeXtGamma
|
||||
|
||||
kernel_initial = tf.keras.initializers.TruncatedNormal(stddev=0.2)
|
||||
bias_initial = tf.keras.initializers.Constant(value=0)
|
||||
|
||||
depths_dims = dict(
|
||||
convnext_xtiny = (dict(depths=[3, 3, 6, 3], dims=[66, 132, 264, 528])),
|
||||
# architectures from: https://github.com/facebookresearch/ConvNeXt
|
||||
# A ConvNet for the 2020s: https://arxiv.org/abs/2201.03545
|
||||
convnext_tiny = (dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])),
|
||||
convnext_small = (dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])),
|
||||
convnext_base = (dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])),
|
||||
convnext_large = (dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])),
|
||||
convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])),
|
||||
)
|
||||
|
||||
def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
|
||||
"""Makes a ConvNeXt model.
|
||||
Returns a tf.keras.Model.
|
||||
Args:
|
||||
input_shape (int[]): The input shape of the tensor that will be fed to the ConvNeXt model. This is necessary as we make the model using the functional API and thus we need to make an Input layer.
|
||||
arch_name (str, optional): The name of the preset ConvNeXt model architecture to use. Defaults to "convnext_tiny".
|
||||
"""
|
||||
layer_in = tf.keras.layers.Input(
|
||||
shape = input_shape
|
||||
)
|
||||
layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs)
|
||||
return tf.keras.Model(
|
||||
input = layer_in,
|
||||
output = layer_out
|
||||
)
|
||||
|
||||
|
||||
def convnext(
|
||||
x,
|
||||
include_top = True,
|
||||
num_classes = 1000,
|
||||
depths = [3, 3, 9, 3],
|
||||
dims = [96, 192, 384, 768],
|
||||
drop_path_rate = 0.,
|
||||
classifier_activation = 'softmax'
|
||||
# Note that we CAN'T add data_format here, 'cause Dense doesn't support specifying the axis
|
||||
):
|
||||
|
||||
assert len(depths) == len(dims)
|
||||
|
||||
def forward_features(x):
|
||||
i = 0
|
||||
for depth, dim in zip(depths, dims):
|
||||
|
||||
if i == 0:
|
||||
x = tf.keras.layers.Conv2D(
|
||||
dim,
|
||||
kernel_size = 4,
|
||||
strides = 4,
|
||||
padding = "valid",
|
||||
name = "downsample_layers.0.0_conv"
|
||||
)(x)
|
||||
x = tf.keras.layers.LayerNormalization(
|
||||
epsilon = 1e-6,
|
||||
name = "downsample_layers.0.0_norm"
|
||||
)(x)
|
||||
else:
|
||||
x = tf.keras.layers.LayerNormalization(
|
||||
epsilon = 1e-6,
|
||||
name = "stages." + str(i) + "." + str(k) + ".downsample_norm"
|
||||
)(x)
|
||||
x = tf.keras.layers.Conv2D(
|
||||
dim,
|
||||
kernel_size = 2,
|
||||
strides = 2,
|
||||
padding ='same',
|
||||
kernel_initializer = kernel_initial,
|
||||
bias_initializer = bias_initial,
|
||||
name = "stages." + str(i) + "." + str(k) + ".downsample_conv"
|
||||
)(x)
|
||||
|
||||
|
||||
for k in range(depth):
|
||||
x = add_convnext_block(
|
||||
x,
|
||||
dim,
|
||||
drop_path_rate,
|
||||
prefix = "stages." + str(i) + "." + str(k),
|
||||
)
|
||||
i = i +1
|
||||
|
||||
return x
|
||||
|
||||
x = forward_features(x)
|
||||
|
||||
if include_top:
|
||||
x = tf.keras.layers.GlobalAveragePooling2D(
|
||||
name = 'avg'
|
||||
)(x)
|
||||
x = tf.keras.layers.LayerNormalization(
|
||||
epsilon = 1e-6,
|
||||
name = "norm",
|
||||
)(x)
|
||||
|
||||
|
||||
x = tf.keras.layers.Dense(
|
||||
num_classes,
|
||||
activation = classifier_activation,
|
||||
kernel_initializer = kernel_initial,
|
||||
bias_initializer = bias_initial,
|
||||
name = "head"
|
||||
)(x)
|
||||
else:
|
||||
x = tf.keras.layers.GlobalAveragePooling2D(name='avg')(x)
|
||||
x = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="norm")(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def add_convnext_block(y, dim, drop_prob=0, prefix=""):
|
||||
skip = y
|
||||
|
||||
y = tf.keras.layers.DepthwiseConv2D(
|
||||
kernel_size=7,
|
||||
padding='same',
|
||||
name = f'{prefix}.dwconv'
|
||||
)(y)
|
||||
|
||||
y = tf.keras.layers.LayerNormalization(
|
||||
epsilon=1e-6,
|
||||
name=f'{prefix}.norm'
|
||||
)(y)
|
||||
|
||||
y = tf.keras.layers.Dense(
|
||||
4 * dim,
|
||||
name=f'{prefix}.pwconv1'
|
||||
)(y)
|
||||
|
||||
|
||||
y = tf.keras.layers.Activation(
|
||||
'gelu',
|
||||
name=f'{prefix}.act'
|
||||
)(y)
|
||||
|
||||
y = tf.keras.layers.Dense(
|
||||
dim,
|
||||
name=f'{prefix}.pwconv2'
|
||||
)(y)
|
||||
|
||||
y = LayerConvNeXtGamma(
|
||||
const_val = 1e-6,
|
||||
dim = dim,
|
||||
name = f'{prefix}.gamma'
|
||||
)(y)
|
||||
|
||||
y = tfa.layers.StochasticDepth(
|
||||
drop_prob,
|
||||
name = f'{prefix}.drop_path'
|
||||
)([skip, y])
|
||||
|
||||
return y
|
|
@ -6,7 +6,7 @@ from .components.LayerContrastiveEncoder import LayerContrastiveEncoder
|
|||
from .components.LayerCheeseMultipleOut import LayerCheeseMultipleOut
|
||||
from .components.LossContrastive import LossContrastive
|
||||
|
||||
def model_rainfallwater_contrastive(shape_rainfall, shape_water, feature_dim=200):
|
||||
def model_rainfallwater_contrastive(shape_rainfall, shape_water, feature_dim=2048):
|
||||
logger.info(shape_rainfall)
|
||||
logger.info(shape_water)
|
||||
|
||||
|
@ -19,7 +19,7 @@ def model_rainfallwater_contrastive(shape_rainfall, shape_water, feature_dim=200
|
|||
shape=shape_rainfall
|
||||
)
|
||||
input_water = tf.keras.layers.Input(
|
||||
shape=shape_water
|
||||
shape=(water_width, water_height, water_channels)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@ def parse_item(item):
|
|||
rainfall = tf.io.parse_tensor(parsed["rainfallradar"], out_type=tf.float32)
|
||||
water = tf.io.parse_tensor(parsed["waterdepth"], out_type=tf.float32)
|
||||
|
||||
# [channels, width, height] → [width, height, channels] - ref ConvNeXt does not support data_format=channels_first
|
||||
rainfall = tf.transpose(rainfall, [1, 2, 0])
|
||||
# [width, height] → [width, height, channels]
|
||||
water = tf.expand_dims(water, axis=-1)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ def parse_args():
|
|||
# parser.add_argument("--config", "-c", help="Filepath to the TOML config file to load.", required=True)
|
||||
parser.add_argument("--input", "-i", help="Path to input directory containing the .tfrecord.gz files to pretrain with", required=True)
|
||||
parser.add_argument("--output", "-o", help="Path to output directory to write output to (will be automatically created if it doesn't exist)", required=True)
|
||||
parser.add_argument("--feature-dim", help="The size of the output feature dimension of the model [default: 200].", type=int)
|
||||
parser.add_argument("--feature-dim", help="The size of the output feature dimension of the model [default: 2048].", type=int)
|
||||
parser.add_argument("--batch-size", help="Sets the batch size [default: 64].", type=int)
|
||||
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number of systems with high read latency to avoid starving the GPU of data.")
|
||||
|
||||
|
@ -24,7 +24,7 @@ def run(args):
|
|||
if (not hasattr(args, "batch_size")) or args.batch_size == None:
|
||||
args.batch_size = 64
|
||||
if (not hasattr(args, "feature_dim")) or args.feature_dim == None:
|
||||
args.feature_dim = 200
|
||||
args.feature_dim = 2048
|
||||
if (not hasattr(args, "read_multiplier")) or args.read_multiplier == None:
|
||||
args.read_multiplier = 1.5
|
||||
|
||||
|
|
Loading…
Reference in a new issue