This commit is contained in:
@@ -1,14 +1,30 @@
|
||||
defmodule ElixirAi.ChatRunner do
|
||||
require Logger
|
||||
use GenServer
|
||||
import ElixirAi.ChatUtils, only: [ai_tool: 1]
|
||||
alias ElixirAi.{Conversation, Message}
|
||||
alias ElixirAi.{AiTools, Conversation, Message}
|
||||
import ElixirAi.PubsubTopics
|
||||
|
||||
defp via(name), do: {:via, Horde.Registry, {ElixirAi.ChatRegistry, name}}
|
||||
|
||||
def new_user_message(name, text_content) do
|
||||
GenServer.cast(via(name), {:user_message, text_content})
|
||||
def new_user_message(name, text_content, opts \\ []) do
|
||||
tool_choice = Keyword.get(opts, :tool_choice)
|
||||
GenServer.cast(via(name), {:user_message, text_content, tool_choice})
|
||||
end
|
||||
|
||||
def set_allowed_tools(name, tool_names) when is_list(tool_names) do
|
||||
GenServer.call(via(name), {:set_allowed_tools, tool_names})
|
||||
end
|
||||
|
||||
def set_tool_choice(name, tool_choice) when tool_choice in ["auto", "none", "required"] do
|
||||
GenServer.call(via(name), {:set_tool_choice, tool_choice})
|
||||
end
|
||||
|
||||
def register_liveview_pid(name, liveview_pid) when is_pid(liveview_pid) do
|
||||
GenServer.call(via(name), {:register_liveview_pid, liveview_pid})
|
||||
end
|
||||
|
||||
def deregister_liveview_pid(name) do
|
||||
GenServer.call(via(name), :deregister_liveview_pid)
|
||||
end
|
||||
|
||||
@spec get_conversation(String.t()) :: any()
|
||||
@@ -44,13 +60,35 @@ defmodule ElixirAi.ChatRunner do
|
||||
_ -> nil
|
||||
end
|
||||
|
||||
allowed_tools =
|
||||
case Conversation.find_allowed_tools(name) do
|
||||
{:ok, tools} -> tools
|
||||
_ -> AiTools.all_tool_names()
|
||||
end
|
||||
|
||||
tool_choice =
|
||||
case Conversation.find_tool_choice(name) do
|
||||
{:ok, tc} -> tc
|
||||
_ -> "auto"
|
||||
end
|
||||
|
||||
server_tools = AiTools.build_server_tools(self(), allowed_tools)
|
||||
liveview_tools = AiTools.build_liveview_tools(self(), allowed_tools)
|
||||
|
||||
if last_message && last_message.role == :user do
|
||||
Logger.info(
|
||||
"Last message role was #{last_message.role}, requesting AI response for conversation #{name}"
|
||||
)
|
||||
|
||||
broadcast_ui(name, :recovery_restart)
|
||||
ElixirAi.ChatUtils.request_ai_response(self(), messages, tools(self(), name), provider)
|
||||
|
||||
ElixirAi.ChatUtils.request_ai_response(
|
||||
self(),
|
||||
messages,
|
||||
server_tools ++ liveview_tools,
|
||||
provider,
|
||||
tool_choice
|
||||
)
|
||||
end
|
||||
|
||||
{:ok,
|
||||
@@ -59,63 +97,19 @@ defmodule ElixirAi.ChatRunner do
|
||||
messages: messages,
|
||||
streaming_response: nil,
|
||||
pending_tool_calls: [],
|
||||
tools: tools(self(), name),
|
||||
provider: provider
|
||||
allowed_tools: allowed_tools,
|
||||
tool_choice: tool_choice,
|
||||
server_tools: server_tools,
|
||||
liveview_tools: liveview_tools,
|
||||
provider: provider,
|
||||
liveview_pid: nil,
|
||||
liveview_monitor_ref: nil
|
||||
}}
|
||||
end
|
||||
|
||||
def tools(server, name) do
|
||||
[
|
||||
ai_tool(
|
||||
name: "store_thing",
|
||||
description: "store a key value pair in memory",
|
||||
function: &ElixirAi.ToolTesting.hold_thing/1,
|
||||
parameters: ElixirAi.ToolTesting.hold_thing_params(),
|
||||
server: server
|
||||
),
|
||||
ai_tool(
|
||||
name: "read_thing",
|
||||
description: "read a key value pair that was previously stored with store_thing",
|
||||
function: &ElixirAi.ToolTesting.get_thing/1,
|
||||
parameters: ElixirAi.ToolTesting.get_thing_params(),
|
||||
server: server
|
||||
),
|
||||
ai_tool(
|
||||
name: "set_background_color",
|
||||
description:
|
||||
"set the background color of the chat interface, accepts specified tailwind colors",
|
||||
function: fn %{"color" => color} ->
|
||||
Phoenix.PubSub.broadcast(
|
||||
ElixirAi.PubSub,
|
||||
chat_topic(name),
|
||||
{:set_background_color, color}
|
||||
)
|
||||
end,
|
||||
parameters: %{
|
||||
"type" => "object",
|
||||
"properties" => %{
|
||||
"color" => %{
|
||||
"type" => "string",
|
||||
"enum" => [
|
||||
"bg-cyan-950/30",
|
||||
"bg-red-950/30",
|
||||
"bg-green-950/30",
|
||||
"bg-blue-950/30",
|
||||
"bg-yellow-950/30",
|
||||
"bg-purple-950/30",
|
||||
"bg-pink-950/30"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required" => ["color"]
|
||||
},
|
||||
server: server
|
||||
)
|
||||
]
|
||||
end
|
||||
|
||||
def handle_cast({:user_message, text_content}, state) do
|
||||
new_message = %{role: :user, content: text_content}
|
||||
def handle_cast({:user_message, text_content, tool_choice_override}, state) do
|
||||
effective_tool_choice = tool_choice_override || state.tool_choice
|
||||
new_message = %{role: :user, content: text_content, tool_choice: tool_choice_override}
|
||||
broadcast_ui(state.name, {:user_chat_message, new_message})
|
||||
store_message(state.name, new_message)
|
||||
new_state = %{state | messages: state.messages ++ [new_message]}
|
||||
@@ -123,8 +117,9 @@ defmodule ElixirAi.ChatRunner do
|
||||
ElixirAi.ChatUtils.request_ai_response(
|
||||
self(),
|
||||
new_state.messages,
|
||||
state.tools,
|
||||
state.provider
|
||||
state.server_tools ++ state.liveview_tools,
|
||||
state.provider,
|
||||
effective_tool_choice
|
||||
)
|
||||
|
||||
{:noreply, new_state}
|
||||
@@ -269,7 +264,9 @@ defmodule ElixirAi.ChatRunner do
|
||||
{failed, pending} ->
|
||||
with {:ok, decoded_args} <- Jason.decode(tool_call.arguments),
|
||||
tool when not is_nil(tool) <-
|
||||
Enum.find(state.tools, fn t -> t.name == tool_call.name end) do
|
||||
Enum.find(state.server_tools ++ state.liveview_tools, fn t ->
|
||||
t.name == tool_call.name
|
||||
end) do
|
||||
tool.run_function.(id, tool_call.id, decoded_args)
|
||||
{failed, [tool_call.id | pending]}
|
||||
else
|
||||
@@ -319,8 +316,9 @@ defmodule ElixirAi.ChatRunner do
|
||||
ElixirAi.ChatUtils.request_ai_response(
|
||||
self(),
|
||||
state.messages ++ [new_message],
|
||||
state.tools,
|
||||
state.provider
|
||||
state.server_tools ++ state.liveview_tools,
|
||||
state.provider,
|
||||
state.tool_choice
|
||||
)
|
||||
end
|
||||
|
||||
@@ -362,6 +360,46 @@ defmodule ElixirAi.ChatRunner do
|
||||
{:reply, state.streaming_response, state}
|
||||
end
|
||||
|
||||
def handle_call(:get_liveview_pid, _from, state) do
|
||||
{:reply, state.liveview_pid, state}
|
||||
end
|
||||
|
||||
def handle_call({:register_liveview_pid, liveview_pid}, _from, state) do
|
||||
# Clear any previous monitor
|
||||
if state.liveview_monitor_ref, do: Process.demonitor(state.liveview_monitor_ref, [:flush])
|
||||
ref = Process.monitor(liveview_pid)
|
||||
{:reply, :ok, %{state | liveview_pid: liveview_pid, liveview_monitor_ref: ref}}
|
||||
end
|
||||
|
||||
def handle_call(:deregister_liveview_pid, _from, state) do
|
||||
if state.liveview_monitor_ref, do: Process.demonitor(state.liveview_monitor_ref, [:flush])
|
||||
{:reply, :ok, %{state | liveview_pid: nil, liveview_monitor_ref: nil}}
|
||||
end
|
||||
|
||||
def handle_call({:set_tool_choice, tool_choice}, _from, state) do
|
||||
Conversation.update_tool_choice(state.name, tool_choice)
|
||||
{:reply, :ok, %{state | tool_choice: tool_choice}}
|
||||
end
|
||||
|
||||
def handle_call({:set_allowed_tools, tool_names}, _from, state) do
|
||||
Conversation.update_allowed_tools(state.name, tool_names)
|
||||
server_tools = AiTools.build_server_tools(self(), tool_names)
|
||||
liveview_tools = AiTools.build_liveview_tools(self(), tool_names)
|
||||
|
||||
{:reply, :ok,
|
||||
%{
|
||||
state
|
||||
| allowed_tools: tool_names,
|
||||
server_tools: server_tools,
|
||||
liveview_tools: liveview_tools
|
||||
}}
|
||||
end
|
||||
|
||||
def handle_info({:DOWN, ref, :process, _pid, _reason}, %{liveview_monitor_ref: ref} = state) do
|
||||
Logger.info("ChatRunner #{state.name}: LiveView disconnected, clearing liveview_pid")
|
||||
{:noreply, %{state | liveview_pid: nil, liveview_monitor_ref: nil}}
|
||||
end
|
||||
|
||||
defp broadcast_ui(name, msg),
|
||||
do: Phoenix.PubSub.broadcast(ElixirAi.PubSub, chat_topic(name), msg)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user