Skip to content

Commit

Permalink
Update python.js
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Feb 13, 2025
1 parent 8b7e41f commit a1757da
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 10 deletions.
3 changes: 3 additions & 0 deletions source/base.js
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,9 @@ base.Tensor = class {

get data() {
this._read();
if (this._data && this._data.peek) {
this._data = this._data.peek();
}
return this._data;
}

Expand Down
3 changes: 1 addition & 2 deletions source/numpy.js
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ numpy.ModelFactory = class {
const execution = new python.Execution();
execution.on('resolve', (_, name) => unresolved.add(name));
const stream = context.stream;
const buffer = stream.peek();
const bytes = execution.invoke('io.BytesIO', [buffer]);
const bytes = execution.invoke('io.BytesIO', [stream]);
const array = execution.invoke('numpy.load', [bytes]);
if (unresolved.size > 0) {
const name = unresolved.values().next().value;
Expand Down
34 changes: 29 additions & 5 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -2477,12 +2477,24 @@ python.Execution = class {
this._point = 0;
}
seek(offset) {
if (this._buf.seek) {
this._buf.seek(offset);
}
this._point = offset;
}
read(size) {
const start = this._point;
this._point = size === undefined ? this._buf.length : start + size;
return this._buf.subarray(start, this._point);
read(size, stream) {
if (this._buf.stream && stream) {
return this._buf.stream(size);
}
if (this._buf.peek) {
return this._buf.read(size);
}
if (this._buf instanceof Uint8Array) {
const start = this._point;
this._point = size === undefined ? this._buf.length : start + size;
return this._buf.subarray(start, this._point);
}
throw new python.Error('Unsupported buffer type.');
}
write(data) {
const src = this._buf || new Uint8Array();
Expand All @@ -2491,6 +2503,9 @@ python.Execution = class {
this._buf.set(src, 0);
this._buf.set(data, src.length);
}
getbuffer() {
return new builtins.memoryview(this._buf);
}
});
this.registerType('io.StringIO', class {
constructor() {
Expand Down Expand Up @@ -4232,6 +4247,14 @@ python.Execution = class {
throw new python.Error('Unsupported source.');
}
});
this.registerType('builtins.memoryview', class {
constructor(buf) {
this._buf = buf;
}
get nbytes() {
return this._buf.length;
}
});
this.registerType('builtins.frozenset', class extends Set {
constructor(iterable) {
super();
Expand Down Expand Up @@ -4787,7 +4810,8 @@ python.Execution = class {
throw new python.Error(`Unsupported data type '${header.descr}'.`);
}
const count = shape.length === 0 ? 1 : shape.reduce((a, b) => a * b, 1);
data = file.read(dtype.itemsize * count);
const stream = file.getbuffer().nbytes > 0x1000000;
data = file.read(dtype.itemsize * count, stream);
break;
}
default: {
Expand Down
2 changes: 1 addition & 1 deletion source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -2435,8 +2435,8 @@ numpy.Tensor = class {
constructor(array) {
this.type = new numpy.TensorType(array.dtype.__name__, new numpy.TensorShape(array.shape));
this.stride = array.strides.map((stride) => stride / array.itemsize);
this.values = this.type.dataType === 'string' || this.type.dataType === 'object' || this.type.dataType === 'void' ? array.flatten().tolist() : array.tobytes();
this.encoding = this.type.dataType === 'string' || this.type.dataType === 'object' ? '|' : array.dtype.byteorder;
this.values = this.type.dataType === 'string' || this.type.dataType === 'object' || this.type.dataType === 'void' ? array.flatten().tolist() : array.tobytes();
}
};

Expand Down
5 changes: 3 additions & 2 deletions source/view.js
Original file line number Diff line number Diff line change
Expand Up @@ -3289,6 +3289,8 @@ view.TensorView = class extends view.Expander {
content.insertBefore(this._saveButton, content.firstChild);
}
}
}).catch((error) => {
content.innerHTML = error.message;
});
}
return content;
Expand Down Expand Up @@ -5574,8 +5576,7 @@ view.Context = class {
const python = await import('./python.js');
const execution = new python.Execution();
for (const [name, stream] of entries) {
const buffer = stream.peek();
const bytes = execution.invoke('io.BytesIO', [buffer]);
const bytes = execution.invoke('io.BytesIO', [stream]);
const array = execution.invoke('numpy.load', [bytes]);
content.set(name, array);
}
Expand Down

0 comments on commit a1757da

Please sign in to comment.