Skip to content

Commit

Permalink
feat: add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
daviddaytw committed Dec 8, 2024
1 parent da24ce0 commit b47c0d8
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 0 deletions.
13 changes: 13 additions & 0 deletions babel.config.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module.exports = {
presets: [
["@babel/preset-env", { targets: { node: "current" } }],
"@babel/preset-typescript",
"@babel/preset-react",
],
plugins: [
["@babel/plugin-transform-flow-strip-types"],
["@babel/plugin-transform-class-properties", { loose: true }],
["@babel/plugin-transform-private-methods", { loose: true }],
["@babel/plugin-transform-private-property-in-object", { loose: true }],
],
};
38 changes: 38 additions & 0 deletions jest.config.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
module.exports = {
preset: "react-native",
moduleFileExtensions: ["ts", "tsx", "js", "jsx", "json", "node"],
transformIgnorePatterns: [
"node_modules/(?!(" +
"react-native|" +
"@react-native|" +
"@xenova/transformers|" +
"text-encoding-polyfill" +
")/)",
],
setupFiles: [
"./node_modules/react-native/jest/setup.js",
"./src/__tests__/setup.js",
],
testRegex: "(/__tests__/.*(?<!setup)\\.(test|spec))\\.[jt]sx?$",
testEnvironment: "node",
transform: {
"^.+\\.(js|jsx|ts|tsx)$": [
"babel-jest",
{ configFile: "./babel.config.js" },
],
},
globals: {
"ts-jest": {
babelConfig: true,
tsconfig: "tsconfig.json",
},
},
collectCoverage: true,
coverageDirectory: "coverage",
coverageReporters: ["text", "lcov"],
collectCoverageFrom: [
"src/**/*.{js,jsx,ts,tsx}",
"!src/**/*.d.ts",
"!src/__tests__/**",
],
};
3 changes: 3 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
},
"devDependencies": {
"@babel/core": "^7.20.0",
"@babel/plugin-transform-class-properties": "^7.25.9",
"@babel/plugin-transform-private-methods": "^7.25.9",
"@babel/plugin-transform-private-property-in-object": "^7.25.9",
"@react-native/eslint-config": "^0.74.85",
"@release-it/conventional-changelog": "^8.0.1",
"@tsconfig/react-native": "^3.0.5",
Expand Down
61 changes: 61 additions & 0 deletions src/__tests__/setup.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Global Jest setup (loaded via jest.config.js `setupFiles`): replaces
// native and network-backed modules with deterministic in-memory mocks
// before each test file runs. jest.mock calls are hoisted, so each factory
// must stay self-contained.

// Mock text-encoding-polyfill: it only patches globals for React Native;
// under Jest's Node environment it is unnecessary, so stub it out.
jest.mock("text-encoding-polyfill", () => ({}));

// Mock fetch for config loading: every request resolves to an ArrayBuffer
// containing a small model-config JSON.
global.fetch = jest.fn(() =>
  Promise.resolve({
    arrayBuffer: () =>
      Promise.resolve(
        // Manual string -> byte conversion; safe because this JSON is
        // pure ASCII.
        Uint8Array.from(
          JSON.stringify({
            eos_token_id: 2,
            num_key_value_heads: 32,
            hidden_size: 4096,
            num_attention_heads: 32,
            num_hidden_layers: 32,
          })
            .split("")
            .map((c) => c.charCodeAt(0)),
        ).buffer,
      ),
  }),
);

// Mock InferenceSession: `create` resolves to a session whose `run` always
// yields one fixed 4-logit output and whose `release` is a no-op spy.
// `Tensor` echoes its constructor arguments so tests can inspect exactly
// what a caller built.
jest.mock("onnxruntime-react-native", () => ({
  InferenceSession: {
    create: jest.fn().mockResolvedValue({
      run: jest.fn().mockResolvedValue({
        logits: {
          data: new Float32Array([0.1, 0.2, 0.3, 0.4]),
          dims: [1, 1, 4],
          type: "float32",
        },
      }),
      release: jest.fn(),
    }),
  },
  env: {
    logLevel: "error",
  },
  Tensor: jest.fn().mockImplementation((type, data, dims) => ({
    type,
    data,
    dims,
    size: data.length,
    dispose: jest.fn(),
  })),
}));

// Mock transformers: AutoTokenizer.from_pretrained resolves to a stub
// tokenizer whose `decode` always returns the literal "decoded text".
jest.mock("@xenova/transformers", () => ({
  env: {
    allowRemoteModels: true,
    allowLocalModels: false,
  },
  AutoTokenizer: {
    from_pretrained: jest.fn().mockResolvedValue({
      decode: jest.fn((_tokens, _options) => "decoded text"),
    }),
  },
}));
160 changes: 160 additions & 0 deletions src/__tests__/text-generation.model.test.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import { TextGeneration } from "../models/text-generation";
import { Tensor } from "onnxruntime-react-native";

// Mock onnxruntime-react-native for this file: only `Tensor` is needed
// here (sessions are injected directly via setSession below). The factory
// echoes constructor arguments so assertions can inspect the type, data,
// and dims a caller supplied. Note: jest.mock is hoisted above imports,
// so the factory must not reference file-scope variables.
jest.mock("onnxruntime-react-native", () => ({
  Tensor: jest.fn().mockImplementation((type, data, dims) => ({
    type,
    data,
    dims,
    size: data.length,
  })),
}));

/**
 * Test-only subclass that exposes TextGeneration's protected state
 * (`sess`, `feed`, `eos`) so specs can inject stub sessions and inspect
 * internals without widening the production class's API.
 */
class TestableTextGeneration extends TextGeneration {
  /** Inject an arbitrary session stub (or undefined) into the model. */
  setSession(session: any) {
    this.sess = session;
  }

  /** The session currently held by the base class. */
  getSession() {
    return this.sess;
  }

  /** The input feed maintained by the base class. */
  getFeed() {
    return this.feed;
  }

  /** The end-of-sequence token id tracked by the base class. */
  getEos() {
    return this.eos;
  }
}

// Unit tests for the TextGeneration model. A fresh model is built per
// test; ONNX sessions are replaced with hand-rolled stubs via the
// TestableTextGeneration accessor subclass.
describe("TextGeneration Model", () => {
  let model: TestableTextGeneration;
  // Counts invocations of the stub session's `run` within a single test.
  let mockRunCount: number;

  beforeEach(() => {
    mockRunCount = 0;
    model = new TestableTextGeneration();
  });

  describe("initializeFeed", () => {
    it("should reset output tokens", () => {
      model.outputTokens = [1n, 2n, 3n];
      model.initializeFeed();
      expect(model.outputTokens).toEqual([]);
    });
  });

  describe("generate", () => {
    const mockCallback = jest.fn();
    const mockTokens = [1n, 2n]; // Initial tokens

    beforeEach(() => {
      mockCallback.mockClear();
      mockRunCount = 0;
    });

    it("should generate tokens until EOS token is found", async () => {
      // Stub session always puts the highest logit at index 3; presumably
      // the model argmaxes into a token id that eventually matches its
      // EOS id — TODO(review): confirm against TextGeneration's eos value.
      model.setSession({
        run: jest.fn().mockImplementation(() => {
          mockRunCount++;
          return Promise.resolve({
            logits: {
              data: new Float32Array([0.1, 0.2, 0.3, 2.0]), // highest value at index 3
              dims: [1, 1, 4],
              type: "float32",
            },
          });
        }),
      });

      const result = await model.generate(mockTokens, mockCallback, {
        maxTokens: 10,
      });
      expect(result.length).toBeGreaterThan(0);
      expect(mockCallback).toHaveBeenCalled();
    });

    it("should respect maxTokens limit", async () => {
      const maxTokens = 5;
      model.setSession({
        run: jest.fn().mockImplementation(() => {
          mockRunCount++;
          return Promise.resolve({
            logits: {
              data: new Float32Array([0.1, 0.2, 0.3, 0.1]), // will generate token 2 (index with highest value)
              dims: [1, 1, 4],
              type: "float32",
            },
          });
        }),
      });

      const result = await model.generate(mockTokens, mockCallback, {
        maxTokens,
      });
      // Initial tokens (2) + generated tokens should not exceed maxTokens (5)
      expect(result.length).toBeLessThanOrEqual(maxTokens);
      expect(mockRunCount).toBeLessThanOrEqual(maxTokens - mockTokens.length);
    });

    it("should throw error if session is undefined", async () => {
      model.setSession(undefined);
      await expect(
        model.generate(mockTokens, mockCallback, { maxTokens: 10 }),
      ).rejects.toThrow("Session is undefined");
    });

    it("should create correct tensors for input", async () => {
      model.setSession({
        run: jest.fn().mockResolvedValue({
          logits: {
            data: new Float32Array([0.1, 0.2, 0.3, 0.4]),
            dims: [1, 1, 4],
            type: "float32",
          },
        }),
      });

      // The mocked Tensor constructor records its arguments, so we can
      // assert the input ids were packed as a [1, seq_len] int64 tensor.
      await model.generate(mockTokens, mockCallback, { maxTokens: 10 });
      expect(Tensor).toHaveBeenCalledWith("int64", expect.any(BigInt64Array), [
        1,
        mockTokens.length,
      ]);
    });

    it("should handle generation with attention mask", async () => {
      model.setSession({
        run: jest.fn().mockResolvedValue({
          logits: {
            data: new Float32Array([0.1, 0.2, 0.3, 0.4]),
            dims: [1, 1, 4],
            type: "float32",
          },
        }),
      });

      const result = await model.generate(mockTokens, mockCallback, {
        maxTokens: 10,
      });
      const feed = model.getFeed();
      expect(feed.attention_mask).toBeDefined();
      expect(result).toBeDefined();
    });
  });

  describe("release", () => {
    it("should release session resources", async () => {
      const mockSession = {
        release: jest.fn().mockResolvedValue(undefined),
      };
      model.setSession(mockSession);

      // release() must both await the session's own release and drop the
      // reference so the model cannot be used afterwards.
      await model.release();
      expect(mockSession.release).toHaveBeenCalled();
      expect(model.getSession()).toBeUndefined();
    });
  });
});
109 changes: 109 additions & 0 deletions src/__tests__/text-generation.pipeline.test.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import TextGenerationPipeline from "../pipelines/text-generation";
import type { PreTrainedTokenizer } from "@xenova/transformers";

// Mock the transformers library. The tokenizer is modeled as a callable
// (encoding a prompt into token ids) that also carries a `decode` method,
// mirroring the PreTrainedTokenizer call shape the pipeline relies on.
// jest.mock is hoisted above imports, so the factory is self-contained.
jest.mock("@xenova/transformers", () => {
  // Create a mock tokenizer function with the correct type
  const mockTokenizerFn = Object.assign(
    jest
      .fn<Promise<{ input_ids: bigint[] }>, [string, any]>()
      .mockResolvedValue({ input_ids: [1n, 2n] }),
    {
      decode: jest.fn((_tokens: bigint[], _options: unknown) => "decoded text"),
    },
  ) as unknown as PreTrainedTokenizer;

  return {
    env: {
      allowRemoteModels: true,
      allowLocalModels: false,
    },
    AutoTokenizer: {
      from_pretrained: jest.fn().mockResolvedValue(mockTokenizerFn),
    },
  };
});

// Mock the model layer: `generate` resolves with the input tokens
// unchanged and never invokes the streaming callback, so the pipeline
// decodes exactly once per call.
jest.mock("../models/text-generation", () => {
  return {
    TextGeneration: jest.fn().mockImplementation(() => ({
      initializeFeed: jest.fn(),
      generate: jest.fn().mockImplementation((tokens) => {
        // Return tokens without calling callback to avoid double decoding
        return Promise.resolve(tokens);
      }),
      load: jest.fn(),
      release: jest.fn(),
      outputTokens: [],
    })),
  };
});

// Tests for the pipeline facade: init wires up the tokenizer/model,
// generate encodes -> runs the model -> decodes, release tears down.
describe("TextGenerationPipeline", () => {
  beforeEach(() => {
    jest.clearAllMocks();
    // Reset module state
    // NOTE(review): isolateModules loads a fresh copy of the module but
    // discards it; the singleton imported at the top of this file keeps
    // its state across tests — confirm this reset actually does anything.
    jest.isolateModules(() => {
      require("../pipelines/text-generation");
    });
  });

  describe("init", () => {
    it("should initialize with default options", async () => {
      await TextGenerationPipeline.init("test-model", "test-path");
      expect(
        require("@xenova/transformers").AutoTokenizer.from_pretrained,
      ).toHaveBeenCalledWith("test-model");
    });

    it("should initialize with custom options", async () => {
      // Options are pipeline-level; the tokenizer lookup still only
      // receives the model name.
      await TextGenerationPipeline.init("test-model", "test-path", {
        show_special: true,
        max_tokens: 100,
      });
      expect(
        require("@xenova/transformers").AutoTokenizer.from_pretrained,
      ).toHaveBeenCalledWith("test-model");
    });
  });

  describe("generate", () => {
    beforeEach(async () => {
      await TextGenerationPipeline.init("test-model", "test-path");
    });

    it("should generate text from prompt", async () => {
      // "decoded text" is the fixed value returned by the mocked decode.
      const result = await TextGenerationPipeline.generate("test prompt");
      expect(result).toBe("decoded text");
    });

    it("should call callback with generated text", async () => {
      const callback = jest.fn();
      await TextGenerationPipeline.generate("test prompt", callback);
      expect(callback).toHaveBeenCalledWith("decoded text");
    });

    it("should throw error if not initialized", async () => {
      // Reset module state to clear tokenizer
      jest.resetModules();
      const freshPipeline = require("../pipelines/text-generation").default;
      await expect(freshPipeline.generate("test")).rejects.toThrow(
        "Tokenizer undefined, please initialize first.",
      );
    });
  });

  describe("release", () => {
    it("should release resources", async () => {
      await TextGenerationPipeline.init("test-model", "test-path");
      await TextGenerationPipeline.release();
      // Reset module state to clear tokenizer
      // NOTE(review): this asserts a *fresh* module instance is
      // uninitialized, not that release() cleared the original one —
      // consider asserting on TextGenerationPipeline itself.
      jest.resetModules();
      const freshPipeline = require("../pipelines/text-generation").default;
      await expect(freshPipeline.generate("test")).rejects.toThrow(
        "Tokenizer undefined, please initialize first.",
      );
    });
  });
});
Loading

0 comments on commit b47c0d8

Please sign in to comment.