diff --git a/lib/elixir_ai/ai_tools.ex b/lib/elixir_ai/ai_tools.ex
index e89822a..4e0bbc5 100644
--- a/lib/elixir_ai/ai_tools.ex
+++ b/lib/elixir_ai/ai_tools.ex
@@ -131,7 +131,7 @@ defmodule ElixirAi.AiTools do
   # ---------------------------------------------------------------------------
 
   defp dispatch_to_liveview(server, tool_name, args) do
-    pids = GenServer.call(server, :get_liveview_pids)
+    pids = GenServer.call(server, {:session, :get_liveview_pids})
 
     case pids do
       [] ->
diff --git a/lib/elixir_ai/ai_utils/chat_utils.ex b/lib/elixir_ai/ai_utils/chat_utils.ex
index 45697e8..a5486b6 100644
--- a/lib/elixir_ai/ai_utils/chat_utils.ex
+++ b/lib/elixir_ai/ai_utils/chat_utils.ex
@@ -30,12 +30,16 @@ defmodule ElixirAi.ChatUtils do
     Task.start_link(fn ->
       try do
         result = function.(args)
-        send(server, {:tool_response, current_message_id, tool_call_id, result})
+        send(server, {:stream, {:tool_response, current_message_id, tool_call_id, result}})
       rescue
         e ->
           reason = Exception.format(:error, e, __STACKTRACE__)
           Logger.error("Tool task crashed: #{reason}")
-          send(server, {:tool_response, current_message_id, tool_call_id, {:error, reason}})
+
+          send(
+            server,
+            {:stream, {:tool_response, current_message_id, tool_call_id, {:error, reason}}}
+          )
       end
     end)
   end
@@ -91,7 +95,7 @@ defmodule ElixirAi.ChatUtils do
 
         {:error, reason} ->
           Logger.warning("AI request failed: #{inspect(reason)} for #{api_url}")
-          send(server, {:ai_request_error, reason})
+          send(server, {:stream, {:ai_request_error, reason}})
       end
     end)
   end
diff --git a/lib/elixir_ai/ai_utils/stream_line_utils.ex b/lib/elixir_ai/ai_utils/stream_line_utils.ex
index 1e87830..d5e8aec 100644
--- a/lib/elixir_ai/ai_utils/stream_line_utils.ex
+++ b/lib/elixir_ai/ai_utils/stream_line_utils.ex
@@ -29,7 +29,7 @@ defmodule ElixirAi.AiUtils.StreamLineUtils do
       }) do
     send(
       server,
-      {:start_new_ai_response, id}
+      {:stream, {:start_new_ai_response, id}}
     )
   end
 
@@ -45,7 +45,7 @@ defmodule ElixirAi.AiUtils.StreamLineUtils do
 
     send(
       server,
-      {:ai_text_stream_finish, id}
+      {:stream, {:ai_text_stream_finish, id}}
     )
   end
 
@@ -61,7 +61,7 @@ defmodule ElixirAi.AiUtils.StreamLineUtils do
       }) do
     send(
       server,
-      {:ai_reasoning_chunk, id, reasoning_content}
+      {:stream, {:ai_reasoning_chunk, id, reasoning_content}}
    )
   end
 
@@ -77,7 +77,7 @@ defmodule ElixirAi.AiUtils.StreamLineUtils do
       }) do
     send(
       server,
-      {:ai_text_chunk, id, reasoning_content}
+      {:stream, {:ai_text_chunk, id, reasoning_content}}
     )
   end
 
@@ -105,12 +105,13 @@ defmodule ElixirAi.AiUtils.StreamLineUtils do
 
         send(
           server,
-          {:ai_tool_call_start, id, {tool_name, tool_args_start, tool_index, tool_call_id}}
+          {:stream,
+           {:ai_tool_call_start, id, {tool_name, tool_args_start, tool_index, tool_call_id}}}
         )
 
       %{"index" => tool_index, "function" => %{"arguments" => tool_args_diff}} ->
         # Logger.info("Received tool call middle for index #{tool_index}")
-        send(server, {:ai_tool_call_middle, id, {tool_args_diff, tool_index}})
+        send(server, {:stream, {:ai_tool_call_middle, id, {tool_args_diff, tool_index}}})
 
       other ->
         Logger.warning("Unmatched tool call item: #{inspect(other)}")
@@ -126,7 +127,7 @@ defmodule ElixirAi.AiUtils.StreamLineUtils do
         }
       ) do
     # Logger.info("Received tool_calls_finished with message: #{inspect(message)}")
-    send(server, {:ai_tool_call_end, id})
+    send(server, {:stream, {:ai_tool_call_end, id}})
   end
 
   def handle_stream_line(_server, %{"error" => error_info}) do
