Skip to content

Commit

Permalink
Merge pull request #427 from hotg-ai/feature/tensor-from-typed-array
Browse files Browse the repository at this point in the history
Add a Tensor.fromTypedArray() constructor
  • Loading branch information
Michael Bryan authored Apr 19, 2022
2 parents 900cb6b + f29a081 commit 7316818
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 4 deletions.
2 changes: 1 addition & 1 deletion bindings/web/rune/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@hotg-ai/rune",
"version": "0.11.6",
"version": "0.11.8",
"description": "Execute Runes inside a JavaScript environment.",
"repository": "https://github.com/hotg-ai/rune",
"homepage": "https://hotg.dev/",
Expand Down
9 changes: 9 additions & 0 deletions bindings/web/rune/src/Tensor.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,13 @@ describe("Tensor", () => {

expect(Array.from(typed)).toEqual(numbers.slice(3, 6));
});

it("can be constructed from a typed array", () => {
const values = [1, 2, 3, 4, -5, -6];
const raw = new Int16Array(values);

const tensor = Tensor.fromTypedArray("i16", [6], raw);

expect(Array.from(tensor.asTypedArray("i16"))).toEqual(values);
});
});
33 changes: 30 additions & 3 deletions bindings/web/rune/src/Tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import Shape from "./Shape";
const BigUint64ArrayShim = global.BigUint64Array ?? class { constructor() { throw new Error("BigUint64Array is not supported on this device"); } };
const BigInt64ArrayShim = global.BigInt64Array ?? class { constructor() { throw new Error("BigInt64Array is not supported on this device"); } };

const typedArrays = {
const typedArrayConstructors = {
"f64": Float64Array,
"i64": BigInt64ArrayShim,
"u64": BigUint64ArrayShim,
Expand All @@ -21,6 +21,12 @@ const typedArrays = {
"i8": Int8Array,
} as const;

type TypedArrayConstructors = typeof typedArrayConstructors;

export type TypedArrays = {
[Key in keyof TypedArrayConstructors]: InstanceType<TypedArrayConstructors[Key]>;
}

/**
* An opaque tensor.
*/
Expand All @@ -39,6 +45,25 @@ export default class Tensor {
this.elements = elements;
}

/**
* Construct a new Tensor from a typed array containing its flattened
* elements in row-major order.
*
* @param elementType The type of the element
* @param dimensions The tensor's dimensions
* @param elements The elements
* @returns
*/
public static fromTypedArray<S extends keyof TypedArrays>(
elementType: S,
dimensions: readonly number[],
elements: TypedArrays[S],
): Tensor {
const { buffer, byteLength, byteOffset } = elements;
const shape = new Shape(elementType, [...dimensions]);
return new Tensor(shape, new Uint8Array(buffer, byteOffset, byteLength));
}

/**
* View this tensor's data as an array of 64-bit floats.
*
Expand Down Expand Up @@ -102,14 +127,14 @@ export default class Tensor {
*/
public asTypedArray(elementType: "u8"): Uint8ClampedArray;

public asTypedArray(elementType: keyof typeof typedArrays): ArrayBuffer {
public asTypedArray(elementType: keyof typeof typedArrayConstructors): ArrayBuffer {
if (this.shape.type != elementType) {
throw new Error(`Attempting to interpret a ${this.shape.toString()} as a ${elementType} tensor`);
}

const { buffer, byteOffset, byteLength } = this.elements;
const length = byteLength / Shape.ByteSize[this.shape.type];
const constructor = typedArrays[elementType];
const constructor = typedArrayConstructors[elementType];

return new constructor(buffer, byteOffset, length);
}
Expand All @@ -122,3 +147,5 @@ export default class Tensor {
return this.shape.dimensions;
}
}

const x = Tensor.fromTypedArray

0 comments on commit 7316818

Please sign in to comment.