fix(js/plugins/ollama): fix ollama embedding config (#1246)
* fix(js/plugins/ollama): fix ollama embedding config

* docs(js/plugins/ollama): add embedding docs

* Update docs/plugins/ollama.md

Co-authored-by: Pavel Jbanov <[email protected]>

* Update docs/plugins/ollama.md

Co-authored-by: Pavel Jbanov <[email protected]>

---------

Co-authored-by: Pavel Jbanov <[email protected]>
cabljac and pavelgj authored Nov 12, 2024
1 parent 757a088 commit 4593450
Showing 12 changed files with 485 additions and 123 deletions.
44 changes: 35 additions & 9 deletions docs/plugins/ollama.md
@@ -21,21 +21,22 @@ example:
ollama pull gemma
```

To use this plugin, specify it when you call `configureGenkit()`.
To use the plugin, specify it when you call genkit:

```js
```typescript
import { genkit } from 'genkit';
import { ollama } from 'genkitx-ollama';

export default configureGenkit({
const ai = genkit({
plugins: [
ollama({
models: [
{
name: 'gemma',
type: 'generate', // type: 'chat' | 'generate' | undefined
type: 'generate', // Options: 'chat' | 'generate' | undefined
},
],
serverAddress: 'http://127.0.0.1:11434', // default local address
serverAddress: 'http://127.0.0.1:11434', // default serverAddress to use
}),
],
});
@@ -64,7 +65,7 @@ the Google Auth library:
```js
import { GoogleAuth } from 'google-auth-library';
import { ollama, OllamaPluginParams } from 'genkitx-ollama';
import { configureGenkit, isDevEnv } from '@genkit-ai/core';
import { genkit, isDevEnv } from '@genkit-ai/core';

const ollamaCommon = { models: [{ name: 'gemma:2b' }] };

@@ -82,7 +83,7 @@ const ollamaProd = {
},
} as OllamaPluginParams;

export default configureGenkit({
const ai = genkit({
plugins: [
ollama(isDevEnv() ? ollamaDev : ollamaProd),
],
@@ -117,8 +118,33 @@ This plugin doesn't statically export model references. Specify one of the
models you configured using a string identifier:

```js
const llmResponse = await generate({
model: 'ollama/gemma',
const llmResponse = await ai.generate({
model: 'ollama/gemma:2b',
prompt: 'Tell me a joke.',
});
```
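
The plugin maps common Genkit generation config onto Ollama's request options, so parameters such as `temperature`, `topP`, `topK`, `stopSequences`, and `maxOutputTokens` can be set per call. A minimal sketch, assuming the `gemma` model configured earlier:

```typescript
// A minimal sketch: these config fields are translated by the plugin into
// Ollama request options (temperature, top_p, top_k, stop, num_predict).
const response = await ai.generate({
  model: 'ollama/gemma',
  prompt: 'Write a one-sentence summary of what Ollama does.',
  config: {
    temperature: 0.7,
    topK: 40,
    maxOutputTokens: 200,
  },
});
```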

## Embedders
The Ollama plugin supports embeddings, which can be used for similarity searches and other NLP tasks.

```typescript
const ai = genkit({
plugins: [
ollama({
serverAddress: 'http://localhost:11434',
embedders: [{ name: 'nomic-embed-text', dimensions: 768 }],
}),
],
});

async function getEmbedding() {
const embedding = await ai.embed({
embedder: 'ollama/nomic-embed-text',
content: 'Some text to embed!',
})

return embedding;
}

getEmbedding().then((e) => console.log(e))
```
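
Because the embedding is a numeric vector, a similarity search can be implemented by comparing vectors directly. A minimal sketch, assuming the `ai.embed` call above resolves to a plain `number[]`:

```typescript
// A minimal similarity-search sketch. Assumes ai.embed resolves to a
// plain numeric vector (number[]), as in the example above.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

async function mostSimilar(query: string, candidates: string[]) {
  const queryVec = await ai.embed({
    embedder: 'ollama/nomic-embed-text',
    content: query,
  });
  const scored = await Promise.all(
    candidates.map(async (text) => {
      const vec = await ai.embed({
        embedder: 'ollama/nomic-embed-text',
        content: text,
      });
      return { text, score: cosineSimilarity(queryVec, vec) };
    })
  );
  return scored.sort((a, b) => b.score - a.score)[0];
}
```
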
1 change: 1 addition & 0 deletions js/plugins/ollama/package.json
@@ -34,6 +34,7 @@
"devDependencies": {
"@types/node": "^20.11.16",
"npm-run-all": "^4.1.5",
"ollama": "^0.5.9",
"rimraf": "^6.0.1",
"tsup": "^8.0.2",
"tsx": "^4.7.0",
120 changes: 74 additions & 46 deletions js/plugins/ollama/src/embeddings.ts
@@ -13,19 +13,50 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { Genkit } from 'genkit';
import { logger } from 'genkit/logging';
import { OllamaPluginParams } from './index.js';
import { Document, Genkit } from 'genkit';
import { EmbedRequest, EmbedResponse } from 'ollama';
import { DefineOllamaEmbeddingParams, RequestHeaders } from './types.js';

interface OllamaEmbeddingPrediction {
embedding: number[];
}
async function toOllamaEmbedRequest(
modelName: string,
dimensions: number,
documents: Document[],
serverAddress: string,
requestHeaders?: RequestHeaders
): Promise<{
url: string;
requestPayload: EmbedRequest;
headers: Record<string, string>;
}> {
const requestPayload: EmbedRequest = {
model: modelName,
input: documents.map((doc) => doc.text),
};

// Determine headers
const extraHeaders = requestHeaders
? typeof requestHeaders === 'function'
? await requestHeaders({
serverAddress,
model: {
name: modelName,
dimensions,
},
embedRequest: requestPayload,
})
: requestHeaders
: {};

interface DefineOllamaEmbeddingParams {
name: string;
modelName: string;
dimensions: number;
options: OllamaPluginParams;
const headers = {
'Content-Type': 'application/json',
...extraHeaders, // Add any dynamic headers
};

return {
url: `${serverAddress}/api/embed`,
requestPayload,
headers,
};
}

export function defineOllamaEmbedder(
@@ -34,50 +65,47 @@ export function defineOllamaEmbedder(
) {
return ai.defineEmbedder(
{
name,
name: `ollama/${name}`,
info: {
label: 'Ollama Embedding - ' + modelName,
label: 'Ollama Embedding - ' + name,
dimensions,
supports: {
// TODO: do any ollama models support other modalities?
input: ['text'],
},
},
},
async (input) => {
const serverAddress = options.serverAddress;
const responses = await Promise.all(
input.map(async (i) => {
const requestPayload = {
model: modelName,
prompt: i.text,
};
let res: Response;
try {
res = await fetch(`${serverAddress}/api/embeddings`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(requestPayload),
});
} catch (e) {
logger.error('Failed to fetch Ollama embedding');
throw new Error(`Error fetching embedding from Ollama: ${e}`);
}
if (!res.ok) {
logger.error('Failed to fetch Ollama embedding');
throw new Error(
`Error fetching embedding from Ollama: ${res.statusText}`
);
}
const responseData = (await res.json()) as OllamaEmbeddingPrediction;
return responseData;
})
async (input, config) => {
const serverAddress = config?.serverAddress || options.serverAddress;

const { url, requestPayload, headers } = await toOllamaEmbedRequest(
modelName,
dimensions,
input,
serverAddress,
options.requestHeaders
);
return {
embeddings: responses,
};

const response: Response = await fetch(url, {
method: 'POST',
headers,
body: JSON.stringify(requestPayload),
});

if (!response.ok) {
throw new Error(
`Error fetching embedding from Ollama: ${response.statusText}`
);
}

const payload: EmbedResponse = await response.json();

const embeddings: { embedding: number[] }[] = [];

for (const embedding of payload.embeddings) {
embeddings.push({ embedding });
}
return { embeddings };
}
);
}
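
The reworked embedder batches all input documents into a single request to Ollama's `/api/embed` endpoint and unpacks the returned `embeddings` array. Because `defineOllamaEmbedder` is also exported from the package, it can be registered directly; a hypothetical standalone registration, mirroring the parameter shape the plugin passes in (the `options` values below are illustrative):

```typescript
import { genkit } from 'genkit';
import { defineOllamaEmbedder } from 'genkitx-ollama';

const ai = genkit({ plugins: [] });

// Hypothetical direct registration; the plugin normally does this for each
// entry in `embedders`. Server address and dimensions are assumptions.
const nomicEmbedder = defineOllamaEmbedder(ai, {
  name: 'nomic-embed-text',
  modelName: 'nomic-embed-text',
  dimensions: 768,
  options: {
    models: [],
    serverAddress: 'http://127.0.0.1:11434',
  },
});

// The returned reference can then be used as the embedder argument,
// e.g. ai.embed({ embedder: nomicEmbedder, content: '...' }).
```
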
75 changes: 29 additions & 46 deletions js/plugins/ollama/src/index.ts
@@ -25,40 +25,24 @@ import {
} from 'genkit/model';
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';
import { defineOllamaEmbedder } from './embeddings';
import {
ApiType,
ModelDefinition,
OllamaPluginParams,
RequestHeaders,
} from './types';

type ApiType = 'chat' | 'generate';

type RequestHeaders =
| Record<string, string>
| ((
params: { serverAddress: string; model: ModelDefinition },
request: GenerateRequest
) => Promise<Record<string, string> | void>);

type ModelDefinition = { name: string; type?: ApiType };
type EmbeddingModelDefinition = { name: string; dimensions: number };

export interface OllamaPluginParams {
models: ModelDefinition[];
embeddingModels?: EmbeddingModelDefinition[];

/**
* ollama server address.
*/
serverAddress: string;

requestHeaders?: RequestHeaders;
}
export { defineOllamaEmbedder };

export function ollama(params: OllamaPluginParams): GenkitPlugin {
return genkitPlugin('ollama', async (ai: Genkit) => {
const serverAddress = params?.serverAddress;
params.models.map((model) =>
const serverAddress = params.serverAddress;
params.models?.map((model) =>
ollamaModel(ai, model, serverAddress, params.requestHeaders)
);
params.embeddingModels?.map((model) =>
params.embedders?.map((model) =>
defineOllamaEmbedder(ai, {
name: `${ollama}/model.name`,
name: model.name,
modelName: model.name,
dimensions: model.dimensions,
options: params,
@@ -85,20 +69,20 @@ function ollamaModel(
},
async (input, streamingCallback) => {
const options: Record<string, any> = {};
if (input.config?.hasOwnProperty('temperature')) {
options.temperature = input.config?.temperature;
if (input.config?.temperature !== undefined) {
options.temperature = input.config.temperature;
}
if (input.config?.hasOwnProperty('topP')) {
options.top_p = input.config?.topP;
if (input.config?.topP !== undefined) {
options.top_p = input.config.topP;
}
if (input.config?.hasOwnProperty('topK')) {
options.top_k = input.config?.topK;
if (input.config?.topK !== undefined) {
options.top_k = input.config.topK;
}
if (input.config?.hasOwnProperty('stopSequences')) {
options.stop = input.config?.stopSequences?.join('');
if (input.config?.stopSequences !== undefined) {
options.stop = input.config.stopSequences.join('');
}
if (input.config?.hasOwnProperty('maxOutputTokens')) {
options.num_predict = input.config?.maxOutputTokens;
if (input.config?.maxOutputTokens !== undefined) {
options.num_predict = input.config.maxOutputTokens;
}
const type = model.type ?? 'chat';
const request = toOllamaRequest(
@@ -137,13 +121,12 @@ function ollamaModel(
);
} catch (e) {
const cause = (e as any).cause;
if (cause) {
if (
cause instanceof Error &&
cause.message?.includes('ECONNREFUSED')
) {
cause.message += '. Make sure ollama server is running.';
}
if (
cause &&
cause instanceof Error &&
cause.message?.includes('ECONNREFUSED')
) {
cause.message += '. Make sure the Ollama server is running.';
throw cause;
}
throw e;
@@ -225,11 +208,11 @@ function toOllamaRequest(
type: ApiType,
stream: boolean
) {
const request = {
const request: any = {
model: name,
options,
stream,
} as any;
};
if (type === 'chat') {
const messages: Message[] = [];
input.messages.forEach((m) => {
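
The plugin options also accept `requestHeaders`, which this commit threads through to embedding requests as well as generate requests (see `toOllamaEmbedRequest` above). A rough sketch using the static record form, with a placeholder host and token; the function form that computes headers per request is also supported:

```typescript
import { genkit } from 'genkit';
import { ollama } from 'genkitx-ollama';

// A rough sketch: static requestHeaders attached to every Ollama call.
// The server address and bearer token are placeholders.
const ai = genkit({
  plugins: [
    ollama({
      serverAddress: 'https://my-ollama-host.example.com',
      models: [{ name: 'gemma' }],
      embedders: [{ name: 'nomic-embed-text', dimensions: 768 }],
      requestHeaders: {
        Authorization: 'Bearer <token>',
      },
    }),
  ],
});
```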
(Diffs for the remaining 8 changed files are not shown.)
