Skip to content

Commit

Permalink
support gemini-2.0-flash-thinking-exp model
Browse files Browse the repository at this point in the history
  • Loading branch information
zuisong committed Jan 24, 2025
1 parent 9850312 commit 98a4774
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 60 deletions.
52 changes: 41 additions & 11 deletions dist/main_bun.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,7 @@ function openAiMessageToGeminiMessage(messages) {
return result;
}
function genModel(req) {
const defaultModel = (m) => {
if (m.startsWith("gemini")) {
return m;
}
return "gemini-1.5-flash-latest";
};
const model = ModelMapping[req.model] ?? defaultModel(req.model);
const model = GeminiModel.modelMapping(req.model);
let functions = req.tools?.filter((it) => it.type === "function")?.map((it) => it.function) ?? [];
functions = functions.concat((req.functions ?? []).map((it) => ({ strict: null, ...it })));
const responseMimeType = req.response_format?.type === "json_object" ? "application/json" : "text/plain";
Expand All @@ -118,7 +112,10 @@ function genModel(req) {
maxOutputTokens: req.max_tokens ?? void 0,
temperature: req.temperature ?? void 0,
topP: req.top_p ?? void 0,
responseMimeType
responseMimeType,
thinkingConfig: !model.isThinkingModel() ? void 0 : {
includeThoughts: true
}
},
tools: functions.length === 0 ? void 0 : [
{
Expand All @@ -137,6 +134,34 @@ function genModel(req) {
};
return [model, generateContentRequest];
}
/**
 * Wrapper around a Gemini model name that knows which API version serves it.
 * "Thinking" models (e.g. gemini-2.0-flash-thinking-exp) are only available
 * on the v1alpha endpoint; all other models use v1beta.
 */
var GeminiModel = class _GeminiModel {
  /**
   * Resolve an incoming (OpenAI-style) model name to a GeminiModel.
   * Uses an own-property check so inherited Object.prototype keys
   * (e.g. "toString", "constructor") cannot be mistaken for mapping
   * entries — with plain `ModelMapping[model] ?? ...` those lookups
   * return truthy prototype members and skip the default fallback.
   */
  static modelMapping(model) {
    const modelName = Object.prototype.hasOwnProperty.call(ModelMapping, model) ? ModelMapping[model] : _GeminiModel.defaultModel(model);
    return new _GeminiModel(modelName);
  }
  // Resolved Gemini model name, e.g. "gemini-1.5-flash-latest".
  model;
  constructor(model) {
    this.model = model;
  }
  // True for reasoning models whose responses include thought output.
  isThinkingModel() {
    return this.model.includes("thinking");
  }
  // API version path segment for request URLs; thinking models are v1alpha-only.
  apiVersion() {
    if (this.isThinkingModel()) {
      return "v1alpha";
    }
    return "v1beta";
  }
  toString() {
    return this.model;
  }
  // Pass through names that already look like Gemini models;
  // anything else falls back to a sensible default.
  static defaultModel(m) {
    if (m.startsWith("gemini")) {
      return m;
    }
    return "gemini-1.5-flash-latest";
  }
};
var ModelMapping = {
"gpt-3.5-turbo": "gemini-1.5-flash-8b-latest",
"gpt-4": "gemini-1.5-pro-latest",
Expand Down Expand Up @@ -405,7 +430,7 @@ var RequestUrl = class {
this.apiParam = apiParam;
}
toURL() {
const api_version = "v1beta";
const api_version = this.model.apiVersion();
const url = new URL(`${BASE_URL}/${api_version}/models/${this.model}:${this.task}`);
url.searchParams.append("key", this.apiParam.apikey);
if (this.stream) {
Expand Down Expand Up @@ -676,7 +701,12 @@ async function embeddingProxyHandler(rawReq) {
log?.warn("request", embedContentRequest);
let geminiResp = [];
try {
for await (const it of generateContent("embedContent", apiParam, "text-embedding-004", embedContentRequest)) {
for await (const it of generateContent(
"embedContent",
apiParam,
new GeminiModel("text-embedding-004"),
embedContentRequest
)) {
const data = it.embedding?.values;
geminiResp = data;
break;
Expand Down Expand Up @@ -745,7 +775,7 @@ app.post("/v1/chat/completions", chatProxyHandler);
app.post("/v1/embeddings", embeddingProxyHandler);
app.get("/v1/models", () => Response.json(models()));
app.get("/v1/models/:model", (c) => Response.json(modelDetail(c.params.model)));
app.post(":model_version/models/:model_and_action", geminiProxy);
app.post("/:model_version/models/:model_and_action", geminiProxy);
app.all("*", () => new Response("Page Not Found", { status: 404 }));

// main_bun.ts
Expand Down
52 changes: 41 additions & 11 deletions dist/main_cloudflare-workers.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,7 @@ function openAiMessageToGeminiMessage(messages) {
return result;
}
function genModel(req) {
const defaultModel = (m) => {
if (m.startsWith("gemini")) {
return m;
}
return "gemini-1.5-flash-latest";
};
const model = ModelMapping[req.model] ?? defaultModel(req.model);
const model = GeminiModel.modelMapping(req.model);
let functions = req.tools?.filter((it) => it.type === "function")?.map((it) => it.function) ?? [];
functions = functions.concat((req.functions ?? []).map((it) => ({ strict: null, ...it })));
const responseMimeType = req.response_format?.type === "json_object" ? "application/json" : "text/plain";
Expand All @@ -118,7 +112,10 @@ function genModel(req) {
maxOutputTokens: req.max_tokens ?? void 0,
temperature: req.temperature ?? void 0,
topP: req.top_p ?? void 0,
responseMimeType
responseMimeType,
thinkingConfig: !model.isThinkingModel() ? void 0 : {
includeThoughts: true
}
},
tools: functions.length === 0 ? void 0 : [
{
Expand All @@ -137,6 +134,34 @@ function genModel(req) {
};
return [model, generateContentRequest];
}
/**
 * Wrapper around a Gemini model name that knows which API version serves it.
 * "Thinking" models (e.g. gemini-2.0-flash-thinking-exp) are only available
 * on the v1alpha endpoint; all other models use v1beta.
 */
var GeminiModel = class _GeminiModel {
  /**
   * Resolve an incoming (OpenAI-style) model name to a GeminiModel.
   * Uses an own-property check so inherited Object.prototype keys
   * (e.g. "toString", "constructor") cannot be mistaken for mapping
   * entries — with plain `ModelMapping[model] ?? ...` those lookups
   * return truthy prototype members and skip the default fallback.
   */
  static modelMapping(model) {
    const modelName = Object.prototype.hasOwnProperty.call(ModelMapping, model) ? ModelMapping[model] : _GeminiModel.defaultModel(model);
    return new _GeminiModel(modelName);
  }
  // Resolved Gemini model name, e.g. "gemini-1.5-flash-latest".
  model;
  constructor(model) {
    this.model = model;
  }
  // True for reasoning models whose responses include thought output.
  isThinkingModel() {
    return this.model.includes("thinking");
  }
  // API version path segment for request URLs; thinking models are v1alpha-only.
  apiVersion() {
    if (this.isThinkingModel()) {
      return "v1alpha";
    }
    return "v1beta";
  }
  toString() {
    return this.model;
  }
  // Pass through names that already look like Gemini models;
  // anything else falls back to a sensible default.
  static defaultModel(m) {
    if (m.startsWith("gemini")) {
      return m;
    }
    return "gemini-1.5-flash-latest";
  }
};
var ModelMapping = {
"gpt-3.5-turbo": "gemini-1.5-flash-8b-latest",
"gpt-4": "gemini-1.5-pro-latest",
Expand Down Expand Up @@ -405,7 +430,7 @@ var RequestUrl = class {
this.apiParam = apiParam;
}
toURL() {
const api_version = "v1beta";
const api_version = this.model.apiVersion();
const url = new URL(`${BASE_URL}/${api_version}/models/${this.model}:${this.task}`);
url.searchParams.append("key", this.apiParam.apikey);
if (this.stream) {
Expand Down Expand Up @@ -676,7 +701,12 @@ async function embeddingProxyHandler(rawReq) {
log?.warn("request", embedContentRequest);
let geminiResp = [];
try {
for await (const it of generateContent("embedContent", apiParam, "text-embedding-004", embedContentRequest)) {
for await (const it of generateContent(
"embedContent",
apiParam,
new GeminiModel("text-embedding-004"),
embedContentRequest
)) {
const data = it.embedding?.values;
geminiResp = data;
break;
Expand Down Expand Up @@ -745,7 +775,7 @@ app.post("/v1/chat/completions", chatProxyHandler);
app.post("/v1/embeddings", embeddingProxyHandler);
app.get("/v1/models", () => Response.json(models()));
app.get("/v1/models/:model", (c) => Response.json(modelDetail(c.params.model)));
app.post(":model_version/models/:model_and_action", geminiProxy);
app.post("/:model_version/models/:model_and_action", geminiProxy);
app.all("*", () => new Response("Page Not Found", { status: 404 }));

// main_cloudflare-workers.ts
Expand Down
52 changes: 41 additions & 11 deletions dist/main_deno.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,7 @@ function openAiMessageToGeminiMessage(messages) {
return result;
}
function genModel(req) {
const defaultModel = (m) => {
if (m.startsWith("gemini")) {
return m;
}
return "gemini-1.5-flash-latest";
};
const model = ModelMapping[req.model] ?? defaultModel(req.model);
const model = GeminiModel.modelMapping(req.model);
let functions = req.tools?.filter((it) => it.type === "function")?.map((it) => it.function) ?? [];
functions = functions.concat((req.functions ?? []).map((it) => ({ strict: null, ...it })));
const responseMimeType = req.response_format?.type === "json_object" ? "application/json" : "text/plain";
Expand All @@ -118,7 +112,10 @@ function genModel(req) {
maxOutputTokens: req.max_tokens ?? void 0,
temperature: req.temperature ?? void 0,
topP: req.top_p ?? void 0,
responseMimeType
responseMimeType,
thinkingConfig: !model.isThinkingModel() ? void 0 : {
includeThoughts: true
}
},
tools: functions.length === 0 ? void 0 : [
{
Expand All @@ -137,6 +134,34 @@ function genModel(req) {
};
return [model, generateContentRequest];
}
/**
 * Wrapper around a Gemini model name that knows which API version serves it.
 * "Thinking" models (e.g. gemini-2.0-flash-thinking-exp) are only available
 * on the v1alpha endpoint; all other models use v1beta.
 */
var GeminiModel = class _GeminiModel {
  /**
   * Resolve an incoming (OpenAI-style) model name to a GeminiModel.
   * Uses an own-property check so inherited Object.prototype keys
   * (e.g. "toString", "constructor") cannot be mistaken for mapping
   * entries — with plain `ModelMapping[model] ?? ...` those lookups
   * return truthy prototype members and skip the default fallback.
   */
  static modelMapping(model) {
    const modelName = Object.prototype.hasOwnProperty.call(ModelMapping, model) ? ModelMapping[model] : _GeminiModel.defaultModel(model);
    return new _GeminiModel(modelName);
  }
  // Resolved Gemini model name, e.g. "gemini-1.5-flash-latest".
  model;
  constructor(model) {
    this.model = model;
  }
  // True for reasoning models whose responses include thought output.
  isThinkingModel() {
    return this.model.includes("thinking");
  }
  // API version path segment for request URLs; thinking models are v1alpha-only.
  apiVersion() {
    if (this.isThinkingModel()) {
      return "v1alpha";
    }
    return "v1beta";
  }
  toString() {
    return this.model;
  }
  // Pass through names that already look like Gemini models;
  // anything else falls back to a sensible default.
  static defaultModel(m) {
    if (m.startsWith("gemini")) {
      return m;
    }
    return "gemini-1.5-flash-latest";
  }
};
var ModelMapping = {
"gpt-3.5-turbo": "gemini-1.5-flash-8b-latest",
"gpt-4": "gemini-1.5-pro-latest",
Expand Down Expand Up @@ -405,7 +430,7 @@ var RequestUrl = class {
this.apiParam = apiParam;
}
toURL() {
const api_version = "v1beta";
const api_version = this.model.apiVersion();
const url = new URL(`${BASE_URL}/${api_version}/models/${this.model}:${this.task}`);
url.searchParams.append("key", this.apiParam.apikey);
if (this.stream) {
Expand Down Expand Up @@ -676,7 +701,12 @@ async function embeddingProxyHandler(rawReq) {
log?.warn("request", embedContentRequest);
let geminiResp = [];
try {
for await (const it of generateContent("embedContent", apiParam, "text-embedding-004", embedContentRequest)) {
for await (const it of generateContent(
"embedContent",
apiParam,
new GeminiModel("text-embedding-004"),
embedContentRequest
)) {
const data = it.embedding?.values;
geminiResp = data;
break;
Expand Down Expand Up @@ -745,7 +775,7 @@ app.post("/v1/chat/completions", chatProxyHandler);
app.post("/v1/embeddings", embeddingProxyHandler);
app.get("/v1/models", () => Response.json(models()));
app.get("/v1/models/:model", (c) => Response.json(modelDetail(c.params.model)));
app.post(":model_version/models/:model_and_action", geminiProxy);
app.post("/:model_version/models/:model_and_action", geminiProxy);
app.all("*", () => new Response("Page Not Found", { status: 404 }));

// main_deno.ts
Expand Down
52 changes: 41 additions & 11 deletions dist/main_node.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -547,13 +547,7 @@ function openAiMessageToGeminiMessage(messages) {
return result;
}
function genModel(req) {
const defaultModel = (m) => {
if (m.startsWith("gemini")) {
return m;
}
return "gemini-1.5-flash-latest";
};
const model = ModelMapping[req.model] ?? defaultModel(req.model);
const model = GeminiModel.modelMapping(req.model);
let functions = req.tools?.filter((it) => it.type === "function")?.map((it) => it.function) ?? [];
functions = functions.concat((req.functions ?? []).map((it) => ({ strict: null, ...it })));
const responseMimeType = req.response_format?.type === "json_object" ? "application/json" : "text/plain";
Expand All @@ -563,7 +557,10 @@ function genModel(req) {
maxOutputTokens: req.max_tokens ?? void 0,
temperature: req.temperature ?? void 0,
topP: req.top_p ?? void 0,
responseMimeType
responseMimeType,
thinkingConfig: !model.isThinkingModel() ? void 0 : {
includeThoughts: true
}
},
tools: functions.length === 0 ? void 0 : [
{
Expand All @@ -582,6 +579,34 @@ function genModel(req) {
};
return [model, generateContentRequest];
}
/**
 * Wrapper around a Gemini model name that knows which API version serves it.
 * "Thinking" models (e.g. gemini-2.0-flash-thinking-exp) are only available
 * on the v1alpha endpoint; all other models use v1beta.
 */
var GeminiModel = class _GeminiModel {
  /**
   * Resolve an incoming (OpenAI-style) model name to a GeminiModel.
   * Uses an own-property check so inherited Object.prototype keys
   * (e.g. "toString", "constructor") cannot be mistaken for mapping
   * entries — with plain `ModelMapping[model] ?? ...` those lookups
   * return truthy prototype members and skip the default fallback.
   */
  static modelMapping(model) {
    const modelName = Object.prototype.hasOwnProperty.call(ModelMapping, model) ? ModelMapping[model] : _GeminiModel.defaultModel(model);
    return new _GeminiModel(modelName);
  }
  // Resolved Gemini model name, e.g. "gemini-1.5-flash-latest".
  model;
  constructor(model) {
    this.model = model;
  }
  // True for reasoning models whose responses include thought output.
  isThinkingModel() {
    return this.model.includes("thinking");
  }
  // API version path segment for request URLs; thinking models are v1alpha-only.
  apiVersion() {
    if (this.isThinkingModel()) {
      return "v1alpha";
    }
    return "v1beta";
  }
  toString() {
    return this.model;
  }
  // Pass through names that already look like Gemini models;
  // anything else falls back to a sensible default.
  static defaultModel(m) {
    if (m.startsWith("gemini")) {
      return m;
    }
    return "gemini-1.5-flash-latest";
  }
};
var ModelMapping = {
"gpt-3.5-turbo": "gemini-1.5-flash-8b-latest",
"gpt-4": "gemini-1.5-pro-latest",
Expand Down Expand Up @@ -850,7 +875,7 @@ var RequestUrl = class {
this.apiParam = apiParam;
}
toURL() {
const api_version = "v1beta";
const api_version = this.model.apiVersion();
const url = new URL(`${BASE_URL}/${api_version}/models/${this.model}:${this.task}`);
url.searchParams.append("key", this.apiParam.apikey);
if (this.stream) {
Expand Down Expand Up @@ -1121,7 +1146,12 @@ async function embeddingProxyHandler(rawReq) {
log?.warn("request", embedContentRequest);
let geminiResp = [];
try {
for await (const it of generateContent("embedContent", apiParam, "text-embedding-004", embedContentRequest)) {
for await (const it of generateContent(
"embedContent",
apiParam,
new GeminiModel("text-embedding-004"),
embedContentRequest
)) {
const data = it.embedding?.values;
geminiResp = data;
break;
Expand Down Expand Up @@ -1190,7 +1220,7 @@ app.post("/v1/chat/completions", chatProxyHandler);
app.post("/v1/embeddings", embeddingProxyHandler);
app.get("/v1/models", () => Response.json(models()));
app.get("/v1/models/:model", (c) => Response.json(modelDetail(c.params.model)));
app.post(":model_version/models/:model_and_action", geminiProxy);
app.post("/:model_version/models/:model_and_action", geminiProxy);
app.all("*", () => new Response("Page Not Found", { status: 404 }));

// main_node.ts
Expand Down
2 changes: 1 addition & 1 deletion fly.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dockerfile = 'docker/bun.Dockerfile'
[http_service]
internal_port = 8000
force_https = true
auto_stop_machines = true
auto_stop_machines = "suspend"
auto_start_machines = true
min_machines_running = 0
processes = ['app']
Expand Down
2 changes: 1 addition & 1 deletion src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ app.post("/v1/chat/completions", chatProxyHandler)
app.post("/v1/embeddings", embeddingProxyHandler)
app.get("/v1/models", () => Response.json(models()))
app.get("/v1/models/:model", (c) => Response.json(modelDetail(c.params.model)))
app.post(":model_version/models/:model_and_action", geminiProxy)
app.post("/:model_version/models/:model_and_action", geminiProxy)
app.all("*", () => new Response("Page Not Found", { status: 404 }))

export { app }
Loading

0 comments on commit 98a4774

Please sign in to comment.