ResNetRSV2 → ConvNeXt

ironically this makes the model simpler o/
This commit is contained in:
Starbeamrainbowlabs 2022-08-31 18:51:01 +01:00
parent 3d614d105b
commit 51cf08a386
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
5 changed files with 188 additions and 25 deletions

View file

@ -1,17 +1,17 @@
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.applications.resnet_v2 import ResNet50V2 # from tensorflow.keras.applications.resnet_v2 import ResNet50V2
# from transformers import TFConvNextModel, ConvNextConfig
from ..helpers.summarywriter import summarylogger from ..helpers.summarywriter import summarylogger
from .convnext import make_convnext
class LayerContrastiveEncoder(tf.keras.layers.Layer): class LayerContrastiveEncoder(tf.keras.layers.Layer):
def __init__(self, input_width, input_height, channels, feature_dim=200, **kwargs): def __init__(self, input_width, input_height, channels, feature_dim=2048, **kwargs):
"""Creates a new contrastive learning encoder layer. """Creates a new contrastive learning encoder layer.
Note that the input format MUST be channels_last. This is because Tensorflow/Keras' Dense layer does NOT support specifying an axis. Go complain to them, not me.
While this is intended for contrastive learning, this can (in theory) be used anywhere as it's just a generic wrapper layer. While this is intended for contrastive learning, this can (in theory) be used anywhere as it's just a generic wrapper layer.
The key feature here is that it does not care about the input size or the number of channels. The key feature here is that it does not care about the input size or the number of channels.
Currently it uses a ResNetV2 internally, but an upgrade to ConvNeXt is planned once Tensorflow Keras' implementation comes out of nightly and into stable. Currently it uses a ConvNeXt internally, but an upgrade to Tensorflow's internal ConvNeXt implementation is planned once it comes out of nightly and into stable.
We would use ResNetRS (as it's technically superior), but the implementation is bad and in places outright *wrong* O.o
Args: Args:
feature_dim (int, optional): The size of the features dimension in the output shape. Note that there are *two* feature dimensions outputted - one for the left, and one for the right. They will both be in the form [ batch_size, feature_dim ]. Set to a low value (e.g. 25) to be able to plot a sensible a parallel coordinates graph. Defaults to 200. feature_dim (int, optional): The size of the features dimension in the output shape. Note that there are *two* feature dimensions outputted - one for the left, and one for the right. They will both be in the form [ batch_size, feature_dim ]. Set to a low value (e.g. 25) to be able to plot a sensible a parallel coordinates graph. Defaults to 200.
@ -26,20 +26,16 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
self.param_channels = channels self.param_channels = channels
self.param_feature_dim = feature_dim self.param_feature_dim = feature_dim
"""The main ResNet model that forms the encoder. """The main ConvNeXt model that forms the encoder.
Note that both the left AND the right go through the SAME encoder!s
""" """
self.encoder = ResNet50V2( self.encoder = make_convnext(
include_top=False, input_shape = (self.param_input_width, self.param_input_height, self.param_channels),
input_shape=(self.param_channels, self.param_input_width, self.param_input_height), classifier_activation = tf.nn.relu, # this is not actually a classifier, but rather a feature encoder
weights=None, num_classes = self.param_feature_dim # size of the feature dimension, see the line above this one
pooling=None,
data_format="channels_first"
) )
"""Small sequential stack of layers that control the size of the outputted feature dimension. # """Small sequential stack of layers that control the size of the outputted feature dimension.
""" # """
self.embedding = tf.keras.layers.Dense(self.param_feature_dim) # self.embedding = tf.keras.layers.Dense(self.param_feature_dim)
self.embedding_input_shape = [None, 2048] # The output shape of the above ResNet AFTER reshaping.
summarylogger(self.encoder) summarylogger(self.encoder)
@ -59,9 +55,10 @@ class LayerContrastiveEncoder(tf.keras.layers.Layer):
def call(self, input_thing): def call(self, input_thing):
result = self.encoder(input_thing) result = self.encoder(input_thing)
shape_ksize = result.shape[1] # The encoder is handled by the ConvNeXt model \o/
result = tf.nn.avg_pool(result, ksize=shape_ksize, strides=1, padding="VALID") # shape_ksize = result.shape[1]
# result = tf.nn.avg_pool(result, ksize=shape_ksize, strides=1, padding="VALID")
target_shape = [ -1, result.shape[-1] ] # target_shape = [ -1, result.shape[-1] ]
result = self.embedding(tf.reshape(result, target_shape)) # result = self.embedding(tf.reshape(result, target_shape))
return result return result

View file

@ -0,0 +1,164 @@
from unicodedata import name
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
# Code from https://github.com/leanderme/ConvNeXt-Tensorflow/blob/main/ConvNeXt.ipynb
from .LayerConvNeXtGamma import LayerConvNeXtGamma
kernel_initial = tf.keras.initializers.TruncatedNormal(stddev=0.2)
bias_initial = tf.keras.initializers.Constant(value=0)
depths_dims = dict(
convnext_xtiny = (dict(depths=[3, 3, 6, 3], dims=[66, 132, 264, 528])),
# architectures from: https://github.com/facebookresearch/ConvNeXt
# A ConvNet for the 2020s: https://arxiv.org/abs/2201.03545
convnext_tiny = (dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])),
convnext_small = (dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])),
convnext_base = (dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])),
convnext_large = (dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])),
convnext_xlarge = (dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])),
)
def make_convnext(input_shape, arch_name="convnext_tiny", **kwargs):
"""Makes a ConvNeXt model.
Returns a tf.keras.Model.
Args:
input_shape (int[]): The input shape of the tensor that will be fed to the ConvNeXt model. This is necessary as we make the model using the functional API and thus we need to make an Input layer.
arch_name (str, optional): The name of the preset ConvNeXt model architecture to use. Defaults to "convnext_tiny".
"""
layer_in = tf.keras.layers.Input(
shape = input_shape
)
layer_out = convnext(layer_in, **depths_dims[arch_name], **kwargs)
return tf.keras.Model(
input = layer_in,
output = layer_out
)
def convnext(
x,
include_top = True,
num_classes = 1000,
depths = [3, 3, 9, 3],
dims = [96, 192, 384, 768],
drop_path_rate = 0.,
classifier_activation = 'softmax'
# Note that we CAN'T add data_format here, 'cause Dense doesn't support specifying the axis
):
assert len(depths) == len(dims)
def forward_features(x):
i = 0
for depth, dim in zip(depths, dims):
if i == 0:
x = tf.keras.layers.Conv2D(
dim,
kernel_size = 4,
strides = 4,
padding = "valid",
name = "downsample_layers.0.0_conv"
)(x)
x = tf.keras.layers.LayerNormalization(
epsilon = 1e-6,
name = "downsample_layers.0.0_norm"
)(x)
else:
x = tf.keras.layers.LayerNormalization(
epsilon = 1e-6,
name = "stages." + str(i) + "." + str(k) + ".downsample_norm"
)(x)
x = tf.keras.layers.Conv2D(
dim,
kernel_size = 2,
strides = 2,
padding ='same',
kernel_initializer = kernel_initial,
bias_initializer = bias_initial,
name = "stages." + str(i) + "." + str(k) + ".downsample_conv"
)(x)
for k in range(depth):
x = add_convnext_block(
x,
dim,
drop_path_rate,
prefix = "stages." + str(i) + "." + str(k),
)
i = i +1
return x
x = forward_features(x)
if include_top:
x = tf.keras.layers.GlobalAveragePooling2D(
name = 'avg'
)(x)
x = tf.keras.layers.LayerNormalization(
epsilon = 1e-6,
name = "norm",
)(x)
x = tf.keras.layers.Dense(
num_classes,
activation = classifier_activation,
kernel_initializer = kernel_initial,
bias_initializer = bias_initial,
name = "head"
)(x)
else:
x = tf.keras.layers.GlobalAveragePooling2D(name='avg')(x)
x = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="norm")(x)
return x
def add_convnext_block(y, dim, drop_prob=0, prefix=""):
skip = y
y = tf.keras.layers.DepthwiseConv2D(
kernel_size=7,
padding='same',
name = f'{prefix}.dwconv'
)(y)
y = tf.keras.layers.LayerNormalization(
epsilon=1e-6,
name=f'{prefix}.norm'
)(y)
y = tf.keras.layers.Dense(
4 * dim,
name=f'{prefix}.pwconv1'
)(y)
y = tf.keras.layers.Activation(
'gelu',
name=f'{prefix}.act'
)(y)
y = tf.keras.layers.Dense(
dim,
name=f'{prefix}.pwconv2'
)(y)
y = LayerConvNeXtGamma(
const_val = 1e-6,
dim = dim,
name = f'{prefix}.gamma'
)(y)
y = tfa.layers.StochasticDepth(
drop_prob,
name = f'{prefix}.drop_path'
)([skip, y])
return y

