persisting in postgres
This commit is contained in:
@@ -6,6 +6,7 @@ defmodule ElixirAi.Application do
|
||||
def start(_type, _args) do
|
||||
children = [
|
||||
ElixirAiWeb.Telemetry,
|
||||
ElixirAi.Repo,
|
||||
{DNSCluster, query: Application.get_env(:elixir_ai, :dns_cluster_query) || :ignore},
|
||||
{Phoenix.PubSub, name: ElixirAi.PubSub},
|
||||
ElixirAi.ToolTesting,
|
||||
|
||||
@@ -2,9 +2,11 @@ defmodule ElixirAi.ChatRunner do
|
||||
require Logger
|
||||
use GenServer
|
||||
import ElixirAi.ChatUtils
|
||||
alias ElixirAi.{Conversation, Message}
|
||||
|
||||
defp via(name), do: {:via, Registry, {ElixirAi.ChatRegistry, name}}
|
||||
defp topic(name), do: "ai_chat:#{name}"
|
||||
defp message_topic(name), do: "conversation_messages:#{name}"
|
||||
|
||||
def new_user_message(name, text_content) do
|
||||
GenServer.cast(via(name), {:user_message, text_content})
|
||||
@@ -24,13 +26,20 @@ defmodule ElixirAi.ChatRunner do
|
||||
end
|
||||
|
||||
def init(name) do
|
||||
{:ok, %{
|
||||
name: name,
|
||||
messages: [],
|
||||
streaming_response: nil,
|
||||
pending_tool_calls: [],
|
||||
tools: tools(self(), name)
|
||||
}}
|
||||
messages =
|
||||
case Conversation.find_id(name) do
|
||||
{:ok, conv_id} -> Message.load_for_conversation(conv_id)
|
||||
_ -> []
|
||||
end
|
||||
|
||||
{:ok,
|
||||
%{
|
||||
name: name,
|
||||
messages: messages,
|
||||
streaming_response: nil,
|
||||
pending_tool_calls: [],
|
||||
tools: tools(self(), name)
|
||||
}}
|
||||
end
|
||||
|
||||
def tools(server, name) do
|
||||
@@ -54,7 +63,11 @@ defmodule ElixirAi.ChatRunner do
|
||||
description:
|
||||
"set the background color of the chat interface, accepts specified tailwind colors",
|
||||
function: fn %{"color" => color} ->
|
||||
Phoenix.PubSub.broadcast(ElixirAi.PubSub, "ai_chat:#{name}", {:set_background_color, color})
|
||||
Phoenix.PubSub.broadcast(
|
||||
ElixirAi.PubSub,
|
||||
"ai_chat:#{name}",
|
||||
{:set_background_color, color}
|
||||
)
|
||||
end,
|
||||
parameters: %{
|
||||
"type" => "object",
|
||||
@@ -73,7 +86,8 @@ defmodule ElixirAi.ChatRunner do
|
||||
|
||||
def handle_cast({:user_message, text_content}, state) do
|
||||
new_message = %{role: :user, content: text_content}
|
||||
broadcast(state.name, {:user_chat_message, new_message})
|
||||
broadcast_ui(state.name, {:user_chat_message, new_message})
|
||||
store_message(state.name, new_message)
|
||||
new_state = %{state | messages: state.messages ++ [new_message]}
|
||||
|
||||
request_ai_response(self(), new_state.messages, state.tools)
|
||||
@@ -82,7 +96,7 @@ defmodule ElixirAi.ChatRunner do
|
||||
|
||||
def handle_info({:start_new_ai_response, id}, state) do
|
||||
starting_response = %{id: id, reasoning_content: "", content: "", tool_calls: []}
|
||||
broadcast(state.name, {:start_ai_response_stream, starting_response})
|
||||
broadcast_ui(state.name, {:start_ai_response_stream, starting_response})
|
||||
|
||||
{:noreply, %{state | streaming_response: starting_response}}
|
||||
end
|
||||
@@ -100,7 +114,7 @@ defmodule ElixirAi.ChatRunner do
|
||||
end
|
||||
|
||||
def handle_info({:ai_reasoning_chunk, _id, reasoning_content}, state) do
|
||||
broadcast(state.name, {:reasoning_chunk_content, reasoning_content})
|
||||
broadcast_ui(state.name, {:reasoning_chunk_content, reasoning_content})
|
||||
|
||||
{:noreply,
|
||||
%{
|
||||
@@ -113,7 +127,7 @@ defmodule ElixirAi.ChatRunner do
|
||||
end
|
||||
|
||||
def handle_info({:ai_text_chunk, _id, text_content}, state) do
|
||||
broadcast(state.name, {:text_chunk_content, text_content})
|
||||
broadcast_ui(state.name, {:text_chunk_content, text_content})
|
||||
|
||||
{:noreply,
|
||||
%{
|
||||
@@ -137,7 +151,8 @@ defmodule ElixirAi.ChatRunner do
|
||||
tool_calls: state.streaming_response.tool_calls
|
||||
}
|
||||
|
||||
broadcast(state.name, {:end_ai_response, final_message})
|
||||
broadcast_ui(state.name, {:end_ai_response, final_message})
|
||||
store_message(state.name, final_message)
|
||||
|
||||
{:noreply,
|
||||
%{
|
||||
@@ -202,13 +217,14 @@ defmodule ElixirAi.ChatRunner do
|
||||
tool_calls: state.streaming_response.tool_calls
|
||||
}
|
||||
|
||||
broadcast(state.name, {:tool_request_message, tool_request_message})
|
||||
|
||||
broadcast_ui(state.name, {:tool_request_message, tool_request_message})
|
||||
|
||||
{failed_call_messages, pending_call_ids} =
|
||||
Enum.reduce(state.streaming_response.tool_calls, {[], []}, fn tool_call, {failed, pending} ->
|
||||
Enum.reduce(state.streaming_response.tool_calls, {[], []}, fn tool_call,
|
||||
{failed, pending} ->
|
||||
with {:ok, decoded_args} <- Jason.decode(tool_call.arguments),
|
||||
tool when not is_nil(tool) <- Enum.find(state.tools, fn t -> t.name == tool_call.name end) do
|
||||
tool when not is_nil(tool) <-
|
||||
Enum.find(state.tools, fn t -> t.name == tool_call.name end) do
|
||||
tool.run_function.(id, tool_call.id, decoded_args)
|
||||
{failed, [tool_call.id | pending]}
|
||||
else
|
||||
@@ -224,6 +240,8 @@ defmodule ElixirAi.ChatRunner do
|
||||
end
|
||||
end)
|
||||
|
||||
store_message(state.name, [tool_request_message] ++ failed_call_messages)
|
||||
|
||||
{:noreply,
|
||||
%{
|
||||
state
|
||||
@@ -235,7 +253,8 @@ defmodule ElixirAi.ChatRunner do
|
||||
def handle_info({:tool_response, _id, tool_call_id, result}, state) do
|
||||
new_message = %{role: :tool, content: inspect(result), tool_call_id: tool_call_id}
|
||||
|
||||
broadcast(state.name, {:one_tool_finished, new_message})
|
||||
broadcast_ui(state.name, {:one_tool_finished, new_message})
|
||||
store_message(state.name, new_message)
|
||||
|
||||
new_pending_tool_calls =
|
||||
Enum.filter(state.pending_tool_calls, fn id -> id != tool_call_id end)
|
||||
@@ -250,7 +269,7 @@ defmodule ElixirAi.ChatRunner do
|
||||
end
|
||||
|
||||
if new_pending_tool_calls == [] do
|
||||
broadcast(state.name, :tool_calls_finished)
|
||||
broadcast_ui(state.name, :tool_calls_finished)
|
||||
request_ai_response(self(), state.messages ++ [new_message], state.tools)
|
||||
end
|
||||
|
||||
@@ -271,5 +290,20 @@ defmodule ElixirAi.ChatRunner do
|
||||
{:reply, state.streaming_response, state}
|
||||
end
|
||||
|
||||
defp broadcast(name, msg), do: Phoenix.PubSub.broadcast(ElixirAi.PubSub, topic(name), msg)
|
||||
defp broadcast_ui(name, msg), do: Phoenix.PubSub.broadcast(ElixirAi.PubSub, topic(name), msg)
|
||||
|
||||
defp store_message(name, messages) when is_list(messages) do
|
||||
Enum.each(messages, &store_message(name, &1))
|
||||
messages
|
||||
end
|
||||
|
||||
defp store_message(name, message) do
|
||||
Phoenix.PubSub.broadcast(
|
||||
ElixirAi.PubSub,
|
||||
message_topic(name),
|
||||
{:store_message, name, message}
|
||||
)
|
||||
|
||||
message
|
||||
end
|
||||
end
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
defmodule ElixirAi.ConversationManager do
|
||||
use GenServer
|
||||
alias ElixirAi.{Conversation, Message}
|
||||
|
||||
def start_link(_opts), do: GenServer.start_link(__MODULE__, [], name: __MODULE__)
|
||||
def init(names), do: {:ok, names}
|
||||
def start_link(_opts), do: GenServer.start_link(__MODULE__, nil, name: __MODULE__)
|
||||
|
||||
def init(_) do
|
||||
names = Conversation.all_names()
|
||||
conversations = Map.new(names, fn name -> {name, []} end)
|
||||
{:ok, conversations}
|
||||
end
|
||||
|
||||
def create_conversation(name) do
|
||||
GenServer.call(__MODULE__, {:create, name})
|
||||
@@ -16,34 +22,76 @@ defmodule ElixirAi.ConversationManager do
|
||||
GenServer.call(__MODULE__, :list)
|
||||
end
|
||||
|
||||
def handle_call({:create, name}, _from, names) do
|
||||
if name in names do
|
||||
{:reply, {:error, :already_exists}, names}
|
||||
def get_messages(name) do
|
||||
GenServer.call(__MODULE__, {:get_messages, name})
|
||||
end
|
||||
|
||||
def handle_call({:create, name}, _from, conversations) do
|
||||
if Map.has_key?(conversations, name) do
|
||||
{:reply, {:error, :already_exists}, conversations}
|
||||
else
|
||||
{:reply, start_runner(name), [name | names]}
|
||||
case Conversation.create(name) do
|
||||
:ok ->
|
||||
case start_and_subscribe(name) do
|
||||
{:ok, _pid} = ok -> {:reply, ok, Map.put(conversations, name, [])}
|
||||
error -> {:reply, error, conversations}
|
||||
end
|
||||
|
||||
{:error, _} = error ->
|
||||
{:reply, error, conversations}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def handle_call({:open, name}, _from, names) do
|
||||
if name in names do
|
||||
{:reply, start_runner(name), names}
|
||||
def handle_call({:open, name}, _from, conversations) do
|
||||
if Map.has_key?(conversations, name) do
|
||||
case start_and_subscribe(name) do
|
||||
{:ok, _pid} = ok -> {:reply, ok, conversations}
|
||||
error -> {:reply, error, conversations}
|
||||
end
|
||||
else
|
||||
{:reply, {:error, :not_found}, names}
|
||||
end
|
||||
end
|
||||
def handle_call(:list, _from, names) do
|
||||
{:reply, names, names}
|
||||
end
|
||||
|
||||
defp start_runner(name) do
|
||||
case DynamicSupervisor.start_child(
|
||||
ElixirAi.ChatRunnerSupervisor,
|
||||
{ElixirAi.ChatRunner, name: name}
|
||||
) do
|
||||
{:ok, pid} -> {:ok, pid}
|
||||
{:error, {:already_started, pid}} -> {:ok, pid}
|
||||
error -> error
|
||||
{:reply, {:error, :not_found}, conversations}
|
||||
end
|
||||
end
|
||||
|
||||
def handle_call(:list, _from, conversations) do
|
||||
{:reply, Map.keys(conversations), conversations}
|
||||
end
|
||||
|
||||
def handle_call({:get_messages, name}, _from, conversations) do
|
||||
{:reply, Map.get(conversations, name, []), conversations}
|
||||
end
|
||||
|
||||
def handle_info({:store_message, name, message}, conversations) do
|
||||
messages = Map.get(conversations, name, [])
|
||||
position = length(messages)
|
||||
|
||||
case Conversation.find_id(name) do
|
||||
{:ok, conv_id} -> Message.insert(conv_id, message, position)
|
||||
_ -> :ok
|
||||
end
|
||||
|
||||
{:noreply, Map.update(conversations, name, [message], &(&1 ++ [message]))}
|
||||
end
|
||||
|
||||
defp start_and_subscribe(name) do
|
||||
result =
|
||||
case DynamicSupervisor.start_child(
|
||||
ElixirAi.ChatRunnerSupervisor,
|
||||
{ElixirAi.ChatRunner, name: name}
|
||||
) do
|
||||
{:ok, pid} -> {:ok, pid}
|
||||
{:error, {:already_started, pid}} -> {:ok, pid}
|
||||
error -> error
|
||||
end
|
||||
|
||||
case result do
|
||||
{:ok, _pid} ->
|
||||
Phoenix.PubSub.subscribe(ElixirAi.PubSub, "conversation_messages:#{name}")
|
||||
result
|
||||
|
||||
_ ->
|
||||
result
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
26
lib/elixir_ai/data/conversation.ex
Normal file
26
lib/elixir_ai/data/conversation.ex
Normal file
@@ -0,0 +1,26 @@
|
||||
defmodule ElixirAi.Conversation do
|
||||
import Ecto.Query
|
||||
alias ElixirAi.Repo
|
||||
|
||||
def all_names do
|
||||
Repo.all(from c in "conversations", select: c.name)
|
||||
end
|
||||
|
||||
def create(name) do
|
||||
case Repo.insert_all("conversations", [[id: Ecto.UUID.generate(), name: name, inserted_at: now(), updated_at: now()]]) do
|
||||
{1, _} -> :ok
|
||||
_ -> {:error, :db_error}
|
||||
end
|
||||
rescue
|
||||
e in Ecto.ConstraintError -> if e.constraint == "conversations_name_index", do: {:error, :already_exists}, else: {:error, :db_error}
|
||||
end
|
||||
|
||||
def find_id(name) do
|
||||
case Repo.one(from c in "conversations", where: c.name == ^name, select: c.id) do
|
||||
nil -> {:error, :not_found}
|
||||
id -> {:ok, id}
|
||||
end
|
||||
end
|
||||
|
||||
defp now, do: DateTime.truncate(DateTime.utc_now(), :second)
|
||||
end
|
||||
49
lib/elixir_ai/data/message.ex
Normal file
49
lib/elixir_ai/data/message.ex
Normal file
@@ -0,0 +1,49 @@
|
||||
defmodule ElixirAi.Message do
|
||||
import Ecto.Query
|
||||
alias ElixirAi.Repo
|
||||
|
||||
def load_for_conversation(conversation_id) do
|
||||
Repo.all(
|
||||
from m in "messages",
|
||||
where: m.conversation_id == ^conversation_id,
|
||||
order_by: m.position,
|
||||
select: %{
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
reasoning_content: m.reasoning_content,
|
||||
tool_calls: m.tool_calls,
|
||||
tool_call_id: m.tool_call_id
|
||||
}
|
||||
)
|
||||
|> Enum.map(&decode_message/1)
|
||||
end
|
||||
|
||||
def insert(conversation_id, message, position) do
|
||||
Repo.insert_all("messages", [
|
||||
[
|
||||
id: Ecto.UUID.generate(),
|
||||
conversation_id: conversation_id,
|
||||
role: to_string(message.role),
|
||||
content: message[:content],
|
||||
reasoning_content: message[:reasoning_content],
|
||||
tool_calls: encode_tool_calls(message[:tool_calls]),
|
||||
tool_call_id: message[:tool_call_id],
|
||||
position: position,
|
||||
inserted_at: DateTime.truncate(DateTime.utc_now(), :second)
|
||||
]
|
||||
])
|
||||
end
|
||||
|
||||
defp encode_tool_calls(nil), do: nil
|
||||
defp encode_tool_calls(calls), do: Jason.encode!(calls)
|
||||
|
||||
defp decode_message(row) do
|
||||
row
|
||||
|> Map.update!(:role, &String.to_existing_atom/1)
|
||||
|> drop_nil_fields()
|
||||
end
|
||||
|
||||
defp drop_nil_fields(map) do
|
||||
Map.reject(map, fn {_k, v} -> is_nil(v) end)
|
||||
end
|
||||
end
|
||||
5
lib/elixir_ai/data/repo.ex
Normal file
5
lib/elixir_ai/data/repo.ex
Normal file
@@ -0,0 +1,5 @@
|
||||
defmodule ElixirAi.Repo do
|
||||
use Ecto.Repo,
|
||||
otp_app: :elixir_ai,
|
||||
adapter: Ecto.Adapters.Postgres
|
||||
end
|
||||
Reference in New Issue
Block a user