Skip to content

Commit

Permalink
kyo-llm: refactor tools to agents
Browse files Browse the repository at this point in the history
  • Loading branch information
fwbrasil committed Dec 10, 2023
1 parent 13c9ba9 commit 14fa543
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 288 deletions.
125 changes: 125 additions & 0 deletions kyo-llm/shared/src/main/scala/kyo/llm/agents.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package kyo.llm

import kyo._
import kyo.ios._
import kyo.seqs._
import kyo.tries._
import kyo.locals._
import kyo.llm.ais._
import scala.util._
import kyo.llm.contexts._
import kyo.concurrent.atomics._
import zio.schema.Schema
import zio.schema.codec.JsonCodec
import kyo.llm.util.JsonSchema
import scala.annotation.implicitNotFound

object agents {

abstract class Agent {

type Input
type Output

case class Info(
name: String,
description: String
)(implicit
val input: ValueSchema[Input],
val output: ValueSchema[Output]
) {
val schema = JsonSchema(input.get)
val decoder = JsonCodec.jsonDecoder(input.get)
val encoder = JsonCodec.jsonEncoder(output.get)
}

val info: Info

def run(input: Input): Output > AIs

private val local = Locals.init(Option.empty[AI])

protected def caller: AI > AIs =
local.get.map {
case Some(ai) => ai
case None => AIs.init
}

private[kyo] def handle(ai: AI, v: String): String > AIs =
info.decoder.decodeJson(v) match {
case Left(error) =>
AIs.fail(
"Invalid json input. **Correct any mistakes before retrying**. " + error
)
case Right(value) =>
local.let(Some(ai)) {
run(value.value)
}.map { v =>
info.encoder.encodeJson(Value(v)).toString()
}
}
}

object Agents {
private val local = Locals.init(Set.empty[Agent])

def get: Set[Agent] > AIs = local.get

def enable[T, S](p: Agent*)(v: => T > S): T > (AIs with S) =
local.get.map { set =>
local.let(set ++ p.toSeq)(v)
}

def disable[T, S](f: T > S): T > (AIs with S) =
local.let(Set.empty)(f)

private[kyo] def resultAgent[T](implicit
t: ValueSchema[T]
): (Agent, Option[T] > AIs) > AIs =
Atomics.initRef(Option.empty[T]).map { ref =>
val agent =
new Agent {
type Input = T
type Output = String

val info = Info(
"resultAgent",
"Call this agent with the result."
)

def run(input: T) =
ref.set(Some(input)).andThen("Result processed.")
}
(agent, ref.get)
}

private[kyo] def handle(ai: AI, agents: Set[Agent], calls: List[Call]): Unit > AIs =
Seqs.traverseUnit(calls) { call =>
agents.find(_.info.name == call.function) match {
case None =>
ai.agentMessage(call.id, "Agent not found: " + call)
case Some(agent) =>
AIs.ephemeral {
Agents.disable {
Tries.run[String, AIs] {
ai.agentMessage(
call.id,
p"""
Entering the agent execution flow. Further interactions
are automated and indirectly initiated by a human.
"""
).andThen {
agent.handle(ai, call.arguments)
}
}
}
}.map {
case Success(result) =>
ai.agentMessage(call.id, result)
case Failure(ex) =>
ai.agentMessage(call.id, "Failure:" + ex)
}
}
}
}
}
42 changes: 21 additions & 21 deletions kyo-llm/shared/src/main/scala/kyo/llm/ais.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import kyo._
import kyo.llm.completions._
import kyo.llm.configs._
import kyo.llm.contexts._
import kyo.llm.tools._
import kyo.llm.agents._
import kyo.concurrent.Joins
import kyo.concurrent.atomics._
import kyo.concurrent.fibers._
Expand Down Expand Up @@ -79,34 +79,34 @@ object ais {
def assistantMessage(msg: String, calls: List[Call] = Nil): Unit > AIs =
update(_.assistantMessage(msg, calls))

def toolMessage(callId: CallId, msg: String): Unit > AIs =
update(_.toolMessage(callId, msg))
def agentMessage(callId: CallId, msg: String): Unit > AIs =
update(_.agentMessage(callId, msg))

def ask(msg: String): String > AIs =
userMessage(msg).andThen(ask)

def ask: String > AIs = {
def eval(tools: Set[Tool[_, _]]): String > AIs =
fetch(tools).map { r =>
def eval(agents: Set[Agent]): String > AIs =
fetch(agents).map { r =>
r.calls match {
case Nil =>
r.content
case calls =>
Tools.handle(this, tools, calls)
.andThen(eval(tools))
Agents.handle(this, agents, calls)
.andThen(eval(agents))
}
}
Tools.get.map(eval)
Agents.get.map(eval)
}

def gen[T](msg: String)(implicit t: ValueSchema[T]): T > AIs =
userMessage(msg).andThen(gen[T])

def gen[T](implicit t: ValueSchema[T]): T > AIs = {
Tools.resultTool[T].map { case (resultTool, result) =>
Agents.resultAgent[T].map { case (resultAgent, result) =>
def eval(): T > AIs =
fetch(Set(resultTool), Some(resultTool)).map { r =>
Tools.handle(this, Set(resultTool), r.calls).andThen {
fetch(Set(resultAgent), Some(resultAgent)).map { r =>
Agents.handle(this, Set(resultAgent), r.calls).andThen {
result.map {
case Some(v) =>
v
Expand All @@ -123,34 +123,34 @@ object ais {
userMessage(msg).andThen(infer[T])

def infer[T](implicit t: ValueSchema[T]): T > AIs = {
Tools.resultTool[T].map { case (resultTool, result) =>
def eval(tools: Set[Tool[_, _]], constrain: Option[Tool[_, _]] = None): T > AIs =
fetch(tools, constrain).map { r =>
Agents.resultAgent[T].map { case (resultAgent, result) =>
def eval(agents: Set[Agent], constrain: Option[Agent] = None): T > AIs =
fetch(agents, constrain).map { r =>
r.calls match {
case Nil =>
eval(tools, Some(resultTool))
eval(agents, Some(resultAgent))
case calls =>
Tools.handle(this, tools, calls).andThen {
Agents.handle(this, agents, calls).andThen {
result.map {
case None =>
eval(tools)
eval(agents)
case Some(v) =>
v
}
}
}
}
Tools.get.map(p => eval(p + resultTool))
Agents.get.map(p => eval(p + resultAgent))
}
}

private def fetch(
tools: Set[Tool[_, _]],
constrain: Option[Tool[_, _]] = None
agents: Set[Agent],
constrain: Option[Agent] = None
): Completions.Result > AIs =
for {
ctx <- save
r <- Completions(ctx, tools, constrain)
r <- Completions(ctx, agents, constrain)
_ <- assistantMessage(r.content, r.calls)
} yield r
}
Expand Down
36 changes: 22 additions & 14 deletions kyo-llm/shared/src/main/scala/kyo/llm/completions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package kyo.llm
import kyo._
import kyo.llm.configs._
import kyo.llm.contexts._
import kyo.llm.tools._
import kyo.llm.agents._
import kyo.llm.ais._
import kyo.llm.util.JsonSchema
import kyo.ios._
Expand All @@ -28,12 +28,12 @@ object completions {

def apply(
ctx: Context,
tools: Set[Tool[_, _]] = Set.empty,
constrain: Option[Tool[_, _]] = None
agents: Set[Agent] = Set.empty,
constrain: Option[Agent] = None
): Result > (IOs with Requests) =
for {
config <- Configs.get
req = Request(ctx, config, tools, constrain)
req = Request(ctx, config, agents, constrain)
_ <- Logs.debug(req.toJsonPretty)
response <- config.completionsMeter.run(fetch(config, req))
_ <- Logs.debug(response.toJsonPretty)
Expand Down Expand Up @@ -123,7 +123,7 @@ object completions {
VisionEntry.Content.Text(msg.content)
)
case _ =>
val toolCalls =
val agentCalls =
msg match {
case msg: Message.AssistantMessage =>
Some(
Expand All @@ -136,21 +136,21 @@ object completions {
}
val callId =
msg match {
case msg: Message.ToolMessage =>
case msg: Message.AgentMessage =>
Some(msg.callId.id)
case _ =>
None
}
MessageEntry(msg.role.name, Some(msg.content), toolCalls, callId)
MessageEntry(msg.role.name, Some(msg.content), agentCalls, callId)
}
}

object Request {
def apply(
ctx: Context,
config: Config,
tools: Set[Tool[_, _]],
constrain: Option[Tool[_, _]]
agents: Set[Agent],
constrain: Option[Agent]
): Request = {
val reminder =
ctx.reminder.map(r =>
Expand All @@ -165,17 +165,25 @@ object completions {
val entries =
(reminder ++ ctx.messages ++ ctx.seed.map(s => Message.SystemMessage(s)))
.map(toEntry).reverse
val toolDefs =
if (tools.isEmpty) None
else Some(tools.map(p => ToolDef(FunctionDef(p.description, p.name, p.schema))).toList)
val agentDefs =
if (agents.isEmpty)
None
else
Some(agents.map(p =>
ToolDef(FunctionDef(
p.info.description,
p.info.name,
p.info.schema
))
).toList)
Request(
config.model.name,
config.temperature,
config.maxTokens,
config.seed,
entries,
toolDefs,
constrain.map(p => ToolChoice(Name(p.name)))
agentDefs,
constrain.map(p => ToolChoice(Name(p.info.name)))
)
}
}
Expand Down
12 changes: 6 additions & 6 deletions kyo-llm/shared/src/main/scala/kyo/llm/contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ object contexts {
val system: Role = Role("system")
val user: Role = Role("user")
val assistant: Role = Role("assistant")
val tool: Role = Role("tool")
val agent: Role = Role("tool")
}

case class CallId(id: String)
Expand Down Expand Up @@ -43,10 +43,10 @@ object contexts {
role: Role = Role.assistant
) extends Message

case class ToolMessage(
case class AgentMessage(
callId: CallId,
content: String,
role: Role = Role.tool
role: Role = Role.agent
) extends Message

def apply(role: Role, content: String): Message =
Expand Down Expand Up @@ -79,8 +79,8 @@ object contexts {
def assistantMessage(content: String, calls: List[Call] = Nil): Context =
add(Message.AssistantMessage(content, calls))

def toolMessage(callId: CallId, content: String): Context =
add(Message.ToolMessage(callId, content))
def agentMessage(callId: CallId, content: String): Context =
add(Message.AgentMessage(callId, content))

def isEmpty: Boolean =
seed.isEmpty && messages.isEmpty
Expand Down Expand Up @@ -111,7 +111,7 @@ object contexts {
).mkString(", ")
s"\n .assistantMessage(${stringify(content)}${if (calls.isEmpty) ""
else s", List($callsStr)"})"
case Message.ToolMessage(callId, content, _) =>
case Message.AgentMessage(callId, content, _) =>
s"\n .toolMessage(CallId(\"${callId.id}\"), ${stringify(content)})"
}.mkString
s"Contexts.init$seedStr$reminderStr$messagesStr"
Expand Down
Loading

0 comments on commit 14fa543

Please sign in to comment.