-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Persist AI Assistant conversations and enable it for all users (#2453)
* placeholder changelog * Persist AI messages (#2427) * persist messages * use role instead of sender in chat_messages. Simplifies things * allow all users to access the ai assistant * add test for viewing form * make dialyzer happy * make dialyzer happy * Resume Chat Sessions (#2439) * WIP * users can resume sessions * redirect to the created session for new chat sessions * make credo happy * save session title * fix bug when following run * Polish Chat Assistant UI (#2452) * polish UI * show user avatars in the chat session * disable submit when user is not allowed to edit workflow * Handle errors when saving message and querying apollo (#2456) this also gives way for handling limiter errors * Tests For Ai Assistant (#2458) * add tests for assistant * add test for failures * add ability to clear error message * increase sleep duration for async task in test * update chagelog * fix bug where closing edit modal did not update the onboarding ui * Get rid of Process.sleep in test instead use PubSub * Limit AI queries (#2457) * Limit AI queries * Remove forced error and add test case * Increment ai chat messages * Changelog and formatting * Handles assistant role as string sent by async process_message * Fix banner * Create index for counting * Use extension to increment ai queries * Simplify test and check expected limiter extension calls * Fix rebase * Formatting * Increment on reply * Changelog * Icon and Center of banner The Common.banner needs a fix * move limiter banner to the component * fix failing tests --------- Co-authored-by: Frank Midigo <[email protected]> * always show the limit banner --------- Co-authored-by: Rogerio Pontual <[email protected]>
- Loading branch information
1 parent
97d44fb
commit 4ba12e9
Showing
17 changed files
with
1,593 additions
and
306 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,103 +3,89 @@ defmodule Lightning.AiAssistant do | |
The AI assistant module. | ||
""" | ||
|
||
import Ecto.Query | ||
|
||
alias Ecto.Multi | ||
alias Lightning.Accounts.User | ||
alias Lightning.AiAssistant.ChatSession | ||
alias Lightning.ApolloClient | ||
alias Lightning.Repo | ||
alias Lightning.Services.UsageLimiter | ||
alias Lightning.Workflows.Job | ||
|
||
defmodule Session do | ||
@moduledoc """ | ||
Represents a session with the AI assistant. | ||
""" | ||
|
||
defstruct [ | ||
:id, | ||
:expression, | ||
:adaptor, | ||
:history | ||
] | ||
|
||
@type t() :: %__MODULE__{ | ||
id: Ecto.UUID.t(), | ||
expression: String.t(), | ||
adaptor: String.t(), | ||
history: history() | ||
} | ||
|
||
@type history() :: [ | ||
%{role: :user | :assistant, content: String.t()} | ||
] | ||
|
||
@spec new(Job.t()) :: t() | ||
def new(job) do | ||
%Session{ | ||
id: job.id, | ||
expression: job.body, | ||
adaptor: Lightning.AdaptorRegistry.resolve_adaptor(job.adaptor), | ||
history: [] | ||
} | ||
end | ||
|
||
@spec put_history(t(), history() | [%{String.t() => any()}]) :: t() | ||
def put_history(session, history) do | ||
history = | ||
Enum.map(history, fn h -> | ||
%{role: h["role"] || h[:role], content: h["content"] || h[:content]} | ||
end) | ||
|
||
%{session | history: history} | ||
end | ||
|
||
@spec push_history(t(), %{String.t() => any()}) :: t() | ||
def push_history(session, message) do | ||
history = | ||
session.history ++ | ||
[ | ||
%{ | ||
role: message["role"] || message[:role], | ||
content: message["content"] || message[:content] | ||
} | ||
] | ||
|
||
%{session | history: history} | ||
end | ||
@spec put_expression_and_adaptor(ChatSession.t(), String.t(), String.t()) :: | ||
ChatSession.t() | ||
def put_expression_and_adaptor(session, expression, adaptor) do | ||
%{ | ||
session | ||
| expression: expression, | ||
adaptor: Lightning.AdaptorRegistry.resolve_adaptor(adaptor) | ||
} | ||
end | ||
|
||
@doc """ | ||
Puts the given expression into the session. | ||
""" | ||
@spec put_expression(t(), String.t()) :: t() | ||
def put_expression(session, expression) do | ||
%{session | expression: expression} | ||
end | ||
@spec list_sessions_for_job(Job.t()) :: [ChatSession.t(), ...] | [] | ||
def list_sessions_for_job(job) do | ||
Repo.all( | ||
from s in ChatSession, | ||
where: s.job_id == ^job.id, | ||
order_by: [desc: :updated_at], | ||
preload: [:user] | ||
) | ||
end | ||
|
||
@doc """ | ||
Creates a new session with the given job. | ||
@spec get_session!(Ecto.UUID.t()) :: ChatSession.t() | ||
def get_session!(id) do | ||
ChatSession |> Repo.get!(id) |> Repo.preload(messages: :user) | ||
end | ||
|
||
**Example** | ||
@spec create_session(Job.t(), User.t(), String.t()) :: | ||
{:ok, ChatSession.t()} | {:error, Ecto.Changeset.t()} | ||
def create_session(job, user, content) do | ||
%ChatSession{ | ||
id: Ecto.UUID.generate(), | ||
job_id: job.id, | ||
user_id: user.id, | ||
title: String.slice(content, 0, 40), | ||
messages: [] | ||
} | ||
|> put_expression_and_adaptor(job.body, job.adaptor) | ||
|> save_message(%{role: :user, content: content, user: user}) | ||
end | ||
|
||
iex> AiAssistant.new_session(%Lightning.Workflows.Job{ | ||
...> body: "fn()", | ||
...> adaptor: "@openfn/language-common@latest" | ||
...> }) | ||
%Lightning.AiAssistant.Session{ | ||
expression: "fn()", | ||
adaptor: "@openfn/[email protected]", | ||
history: [] | ||
} | ||
> ℹ️ The `adaptor` field is resolved to the latest version when `@latest` | ||
> is provided as Apollo expects a specific version. | ||
""" | ||
@spec save_message(ChatSession.t(), %{any() => any()}) :: | ||
{:ok, ChatSession.t()} | {:error, Ecto.Changeset.t()} | ||
def save_message(session, message) do | ||
# we can call the limiter at this point | ||
# note: we should only increment the counter when role is `:assistant` | ||
messages = Enum.map(session.messages, &Map.take(&1, [:id])) | ||
|
||
Multi.new() | ||
|> Multi.put(:message, message) | ||
|> Multi.insert_or_update( | ||
:upsert, | ||
session | ||
|> ChatSession.changeset(%{messages: messages ++ [message]}) | ||
) | ||
|> Multi.merge(&maybe_increment_msgs_counter/1) | ||
|> Repo.transaction() | ||
|> case do | ||
{:ok, %{upsert: session}} -> | ||
{:ok, session} | ||
|
||
@spec new_session(Job.t()) :: Session.t() | ||
def new_session(job) do | ||
Session.new(job) | ||
{:error, _operation, changeset, _changes} -> | ||
{:error, changeset} | ||
end | ||
end | ||
|
||
@spec push_history(Session.t(), %{String.t() => any()}) :: Session.t() | ||
def push_history(session, message) do | ||
Session.push_history(session, message) | ||
@spec project_has_any_session?(Ecto.UUID.t()) :: boolean() | ||
def project_has_any_session?(project_id) do | ||
query = | ||
from s in ChatSession, | ||
join: j in assoc(s, :job), | ||
join: w in assoc(j, :workflow), | ||
where: w.project_id == ^project_id | ||
|
||
Repo.exists?(query) | ||
end | ||
|
||
@doc """ | ||
|
@@ -112,28 +98,48 @@ defmodule Lightning.AiAssistant do | |
iex> AiAssistant.query(session, "fn()") | ||
{:ok, session} | ||
""" | ||
@spec query(Session.t(), String.t()) :: {:ok, Session.t()} | :error | ||
@spec query(ChatSession.t(), String.t()) :: | ||
{:ok, ChatSession.t()} | ||
| {:error, Ecto.Changeset.t() | :apollo_unavailable} | ||
def query(session, content) do | ||
ApolloClient.query( | ||
content, | ||
%{expression: session.expression, adaptor: session.adaptor}, | ||
session.history | ||
) | ||
|> case do | ||
{:ok, %Tesla.Env{status: status} = response} when status in 200..299 -> | ||
{:ok, session |> Session.put_history(response.body["history"])} | ||
apollo_resp = | ||
ApolloClient.query( | ||
content, | ||
%{expression: session.expression, adaptor: session.adaptor}, | ||
build_history(session) | ||
) | ||
|
||
case apollo_resp do | ||
{:ok, %Tesla.Env{status: status, body: body}} when status in 200..299 -> | ||
message = body["history"] |> Enum.reverse() |> hd() | ||
save_message(session, message) | ||
|
||
_ -> | ||
:error | ||
{:error, :apollo_unavailable} | ||
end | ||
end | ||
|
||
defp build_history(session) do | ||
case Enum.reverse(session.messages) do | ||
[%{role: :user} | other] -> | ||
other | ||
|> Enum.reverse() | ||
|> Enum.map(&Map.take(&1, [:role, :content])) | ||
|
||
messages -> | ||
Enum.map(messages, &Map.take(&1, [:role, :content])) | ||
end | ||
end | ||
|
||
@doc """ | ||
Checks if the user is authorized to access the AI assistant. | ||
Checks if the AI assistant is enabled. | ||
""" | ||
@spec authorized?(User.t()) :: boolean() | ||
def authorized?(user) do | ||
user.role == :superuser | ||
@spec enabled?() :: boolean() | ||
def enabled? do | ||
endpoint = Lightning.Config.apollo(:endpoint) | ||
api_key = Lightning.Config.apollo(:openai_api_key) | ||
|
||
is_binary(endpoint) && is_binary(api_key) | ||
end | ||
|
||
@doc """ | ||
|
@@ -143,4 +149,23 @@ defmodule Lightning.AiAssistant do | |
def endpoint_available? do | ||
ApolloClient.test() == :ok | ||
end | ||
|
||
# assistant role sent via async as string | ||
defp maybe_increment_msgs_counter(%{ | ||
upsert: session, | ||
message: %{"role" => "assistant"} | ||
}), | ||
do: | ||
maybe_increment_msgs_counter(%{ | ||
upsert: session, | ||
message: %{role: :assistant} | ||
}) | ||
|
||
defp maybe_increment_msgs_counter(%{ | ||
upsert: session, | ||
message: %{role: :assistant} | ||
}), | ||
do: UsageLimiter.increment_ai_queries(session) | ||
|
||
defp maybe_increment_msgs_counter(_user_role), do: Multi.new() | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
defmodule Lightning.AiAssistant.ChatMessage do | ||
@moduledoc false | ||
|
||
use Lightning.Schema | ||
import Ecto.Changeset | ||
import Lightning.Validators, only: [validate_required_assoc: 2] | ||
|
||
@type role() :: :user | :assistant | ||
@type t() :: %__MODULE__{ | ||
id: Ecto.UUID.t(), | ||
content: String.t() | nil, | ||
role: role(), | ||
is_deleted: boolean(), | ||
is_public: boolean() | ||
} | ||
|
||
schema "ai_chat_messages" do | ||
field :content, :string | ||
field :role, Ecto.Enum, values: [:user, :assistant] | ||
field :is_deleted, :boolean, default: false | ||
field :is_public, :boolean, default: true | ||
|
||
belongs_to :chat_session, Lightning.AiAssistant.ChatSession | ||
belongs_to :user, Lightning.Accounts.User | ||
|
||
timestamps() | ||
end | ||
|
||
def changeset(chat_message, attrs) do | ||
chat_message | ||
|> cast(attrs, [ | ||
:content, | ||
:role, | ||
:is_deleted, | ||
:is_public, | ||
:chat_session_id | ||
]) | ||
|> validate_required([:content, :role]) | ||
|> maybe_put_user_assoc(attrs[:user] || attrs["user"]) | ||
|> maybe_require_user() | ||
end | ||
|
||
defp maybe_put_user_assoc(changeset, user) do | ||
if user do | ||
put_assoc(changeset, :user, user) | ||
else | ||
changeset | ||
end | ||
end | ||
|
||
defp maybe_require_user(changeset) do | ||
if get_field(changeset, :role) == :user do | ||
validate_required_assoc(changeset, :user) | ||
else | ||
changeset | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
defmodule Lightning.AiAssistant.ChatSession do | ||
@moduledoc false | ||
|
||
use Lightning.Schema | ||
import Ecto.Changeset | ||
|
||
alias Lightning.Accounts.User | ||
alias Lightning.AiAssistant.ChatMessage | ||
alias Lightning.Workflows.Job | ||
|
||
@type t() :: %__MODULE__{ | ||
id: Ecto.UUID.t(), | ||
job_id: Ecto.UUID.t(), | ||
user_id: Ecto.UUID.t(), | ||
title: String.t(), | ||
expression: String.t() | nil, | ||
adaptor: String.t() | nil, | ||
is_public: boolean(), | ||
is_deleted: boolean(), | ||
messages: [ChatMessage.t(), ...] | [] | ||
} | ||
|
||
schema "ai_chat_sessions" do | ||
field :expression, :string, virtual: true | ||
field :adaptor, :string, virtual: true | ||
field :title, :string | ||
field :is_public, :boolean, default: false | ||
field :is_deleted, :boolean, default: false | ||
belongs_to :job, Job | ||
belongs_to :user, User | ||
|
||
has_many :messages, ChatMessage, preload_order: [asc: :inserted_at] | ||
|
||
timestamps() | ||
end | ||
|
||
def changeset(chat_session, attrs) do | ||
chat_session | ||
|> cast(attrs, [:title, :is_public, :is_deleted, :job_id, :user_id]) | ||
|> validate_required([:title, :job_id, :user_id]) | ||
|> cast_assoc(:messages) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
defmodule Lightning.AiAssistant.Limiter do | ||
@moduledoc """ | ||
The AI assistant limiter to check for AI query quota. | ||
""" | ||
|
||
alias Lightning.Extensions.UsageLimiting | ||
alias Lightning.Extensions.UsageLimiting.Action | ||
alias Lightning.Extensions.UsageLimiting.Context | ||
alias Lightning.Services.UsageLimiter | ||
|
||
@doc """ | ||
Checks if has not reached the limit of the project ai queries quota. | ||
""" | ||
@spec validate_quota(Ecto.UUID.t()) :: :ok | UsageLimiting.error() | ||
def validate_quota(project_id) do | ||
UsageLimiter.limit_action(%Action{type: :ai_query}, %Context{ | ||
project_id: project_id | ||
}) | ||
end | ||
end |
Oops, something went wrong.