diff --git a/lib/elixir_ai/chat_runner.ex b/lib/elixir_ai/chat_runner.ex
deleted file mode 100644
index 47215e1..0000000
--- a/lib/elixir_ai/chat_runner.ex
+++ /dev/null
@@ -1,429 +0,0 @@
-defmodule ElixirAi.ChatRunner do
-  require Logger
-  use GenServer
-  alias ElixirAi.{AiTools, Conversation, Message}
-  import ElixirAi.PubsubTopics
-
-  defp via(name), do: {:via, Horde.Registry, {ElixirAi.ChatRegistry, name}}
-
-  def new_user_message(name, text_content, opts \\ []) do
-    tool_choice = Keyword.get(opts, :tool_choice)
-    GenServer.cast(via(name), {:user_message, text_content, tool_choice})
-  end
-
-  def set_allowed_tools(name, tool_names) when is_list(tool_names) do
-    GenServer.call(via(name), {:set_allowed_tools, tool_names})
-  end
-
-  def set_tool_choice(name, tool_choice) when tool_choice in ["auto", "none", "required"] do
-    GenServer.call(via(name), {:set_tool_choice, tool_choice})
-  end
-
-  def register_liveview_pid(name, liveview_pid) when is_pid(liveview_pid) do
-    GenServer.call(via(name), {:register_liveview_pid, liveview_pid})
-  end
-
-  def deregister_liveview_pid(name, liveview_pid) when is_pid(liveview_pid) do
-    GenServer.call(via(name), {:deregister_liveview_pid, liveview_pid})
-  end
-
-  @spec get_conversation(String.t()) :: any()
-  def get_conversation(name) do
-    GenServer.call(via(name), :get_conversation)
-  end
-
-  def get_streaming_response(name) do
-    GenServer.call(via(name), :get_streaming_response)
-  end
-
-  def start_link(name: name) do
-    GenServer.start_link(__MODULE__, name, name: via(name))
-  end
-
-  def init(name) do
-    Phoenix.PubSub.subscribe(ElixirAi.PubSub, conversation_message_topic(name))
-
-    messages =
-      case Conversation.find_id(name) do
-        {:ok, conv_id} ->
-          Message.load_for_conversation(conv_id, topic: conversation_message_topic(name))
-
-        _ ->
-          []
-      end
-
-    last_message = List.last(messages)
-
-    provider =
-      case Conversation.find_provider(name) do
-        {:ok, p} -> p
-        _ -> nil
-      end
-
-    allowed_tools =
-      case Conversation.find_allowed_tools(name) do
-        {:ok, tools} -> tools
-        _ -> AiTools.all_tool_names()
-      end
-
-    tool_choice =
-      case Conversation.find_tool_choice(name) do
-        {:ok, tc} -> tc
-        _ -> "auto"
-      end
-
-    server_tools = AiTools.build_server_tools(self(), allowed_tools)
-    liveview_tools = AiTools.build_liveview_tools(self(), allowed_tools)
-
-    if last_message && last_message.role == :user do
-      Logger.info(
-        "Last message role was #{last_message.role}, requesting AI response for conversation #{name}"
-      )
-
-      broadcast_ui(name, :recovery_restart)
-
-      ElixirAi.ChatUtils.request_ai_response(
-        self(),
-        messages,
-        server_tools ++ liveview_tools,
-        provider,
-        tool_choice
-      )
-    end
-
-    {:ok,
-     %{
-       name: name,
-       messages: messages,
-       streaming_response: nil,
-       pending_tool_calls: [],
-       allowed_tools: allowed_tools,
-       tool_choice: tool_choice,
-       server_tools: server_tools,
-       liveview_tools: liveview_tools,
-       provider: provider,
-       liveview_pids: %{}
-     }}
-  end
-
-  def handle_cast({:user_message, text_content, tool_choice_override}, state) do
-    effective_tool_choice = tool_choice_override || state.tool_choice
-    new_message = %{role: :user, content: text_content, tool_choice: tool_choice_override}
-    broadcast_ui(state.name, {:user_chat_message, new_message})
-    store_message(state.name, new_message)
-    new_state = %{state | messages: state.messages ++ [new_message]}
-
-    ElixirAi.ChatUtils.request_ai_response(
-      self(),
-      new_state.messages,
-      state.server_tools ++ state.liveview_tools,
-      state.provider,
-      effective_tool_choice
-    )
-
-    {:noreply, new_state}
-  end
-
-  @ai_stream_events [
-    :ai_text_chunk,
-    :ai_reasoning_chunk,
-    :ai_text_stream_finish,
-    :ai_tool_call_start,
-    :ai_tool_call_middle,
-    :ai_tool_call_end,
-    :tool_response
-  ]
-
-  def handle_info({:start_new_ai_response, id}, state) do
-    starting_response = %{id: id, reasoning_content: "", content: "", tool_calls: []}
-    broadcast_ui(state.name, {:start_ai_response_stream, starting_response})
-
-    {:noreply, %{state | streaming_response: starting_response}}
-  end
-
-  def handle_info(
-        msg,
-        %{streaming_response: %{id: current_id}} = state
-      )
-      when is_tuple(msg) and tuple_size(msg) in [2, 3] and
-             elem(msg, 0) in @ai_stream_events and elem(msg, 1) != current_id do
-    Logger.warning(
-      "Received #{elem(msg, 0)} for id #{inspect(elem(msg, 1))} but current streaming response is for id #{inspect(current_id)}"
-    )
-
-    {:noreply, state}
-  end
-
-  def handle_info({:ai_reasoning_chunk, _id, reasoning_content}, state) do
-    broadcast_ui(state.name, {:reasoning_chunk_content, reasoning_content})
-
-    {:noreply,
-     %{
-       state
-       | streaming_response: %{
-           state.streaming_response
-           | reasoning_content: state.streaming_response.reasoning_content <> reasoning_content
-         }
-     }}
-  end
-
-  def handle_info({:ai_text_chunk, _id, text_content}, state) do
-    broadcast_ui(state.name, {:text_chunk_content, text_content})
-
-    {:noreply,
-     %{
-       state
-       | streaming_response: %{
-           state.streaming_response
-           | content: state.streaming_response.content <> text_content
-         }
-     }}
-  end
-
-  def handle_info({:ai_text_stream_finish, _id}, state) do
-    Logger.info(
-      "AI stream finished for id #{state.streaming_response.id}, broadcasting end of AI response"
-    )
-
-    final_message = %{
-      role: :assistant,
-      content: state.streaming_response.content,
-      reasoning_content: state.streaming_response.reasoning_content,
-      tool_calls: state.streaming_response.tool_calls
-    }
-
-    broadcast_ui(state.name, {:end_ai_response, final_message})
-    store_message(state.name, final_message)
-
-    {:noreply,
-     %{
-       state
-       | streaming_response: nil,
-         messages: state.messages ++ [final_message]
-     }}
-  end
-
-  def handle_info(
-        {:ai_tool_call_start, _id, {tool_name, tool_args_start, tool_index, tool_call_id}},
-        state
-      ) do
-    Logger.info("AI started tool call #{tool_name}")
-
-    new_streaming_response = %{
-      state.streaming_response
-      | tool_calls:
-          state.streaming_response.tool_calls ++
-            [
-              %{
-                name: tool_name,
-                arguments: tool_args_start,
-                index: tool_index,
-                id: tool_call_id
-              }
-            ]
-    }
-
-    {:noreply, %{state | streaming_response: new_streaming_response}}
-  end
-
-  def handle_info({:ai_tool_call_middle, _id, {tool_args_diff, tool_index}}, state) do
-    new_streaming_response = %{
-      state.streaming_response
-      | tool_calls:
-          Enum.map(state.streaming_response.tool_calls, fn
-            %{
-              arguments: existing_args,
-              index: ^tool_index
-            } = tool_call ->
-              %{
-                tool_call
-                | arguments: existing_args <> tool_args_diff
-              }
-
-            other ->
-              other
-          end)
-    }
-
-    {:noreply, %{state | streaming_response: new_streaming_response}}
-  end
-
-  def handle_info({:ai_tool_call_end, id}, state) do
-    tool_request_message = %{
-      role: :assistant,
-      content: state.streaming_response.content,
-      reasoning_content: state.streaming_response.reasoning_content,
-      tool_calls: state.streaming_response.tool_calls
-    }
-
-    broadcast_ui(state.name, {:tool_request_message, tool_request_message})
-
-    {failed_call_messages, pending_call_ids} =
-      Enum.reduce(state.streaming_response.tool_calls, {[], []}, fn tool_call,
-                                                                    {failed, pending} ->
-        with {:ok, decoded_args} <- Jason.decode(tool_call.arguments),
-             tool when not is_nil(tool) <-
-               Enum.find(state.server_tools ++ state.liveview_tools, fn t ->
-                 t.name == tool_call.name
-               end) do
-          tool.run_function.(id, tool_call.id, decoded_args)
-          {failed, [tool_call.id | pending]}
-        else
-          {:error, e} ->
-            error_msg = "Failed to decode tool arguments: #{inspect(e)}"
-            Logger.error("Tool call #{tool_call.name} failed: #{error_msg}")
-            {[%{role: :tool, content: error_msg, tool_call_id: tool_call.id} | failed], pending}
-
-          nil ->
-            error_msg = "No tool definition found for #{tool_call.name}"
-            Logger.error(error_msg)
-            {[%{role: :tool, content: error_msg, tool_call_id: tool_call.id} | failed], pending}
-        end
-      end)
-
-    store_message(state.name, [tool_request_message] ++ failed_call_messages)
-
-    {:noreply,
-     %{
-       state
-       | messages: state.messages ++ [tool_request_message] ++ failed_call_messages,
-         pending_tool_calls: pending_call_ids
-     }}
-  end
-
-  def handle_info({:tool_response, _id, tool_call_id, result}, state) do
-    new_message = %{role: :tool, content: inspect(result), tool_call_id: tool_call_id}
-
-    broadcast_ui(state.name, {:one_tool_finished, new_message})
-    store_message(state.name, new_message)
-
-    new_pending_tool_calls =
-      Enum.filter(state.pending_tool_calls, fn id -> id != tool_call_id end)
-
-    new_streaming_response =
-      case new_pending_tool_calls do
-        [] ->
-          nil
-
-        _ ->
-          state.streaming_response
-      end
-
-    if new_pending_tool_calls == [] do
-      broadcast_ui(state.name, :tool_calls_finished)
-
-      ElixirAi.ChatUtils.request_ai_response(
-        self(),
-        state.messages ++ [new_message],
-        state.server_tools ++ state.liveview_tools,
-        state.provider,
-        state.tool_choice
-      )
-    end
-
-    {:noreply,
-     %{
-       state
-       | pending_tool_calls: new_pending_tool_calls,
-         streaming_response: new_streaming_response,
-         messages: state.messages ++ [new_message]
-     }}
-  end
-
-  def handle_info({:db_error, reason}, state) do
-    broadcast_ui(state.name, {:db_error, reason})
-    {:noreply, state}
-  end
-
-  def handle_info({:sql_result_validation_error, error}, state) do
-    Logger.error("ChatRunner received sql_result_validation_error: #{inspect(error)}")
-    broadcast_ui(state.name, {:db_error, "Schema validation error: #{inspect(error)}"})
-    {:noreply, state}
-  end
-
-  def handle_info({:store_message, _name, _message}, state) do
-    {:noreply, state}
-  end
-
-  def handle_info({:ai_request_error, reason}, state) do
-    Logger.error("AI request error: #{inspect(reason)}")
-    broadcast_ui(state.name, {:ai_request_error, reason})
-    {:noreply, %{state | streaming_response: nil, pending_tool_calls: []}}
-  end
-
-  def handle_info({:DOWN, ref, :process, pid, _reason}, state) do
-    case Map.get(state.liveview_pids, pid) do
-      ^ref ->
-        Logger.info("ChatRunner #{state.name}: LiveView #{inspect(pid)} disconnected")
-        {:noreply, %{state | liveview_pids: Map.delete(state.liveview_pids, pid)}}
-
-      _ ->
-        {:noreply, state}
-    end
-  end
-
-  def handle_call(:get_conversation, _from, state) do
-    {:reply, state, state}
-  end
-
-  def handle_call(:get_streaming_response, _from, state) do
-    {:reply, state.streaming_response, state}
-  end
-
-  def handle_call(:get_liveview_pids, _from, state) do
-    {:reply, Map.keys(state.liveview_pids), state}
-  end
-
-  def handle_call({:register_liveview_pid, liveview_pid}, _from, state) do
-    ref = Process.monitor(liveview_pid)
-    {:reply, :ok, %{state | liveview_pids: Map.put(state.liveview_pids, liveview_pid, ref)}}
-  end
-
-  def handle_call({:deregister_liveview_pid, liveview_pid}, _from, state) do
-    case Map.pop(state.liveview_pids, liveview_pid) do
-      {nil, _} ->
-        {:reply, :ok, state}
-
-      {ref, new_pids} ->
-        Process.demonitor(ref, [:flush])
-        {:reply, :ok, %{state | liveview_pids: new_pids}}
-    end
-  end
-
-  def handle_call({:set_tool_choice, tool_choice}, _from, state) do
-    Conversation.update_tool_choice(state.name, tool_choice)
-    {:reply, :ok, %{state | tool_choice: tool_choice}}
-  end
-
-  def handle_call({:set_allowed_tools, tool_names}, _from, state) do
-    Conversation.update_allowed_tools(state.name, tool_names)
-    server_tools = AiTools.build_server_tools(self(), tool_names)
-    liveview_tools = AiTools.build_liveview_tools(self(), tool_names)
-
-    {:reply, :ok,
-     %{
-       state
-       | allowed_tools: tool_names,
-         server_tools: server_tools,
-         liveview_tools: liveview_tools
-     }}
-  end
-
-  defp broadcast_ui(name, msg),
-    do: Phoenix.PubSub.broadcast(ElixirAi.PubSub, chat_topic(name), msg)
-
-  defp store_message(name, messages) when is_list(messages) do
-    Enum.each(messages, &store_message(name, &1))
-    messages
-  end
-
-  defp store_message(name, message) do
-    Phoenix.PubSub.broadcast(
-      ElixirAi.PubSub,
-      conversation_message_topic(name),
-      {:store_message, name, message}
-    )
-
-    message
-  end
-end
diff --git a/lib/elixir_ai/chat_runner/chat_runner.ex b/lib/elixir_ai/chat_runner/chat_runner.ex
new file mode 100644
index 0000000..e54d9d9
--- /dev/null
+++ b/lib/elixir_ai/chat_runner/chat_runner.ex
@@ -0,0 +1,156 @@
+defmodule ElixirAi.ChatRunner do
+  require Logger
+  use GenServer
+  alias ElixirAi.{AiTools, Conversation, Message}
+  import ElixirAi.PubsubTopics
+  import ElixirAi.ChatRunner.OutboundHelpers
+
+  alias ElixirAi.ChatRunner.{
+    ConversationCalls,
+    ErrorHandler,
+    LiveviewSession,
+    StreamHandler,
+    ToolConfig
+  }
+
+  @ai_stream_events [
+    :ai_text_chunk,
+    :ai_reasoning_chunk,
+    :ai_text_stream_finish,
+    :ai_tool_call_start,
+    :ai_tool_call_middle,
+    :ai_tool_call_end,
+    :tool_response
+  ]
+
+  defp via(name), do: {:via, Horde.Registry, {ElixirAi.ChatRegistry, name}}
+
+  def new_user_message(name, text_content, opts \\ []) do
+    tool_choice = Keyword.get(opts, :tool_choice)
+    GenServer.cast(via(name), {:conversation, {:user_message, text_content, tool_choice}})
+  end
+
+  def set_allowed_tools(name, tool_names) when is_list(tool_names) do
+    GenServer.call(via(name), {:tool_config, {:set_allowed_tools, tool_names}})
+  end
+
+  def set_tool_choice(name, tool_choice) when tool_choice in ["auto", "none", "required"] do
+    GenServer.call(via(name), {:tool_config, {:set_tool_choice, tool_choice}})
+  end
+
+  def register_liveview_pid(name, liveview_pid) when is_pid(liveview_pid) do
+    GenServer.call(via(name), {:session, {:register_liveview_pid, liveview_pid}})
+  end
+
+  def deregister_liveview_pid(name, liveview_pid) when is_pid(liveview_pid) do
+    GenServer.call(via(name), {:session, {:deregister_liveview_pid, liveview_pid}})
+  end
+
+  def get_conversation(name) do
+    GenServer.call(via(name), {:conversation, :get_conversation})
+  end
+
+  def get_streaming_response(name) do
+    GenServer.call(via(name), {:conversation, :get_streaming_response})
+  end
+
+  def start_link(name: name) do
+    GenServer.start_link(__MODULE__, name, name: via(name))
+  end
+
+  def init(name) do
+    Phoenix.PubSub.subscribe(ElixirAi.PubSub, conversation_message_topic(name))
+
+    messages =
+      case Conversation.find_id(name) do
+        {:ok, conv_id} ->
+          Message.load_for_conversation(conv_id, topic: conversation_message_topic(name))
+
+        _ ->
+          []
+      end
+
+    last_message = List.last(messages)
+
+    provider =
+      case Conversation.find_provider(name) do
+        {:ok, p} -> p
+        _ -> nil
+      end
+
+    allowed_tools =
+      case Conversation.find_allowed_tools(name) do
+        {:ok, tools} -> tools
+        _ -> AiTools.all_tool_names()
+      end
+
+    tool_choice =
+      case Conversation.find_tool_choice(name) do
+        {:ok, tc} -> tc
+        _ -> "auto"
+      end
+
+    server_tools = AiTools.build_server_tools(self(), allowed_tools)
+    liveview_tools = AiTools.build_liveview_tools(self(), allowed_tools)
+
+    if last_message && last_message.role == :user do
+      Logger.info(
+        "Last message role was #{last_message.role}, requesting AI response for conversation #{name}"
+      )
+
+      broadcast_ui(name, :recovery_restart)
+
+      ElixirAi.ChatUtils.request_ai_response(
+        self(),
+        messages,
+        server_tools ++ liveview_tools,
+        provider,
+        tool_choice
+      )
+    end
+
+    {:ok,
+     %{
+       name: name,
+       messages: messages,
+       streaming_response: nil,
+       pending_tool_calls: [],
+       allowed_tools: allowed_tools,
+       tool_choice: tool_choice,
+       server_tools: server_tools,
+       liveview_tools: liveview_tools,
+       provider: provider,
+       liveview_pids: %{}
+     }}
+  end
+
+  def handle_cast({:conversation, inner}, state), do: ConversationCalls.handle_cast(inner, state)
+
+  def handle_info(
+        {:stream, msg},
+        %{streaming_response: %{id: current_id}} = state
+      )
+      when is_tuple(msg) and tuple_size(msg) in [2, 3] and
+             elem(msg, 0) in @ai_stream_events and elem(msg, 1) != current_id do
+    Logger.warning(
+      "Received #{elem(msg, 0)} for id #{inspect(elem(msg, 1))} but current streaming response is for id #{inspect(current_id)}"
+    )
+
+    {:noreply, state}
+  end
+
+  def handle_info({:stream, inner}, state), do: StreamHandler.handle(inner, state)
+  def handle_info({:error, inner}, state), do: ErrorHandler.handle(inner, state)
+
+  def handle_info({:DOWN, ref, :process, pid, reason}, state),
+    do: LiveviewSession.handle_down(ref, pid, reason, state)
+
+  def handle_call({:conversation, inner}, from, state),
+    do: ConversationCalls.handle_call(inner, from, state)
+
+  def handle_call({:session, inner}, from, state),
+    do: LiveviewSession.handle_call(inner, from, state)
+
+  def handle_call({:tool_config, inner}, from, state),
+    do: ToolConfig.handle_call(inner, from, state)
+end
diff --git a/lib/elixir_ai/chat_runner/conversation_calls.ex b/lib/elixir_ai/chat_runner/conversation_calls.ex
new file mode 100644
index 0000000..78bb7c4
--- /dev/null
+++ b/lib/elixir_ai/chat_runner/conversation_calls.ex
@@ -0,0 +1,29 @@
+defmodule ElixirAi.ChatRunner.ConversationCalls do
+  import ElixirAi.ChatRunner.OutboundHelpers
+
+  def handle_cast({:user_message, text_content, tool_choice_override}, state) do
+    effective_tool_choice = tool_choice_override || state.tool_choice
+    new_message = %{role: :user, content: text_content, tool_choice: tool_choice_override}
+    broadcast_ui(state.name, {:user_chat_message, new_message})
+    store_message(state.name, new_message)
+    new_state = %{state | messages: state.messages ++ [new_message]}
+
+    ElixirAi.ChatUtils.request_ai_response(
+      self(),
+      new_state.messages,
+      state.server_tools ++ state.liveview_tools,
+      state.provider,
+      effective_tool_choice
+    )
+
+    {:noreply, new_state}
+  end
+
+  def handle_call(:get_conversation, _from, state) do
+    {:reply, state, state}
+  end
+
+  def handle_call(:get_streaming_response, _from, state) do
+    {:reply, state.streaming_response, state}
+  end
+end
diff --git a/lib/elixir_ai/chat_runner/error_handler.ex b/lib/elixir_ai/chat_runner/error_handler.ex
new file mode 100644
index 0000000..05c6669
--- /dev/null
+++ b/lib/elixir_ai/chat_runner/error_handler.ex
@@ -0,0 +1,19 @@
+defmodule ElixirAi.ChatRunner.ErrorHandler do
+  require Logger
+  import ElixirAi.ChatRunner.OutboundHelpers
+
+  def handle({:db_error, reason}, state) do
+    broadcast_ui(state.name, {:db_error, reason})
+    {:noreply, state}
+  end
+
+  def handle({:sql_result_validation_error, error}, state) do
+    Logger.error("ChatRunner received sql_result_validation_error: #{inspect(error)}")
+    broadcast_ui(state.name, {:db_error, "Schema validation error: #{inspect(error)}"})
+    {:noreply, state}
+  end
+
+  def handle({:store_message, _name, _message}, state) do
+    {:noreply, state}
+  end
+end
diff --git a/lib/elixir_ai/chat_runner/liveview_session.ex b/lib/elixir_ai/chat_runner/liveview_session.ex
new file mode 100644
index 0000000..027bc38
--- /dev/null
+++ b/lib/elixir_ai/chat_runner/liveview_session.ex
@@ -0,0 +1,34 @@
+defmodule ElixirAi.ChatRunner.LiveviewSession do
+  require Logger
+
+  def handle_call(:get_liveview_pids, _from, state) do
+    {:reply, Map.keys(state.liveview_pids), state}
+  end
+
+  def handle_call({:register_liveview_pid, liveview_pid}, _from, state) do
+    ref = Process.monitor(liveview_pid)
+    {:reply, :ok, %{state | liveview_pids: Map.put(state.liveview_pids, liveview_pid, ref)}}
+  end
+
+  def handle_call({:deregister_liveview_pid, liveview_pid}, _from, state) do
+    case Map.pop(state.liveview_pids, liveview_pid) do
+      {nil, _} ->
+        {:reply, :ok, state}
+
+      {ref, new_pids} ->
+        Process.demonitor(ref, [:flush])
+        {:reply, :ok, %{state | liveview_pids: new_pids}}
+    end
+  end
+
+  def handle_down(ref, pid, _reason, state) do
+    case Map.get(state.liveview_pids, pid) do
+      ^ref ->
+        Logger.info("ChatRunner #{state.name}: LiveView #{inspect(pid)} disconnected")
+        {:noreply, %{state | liveview_pids: Map.delete(state.liveview_pids, pid)}}
+
+      _ ->
+        {:noreply, state}
+    end
+  end
+end
diff --git a/lib/elixir_ai/chat_runner/outbound_helpers.ex b/lib/elixir_ai/chat_runner/outbound_helpers.ex
new file mode 100644
index 0000000..1dc8f37
--- /dev/null
+++ b/lib/elixir_ai/chat_runner/outbound_helpers.ex
@@ -0,0 +1,21 @@
+defmodule ElixirAi.ChatRunner.OutboundHelpers do
+  import ElixirAi.PubsubTopics
+
+  def broadcast_ui(name, msg),
+    do: Phoenix.PubSub.broadcast(ElixirAi.PubSub, chat_topic(name), msg)
+
+  def store_message(name, messages) when is_list(messages) do
+    Enum.each(messages, &store_message(name, &1))
+    messages
+  end
+
+  def store_message(name, message) do
+    Phoenix.PubSub.broadcast(
+      ElixirAi.PubSub,
+      conversation_message_topic(name),
+      {:error, {:store_message, name, message}}
+    )
+
+    message
+  end
+end
diff --git a/lib/elixir_ai/chat_runner/stream_handler.ex b/lib/elixir_ai/chat_runner/stream_handler.ex
new file mode 100644
index 0000000..c506475
--- /dev/null
+++ b/lib/elixir_ai/chat_runner/stream_handler.ex
@@ -0,0 +1,183 @@
+defmodule ElixirAi.ChatRunner.StreamHandler do
+  require Logger
+  import ElixirAi.ChatRunner.OutboundHelpers
+
+  def handle({:start_new_ai_response, id}, state) do
+    starting_response = %{id: id, reasoning_content: "", content: "", tool_calls: []}
+    broadcast_ui(state.name, {:start_ai_response_stream, starting_response})
+    {:noreply, %{state | streaming_response: starting_response}}
+  end
+
+  def handle({:ai_reasoning_chunk, _id, reasoning_content}, state) do
+    broadcast_ui(state.name, {:reasoning_chunk_content, reasoning_content})
+
+    {:noreply,
+     %{
+       state
+       | streaming_response: %{
+           state.streaming_response
+           | reasoning_content: state.streaming_response.reasoning_content <> reasoning_content
+         }
+     }}
+  end
+
+  def handle({:ai_text_chunk, _id, text_content}, state) do
+    broadcast_ui(state.name, {:text_chunk_content, text_content})
+
+    {:noreply,
+     %{
+       state
+       | streaming_response: %{
+           state.streaming_response
+           | content: state.streaming_response.content <> text_content
+         }
+     }}
+  end
+
+  def handle({:ai_text_stream_finish, _id}, state) do
+    Logger.info(
+      "AI stream finished for id #{state.streaming_response.id}, broadcasting end of AI response"
+    )
+
+    final_message = %{
+      role: :assistant,
+      content: state.streaming_response.content,
+      reasoning_content: state.streaming_response.reasoning_content,
+      tool_calls: state.streaming_response.tool_calls
+    }
+
+    broadcast_ui(state.name, {:end_ai_response, final_message})
+    store_message(state.name, final_message)
+
+    {:noreply,
+     %{
+       state
+       | streaming_response: nil,
+         messages: state.messages ++ [final_message]
+     }}
+  end
+
+  def handle(
+        {:ai_tool_call_start, _id, {tool_name, tool_args_start, tool_index, tool_call_id}},
+        state
+      ) do
+    Logger.info("AI started tool call #{tool_name}")
+
+    new_streaming_response = %{
+      state.streaming_response
+      | tool_calls:
+          state.streaming_response.tool_calls ++
+            [
+              %{
+                name: tool_name,
+                arguments: tool_args_start,
+                index: tool_index,
+                id: tool_call_id
+              }
+            ]
+    }
+
+    {:noreply, %{state | streaming_response: new_streaming_response}}
+  end
+
+  def handle({:ai_tool_call_middle, _id, {tool_args_diff, tool_index}}, state) do
+    new_streaming_response = %{
+      state.streaming_response
+      | tool_calls:
+          Enum.map(state.streaming_response.tool_calls, fn
+            %{arguments: existing_args, index: ^tool_index} = tool_call ->
+              %{tool_call | arguments: existing_args <> tool_args_diff}
+
+            other ->
+              other
+          end)
+    }
+
+    {:noreply, %{state | streaming_response: new_streaming_response}}
+  end
+
+  def handle({:ai_tool_call_end, id}, state) do
+    tool_request_message = %{
+      role: :assistant,
+      content: state.streaming_response.content,
+      reasoning_content: state.streaming_response.reasoning_content,
+      tool_calls: state.streaming_response.tool_calls
+    }
+
+    broadcast_ui(state.name, {:tool_request_message, tool_request_message})
+
+    {failed_call_messages, pending_call_ids} =
+      Enum.reduce(state.streaming_response.tool_calls, {[], []}, fn tool_call,
+                                                                    {failed, pending} ->
+        with {:ok, decoded_args} <- Jason.decode(tool_call.arguments),
+             tool when not is_nil(tool) <-
+               Enum.find(state.server_tools ++ state.liveview_tools, fn t ->
+                 t.name == tool_call.name
+               end) do
+          tool.run_function.(id, tool_call.id, decoded_args)
+          {failed, [tool_call.id | pending]}
+        else
+          {:error, e} ->
+            error_msg = "Failed to decode tool arguments: #{inspect(e)}"
+            Logger.error("Tool call #{tool_call.name} failed: #{error_msg}")
+            {[%{role: :tool, content: error_msg, tool_call_id: tool_call.id} | failed], pending}
+
+          nil ->
+            error_msg = "No tool definition found for #{tool_call.name}"
+            Logger.error(error_msg)
+            {[%{role: :tool, content: error_msg, tool_call_id: tool_call.id} | failed], pending}
+        end
+      end)
+
+    store_message(state.name, [tool_request_message] ++ failed_call_messages)
+
+    {:noreply,
+     %{
+       state
+       | messages: state.messages ++ [tool_request_message] ++ failed_call_messages,
+         pending_tool_calls: pending_call_ids
+     }}
+  end
+
+  def handle({:tool_response, _id, tool_call_id, result}, state) do
+    new_message = %{role: :tool, content: inspect(result), tool_call_id: tool_call_id}
+
+    broadcast_ui(state.name, {:one_tool_finished, new_message})
+    store_message(state.name, new_message)
+
+    new_pending_tool_calls =
+      Enum.filter(state.pending_tool_calls, fn id -> id != tool_call_id end)
+
+    new_streaming_response =
+      case new_pending_tool_calls do
+        [] -> nil
+        _ -> state.streaming_response
+      end
+
+    if new_pending_tool_calls == [] do
+      broadcast_ui(state.name, :tool_calls_finished)
+
+      ElixirAi.ChatUtils.request_ai_response(
+        self(),
+        state.messages ++ [new_message],
+        state.server_tools ++ state.liveview_tools,
+        state.provider,
+        state.tool_choice
+      )
+    end
+
+    {:noreply,
+     %{
+       state
+       | pending_tool_calls: new_pending_tool_calls,
+         streaming_response: new_streaming_response,
+         messages: state.messages ++ [new_message]
+     }}
+  end
+
+  def handle({:ai_request_error, reason}, state) do
+    Logger.error("AI request error: #{inspect(reason)}")
+    broadcast_ui(state.name, {:ai_request_error, reason})
+    {:noreply, %{state | streaming_response: nil, pending_tool_calls: []}}
+  end
+end
diff --git a/lib/elixir_ai/chat_runner/tool_config.ex b/lib/elixir_ai/chat_runner/tool_config.ex
new file mode 100644
index 0000000..852a4bb
--- /dev/null
+++ b/lib/elixir_ai/chat_runner/tool_config.ex
@@ -0,0 +1,22 @@
+defmodule ElixirAi.ChatRunner.ToolConfig do
+  alias ElixirAi.{AiTools, Conversation}
+
+  def handle_call({:set_tool_choice, tool_choice}, _from, state) do
+    Conversation.update_tool_choice(state.name, tool_choice)
+    {:reply, :ok, %{state | tool_choice: tool_choice}}
+  end
+
+  def handle_call({:set_allowed_tools, tool_names}, _from, state) do
+    Conversation.update_allowed_tools(state.name, tool_names)
+    server_tools = AiTools.build_server_tools(self(), tool_names)
+    liveview_tools = AiTools.build_liveview_tools(self(), tool_names)
+
+    {:reply, :ok,
+     %{
+       state
+       | allowed_tools: tool_names,
+         server_tools: server_tools,
+         liveview_tools: liveview_tools
+     }}
+  end
+end
diff --git a/lib/elixir_ai/conversation_manager.ex b/lib/elixir_ai/conversation_manager.ex
index d50692f..84860f2 100644
--- a/lib/elixir_ai/conversation_manager.ex
+++ b/lib/elixir_ai/conversation_manager.ex
@@ -109,17 +109,20 @@ defmodule ElixirAi.ConversationManager do
     {:noreply, %{state | runners: runners}}
   end
 
