Commit b47c0d8 (parent: da24ce0). Showing 7 changed files with 387 additions and 0 deletions.
@@ -0,0 +1,13 @@
module.exports = {
  presets: [
    ["@babel/preset-env", { targets: { node: "current" } }],
    "@babel/preset-typescript",
    "@babel/preset-react",
  ],
  plugins: [
    ["@babel/plugin-transform-flow-strip-types"],
    ["@babel/plugin-transform-class-properties", { loose: true }],
    ["@babel/plugin-transform-private-methods", { loose: true }],
    ["@babel/plugin-transform-private-property-in-object", { loose: true }],
  ],
};
@@ -0,0 +1,38 @@
module.exports = {
  preset: "react-native",
  moduleFileExtensions: ["ts", "tsx", "js", "jsx", "json", "node"],
  transformIgnorePatterns: [
    "node_modules/(?!(" +
      "react-native|" +
      "@react-native|" +
      "@xenova/transformers|" +
      "text-encoding-polyfill" +
      ")/)",
  ],
  setupFiles: [
    "./node_modules/react-native/jest/setup.js",
    "./src/__tests__/setup.js",
  ],
  testRegex: "(/__tests__/.*(?<!setup)\\.(test|spec))\\.[jt]sx?$",
  testEnvironment: "node",
  transform: {
    "^.+\\.(js|jsx|ts|tsx)$": [
      "babel-jest",
      { configFile: "./babel.config.js" },
    ],
  },
  globals: {
    "ts-jest": {
      babelConfig: true,
      tsconfig: "tsconfig.json",
    },
  },
  collectCoverage: true,
  coverageDirectory: "coverage",
  coverageReporters: ["text", "lcov"],
  collectCoverageFrom: [
    "src/**/*.{js,jsx,ts,tsx}",
    "!src/**/*.d.ts",
    "!src/__tests__/**",
  ],
};
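Editor's note, not part of the commit: the negative lookbehind in testRegex keeps any spec whose name ends in "setup" (for example a hypothetical setup.test.ts next to the ./src/__tests__/setup.js registered in setupFiles) from being collected as a test suite. A minimal standalone check of the pattern, run directly against illustrative path strings rather than through Jest:

// Sanity check of the testRegex from the Jest config above; paths are illustrative.
const testRegex = new RegExp("(/__tests__/.*(?<!setup)\\.(test|spec))\\.[jt]sx?$");

console.log(testRegex.test("src/__tests__/text-generation.test.ts")); // true
console.log(testRegex.test("src/__tests__/setup.js"));                // false: no .test/.spec suffix
console.log(testRegex.test("src/__tests__/setup.test.ts"));           // false: blocked by the (?<!setup) lookbehind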
@@ -0,0 +1,61 @@
// Mock text-encoding-polyfill
jest.mock("text-encoding-polyfill", () => ({}));

// Mock fetch for config loading
global.fetch = jest.fn(() =>
  Promise.resolve({
    arrayBuffer: () =>
      Promise.resolve(
        Uint8Array.from(
          JSON.stringify({
            eos_token_id: 2,
            num_key_value_heads: 32,
            hidden_size: 4096,
            num_attention_heads: 32,
            num_hidden_layers: 32,
          })
            .split("")
            .map((c) => c.charCodeAt(0)),
        ).buffer,
      ),
  }),
);

// Mock InferenceSession
jest.mock("onnxruntime-react-native", () => ({
  InferenceSession: {
    create: jest.fn().mockResolvedValue({
      run: jest.fn().mockResolvedValue({
        logits: {
          data: new Float32Array([0.1, 0.2, 0.3, 0.4]),
          dims: [1, 1, 4],
          type: "float32",
        },
      }),
      release: jest.fn(),
    }),
  },
  env: {
    logLevel: "error",
  },
  Tensor: jest.fn().mockImplementation((type, data, dims) => ({
    type,
    data,
    dims,
    size: data.length,
    dispose: jest.fn(),
  })),
}));

// Mock transformers
jest.mock("@xenova/transformers", () => ({
  env: {
    allowRemoteModels: true,
    allowLocalModels: false,
  },
  AutoTokenizer: {
    from_pretrained: jest.fn().mockResolvedValue({
      decode: jest.fn((_tokens, _options) => "decoded text"),
    }),
  },
}));
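Editor's note, not part of the commit: the fetch mock above hands back the model configuration as raw bytes, which presumably mirrors how the library's config loader consumes it. A rough sketch of that consumption under that assumption (loadConfig and its url parameter are hypothetical names, not taken from this repository):

// Hypothetical sketch of the config loader the mocked fetch stands in for.
// The mock encodes the JSON character-by-character, so the reverse mapping recovers it.
async function loadConfig(url: string): Promise<Record<string, number>> {
  const res = await fetch(url);
  const buf = await res.arrayBuffer();
  const json = String.fromCharCode(...new Uint8Array(buf));
  return JSON.parse(json); // e.g. { eos_token_id: 2, hidden_size: 4096, ... }
}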
@@ -0,0 +1,160 @@
import { TextGeneration } from "../models/text-generation";
import { Tensor } from "onnxruntime-react-native";

// Mock onnxruntime-react-native
jest.mock("onnxruntime-react-native", () => ({
  Tensor: jest.fn().mockImplementation((type, data, dims) => ({
    type,
    data,
    dims,
    size: data.length,
  })),
}));

// Create a test-specific subclass to access protected properties
class TestableTextGeneration extends TextGeneration {
  public getSession() {
    return this.sess;
  }

  public setSession(session: any) {
    this.sess = session;
  }

  public getFeed() {
    return this.feed;
  }

  public getEos() {
    return this.eos;
  }
}

describe("TextGeneration Model", () => {
  let model: TestableTextGeneration;
  let mockRunCount: number;

  beforeEach(() => {
    mockRunCount = 0;
    model = new TestableTextGeneration();
  });

  describe("initializeFeed", () => {
    it("should reset output tokens", () => {
      model.outputTokens = [1n, 2n, 3n];
      model.initializeFeed();
      expect(model.outputTokens).toEqual([]);
    });
  });

  describe("generate", () => {
    const mockCallback = jest.fn();
    const mockTokens = [1n, 2n]; // Initial tokens

    beforeEach(() => {
      mockCallback.mockClear();
      mockRunCount = 0;
    });

    it("should generate tokens until EOS token is found", async () => {
      model.setSession({
        run: jest.fn().mockImplementation(() => {
          mockRunCount++;
          return Promise.resolve({
            logits: {
              data: new Float32Array([0.1, 0.2, 0.3, 2.0]), // highest value at index 3
              dims: [1, 1, 4],
              type: "float32",
            },
          });
        }),
      });

      const result = await model.generate(mockTokens, mockCallback, {
        maxTokens: 10,
      });
      expect(result.length).toBeGreaterThan(0);
      expect(mockCallback).toHaveBeenCalled();
    });

    it("should respect maxTokens limit", async () => {
      const maxTokens = 5;
      model.setSession({
        run: jest.fn().mockImplementation(() => {
          mockRunCount++;
          return Promise.resolve({
            logits: {
              data: new Float32Array([0.1, 0.2, 0.3, 0.1]), // will generate token 2 (index with highest value)
              dims: [1, 1, 4],
              type: "float32",
            },
          });
        }),
      });

      const result = await model.generate(mockTokens, mockCallback, {
        maxTokens,
      });
      // Initial tokens (2) + generated tokens should not exceed maxTokens (5)
      expect(result.length).toBeLessThanOrEqual(maxTokens);
      expect(mockRunCount).toBeLessThanOrEqual(maxTokens - mockTokens.length);
    });

    it("should throw error if session is undefined", async () => {
      model.setSession(undefined);
      await expect(
        model.generate(mockTokens, mockCallback, { maxTokens: 10 }),
      ).rejects.toThrow("Session is undefined");
    });

    it("should create correct tensors for input", async () => {
      model.setSession({
        run: jest.fn().mockResolvedValue({
          logits: {
            data: new Float32Array([0.1, 0.2, 0.3, 0.4]),
            dims: [1, 1, 4],
            type: "float32",
          },
        }),
      });

      await model.generate(mockTokens, mockCallback, { maxTokens: 10 });
      expect(Tensor).toHaveBeenCalledWith("int64", expect.any(BigInt64Array), [
        1,
        mockTokens.length,
      ]);
    });

    it("should handle generation with attention mask", async () => {
      model.setSession({
        run: jest.fn().mockResolvedValue({
          logits: {
            data: new Float32Array([0.1, 0.2, 0.3, 0.4]),
            dims: [1, 1, 4],
            type: "float32",
          },
        }),
      });

      const result = await model.generate(mockTokens, mockCallback, {
        maxTokens: 10,
      });
      const feed = model.getFeed();
      expect(feed.attention_mask).toBeDefined();
      expect(result).toBeDefined();
    });
  });

  describe("release", () => {
    it("should release session resources", async () => {
      const mockSession = {
        release: jest.fn().mockResolvedValue(undefined),
      };
      model.setSession(mockSession);

      await model.release();
      expect(mockSession.release).toHaveBeenCalled();
      expect(model.getSession()).toBeUndefined();
    });
  });
});
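Editor's note, not part of the commit: the class under test is not included in this diff. The sketch below is only the shape these assertions imply; member names and types are inferred from the test, not copied from the actual src/models/text-generation source.

// Inferred sketch of the TextGeneration surface exercised by the tests above.
import type { InferenceSession, Tensor } from "onnxruntime-react-native";

declare class TextGeneration {
  // Generated token ids; cleared by initializeFeed().
  outputTokens: bigint[];
  // ONNX session; generate() rejects with "Session is undefined" when unset,
  // and release() disposes it and sets it back to undefined.
  protected sess?: InferenceSession;
  // Named input tensors, including attention_mask (and presumably input_ids).
  protected feed: Record<string, Tensor>;
  // End-of-sequence token id that ends generation.
  protected eos: bigint;

  initializeFeed(): void;
  generate(
    tokens: bigint[],
    callback: (tokens: bigint[]) => void,
    options: { maxTokens: number },
  ): Promise<bigint[]>;
  release(): Promise<void>;
}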
@@ -0,0 +1,109 @@
import TextGenerationPipeline from "../pipelines/text-generation";
import type { PreTrainedTokenizer } from "@xenova/transformers";

// Mock the transformers library
jest.mock("@xenova/transformers", () => {
  // Create a mock tokenizer function with the correct type
  const mockTokenizerFn = Object.assign(
    jest
      .fn<Promise<{ input_ids: bigint[] }>, [string, any]>()
      .mockResolvedValue({ input_ids: [1n, 2n] }),
    {
      decode: jest.fn((_tokens: bigint[], _options: unknown) => "decoded text"),
    },
  ) as unknown as PreTrainedTokenizer;

  return {
    env: {
      allowRemoteModels: true,
      allowLocalModels: false,
    },
    AutoTokenizer: {
      from_pretrained: jest.fn().mockResolvedValue(mockTokenizerFn),
    },
  };
});

// Mock the model
jest.mock("../models/text-generation", () => {
  return {
    TextGeneration: jest.fn().mockImplementation(() => ({
      initializeFeed: jest.fn(),
      generate: jest.fn().mockImplementation((tokens) => {
        // Return tokens without calling callback to avoid double decoding
        return Promise.resolve(tokens);
      }),
      load: jest.fn(),
      release: jest.fn(),
      outputTokens: [],
    })),
  };
});

describe("TextGenerationPipeline", () => {
  beforeEach(() => {
    jest.clearAllMocks();
    // Reset module state
    jest.isolateModules(() => {
      require("../pipelines/text-generation");
    });
  });

  describe("init", () => {
    it("should initialize with default options", async () => {
      await TextGenerationPipeline.init("test-model", "test-path");
      expect(
        require("@xenova/transformers").AutoTokenizer.from_pretrained,
      ).toHaveBeenCalledWith("test-model");
    });

    it("should initialize with custom options", async () => {
      await TextGenerationPipeline.init("test-model", "test-path", {
        show_special: true,
        max_tokens: 100,
      });
      expect(
        require("@xenova/transformers").AutoTokenizer.from_pretrained,
      ).toHaveBeenCalledWith("test-model");
    });
  });

  describe("generate", () => {
    beforeEach(async () => {
      await TextGenerationPipeline.init("test-model", "test-path");
    });

    it("should generate text from prompt", async () => {
      const result = await TextGenerationPipeline.generate("test prompt");
      expect(result).toBe("decoded text");
    });

    it("should call callback with generated text", async () => {
      const callback = jest.fn();
      await TextGenerationPipeline.generate("test prompt", callback);
      expect(callback).toHaveBeenCalledWith("decoded text");
    });

    it("should throw error if not initialized", async () => {
      // Reset module state to clear tokenizer
      jest.resetModules();
      const freshPipeline = require("../pipelines/text-generation").default;
      await expect(freshPipeline.generate("test")).rejects.toThrow(
        "Tokenizer undefined, please initialize first.",
      );
    });
  });

  describe("release", () => {
    it("should release resources", async () => {
      await TextGenerationPipeline.init("test-model", "test-path");
      await TextGenerationPipeline.release();
      // Reset module state to clear tokenizer
      jest.resetModules();
      const freshPipeline = require("../pipelines/text-generation").default;
      await expect(freshPipeline.generate("test")).rejects.toThrow(
        "Tokenizer undefined, please initialize first.",
      );
    });
  });
});
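Editor's note, not part of the commit: taken together, these tests imply the pipeline's public surface. A hypothetical usage sketch based only on the assertions above; the import path, option values, and prompt are illustrative, and nothing here comes from project documentation.

// Hypothetical usage of the pipeline, inferred from the test assertions above.
import TextGenerationPipeline from "./src/pipelines/text-generation"; // illustrative path

async function main(): Promise<void> {
  // init(model, modelPath, options?) loads the tokenizer and the ONNX model.
  await TextGenerationPipeline.init("test-model", "test-path", {
    show_special: false, // option names taken from the "custom options" test
    max_tokens: 100,
  });

  // generate(prompt, callback?) resolves to the decoded text and reports
  // progress through the optional callback.
  const text = await TextGenerationPipeline.generate("test prompt", (partial: string) => {
    console.log("partial:", partial);
  });
  console.log("result:", text);

  // After release(), generate() rejects with
  // "Tokenizer undefined, please initialize first."
  await TextGenerationPipeline.release();
}

main();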