View file

@ -6,7 +6,7 @@ from .components.LayerContrastiveEncoder import LayerContrastiveEncoder
from .components.LayerCheeseMultipleOut import LayerCheeseMultipleOut from .components.LayerCheeseMultipleOut import LayerCheeseMultipleOut
from .components.LossContrastive import LossContrastive from .components.LossContrastive import LossContrastive
def model_rainfallwater_contrastive(shape_rainfall, shape_water, feature_dim=200): def model_rainfallwater_contrastive(shape_rainfall, shape_water, feature_dim=2048):
logger.info(shape_rainfall) logger.info(shape_rainfall)
logger.info(shape_water) logger.info(shape_water)
@ -19,7 +19,7 @@ def model_rainfallwater_contrastive(shape_rainfall, shape_water, feature_dim=200
shape=shape_rainfall shape=shape_rainfall
) )
input_water = tf.keras.layers.Input( input_water = tf.keras.layers.Input(
shape=shape_water shape=(water_width, water_height, water_channels)
) )

View file

@ -20,6 +20,8 @@ def parse_item(item):
rainfall = tf.io.parse_tensor(parsed["rainfallradar"], out_type=tf.float32) rainfall = tf.io.parse_tensor(parsed["rainfallradar"], out_type=tf.float32)
water = tf.io.parse_tensor(parsed["waterdepth"], out_type=tf.float32) water = tf.io.parse_tensor(parsed["waterdepth"], out_type=tf.float32)
# [channels, width, height] → [width, height, channels] - ref ConvNeXt does not support data_format=channels_first
rainfall = tf.transpose(rainfall, [1, 2, 0])
# [width, height] → [width, height, channels] # [width, height] → [width, height, channels]
water = tf.expand_dims(water, axis=-1) water = tf.expand_dims(water, axis=-1)