-  def handle_info({:db_error, reason}, state) do
+  def handle_info({:error, {:db_error, reason}}, state) do
     Logger.error("ConversationManager received db_error: #{inspect(reason)}")
     {:noreply, state}
   end
 
-  def handle_info({:sql_result_validation_error, error}, state) do
+  def handle_info({:error, {:sql_result_validation_error, error}}, state) do
     Logger.error("ConversationManager received sql_result_validation_error: #{inspect(error)}")
     {:noreply, state}
   end
 
-  def handle_info({:store_message, name, message}, %{conversations: conversations} = state) do
+  def handle_info(
+        {:error, {:store_message, name, message}},
+        %{conversations: conversations} = state
+      ) do
     case Conversation.find_id(name) do
       {:ok, conv_id} ->
         Message.insert(conv_id, message, topic: conversation_message_topic(name))
@@ -177,7 +180,7 @@ defmodule ElixirAi.ConversationManager do
     case start_and_subscribe(name, state) do
       {:ok, pid, new_subscriptions, new_runners} ->
         new_state = %{state | subscriptions: new_subscriptions, runners: new_runners}
-        conversation = GenServer.call(pid, :get_conversation)
+        conversation = GenServer.call(pid, {:conversation, :get_conversation})
         {:reply, {:ok, Map.put(conversation, :runner_pid, pid)}, new_state}
 
       {:error, _reason} = error ->
