more improvements on tool calling

This commit is contained in:
2026-03-06 09:08:16 -07:00
parent b89d4e5a28
commit 7c7e763809
7 changed files with 275 additions and 142 deletions

View File

@@ -37,7 +37,7 @@ defmodule ElixirAi.ChatRunner do
},
"read_thing" => %{
definition: ElixirAi.ToolTesting.read_thing_definition("read_thing"),
function: &ElixirAi.ToolTesting.get_thing/0
function: &ElixirAi.ToolTesting.get_thing/1
}
}
end
@@ -102,6 +102,7 @@ defmodule ElixirAi.ChatRunner do
end
def handle_info({:ai_stream_finish, _id}, state) do
Logger.info("AI stream finished for id #{state.streaming_response.id}, broadcasting end of AI response")
broadcast(:end_ai_response)
final_message = %{
@@ -120,7 +121,10 @@ defmodule ElixirAi.ChatRunner do
}}
end
def handle_info({:ai_tool_call_start, _id, {tool_name, tool_args_start, tool_index}}, state) do
def handle_info(
{:ai_tool_call_start, _id, {tool_name, tool_args_start, tool_index, tool_call_id}},
state
) do
Logger.info("AI started tool call #{tool_name}")
new_streaming_response = %{
@@ -131,7 +135,8 @@ defmodule ElixirAi.ChatRunner do
%{
name: tool_name,
arguments: tool_args_start,
index: tool_index
index: tool_index,
id: tool_call_id
}
]
}
@@ -161,68 +166,45 @@ defmodule ElixirAi.ChatRunner do
{:noreply, %{state | streaming_response: new_streaming_response}}
end
def handle_info({:ai_tool_call_end, _id, tool_index}, state) do
def handle_info({:ai_tool_call_end, _id}, state) do
Logger.info("ending tool call with tools: #{inspect(state.streaming_response.tool_calls)}")
tool_calls =
Enum.map(state.streaming_response.tool_calls, fn
%{
arguments: existing_args,
index: ^tool_index
} = tool_call ->
case Jason.decode(existing_args) do
{:ok, decoded_args} ->
tool_function = tools()[tool_call.name].function
res = tool_function.(decoded_args)
Enum.map(state.streaming_response.tool_calls, fn tool_call ->
case Jason.decode(tool_call.arguments) do
{:ok, decoded_args} ->
tool_function = tools()[tool_call.name].function
res = tool_function.(decoded_args)
Map.put(tool_call, :result, res)
Map.put(tool_call, :result, res)
{:error, e} ->
Map.put(tool_call, :error, "Failed to decode tool arguments: #{inspect(e)}")
end
other ->
other
{:error, e} ->
Map.put(tool_call, :error, "Failed to decode tool arguments: #{inspect(e)}")
end
end)
all_tool_calls_finished =
Enum.all?(tool_calls, fn call ->
Map.has_key?(call, :result) or Map.has_key?(call, :error)
tool_request_message = %{
role: :assistant,
content: state.streaming_response.content,
reasoning_content: state.streaming_response.reasoning_content,
tool_calls: tool_calls
}
result_messages =
Enum.map(tool_calls, fn call ->
if Map.has_key?(call, :result) do
%{role: :tool, content: "#{inspect(call.result)}", tool_call_id: call.id}
else
%{role: :tool, content: "Error in #{call.name}: #{call.error}", tool_call_id: call.id}
end
end)
state =
case all_tool_calls_finished do
true ->
Logger.info("All tool calls finished, broadcasting updated tool calls with results")
new_messages = [tool_request_message] ++ result_messages
new_message = %{
role: :assistant,
content: state.streaming_response.content,
reasoning_content: state.streaming_response.reasoning_content,
tool_calls: tool_calls
}
Logger.info("All tool calls finished, broadcasting updated tool calls with results")
broadcast({:tool_calls_finished, new_messages})
new_state = %{
state
| messages:
state.messages ++
[
new_message
],
streaming_response: nil
}
broadcast({:tool_calls_finished, new_message})
false ->
%{
state
| streaming_response: %{
state.streaming_response
| tool_calls: tool_calls
}
}
end
{:noreply, state}
{:noreply,
%{state | messages: state.messages ++ new_messages, streaming_response: nil}}
end
def handle_call(:get_conversation, _from, state) do