Support sequential models
such models don't have an explicit connected to column. In this case, we assume every subsequent layer is to the one before.
This commit is contained in:
parent
6555100f46
commit
3d777d0de1
2 changed files with 29 additions and 9 deletions
|
@ -19,6 +19,8 @@ export default function make_graph(summary) {
|
||||||
|
|
||||||
const nomnoml_source = result.join("\n");
|
const nomnoml_source = result.join("\n");
|
||||||
|
|
||||||
|
console.info(`[make_graph] nomnoml source:\n${nomnoml_source}`);
|
||||||
|
|
||||||
const svg = nomnoml.renderSvg(nomnoml_source);
|
const svg = nomnoml.renderSvg(nomnoml_source);
|
||||||
return svg;
|
return svg;
|
||||||
}
|
}
|
|
@ -3,32 +3,46 @@
|
||||||
|
|
||||||
export default function parse_summary(text) {
|
export default function parse_summary(text) {
|
||||||
const lines = text.trim().split(/\r?\n/);
|
const lines = text.trim().split(/\r?\n/);
|
||||||
|
|
||||||
|
const has_connected_to = lines[2].search("Connected to") > -1;
|
||||||
|
|
||||||
const layers_raw = lines.slice(
|
const layers_raw = lines.slice(
|
||||||
lines.findIndex(line => line.startsWith("====")) + 1,
|
lines.findIndex(line => line.startsWith("====")) + 1,
|
||||||
lines.findLastIndex(line => line.startsWith("===="))
|
lines.findLastIndex(line => line.startsWith("===="))
|
||||||
);
|
);
|
||||||
const layers = [];
|
const layers = [];
|
||||||
|
|
||||||
|
const sep_output_shape = lines[2].search("Output Shape");
|
||||||
|
const sep_param_hash = lines[2].search("Param #");
|
||||||
|
const sep_connected_to = has_connected_to ? lines[2].search("Connected to") : lines[2].length;
|
||||||
|
|
||||||
|
let layer_prev = null;
|
||||||
let acc = [];
|
let acc = [];
|
||||||
for (const line of layers_raw) {
|
for (const line of layers_raw) {
|
||||||
if(line.trim().length == 0) {
|
if(line.trim().length == 0) {
|
||||||
if(acc.length === 0) continue;
|
if(acc.length === 0) continue;
|
||||||
console.log(acc.map(layer_line => layer_line.substring(65).trim()).join(""));
|
|
||||||
|
console.log(`DEBUG:params`, acc.map(layer_line => layer_line.substring(sep_param_hash, sep_connected_to).trim()));
|
||||||
// Handle parsed item
|
// Handle parsed item
|
||||||
const result = {
|
const result = {
|
||||||
name_raw: acc.map(layer_line => layer_line.substring(0, 32).trim()).join(""),
|
name_raw: acc.map(layer_line => layer_line.substring(0, sep_output_shape).trim()).join(""),
|
||||||
output_shape: acc.map(layer_line => layer_line.substring(32, 53).trim()).join(""),
|
output_shape: acc.map(layer_line => layer_line.substring(sep_output_shape, sep_param_hash).trim()).join(""),
|
||||||
params: parseInt(acc.map(layer_line => layer_line.substring(53, 65).trim()).join("")),
|
params: parseInt(acc.map(layer_line => layer_line.substring(sep_param_hash, sep_connected_to).trim()).join("")),
|
||||||
connected_to: JSON.parse(acc.map(layer_line => layer_line.substring(65).trim()).join("")
|
|
||||||
.replace(/'/g, '"'))
|
|
||||||
.map(connected_name => connected_name.replace(/(\[0\])+$/, ""))
|
|
||||||
};
|
};
|
||||||
|
if(has_connected_to) {
|
||||||
|
result.connected_to = JSON.parse(acc.map(layer_line => layer_line.substring(sep_connected_to).trim()).join("")
|
||||||
|
.replace(/'/g, '"'))
|
||||||
|
.map(connected_name => connected_name.replace(/(\[0\])+$/, ""));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if(layer_prev !== null)
|
||||||
|
result.connected_to = [ layer_prev.name ];
|
||||||
|
}
|
||||||
|
|
||||||
result.type = result.name_raw.match(/ \(([^)]+)\)/)[1];
|
result.type = result.name_raw.match(/ \(([^)]+)\)/)[1];
|
||||||
result.name = result.name_raw.split(/\s+/)[0];
|
result.name = result.name_raw.split(/\s+/)[0];
|
||||||
|
|
||||||
|
layer_prev = result;
|
||||||
layers.push(result);
|
layers.push(result);
|
||||||
|
|
||||||
acc.length = 0;
|
acc.length = 0;
|
||||||
|
@ -40,6 +54,10 @@ export default function parse_summary(text) {
|
||||||
const edges = [];
|
const edges = [];
|
||||||
|
|
||||||
for(const layer of layers) {
|
for(const layer of layers) {
|
||||||
|
if(!(layer.connected_to instanceof Array)) {
|
||||||
|
console.info(`[edge detect] layer.connected_to is not an instance of Array:`, layer);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
for(const connection of layer.connected_to) {
|
for(const connection of layer.connected_to) {
|
||||||
edges.push({
|
edges.push({
|
||||||
from: connection,
|
from: connection,
|
||||||
|
|
Loading…
Reference in a new issue