more improvements on tool calling
This commit is contained in:
@@ -19,6 +19,7 @@ defmodule ElixirAi.ChatUtils do
|
||||
|
||||
headers = [{"authorization", "Bearer #{api_key}"}]
|
||||
|
||||
Logger.info("sending AI request with body: #{inspect(body)}")
|
||||
case Req.post(api_url,
|
||||
json: body,
|
||||
headers: headers,
|
||||
@@ -39,6 +40,28 @@ defmodule ElixirAi.ChatUtils do
|
||||
end)
|
||||
end
|
||||
|
||||
def api_message(%{role: :assistant, tool_calls: [_ | _] = tool_calls} = msg) do
|
||||
%{
|
||||
role: "assistant",
|
||||
content: Map.get(msg, :content, ""),
|
||||
tool_calls:
|
||||
Enum.map(tool_calls, fn call ->
|
||||
%{
|
||||
id: call.id,
|
||||
type: "function",
|
||||
function: %{
|
||||
name: call.name,
|
||||
arguments: call.arguments
|
||||
}
|
||||
}
|
||||
end)
|
||||
}
|
||||
end
|
||||
|
||||
def api_message(%{role: :tool, tool_call_id: tool_call_id, content: content}) do
|
||||
%{role: "tool", tool_call_id: tool_call_id, content: content}
|
||||
end
|
||||
|
||||
def api_message(%{role: role, content: content}) do
|
||||
%{role: Atom.to_string(role), content: content}
|
||||
end
|
||||
|
||||
@@ -76,72 +76,49 @@ defmodule ElixirAi.AiUtils.StreamLineUtils do
|
||||
)
|
||||
end
|
||||
|
||||
# start tool call
|
||||
# start and middle tool call
|
||||
def handle_stream_line(server, %{
|
||||
"choices" => [
|
||||
%{
|
||||
"delta" => %{
|
||||
"tool_calls" => [
|
||||
%{
|
||||
"function" => %{
|
||||
"name" => tool_name,
|
||||
"arguments" => tool_args_start
|
||||
}
|
||||
}
|
||||
]
|
||||
"tool_calls" => tool_calls
|
||||
},
|
||||
"finish_reason" => nil,
|
||||
"index" => tool_index
|
||||
"finish_reason" => nil
|
||||
}
|
||||
],
|
||||
"id" => id
|
||||
}) do
|
||||
send(
|
||||
server,
|
||||
{:ai_tool_call_start, id, {tool_name, tool_args_start, tool_index}}
|
||||
)
|
||||
end
|
||||
})
|
||||
when is_list(tool_calls) do
|
||||
Enum.each(tool_calls, fn
|
||||
%{
|
||||
"id" => tool_call_id,
|
||||
"index" => tool_index,
|
||||
"type" => "function",
|
||||
"function" => %{"name" => tool_name, "arguments" => tool_args_start}
|
||||
} ->
|
||||
Logger.info("Received tool call start for tool #{tool_name}")
|
||||
|
||||
# middle tool call
|
||||
def handle_stream_line(server, %{
|
||||
"choices" => [
|
||||
%{
|
||||
"delta" => %{
|
||||
"tool_calls" => [
|
||||
%{
|
||||
"function" => %{
|
||||
"arguments" => tool_args_diff
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason" => nil,
|
||||
"index" => tool_index
|
||||
}
|
||||
],
|
||||
"id" => id
|
||||
}) do
|
||||
send(
|
||||
server,
|
||||
{:ai_tool_call_middle, id, {tool_args_diff, tool_index}}
|
||||
)
|
||||
send(
|
||||
server,
|
||||
{:ai_tool_call_start, id, {tool_name, tool_args_start, tool_index, tool_call_id}}
|
||||
)
|
||||
|
||||
%{"index" => tool_index, "function" => %{"arguments" => tool_args_diff}} ->
|
||||
Logger.info("Received tool call middle for index #{tool_index}")
|
||||
send(server, {:ai_tool_call_middle, id, {tool_args_diff, tool_index}})
|
||||
|
||||
other ->
|
||||
Logger.warning("Unmatched tool call item: #{inspect(other)}")
|
||||
end)
|
||||
end
|
||||
|
||||
# end tool call
|
||||
def handle_stream_line(server, %{
|
||||
"choices" => [
|
||||
%{
|
||||
"delta" => %{},
|
||||
"finish_reason" => "tool_calls",
|
||||
"index" => tool_index
|
||||
}
|
||||
],
|
||||
"choices" => [%{"finish_reason" => "tool_calls"}],
|
||||
"id" => id
|
||||
}) do
|
||||
send(
|
||||
server,
|
||||
{:ai_tool_call_end, id, tool_index}
|
||||
)
|
||||
Logger.info("Received tool call end")
|
||||
send(server, {:ai_tool_call_end, id})
|
||||
end
|
||||
|
||||
def handle_stream_line(_server, %{"error" => error_info}) do
|
||||
|
||||
@@ -9,6 +9,7 @@ defmodule ElixirAi.Application do
|
||||
{DNSCluster, query: Application.get_env(:elixir_ai, :dns_cluster_query) || :ignore},
|
||||
{Phoenix.PubSub, name: ElixirAi.PubSub},
|
||||
ElixirAi.ChatRunner,
|
||||
ElixirAi.ToolTesting,
|
||||
ElixirAiWeb.Endpoint
|
||||
]
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ defmodule ElixirAi.ChatRunner do
|
||||
},
|
||||
"read_thing" => %{
|
||||
definition: ElixirAi.ToolTesting.read_thing_definition("read_thing"),
|
||||
function: &ElixirAi.ToolTesting.get_thing/0
|
||||
function: &ElixirAi.ToolTesting.get_thing/1
|
||||
}
|
||||
}
|
||||
end
|
||||
@@ -102,6 +102,7 @@ defmodule ElixirAi.ChatRunner do
|
||||
end
|
||||
|
||||
def handle_info({:ai_stream_finish, _id}, state) do
|
||||
Logger.info("AI stream finished for id #{state.streaming_response.id}, broadcasting end of AI response")
|
||||
broadcast(:end_ai_response)
|
||||
|
||||
final_message = %{
|
||||
@@ -120,7 +121,10 @@ defmodule ElixirAi.ChatRunner do
|
||||
}}
|
||||
end
|
||||
|
||||
def handle_info({:ai_tool_call_start, _id, {tool_name, tool_args_start, tool_index}}, state) do
|
||||
def handle_info(
|
||||
{:ai_tool_call_start, _id, {tool_name, tool_args_start, tool_index, tool_call_id}},
|
||||
state
|
||||
) do
|
||||
Logger.info("AI started tool call #{tool_name}")
|
||||
|
||||
new_streaming_response = %{
|
||||
@@ -131,7 +135,8 @@ defmodule ElixirAi.ChatRunner do
|
||||
%{
|
||||
name: tool_name,
|
||||
arguments: tool_args_start,
|
||||
index: tool_index
|
||||
index: tool_index,
|
||||
id: tool_call_id
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -161,68 +166,45 @@ defmodule ElixirAi.ChatRunner do
|
||||
{:noreply, %{state | streaming_response: new_streaming_response}}
|
||||
end
|
||||
|
||||
def handle_info({:ai_tool_call_end, _id, tool_index}, state) do
|
||||
def handle_info({:ai_tool_call_end, _id}, state) do
|
||||
Logger.info("ending tool call with tools: #{inspect(state.streaming_response.tool_calls)}")
|
||||
|
||||
tool_calls =
|
||||
Enum.map(state.streaming_response.tool_calls, fn
|
||||
%{
|
||||
arguments: existing_args,
|
||||
index: ^tool_index
|
||||
} = tool_call ->
|
||||
case Jason.decode(existing_args) do
|
||||
{:ok, decoded_args} ->
|
||||
tool_function = tools()[tool_call.name].function
|
||||
res = tool_function.(decoded_args)
|
||||
Enum.map(state.streaming_response.tool_calls, fn tool_call ->
|
||||
case Jason.decode(tool_call.arguments) do
|
||||
{:ok, decoded_args} ->
|
||||
tool_function = tools()[tool_call.name].function
|
||||
res = tool_function.(decoded_args)
|
||||
Map.put(tool_call, :result, res)
|
||||
|
||||
Map.put(tool_call, :result, res)
|
||||
|
||||
{:error, e} ->
|
||||
Map.put(tool_call, :error, "Failed to decode tool arguments: #{inspect(e)}")
|
||||
end
|
||||
|
||||
other ->
|
||||
other
|
||||
{:error, e} ->
|
||||
Map.put(tool_call, :error, "Failed to decode tool arguments: #{inspect(e)}")
|
||||
end
|
||||
end)
|
||||
|
||||
all_tool_calls_finished =
|
||||
Enum.all?(tool_calls, fn call ->
|
||||
Map.has_key?(call, :result) or Map.has_key?(call, :error)
|
||||
tool_request_message = %{
|
||||
role: :assistant,
|
||||
content: state.streaming_response.content,
|
||||
reasoning_content: state.streaming_response.reasoning_content,
|
||||
tool_calls: tool_calls
|
||||
}
|
||||
|
||||
result_messages =
|
||||
Enum.map(tool_calls, fn call ->
|
||||
if Map.has_key?(call, :result) do
|
||||
%{role: :tool, content: "#{inspect(call.result)}", tool_call_id: call.id}
|
||||
else
|
||||
%{role: :tool, content: "Error in #{call.name}: #{call.error}", tool_call_id: call.id}
|
||||
end
|
||||
end)
|
||||
|
||||
state =
|
||||
case all_tool_calls_finished do
|
||||
true ->
|
||||
Logger.info("All tool calls finished, broadcasting updated tool calls with results")
|
||||
new_messages = [tool_request_message] ++ result_messages
|
||||
|
||||
new_message = %{
|
||||
role: :assistant,
|
||||
content: state.streaming_response.content,
|
||||
reasoning_content: state.streaming_response.reasoning_content,
|
||||
tool_calls: tool_calls
|
||||
}
|
||||
Logger.info("All tool calls finished, broadcasting updated tool calls with results")
|
||||
broadcast({:tool_calls_finished, new_messages})
|
||||
|
||||
new_state = %{
|
||||
state
|
||||
| messages:
|
||||
state.messages ++
|
||||
[
|
||||
new_message
|
||||
],
|
||||
streaming_response: nil
|
||||
}
|
||||
|
||||
broadcast({:tool_calls_finished, new_message})
|
||||
|
||||
false ->
|
||||
%{
|
||||
state
|
||||
| streaming_response: %{
|
||||
state.streaming_response
|
||||
| tool_calls: tool_calls
|
||||
}
|
||||
}
|
||||
end
|
||||
|
||||
{:noreply, state}
|
||||
{:noreply,
|
||||
%{state | messages: state.messages ++ new_messages, streaming_response: nil}}
|
||||
end
|
||||
|
||||
def handle_call(:get_conversation, _from, state) do
|
||||
|
||||
@@ -9,6 +9,10 @@ defmodule ElixirAi.ToolTesting do
|
||||
GenServer.call(__MODULE__, :get_thing)
|
||||
end
|
||||
|
||||
def get_thing(_) do
|
||||
GenServer.call(__MODULE__, :get_thing)
|
||||
end
|
||||
|
||||
def store_thing_definition(name) do
|
||||
%{
|
||||
"type" => "function",
|
||||
|
||||
Reference in New Issue
Block a user