Skip to content

Commit

Permalink
Update ONNX test file (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Feb 10, 2025
1 parent 47348fe commit 97e27c7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
38 changes: 22 additions & 16 deletions source/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -1442,6 +1442,7 @@ onnx.ProtoReader = class {
static async open(context) {
const identifier = context.identifier;
const stream = context.stream;
let offset = 0;
if (stream && stream.length > 5) {
const buffer = stream.peek(Math.min(stream.length, 256));
if (buffer[0] === 0x08 && buffer[1] < 0x0B && buffer[2] === 0x12 && buffer[3] < 64 && (buffer[3] + 4) <= stream.length) {
Expand All @@ -1454,25 +1455,19 @@ onnx.ProtoReader = class {
}
}
const length = buffer[0] | (buffer[1] << 8) | (buffer[2] << 16) | (buffer[3] << 24);
if (length === stream.length - 4) {
stream.seek(4);
try {
const reader = protobuf.BinaryReader.open(stream);
const tags = reader.signature();
if (tags.get(7) === 2) {
stream.seek(4);
return new onnx.ProtoReader(context, 'binary', 'model');
}
} catch {
// continue regardless of error
}
if (length === stream.length - 4 && (buffer[4] === 0x08 || buffer[4] === 0x0A)) {
offset = 4;
}
}
stream.seek(offset);
const binaryTags = await context.tags('pb');
stream.seek(0);
if (binaryTags.size > 0) {
const tags = binaryTags;
if (tags.size === 1 && tags.get(1) === 2) {
stream.seek(offset);
const tags = await context.tags('pb+');
stream.seek(0);
const match = (tags, schema) => {
for (const [key, inner] of schema) {
const value = tags[key];
Expand Down Expand Up @@ -1510,7 +1505,7 @@ onnx.ProtoReader = class {
if (tags.get(1) === 0 && tags.get(2) === 0 && [3, 4, 5, 6].filter((tag) => tags.get(tag)).length <= 1) {
const schema = [[1,0],[2,0],[4,2],[5,2],[7,2],[8,2],[9,2]];
if (schema.every(([key, value]) => !tags.has(key) || tags.get(key) === value)) {
return new onnx.ProtoReader(context, 'binary', 'tensor');
return new onnx.ProtoReader(context, 'binary', 'tensor', offset);
}
}
// GraphProto
Expand All @@ -1537,7 +1532,7 @@ onnx.ProtoReader = class {
if (nodeBuffer) {
const nameBuffer = decode(nodeBuffer, 4);
if (nameBuffer && nameBuffer.every((c) => c > 0x20 && c < 0x7f)) {
return new onnx.ProtoReader(context, 'binary', 'graph');
return new onnx.ProtoReader(context, 'binary', 'graph', offset);
}
}
}
Expand All @@ -1546,7 +1541,7 @@ onnx.ProtoReader = class {
if (tags.get(7) === 2) {
const schema = [[1,0],[2,2],[3,2],[4,2],[5,0],[6,2],[7,2],[8,2],[14,2],[20,2]];
if (schema.every(([key, value]) => !tags.has(key) || tags.get(key) === value)) {
return new onnx.ProtoReader(context, 'binary', 'model');
return new onnx.ProtoReader(context, 'binary', 'model', offset);
}
}
}
Expand All @@ -1571,11 +1566,12 @@ onnx.ProtoReader = class {
return undefined;
}

constructor(context, encoding, type) {
constructor(context, encoding, type, offset) {
this.name = 'onnx.proto';
this.context = context;
this.encoding = encoding;
this.type = type;
this.offset = offset || 0;
this.locations = new Map();
}

Expand Down Expand Up @@ -1607,6 +1603,10 @@ onnx.ProtoReader = class {
break;
}
case 'binary': {
const stream = context.stream;
if (this.offset) {
stream.seek(this.offset);
}
switch (this.type) {
case 'tensor': {
// TensorProto
Expand Down Expand Up @@ -1637,6 +1637,9 @@ onnx.ProtoReader = class {
// GraphProto
try {
const reader = await context.read('protobuf.binary');
if (this.offset) {
stream.seek(0);
}
this.model = new onnx.proto.ModelProto();
this.model.graph = onnx.proto.GraphProto.decode(reader);
this.format = 'ONNX';
Expand All @@ -1662,6 +1665,9 @@ onnx.ProtoReader = class {
throw new onnx.Error('Unsupported ONNX format type.');
}
}
if (this.offset) {
stream.seek(0);
}
break;
}
default: {
Expand Down
4 changes: 2 additions & 2 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4708,9 +4708,9 @@
{
"type": "onnx",
"target": "super_resolution.onnx",
"source": "https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/b385b1b242dc89a35dd808235b885ed8a19aedc1/super_resolution.onnx",
"source": "https://github.com/user-attachments/files/18726996/super_resolution.onnx.zip[super_resolution.onnx]",
"format": "ONNX",
"link": "https://gist.github.com/zhreshold/bcda4716699ac97ea44f791c24310193"
"link": "https://github.com/lutzroeder/netron/issues/6"
},
{
"type": "onnx",
Expand Down

0 comments on commit 97e27c7

Please sign in to comment.