spaces → spaces

This commit is contained in:
Starbeamrainbowlabs 2023-01-05 19:17:44 +00:00
parent 56a501f8a9
commit 67b8a2c6c0
Signed by: sbrl
GPG key ID: 1BE5172E637709C2

View file

@ -60,68 +60,68 @@ logger.info("Validation Dataset:", dataset_validate)
# ██ ██ ██████ ██████ ███████ ███████ # ██ ██ ██████ ██████ ███████ ███████
def convolution_block( def convolution_block(
block_input, block_input,
num_filters=256, num_filters=256,
kernel_size=3, kernel_size=3,
dilation_rate=1, dilation_rate=1,
padding="same", padding="same",
use_bias=False, use_bias=False,
): ):
x = tf.keras.layers.Conv2D( x = tf.keras.layers.Conv2D(
num_filters, num_filters,
kernel_size=kernel_size, kernel_size=kernel_size,
dilation_rate=dilation_rate, dilation_rate=dilation_rate,
padding="same", padding="same",
use_bias=use_bias, use_bias=use_bias,
kernel_initializer=tf.keras.initializers.HeNormal(), kernel_initializer=tf.keras.initializers.HeNormal(),
)(block_input) )(block_input)
x = tf.keras.layers.BatchNormalization()(x) x = tf.keras.layers.BatchNormalization()(x)
return tf.nn.relu(x) return tf.nn.relu(x)
def DilatedSpatialPyramidPooling(dspp_input): def DilatedSpatialPyramidPooling(dspp_input):
dims = dspp_input.shape dims = dspp_input.shape
x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input) x = tf.keras.layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
x = convolution_block(x, kernel_size=1, use_bias=True) x = convolution_block(x, kernel_size=1, use_bias=True)
out_pool = tf.keras.layers.UpSampling2D( out_pool = tf.keras.layers.UpSampling2D(
size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear", size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), interpolation="bilinear",
)(x) )(x)
out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1) out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6) out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12) out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18) out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18]) x = tf.keras.layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
output = convolution_block(x, kernel_size=1) output = convolution_block(x, kernel_size=1)
return output return output
def DeeplabV3Plus(image_size, num_classes, num_channels=3): def DeeplabV3Plus(image_size, num_classes, num_channels=3):
model_input = tf.keras.Input(shape=(image_size, image_size, num_channels)) model_input = tf.keras.Input(shape=(image_size, image_size, num_channels))
resnet50 = tf.keras.applications.ResNet50( resnet50 = tf.keras.applications.ResNet50(
weights="imagenet" if num_channels == 3 else None, weights="imagenet" if num_channels == 3 else None,
include_top=False, input_tensor=model_input include_top=False, input_tensor=model_input
) )
x = resnet50.get_layer("conv4_block6_2_relu").output x = resnet50.get_layer("conv4_block6_2_relu").output
x = DilatedSpatialPyramidPooling(x) x = DilatedSpatialPyramidPooling(x)
input_a = tf.keras.layers.UpSampling2D( input_a = tf.keras.layers.UpSampling2D(
size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]), size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
interpolation="bilinear", interpolation="bilinear",
)(x) )(x)
input_b = resnet50.get_layer("conv2_block3_2_relu").output input_b = resnet50.get_layer("conv2_block3_2_relu").output
input_b = convolution_block(input_b, num_filters=48, kernel_size=1) input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b]) x = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
x = convolution_block(x) x = convolution_block(x)
x = convolution_block(x) x = convolution_block(x)
x = tf.keras.layers.UpSampling2D( x = tf.keras.layers.UpSampling2D(
size=(image_size // x.shape[1], image_size // x.shape[2]), size=(image_size // x.shape[1], image_size // x.shape[2]),
interpolation="bilinear", interpolation="bilinear",
)(x) )(x)
model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x) model_output = tf.keras.layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
return tf.keras.Model(inputs=model_input, outputs=model_output) return tf.keras.Model(inputs=model_input, outputs=model_output)
model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8) model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES, num_channels=8)
@ -138,9 +138,9 @@ summarywriter(model, os.path.join(DIR_OUTPUT, "summary.txt"))
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile( model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss=loss, loss=loss,
metrics=["accuracy"], metrics=["accuracy"],
) )
logger.info(">>> Beginning training") logger.info(">>> Beginning training")
history = model.fit(dataset_train, history = model.fit(dataset_train,
@ -195,59 +195,59 @@ plt.close()
# Loading the Colormap # Loading the Colormap
colormap = loadmat( colormap = loadmat(
PATH_COLOURMAP PATH_COLOURMAP
)["colormap"] )["colormap"]
colormap = colormap * 100 colormap = colormap * 100
colormap = colormap.astype(np.uint8) colormap = colormap.astype(np.uint8)
def infer(model, image_tensor): def infer(model, image_tensor):
predictions = model.predict(tf.expand_dims((image_tensor), axis=0)) predictions = model.predict(tf.expand_dims((image_tensor), axis=0))
predictions = tf.squeeze(predictions) predictions = tf.squeeze(predictions)
predictions = tf.argmax(predictions, axis=2) predictions = tf.argmax(predictions, axis=2)
return predictions return predictions
def decode_segmentation_masks(mask, colormap, n_classes): def decode_segmentation_masks(mask, colormap, n_classes):
r = np.zeros_like(mask).astype(np.uint8) r = np.zeros_like(mask).astype(np.uint8)
g = np.zeros_like(mask).astype(np.uint8) g = np.zeros_like(mask).astype(np.uint8)
b = np.zeros_like(mask).astype(np.uint8) b = np.zeros_like(mask).astype(np.uint8)
for l in range(0, n_classes): for l in range(0, n_classes):
idx = mask == l idx = mask == l
r[idx] = colormap[l, 0] r[idx] = colormap[l, 0]
g[idx] = colormap[l, 1] g[idx] = colormap[l, 1]
b[idx] = colormap[l, 2] b[idx] = colormap[l, 2]
rgb = np.stack([r, g, b], axis=2) rgb = np.stack([r, g, b], axis=2)
return rgb return rgb
def get_overlay(image, coloured_mask): def get_overlay(image, coloured_mask):
image = tf.keras.preprocessing.image.array_to_img(image) image = tf.keras.preprocessing.image.array_to_img(image)
image = np.array(image).astype(np.uint8) image = np.array(image).astype(np.uint8)
overlay = cv2.addWeighted(image, 0.35, coloured_mask, 0.65, 0) overlay = cv2.addWeighted(image, 0.35, coloured_mask, 0.65, 0)
return overlay return overlay
def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)): def plot_samples_matplotlib(filepath, display_list, figsize=(5, 3)):
_, axes = plt.subplots(nrows=1, ncols=len(display_list), figsize=figsize) _, axes = plt.subplots(nrows=1, ncols=len(display_list), figsize=figsize)
for i in range(len(display_list)): for i in range(len(display_list)):
if display_list[i].shape[-1] == 3: if display_list[i].shape[-1] == 3:
axes[i].imshow(tf.keras.preprocessing.image.array_to_img(display_list[i])) axes[i].imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
else: else:
axes[i].imshow(display_list[i]) axes[i].imshow(display_list[i])
plt.savefig(filepath) plt.savefig(filepath)
def plot_predictions(filepath, input_items, colormap, model): def plot_predictions(filepath, input_items, colormap, model):
for input_tensor in input_items: for input_tensor in input_items:
prediction_mask = infer(image_tensor=input_tensor, model=model) prediction_mask = infer(image_tensor=input_tensor, model=model)
prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20) prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)
overlay = get_overlay(input_tensor, prediction_colormap) overlay = get_overlay(input_tensor, prediction_colormap)
plot_samples_matplotlib( plot_samples_matplotlib(
filepath, filepath,
[input_tensor, overlay, prediction_colormap], [input_tensor, overlay, prediction_colormap],
figsize=(18, 14) figsize=(18, 14)
) )
def get_items_from_batched(dataset, count): def get_items_from_batched(dataset, count):
result = [] result = []