Skip to content

Commit

Permalink
introduce exampleSelector at deep layer
Browse files Browse the repository at this point in the history
  • Loading branch information
yuiseki committed Jan 7, 2024
1 parent 403aa18 commit 0906c80
Show file tree
Hide file tree
Showing 5 changed files with 339 additions and 278 deletions.
2 changes: 2 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"jsonv",
"Karnali",
"Kita",
"Kurunegala",
"kwargs",
"landuse",
"langchain",
Expand All @@ -20,6 +21,7 @@
"lightyellow",
"llms",
"Mandera",
"México",
"NDRRMA",
"NEOC",
"OPENAI",
Expand Down
25 changes: 22 additions & 3 deletions src/app/api/ai/deep/route.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
import { NextResponse } from "next/server";
import { OpenAI, OpenAIChat } from "langchain/llms/openai";
import { loadTridentDeepChain } from "@/utils/langchain/chains/deep";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";

export async function POST(request: Request) {
const res = await request.json();
const query = res.query;

const model = new OpenAIChat({ temperature: 0 });
const chain = loadTridentDeepChain({ llm: model });
const result = await chain.call({ text: query });
let embeddings: OpenAIEmbeddings;
let llm: OpenAIChat;
if (process.env.CLOUDFLARE_AI_GATEWAY) {
embeddings = new OpenAIEmbeddings({
configuration: {
baseURL: process.env.CLOUDFLARE_AI_GATEWAY + "/openai",
},
});
llm = new OpenAIChat({
configuration: {
baseURL: process.env.CLOUDFLARE_AI_GATEWAY + "/openai",
},
temperature: 0,
});
} else {
embeddings = new OpenAIEmbeddings();
llm = new OpenAIChat({ temperature: 0 });
}

const chain = await loadTridentDeepChain({ embeddings, llm });
const result = await chain.call({ input: query });

console.log("----- ----- -----");
console.log("----- start deep -----");
Expand Down
1 change: 0 additions & 1 deletion src/app/api/ai/inner/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ export async function POST(request: Request) {

let embeddings: OpenAIEmbeddings;
let llm: OpenAIChat;

if (process.env.CLOUDFLARE_AI_GATEWAY) {
embeddings = new OpenAIEmbeddings({
configuration: {
Expand Down
12 changes: 8 additions & 4 deletions src/utils/langchain/chains/deep/index.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import { LLMChain } from "langchain/chains";
import { TRIDENT_DEEP_PROMPT } from "./prompt";
import { loadTridentDeepPrompt } from "./prompt";
import { BaseLanguageModel } from "langchain/dist/base_language";
import { Embeddings } from "langchain/embeddings/base";

export const loadTridentDeepChain = ({
export const loadTridentDeepChain = async ({
embeddings,
llm,
}: {
embeddings: Embeddings;
llm: BaseLanguageModel;
}): LLMChain => {
}): Promise<LLMChain> => {
const prompt = await loadTridentDeepPrompt(embeddings);
const chain = new LLMChain({
llm: llm,
prompt: TRIDENT_DEEP_PROMPT,
prompt: prompt,
});
return chain;
};
Loading

1 comment on commit 0906c80

@vercel
Copy link

@vercel vercel bot commented on 0906c80 Jan 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.