diff --git a/lib/elixir_ai/data/db_helpers.ex b/lib/elixir_ai/data/db_helpers.ex
index fcc2b2e..2f064c1 100644
--- a/lib/elixir_ai/data/db_helpers.ex
+++ b/lib/elixir_ai/data/db_helpers.ex
@@ -28,7 +28,7 @@ defmodule ElixirAi.Data.DbHelpers do
         Phoenix.PubSub.broadcast(
           ElixirAi.PubSub,
           topic,
-          {:db_error, Exception.message(exception)}
+          {:error, {:db_error, Exception.message(exception)}}
         )
 
         {:error, :db_error}
@@ -55,7 +55,13 @@ defmodule ElixirAi.Data.DbHelpers do
 
       error ->
        Logger.error("Validation error: #{inspect(error)}")
-        Phoenix.PubSub.broadcast(ElixirAi.PubSub, topic, {:sql_result_validation_error, error})
+
+        Phoenix.PubSub.broadcast(
+          ElixirAi.PubSub,
+          topic,
+          {:error, {:sql_result_validation_error, error}}
+        )
+
         error
     end)
   end
diff --git a/lib/elixir_ai_web/chat/chat_live.ex b/lib/elixir_ai_web/chat/chat_live.ex
index e12a9bd..65cd9bc 100644
--- a/lib/elixir_ai_web/chat/chat_live.ex
+++ b/lib/elixir_ai_web/chat/chat_live.ex
@@ -133,7 +133,7 @@ defmodule ElixirAiWeb.ChatLive do
 
   def handle_info(:sync_streaming, %{assigns: %{runner_pid: pid}} = socket)
       when is_pid(pid) do
