Support sequential models

such models don't have an explicit connected to column.
In this case, we assume every subsequent layer is to the one before.
This commit is contained in:
Starbeamrainbowlabs 2023-02-18 02:04:40 +00:00
parent 6555100f46
commit 3d777d0de1
Signed by: sbrl
GPG key ID: 1BE5172E637709C2
2 changed files with 29 additions and 9 deletions

View file

@ -19,6 +19,8 @@ export default function make_graph(summary) {
const nomnoml_source = result.join("\n"); const nomnoml_source = result.join("\n");
console.info(`[make_graph] nomnoml source:\n${nomnoml_source}`);
const svg = nomnoml.renderSvg(nomnoml_source); const svg = nomnoml.renderSvg(nomnoml_source);
return svg; return svg;
} }

View file

@ -4,31 +4,45 @@
export default function parse_summary(text) { export default function parse_summary(text) {
const lines = text.trim().split(/\r?\n/); const lines = text.trim().split(/\r?\n/);
const has_connected_to = lines[2].search("Connected to") > -1;
const layers_raw = lines.slice( const layers_raw = lines.slice(
lines.findIndex(line => line.startsWith("====")) + 1, lines.findIndex(line => line.startsWith("====")) + 1,
lines.findLastIndex(line => line.startsWith("====")) lines.findLastIndex(line => line.startsWith("===="))
); );
const layers = []; const layers = [];
const sep_output_shape = lines[2].search("Output Shape");
const sep_param_hash = lines[2].search("Param #");
const sep_connected_to = has_connected_to ? lines[2].search("Connected to") : lines[2].length;
let layer_prev = null;
let acc = []; let acc = [];
for (const line of layers_raw) { for (const line of layers_raw) {
if(line.trim().length == 0) { if(line.trim().length == 0) {
if(acc.length === 0) continue; if(acc.length === 0) continue;
console.log(acc.map(layer_line => layer_line.substring(65).trim()).join(""));
console.log(`DEBUG:params`, acc.map(layer_line => layer_line.substring(sep_param_hash, sep_connected_to).trim()));
// Handle parsed item // Handle parsed item
const result = { const result = {
name_raw: acc.map(layer_line => layer_line.substring(0, 32).trim()).join(""), name_raw: acc.map(layer_line => layer_line.substring(0, sep_output_shape).trim()).join(""),
output_shape: acc.map(layer_line => layer_line.substring(32, 53).trim()).join(""), output_shape: acc.map(layer_line => layer_line.substring(sep_output_shape, sep_param_hash).trim()).join(""),
params: parseInt(acc.map(layer_line => layer_line.substring(53, 65).trim()).join("")), params: parseInt(acc.map(layer_line => layer_line.substring(sep_param_hash, sep_connected_to).trim()).join("")),
connected_to: JSON.parse(acc.map(layer_line => layer_line.substring(65).trim()).join("")
.replace(/'/g, '"'))
.map(connected_name => connected_name.replace(/(\[0\])+$/, ""))
}; };
if(has_connected_to) {
result.connected_to = JSON.parse(acc.map(layer_line => layer_line.substring(sep_connected_to).trim()).join("")
.replace(/'/g, '"'))
.map(connected_name => connected_name.replace(/(\[0\])+$/, ""));
}
else {
if(layer_prev !== null)
result.connected_to = [ layer_prev.name ];
}
result.type = result.name_raw.match(/ \(([^)]+)\)/)[1]; result.type = result.name_raw.match(/ \(([^)]+)\)/)[1];
result.name = result.name_raw.split(/\s+/)[0]; result.name = result.name_raw.split(/\s+/)[0];
layer_prev = result;
layers.push(result); layers.push(result);
acc.length = 0; acc.length = 0;
@ -40,6 +54,10 @@ export default function parse_summary(text) {
const edges = []; const edges = [];
for(const layer of layers) { for(const layer of layers) {
if(!(layer.connected_to instanceof Array)) {
console.info(`[edge detect] layer.connected_to is not an instance of Array:`, layer);
continue;
}
for(const connection of layer.connected_to) { for(const connection of layer.connected_to) {
edges.push({ edges.push({
from: connection, from: connection,