Enhance generateContentStream with streamCallbacks support. Fixes #322 #401 #403

Open · wants to merge 2 commits into main
32 changes: 25 additions & 7 deletions README.md
@@ -34,7 +34,7 @@ for complete code.
npm install @google/generative-ai
```

-1. Initialize the model
+2. Initialize the model

```js
const { GoogleGenerativeAI } = require("@google/generative-ai");
@@ -44,7 +44,7 @@ const genAI = new GoogleGenerativeAI(process.env.API_KEY);
const model = genAI.getGenerativeModel({ model: "gemini-1.5-flash" });
```

-1. Run a prompt
+3. Run a prompt

```js
const prompt = "Does this look store-bought or homemade?";
@@ -59,6 +59,24 @@ const result = await model.generateContent([prompt, image]);
console.log(result.response.text());
```

## Elastic Embedding Sizes

The SDK supports elastic embedding sizes for text embedding models. You can specify the dimension size when creating embeddings:

```js
const model = genAI.getGenerativeModel({ model: "text-embedding-004" });

// Get an embedding with 128 dimensions instead of the default 768
const result = await model.embedContent({
content: { role: "user", parts: [{ text: "Hello world!" }] },
dimensions: 128
});

console.log("Embedding size:", result.embedding.values.length); // 128
```

Supported dimension sizes are 128, 256, 384, 512, and 768 (default).
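
Per-request sizes also work with `batchEmbedContents`; the sketch below mirrors `samples/elastic_embeddings.js` from this PR:

```js
const result = await model.batchEmbedContents({
  requests: [
    {
      content: { role: "user", parts: [{ text: "What is the meaning of life?" }] },
      dimensions: 128,
    },
    {
      content: { role: "user", parts: [{ text: "How does the brain work?" }] },
      dimensions: 384,
    },
  ],
});

// Each embedding comes back at its requested size.
result.embeddings.forEach((e, i) =>
  console.log(`Embedding ${i + 1} size: ${e.values.length}`)
);
```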

## Try out a sample app

This repository contains sample Node and web apps demonstrating how the SDK can
@@ -69,17 +87,17 @@ access and utilize the Gemini model for various use cases.
1. Check out this repository. \
`git clone https://github.com/google/generative-ai-js`

-1. [Obtain an API key](https://makersuite.google.com/app/apikey) to use with
+2. [Obtain an API key](https://makersuite.google.com/app/apikey) to use with
the Google AI SDKs.

-2. cd into the `samples` folder and run `npm install`.
+3. cd into the `samples` folder and run `npm install`.

-3. Assign your API key to an environment variable: `export API_KEY=MY_API_KEY`.
+4. Assign your API key to an environment variable: `export API_KEY=MY_API_KEY`.

-4. Open the sample file you're interested in. Example: `text_generation.js`.
+5. Open the sample file you're interested in. Example: `text_generation.js`.
In the `runAll()` function, comment out any samples you don't want to run.

-5. Run the sample file. Example: `node text_generation.js`.
+6. Run the sample file. Example: `node text_generation.js`.
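
Run end to end, the steps above look like this (assuming `MY_API_KEY` is your key):

```sh
git clone https://github.com/google/generative-ai-js
cd generative-ai-js/samples
npm install
export API_KEY=MY_API_KEY
node text_generation.js
```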

## Documentation

10 changes: 9 additions & 1 deletion common/api-review/generative-ai.api.md
@@ -201,6 +201,8 @@ export interface EmbedContentRequest {
    // (undocumented)
    content: Content;
+   // (undocumented)
+   dimensions?: number;
    // (undocumented)
    taskType?: TaskType;
    // (undocumented)
    title?: string;
@@ -525,7 +527,7 @@ export class GenerativeModel {
    countTokens(request: CountTokensRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<CountTokensResponse>;
    embedContent(request: EmbedContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<EmbedContentResponse>;
    generateContent(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentResult>;
-   generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions): Promise<GenerateContentStreamResult>;
+   generateContentStream(request: GenerateContentRequest | string | Array<string | Part>, requestOptions?: SingleRequestOptions, streamCallbacks?: StreamCallbacks): Promise<GenerateContentStreamResult>;
    // (undocumented)
    generationConfig: GenerationConfig;
    // (undocumented)
@@ -841,6 +843,12 @@ export interface StartChatParams extends BaseParams {
    tools?: Tool[];
}

+// @public
+export interface StreamCallbacks {
+    onData?: (chunk: string) => void;
+    onDone?: (fullText: string) => void;
+}
+
// @public
export type StringSchema = SimpleStringSchema | EnumStringSchema;

101 changes: 101 additions & 0 deletions samples/elastic_embeddings.js
@@ -0,0 +1,101 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { GoogleGenerativeAI } from "@google/generative-ai";

async function embedContentWithDimensions() {
  const genAI = new GoogleGenerativeAI(process.env.API_KEY);
  const model = genAI.getGenerativeModel({
    model: "text-embedding-004",
  });

  const result = await model.embedContent({
    content: { role: "user", parts: [{ text: "Hello world!" }] },
    dimensions: 128
  });

  console.log("Embedding size:", result.embedding.values.length);
  console.log("First 5 dimensions:", result.embedding.values.slice(0, 5));
}

async function compareEmbeddingSizes() {
  const genAI = new GoogleGenerativeAI(process.env.API_KEY);
  const model = genAI.getGenerativeModel({
    model: "text-embedding-004",
  });

  const text = "The quick brown fox jumps over the lazy dog";
  const dimensions = [128, 256, 384, 512, 768];

  console.log(`Comparing embedding sizes for text: "${text}"`);

  for (const dim of dimensions) {
    const result = await model.embedContent({
      content: { role: "user", parts: [{ text }] },
      dimensions: dim
    });

    console.log(`Dimensions: ${dim}, Actual size: ${result.embedding.values.length}`);
  }
}

async function batchEmbedContentsWithDimensions() {
  const genAI = new GoogleGenerativeAI(process.env.API_KEY);
  const model = genAI.getGenerativeModel({
    model: "text-embedding-004",
  });

  function textToRequest(text, dimensions) {
    return {
      content: { role: "user", parts: [{ text }] },
      dimensions
    };
  }

  const result = await model.batchEmbedContents({
    requests: [
      textToRequest("What is the meaning of life?", 128),
      textToRequest("How much wood would a woodchuck chuck?", 256),
      textToRequest("How does the brain work?", 384),
    ],
  });

  for (let i = 0; i < result.embeddings.length; i++) {
    console.log(`Embedding ${i + 1} size: ${result.embeddings[i].values.length}`);
  }
}

async function runAll() {
  try {
    console.log("=== Embedding with dimensions ===");
    await embedContentWithDimensions();

    console.log("\n=== Comparing embedding sizes ===");
    await compareEmbeddingSizes();

    console.log("\n=== Batch embeddings with dimensions ===");
    await batchEmbedContentsWithDimensions();
  } catch (error) {
    console.error("Error:", error);
  }
}

runAll();
59 changes: 59 additions & 0 deletions samples/stream_callbacks.js
@@ -0,0 +1,59 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

const { GoogleGenerativeAI } = require("@google/generative-ai");

// This sample demonstrates how to use streamCallbacks for receiving
// streaming responses without manually handling Node.js streams.

// Access your API key as an environment variable
const genAI = new GoogleGenerativeAI(process.env.API_KEY);

// For text-only input, use the gemini-pro model
async function runWithCallbacks() {
  const model = genAI.getGenerativeModel({ model: "gemini-pro" });

  console.log("Generating response with callbacks...");

  const result = await model.generateContentStream("Tell me a joke", {}, {
    onData: (chunk) => process.stdout.write(chunk),
    onDone: (fullText) => console.log("\n\nFull response:\n", fullText),
  });

  // The wrapper fires onData/onDone as the stream is consumed, so drain it.
  for await (const _chunk of result.stream) {
    // Chunks are delivered through onData above.
  }
}

// Alternative usage with only onDone callback
async function runWithOnlyDoneCallback() {
  const model = genAI.getGenerativeModel({ model: "gemini-pro" });

  console.log("\nGenerating response with only onDone callback...");

  const result = await model.generateContentStream("Tell me another joke", {}, {
    onDone: (fullText) => console.log("Full response:\n", fullText),
  });

  // With only onDone, the callback is attached to the aggregated response
  // promise; await it so the callback fires before this function returns.
  await result.response;
}

// Run the demos
async function main() {
  try {
    await runWithCallbacks();
    await runWithOnlyDoneCallback();
  } catch (error) {
    console.error("Error:", error);
  }
}

main();
37 changes: 36 additions & 1 deletion src/models/generative-model.ts
@@ -40,6 +40,7 @@ import {
  StartChatParams,
  Tool,
  ToolConfig,
+  StreamCallbacks,
} from "../../types";
import { ChatSession } from "../methods/chat-session";
import { countTokens } from "../methods/count-tokens";
@@ -128,17 +129,23 @@ export class GenerativeModel {
   * Fields set in the optional {@link SingleRequestOptions} parameter will
   * take precedence over the {@link RequestOptions} values provided to
   * {@link GoogleGenerativeAI.getGenerativeModel }.
   *
   * The optional {@link StreamCallbacks} parameter allows receiving text
   * chunks via callbacks without manually handling Node.js streams.
   * - onData: Called with each chunk of text as it arrives
   * - onDone: Called with the full text when streaming is complete
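   *
   * @example
   * ```js
   * // Sketch based on samples/stream_callbacks.js in this PR.
   * const result = await model.generateContentStream("Tell me a joke", {}, {
   *   onData: (chunk) => process.stdout.write(chunk),
   *   onDone: (fullText) => console.log(fullText),
   * });
   * // onData fires only as the returned stream is consumed, so drain it.
   * for await (const chunk of result.stream) {
   *   // no-op: chunks are delivered via onData
   * }
   * ```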
   */
  async generateContentStream(
    request: GenerateContentRequest | string | Array<string | Part>,
    requestOptions: SingleRequestOptions = {},
    streamCallbacks?: StreamCallbacks
  ): Promise<GenerateContentStreamResult> {
    const formattedParams = formatGenerateContentInput(request);
    const generativeModelRequestOptions: SingleRequestOptions = {
      ...this._requestOptions,
      ...requestOptions,
    };
-    return generateContentStream(
+    const result = await generateContentStream(
      this.apiKey,
      this.model,
      {
@@ -152,6 +159,34 @@
      },
      generativeModelRequestOptions,
    );

    // If streamCallbacks are provided, set up the handlers.
    if (streamCallbacks?.onData || streamCallbacks?.onDone) {
      if (streamCallbacks.onData) {
        // Wrap the stream so each chunk is surfaced through onData as the
        // caller consumes it; onDone fires with the accumulated text once
        // the wrapped stream is exhausted. Note that neither callback runs
        // unless the returned stream is actually iterated.
        const originalStream = result.stream;
        result.stream = (async function* () {
          let fullText = "";
          for await (const chunk of originalStream) {
            const text = chunk.text();
            fullText += text;
            streamCallbacks.onData?.(text);
            yield chunk;
          }
          streamCallbacks.onDone?.(fullText);
        })();
      } else if (streamCallbacks.onDone) {
        // With only onDone, hook the aggregated response promise instead.
        // Swallow rejections here; errors still surface to callers that
        // await result.response.
        result.response
          .then((response) => streamCallbacks.onDone?.(response.text()))
          .catch(() => {});
      }
    }

    return result;
  }

/**