Skip to content

Commit

Permalink
feat: add embeddings to client (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
Brandon Meyerowitz authored Jul 26, 2023
1 parent d519b26 commit d3c3a3b
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 42 deletions.
42 changes: 36 additions & 6 deletions client.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { rootApiUrl } from "./constants";
import {
APIErrorResponse,
APIResponse,
ClientOptions,
CompletionConfig,
CompletionResponse,
CompletionResult,
EmbeddingsConfig,
EmbeddingsResponse,
} from "./types";

class StreamConsumer {
Expand Down Expand Up @@ -41,7 +43,7 @@ class StreamConsumer {
if (!line.startsWith("data: ")) {
return this.next();
}
let data: APIResponse;
let data: CompletionResponse;
try {
data = JSON.parse(line.slice(6));
} catch (e) {
Expand Down Expand Up @@ -74,7 +76,7 @@ export class Client {
this.apiUrl = this.options._apiUrl || rootApiUrl;
}

private getBody(config: CompletionConfig) {
private getCompletionBody(config: CompletionConfig) {
return {
...this.options._extraParams,
projectId: config.projectId || this.options.projectId,
Expand All @@ -93,15 +95,25 @@ export class Client {
providerConfig: config.providerConfig,
};
}
private getEmbeddingsBody(config: EmbeddingsConfig) {
return {
...this.options._extraParams,
projectId: config.projectId || this.options.projectId,
apiKey: this.options.apiKey,
userId: config.userId,
input: config.input,
providerConfig: config.providerConfig,
};
}
private async fetchAPI(
path: string,
config: CompletionConfig,
stream = false,
): Promise<APIResponse | StreamConsumer> {
): Promise<CompletionResponse | EmbeddingsResponse | StreamConsumer> {
const res = await fetch(`${this.apiUrl}/${path}`, {
method: "POST",
body: JSON.stringify({
...this.getBody(config),
...this.getCompletionBody(config),
stream,
}),
headers: {
Expand All @@ -128,7 +140,7 @@ export class Client {
const completionsRes = (await this.fetchAPI(
"completions",
config,
)) as APIResponse;
)) as CompletionResponse;

return new CompletionResult(completionsRes);
}
Expand All @@ -138,4 +150,22 @@ export class Client {
): Promise<StreamConsumer> {
return (await this.fetchAPI("completions", config, true)) as StreamConsumer;
}

async createEmbedding(config: EmbeddingsConfig): Promise<EmbeddingsResponse> {
const res = await fetch(`${this.apiUrl}/embeddings`, {
method: "POST",
body: JSON.stringify(this.getEmbeddingsBody(config)),
headers: {
...this.options._extraHeaders,
"Content-Type": "application/json; charset=utf-8",
},
});

if (!res.ok) {
const resBody = await res.json();
throw new APIError(res.status, resBody);
}

return (await res.json()) as EmbeddingsResponse;
}
}
24 changes: 24 additions & 0 deletions examples/embeddings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import * as process from "process";

// import { Client } from "@commonbase/sdk";
import { Client } from "../index";

async function main() {
const client = new Client({
projectId: process.env.CB_PROJECT_ID,
});
const embeddingsResponse = await client.createEmbedding({
input: "Your text string",
providerConfig: {
provider: "cb-openai-eu",
params: {
type: "embeddings",
model: "text-embedding-ada-002",
},
},
});

console.log(embeddingsResponse.data);
}

main().catch(console.error);
95 changes: 59 additions & 36 deletions types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export class CompletionResult {
private readonly _rawResponse: APIResponse;
constructor(response: APIResponse) {
private readonly _rawResponse: CompletionResponse;
constructor(response: CompletionResponse) {
this._rawResponse = response;
}

Expand All @@ -23,7 +23,7 @@ export class CompletionResult {
return this._rawResponse.completed;
}

get _raw(): APIResponse {
get _raw(): CompletionResponse {
return this._rawResponse;
}
}
Expand Down Expand Up @@ -62,7 +62,7 @@ export type ChatContext = {
messages: ChatMessage[];
};

type APIResponseChoice = {
type CompletionResponseChoice = {
index: number;
finish_reason: string;
text: string;
Expand All @@ -78,16 +78,29 @@ type TruncationResult = {
iterations: number;
};

export type APIResponse = {
export type CompletionResponse = {
completed: boolean;
invocationId: string;
projectId: string;
type: string;
model: string;
choices: APIResponseChoice[];
choices: CompletionResponseChoice[];
variableTruncation?: TruncationResult;
};

export type EmbeddingsResponse = {
completed: boolean;
invocationId: string;
projectId: string;
type: string;
model: string;
data: {
object: "embedding";
index: number;
embedding: number[];
}[];
};

export type APIErrorResponse = {
error: string;
invocationId?: string;
Expand All @@ -102,36 +115,38 @@ export type TruncationConfig = {
name?: string;
};

type ProviderConfig =
| {
provider: "openai" | "cb-openai-eu";
params: {
type: "chat" | "text";
model?: string;
temperature?: number;
top_p?: number;
max_tokens?: number;
n?: number;
frequency_penalty?: number;
presence_penalty?: number;
stop?: string[] | string;
best_of?: number;
suffix?: string;
logprobs?: number;
};
}
| {
provider: "anthropic";
params: {
type: "chat" | undefined;
model?: string;
max_tokens_to_sample?: number;
temperature?: number;
stop_sequences?: string[];
top_k?: number;
top_p?: number;
};
};
type OpenAIProviderConfig = {
provider: "openai" | "cb-openai-eu";
params: {
type: "chat" | "text" | "embeddings";
model?: string;
temperature?: number;
top_p?: number;
max_tokens?: number;
n?: number;
frequency_penalty?: number;
presence_penalty?: number;
stop?: string[] | string;
best_of?: number;
suffix?: string;
logprobs?: number;
};
};

type AnthropicProviderConfig = {
provider: "anthropic";
params: {
type: "chat" | undefined;
model?: string;
max_tokens_to_sample?: number;
temperature?: number;
stop_sequences?: string[];
top_k?: number;
top_p?: number;
};
};

type ProviderConfig = OpenAIProviderConfig | AnthropicProviderConfig;

export type CompletionConfig = {
variables?: Record<string, string>;
Expand All @@ -142,3 +157,11 @@ export type CompletionConfig = {
prompt?: string;
providerConfig?: ProviderConfig;
};

export type EmbeddingsConfig = {
input: string;
projectId?: string;
apiKey?: string;
userId?: string;
providerConfig?: ProviderConfig;
};

0 comments on commit d3c3a3b

Please sign in to comment.