-    case GenServer.call(pid, :get_streaming_response) do
+    case GenServer.call(pid, {:conversation, :get_streaming_response}) do
       nil ->
         {:noreply, assign(socket, streaming_response: nil)}
@@ -285,7 +285,7 @@ defmodule ElixirAiWeb.ChatLive do
   end
 
   defp get_snapshot(%{assigns: %{runner_pid: pid}} = _socket) when is_pid(pid) do
-    case GenServer.call(pid, :get_streaming_response) do
+    case GenServer.call(pid, {:conversation, :get_streaming_response}) do
       nil -> %{id: nil, content: "", reasoning_content: "", tool_calls: []}
       snapshot -> snapshot
     end
diff --git a/lib/elixir_ai_web/voice/voice_live.ex b/lib/elixir_ai_web/voice/voice_live.ex
index 68cb64a..7ec9806 100644
--- a/lib/elixir_ai_web/voice/voice_live.ex
+++ b/lib/elixir_ai_web/voice/voice_live.ex
@@ -3,85 +3,125 @@ defmodule ElixirAiWeb.VoiceLive do
   require Logger
 
   def mount(_params, _session, socket) do
-    {:ok, assign(socket, state: :idle, transcription: nil), layout: false}
+    {:ok, assign(socket, state: :idle, transcription: nil, expanded: false), layout: false}
   end
 
   def render(assigns) do
    ~H"""