View file

@ -13,7 +13,7 @@ def parse_args():
# parser.add_argument("--config", "-c", help="Filepath to the TOML config file to load.", required=True) # parser.add_argument("--config", "-c", help="Filepath to the TOML config file to load.", required=True)
parser.add_argument("--input", "-i", help="Path to input directory containing the .tfrecord.gz files to pretrain with", required=True) parser.add_argument("--input", "-i", help="Path to input directory containing the .tfrecord.gz files to pretrain with", required=True)
parser.add_argument("--output", "-o", help="Path to output directory to write output to (will be automatically created if it doesn't exist)", required=True) parser.add_argument("--output", "-o", help="Path to output directory to write output to (will be automatically created if it doesn't exist)", required=True)
parser.add_argument("--feature-dim", help="The size of the output feature dimension of the model [default: 200].", type=int) parser.add_argument("--feature-dim", help="The size of the output feature dimension of the model [default: 2048].", type=int)
parser.add_argument("--batch-size", help="Sets the batch size [default: 64].", type=int) parser.add_argument("--batch-size", help="Sets the batch size [default: 64].", type=int)
parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number of systems with high read latency to avoid starving the GPU of data.") parser.add_argument("--reads-multiplier", help="Optional. The multiplier for the number of files we should read from at once. Defaults to 1.5, which means read ceil(NUMBER_OF_CORES * 1.5) files at once. Set to a higher number of systems with high read latency to avoid starving the GPU of data.")
@ -24,7 +24,7 @@ def run(args):
if (not hasattr(args, "batch_size")) or args.batch_size == None: if (not hasattr(args, "batch_size")) or args.batch_size == None:
args.batch_size = 64 args.batch_size = 64
if (not hasattr(args, "feature_dim")) or args.feature_dim == None: if (not hasattr(args, "feature_dim")) or args.feature_dim == None:
args.feature_dim = 200 args.feature_dim = 2048
if (not hasattr(args, "read_multiplier")) or args.read_multiplier == None: if (not hasattr(args, "read_multiplier")) or args.read_multiplier == None:
args.read_multiplier = 1.5 args.read_multiplier = 1.5