Skip to content

Commit

Permalink
Allow specifying arbitrary "gemini-*" model
Browse files Browse the repository at this point in the history
  • Loading branch information
johnd0e committed Jul 25, 2024
1 parent 35b3079 commit 81416be
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
7 changes: 7 additions & 0 deletions readme.MD
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ _..or_:
set OPENAI_API_BASE=https://my-super-proxy.vercel.app/v1
```

## Models

In case the model specified in the request differs from "gemini-*",
the default [model] `gemini-1.5-pro` will be used.

[model]: https://ai.google.dev/gemini-api/docs/models/gemini

---

## Possible further development
Expand Down
11 changes: 6 additions & 5 deletions src/worker.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ const handleOPTIONS = async () => {
});
};

const DEFAULT_MODEL = "gemini-1.5-pro-latest";
const BASE_URL = "https://generativelanguage.googleapis.com";
const API_VERSION = "v1beta";
// https://github.com/google/generative-ai-js/blob/0931d2ce051215db72785d76fe3ae4e0bc3b5475/packages/main/src/requests/request.ts#L67
const API_CLIENT = "genai-js/0.16.0"; // npm view @google/generative-ai version
async function handleRequest(req, apiKey) {
const MODEL = "gemini-1.5-pro-latest";
const model = req?.model.startsWith("gemini-") ? req.model : DEFAULT_MODEL;
const TASK = req.stream ? "streamGenerateContent" : "generateContent";
let url = `${BASE_URL}/${API_VERSION}/models/${MODEL}:${TASK}`;
let url = `${BASE_URL}/${API_VERSION}/models/${model}:${TASK}`;
if (req.stream) { url += "?alt=sse"; }
let response;
try {
Expand Down Expand Up @@ -80,13 +81,13 @@ async function handleRequest(req, apiKey) {
.pipeThrough(new TransformStream({
transform: toOpenAiStream,
flush: toOpenAiStreamFlush,
MODEL, id, last: [],
model, id, last: [],
}))
.pipeThrough(new TextEncoderStream());
} else {
body = await response.text();
try {
body = await processResponse(JSON.parse(body).candidates, MODEL, id);
body = await processResponse(JSON.parse(body).candidates, model, id);
} catch (err) {
console.error(err);
response = { status: 500 };
Expand Down Expand Up @@ -284,7 +285,7 @@ function transformResponseStream (cand, stop, first) {
id: this.id,
object: "chat.completion.chunk",
created: Math.floor(Date.now()/1000),
model: this.MODEL,
model: this.model,
// system_fingerprint: "fp_69829325d0",
choices: [item],
};
Expand Down

0 comments on commit 81416be

Please sign in to comment.