From b47418f9fdc950be67021d0017be976d9fc288e2 Mon Sep 17 00:00:00 2001 From: yuiseki Date: Tue, 31 Dec 2024 15:36:48 +0900 Subject: [PATCH] =?UTF-8?q?=E3=83=A9=E3=83=BC=E3=83=A1=E3=83=B3=E5=B1=8B?= =?UTF-8?q?=E3=82=A8=E3=83=BC=E3=82=B8=E3=82=A7=E3=83=B3=E3=83=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- package.json | 2 +- scripts/agent.ts | 22 +++++++++++++++--- scripts/tool.ts | 22 ++++++++++++++++++ .../tools/osm/overpass/tokyo_ramen/index.ts | 23 ++++++++++++------- 4 files changed, 57 insertions(+), 12 deletions(-) create mode 100644 scripts/tool.ts diff --git a/package.json b/package.json index 15e67e20..514ea327 100644 --- a/package.json +++ b/package.json @@ -13,7 +13,7 @@ "test:counter": "node --loader ts-node/esm ./tests/counter.ts", "test:sortByArea": "node --loader ts-node/esm ./tests/sortByArea.ts", "test:sortByDistance": "node --loader ts-node/esm ./tests/sortByDistance.ts", - "agents": "tsx ./scripts/agents/index.ts", + "agent": "tsx --network-family-autoselection-attempt-timeout=500 ./scripts/agent.ts", "site:api.reliefweb.int:fetch": "node --loader ts-node/esm ./scripts/api.reliefweb.int/fetch.ts", "site:api.reliefweb.int:update": "node --loader ts-node/esm ./scripts/api.reliefweb.int/update_disasters.ts", "site:api.reliefweb.int:extract": "node --loader ts-node/esm ./scripts/api.reliefweb.int/extract.ts", diff --git a/scripts/agent.ts b/scripts/agent.ts index 6691c556..04872b5f 100644 --- a/scripts/agent.ts +++ b/scripts/agent.ts @@ -1,8 +1,13 @@ import { ChatOllama } from "@langchain/ollama"; import { HumanMessage } from "@langchain/core/messages"; +import { Tool } from "@langchain/core/tools"; +import { BaseChatModel } from "@langchain/core/language_models/chat_models"; // langgraph -import { loadWikipediaAgent } from "./agents/wikipedia.ts"; +import { createReactAgent } from "@langchain/langgraph/prebuilt"; + +// tool +import { OverpassTokyoRamenCount } from "../src/utils/langchain/tools/osm/overpass/tokyo_ramen/index.ts"; const model = new ChatOllama({ // 速いがツールを使わずに返答しちゃう @@ -22,12 +27,23 @@ const model = new ChatOllama({ temperature: 0, }); -const agent = await loadWikipediaAgent(model); +export const loadAgent = async (model: BaseChatModel) => { + const tools: Array = [new OverpassTokyoRamenCount()]; + const prompt = + "You are a specialist of ramen shops. Be sure to use overpass-tokyo-ramen-count tool and reply based on the results. You have up to 10 chances to use tool."; + return createReactAgent({ + llm: model, + tools: tools, + stateModifier: prompt, + }); +}; + +const agent = await loadAgent(model); // Use the agent const stream = await agent.stream( { - messages: [new HumanMessage("Who is the president of the United States?")], + messages: [new HumanMessage("東京都台東区のラーメン屋の数を教えて")], }, { streamMode: "values", diff --git a/scripts/tool.ts b/scripts/tool.ts new file mode 100644 index 00000000..1a8cef21 --- /dev/null +++ b/scripts/tool.ts @@ -0,0 +1,22 @@ +import { OverpassTokyoRamenCount } from "../src/utils/langchain/tools/osm/overpass/tokyo_ramen"; +import { AIMessage } from "@langchain/core/messages"; + +import { ToolNode } from "@langchain/langgraph/prebuilt"; + +const tools = [new OverpassTokyoRamenCount()]; +const toolNode = new ToolNode(tools); + +const messageWithSingleToolCall = new AIMessage({ + content: "", + tool_calls: [ + { + name: "overpass-tokyo-ramen-count", + args: { input: "台東区" }, + id: "tool_call_id", + type: "tool_call", + }, + ], +}); + +const res = await toolNode.invoke({ messages: [messageWithSingleToolCall] }); +console.log(res); diff --git a/src/utils/langchain/tools/osm/overpass/tokyo_ramen/index.ts b/src/utils/langchain/tools/osm/overpass/tokyo_ramen/index.ts index 678abcbb..2b98a91f 100644 --- a/src/utils/langchain/tools/osm/overpass/tokyo_ramen/index.ts +++ b/src/utils/langchain/tools/osm/overpass/tokyo_ramen/index.ts @@ -2,10 +2,10 @@ import { Tool } from "langchain/tools"; export class OverpassTokyoRamenCount extends Tool { name = "overpass-tokyo-ramen-count"; - description = `useful for when you need to count number of ramen shops by a name of area. Input: a name of area.`; + description = `useful for when you need to count number of ramen shops by a name of area. Input: a name of area in Tokyo in Japanese.`; async _call(input: string) { - console.debug("Tool: OverpassTokyoRamenCount, input:", input); + // console.debug("Tool: OverpassTokyoRamenCount, input:", input); try { const overpassQuery = `[out:json][timeout:30000]; area["name"="東京都"]->.outer; @@ -15,9 +15,15 @@ area["name"="${input}"]->.inner; ); out geom;`; const queryString = `data=${encodeURIComponent(overpassQuery)}`; - const overpassApiUrl = `https://z.overpass-api.de/api/interpreter?${queryString}`; - const res = await fetch(overpassApiUrl); + const overpassApiUrl = `https://overpass-api.de/api/interpreter`; + const res = await fetch(overpassApiUrl, { + method: "POST", + body: queryString, + headers: { + "Content-Type": "application/json", + }, + }); const json = await res.json(); if (json.elements.length === 0) { @@ -25,12 +31,13 @@ out geom;`; } const answer = json.elements.length; - console.debug("Tool: OverpassTokyoRamenCount, answer:"); - console.debug(answer); - console.debug(""); + // console.debug("Tool: OverpassTokyoRamenCount, answer:"); + // console.debug(answer); + // console.debug(""); return answer; } catch (error) { - return "I don't know."; + console.error("Tool: OverpassTokyoRamenCount, error:", error); + return "Error. Please try again."; } } }