diff --git a/.github/workflows/pipeline.yml b/.github/workflows/pipeline.yml index bddcb3d..0cfe867 100644 --- a/.github/workflows/pipeline.yml +++ b/.github/workflows/pipeline.yml @@ -47,7 +47,7 @@ jobs: --from-literal=AI_TOKEN="$AI_TOKEN" kubectl create configmap db-schema \ - --from-file=schema.sql=schema.sql \ + --from-file=postgres/schema/ \ --namespace ai-ha-elixir \ --dry-run=client -o yaml | kubectl apply -f - diff --git a/docker-compose.yml b/docker-compose.yml index b3c482e..002ec69 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,7 +7,7 @@ services: POSTGRES_DB: elixir_ai_dev command: postgres -c hba_file=/etc/postgresql/pg_hba.conf volumes: - - ./schema.sql:/docker-entrypoint-initdb.d/schema.sql + - ./postgres/schema/:/docker-entrypoint-initdb.d/ - ./postgres/pg_hba.conf:/etc/postgresql/pg_hba.conf healthcheck: test: ["CMD-SHELL", "pg_isready -U elixir_ai -d elixir_ai_dev"] diff --git a/lib/elixir_ai/chat_runner.ex b/lib/elixir_ai/chat_runner.ex index 0519ea5..2c9d8e2 100644 --- a/lib/elixir_ai/chat_runner.ex +++ b/lib/elixir_ai/chat_runner.ex @@ -329,6 +329,12 @@ defmodule ElixirAi.ChatRunner do {:noreply, state} end + def handle_info({:sql_result_validation_error, error}, state) do + Logger.error("ChatRunner received sql_result_validation_error: #{inspect(error)}") + broadcast_ui(state.name, {:db_error, "Schema validation error: #{inspect(error)}"}) + {:noreply, state} + end + def handle_info({:store_message, _name, _message}, state) do {:noreply, state} end diff --git a/lib/elixir_ai/conversation_manager.ex b/lib/elixir_ai/conversation_manager.ex index 2a9df0a..8da6f68 100644 --- a/lib/elixir_ai/conversation_manager.ex +++ b/lib/elixir_ai/conversation_manager.ex @@ -100,6 +100,11 @@ defmodule ElixirAi.ConversationManager do {:noreply, state} end + def handle_info({:sql_result_validation_error, error}, state) do + Logger.error("ConversationManager received sql_result_validation_error: #{inspect(error)}") + {:noreply, state} + end + def handle_info({:store_message, name, message}, %{conversations: conversations} = state) do case Conversation.find_id(name) do {:ok, conv_id} -> diff --git a/lib/elixir_ai/data/message.ex b/lib/elixir_ai/data/message.ex index 245c030..1800a29 100644 --- a/lib/elixir_ai/data/message.ex +++ b/lib/elixir_ai/data/message.ex @@ -16,34 +16,93 @@ defmodule ElixirAi.Message do end end + # Row schemas for the three message tables + defp text_message_row_schema do + Zoi.object(%{ + id: Zoi.integer(), + prev_message_id: Zoi.nullish(Zoi.integer()), + prev_message_table: Zoi.nullish(Zoi.string()), + role: Zoi.string(), + content: Zoi.nullish(Zoi.string()), + reasoning_content: Zoi.nullish(Zoi.string()), + inserted_at: Zoi.any() + }) + end + + defp tool_call_request_row_schema do + Zoi.object(%{ + id: Zoi.integer(), + text_message_id: Zoi.integer(), + prev_message_id: Zoi.nullish(Zoi.integer()), + prev_message_table: Zoi.nullish(Zoi.string()), + tool_name: Zoi.string(), + tool_call_id: Zoi.string(), + arguments: Zoi.any(), + inserted_at: Zoi.any() + }) + end + + defp tool_response_row_schema do + Zoi.object(%{ + id: Zoi.integer(), + tool_call_id: Zoi.string(), + prev_message_id: Zoi.nullish(Zoi.integer()), + prev_message_table: Zoi.nullish(Zoi.string()), + content: Zoi.string(), + inserted_at: Zoi.any() + }) + end + def load_for_conversation(conversation_id, topic: topic) when is_binary(conversation_id) and byte_size(conversation_id) == 16 do - sql = """ - SELECT role, content, reasoning_content, tool_calls, tool_call_id - FROM messages - WHERE conversation_id = $(conversation_id) - ORDER BY id - """ + with text_messages when is_list(text_messages) <- fetch_text_messages(conversation_id, topic), + tool_call_msgs when is_list(tool_call_msgs) <- + fetch_tool_call_request_messages(conversation_id, topic), + tool_response_msgs when is_list(tool_response_msgs) <- + fetch_tool_response_messages(conversation_id, topic) do + tagged = + Enum.map(text_messages, &Map.put(&1, :_table, "text_messages")) ++ + Enum.map(tool_call_msgs, &Map.put(&1, :_table, "tool_calls_request_messages")) ++ + Enum.map(tool_response_msgs, &Map.put(&1, :_table, "tool_response_messages")) - params = %{"conversation_id" => conversation_id} + by_key = Map.new(tagged, fn row -> {{row._table, row.id}, row} end) - case DbHelpers.run_sql(sql, params, topic) do - {:error, :db_error} -> - [] + ordered = sort_by_prev_message(tagged, by_key) - rows -> - Enum.map(rows, fn row -> - decoded = decode_message(row) + Enum.map(ordered, fn row -> + case row._table do + "text_messages" -> + %MessageSchema{ + role: String.to_existing_atom(row.role), + content: row[:content], + reasoning_content: row[:reasoning_content], + tool_calls: [] + } - case Zoi.parse(MessageSchema.schema(), decoded) do - {:ok, _valid} -> - struct(MessageSchema, decoded) + "tool_calls_request_messages" -> + %MessageSchema{ + role: :assistant, + tool_calls: [ + %{ + id: row.tool_call_id, + name: row.tool_name, + arguments: row.arguments + } + ] + } - {:error, errors} -> - Logger.error("Invalid message data from DB: #{inspect(errors)}") - raise ArgumentError, "Invalid message data: #{inspect(errors)}" - end - end) + "tool_response_messages" -> + %MessageSchema{ + role: :tool, + content: row.content, + tool_call_id: row.tool_call_id + } + end + end) + |> Enum.map(&drop_nil_fields(Map.from_struct(&1))) + |> Enum.map(&struct(MessageSchema, &1)) + else + _ -> [] end end @@ -57,45 +116,88 @@ defmodule ElixirAi.Message do end end - def insert(conversation_id, message, topic: topic) - when is_binary(conversation_id) and byte_size(conversation_id) == 16 do + defp fetch_text_messages(conversation_id, topic) do sql = """ - INSERT INTO messages ( - conversation_id, - role, - content, - reasoning_content, - tool_calls, - tool_call_id, - inserted_at - ) VALUES ( - $(conversation_id), - $(role), - $(content), - $(reasoning_content), - $(tool_calls)::jsonb, - $(tool_call_id), - $(inserted_at) - ) + SELECT + tm.id, + tm.prev_message_id, + tm.prev_message_table, + tm.role, + tm.content, + tm.reasoning_content, + tm.inserted_at + FROM text_messages tm + WHERE tm.conversation_id = $(conversation_id) """ - params = %{ - "conversation_id" => conversation_id, - "role" => to_string(message.role), - "content" => message[:content], - "reasoning_content" => message[:reasoning_content], - "tool_calls" => encode_tool_calls(message[:tool_calls]), - "tool_call_id" => message[:tool_call_id], - "inserted_at" => DateTime.truncate(DateTime.utc_now(), :second) - } + DbHelpers.run_sql( + sql, + %{"conversation_id" => conversation_id}, + topic, + text_message_row_schema() + ) || [] + end - case DbHelpers.run_sql(sql, params, topic) do - {:error, :db_error} -> - {:error, :db_error} + defp fetch_tool_call_request_messages(conversation_id, topic) do + sql = """ + SELECT + tc.id, + tc.text_message_id, + tc.prev_message_id, + tc.prev_message_table, + tc.tool_name, + tc.tool_call_id, + tc.arguments, + tc.inserted_at + FROM tool_calls_request_messages tc + JOIN text_messages tm ON tc.text_message_id = tm.id + WHERE tm.conversation_id = $(conversation_id) + """ - _result -> - # Logger.debug("Inserted message for conversation_id=#{Ecto.UUID.cast!(conversation_id)}") - {:ok, 1} + DbHelpers.run_sql( + sql, + %{"conversation_id" => conversation_id}, + topic, + tool_call_request_row_schema() + ) || [] + end + + defp fetch_tool_response_messages(conversation_id, topic) do + sql = """ + SELECT + tr.id, + tr.tool_call_id, + tr.prev_message_id, + tr.prev_message_table, + tr.content, + tr.inserted_at + FROM tool_response_messages tr + JOIN tool_calls_request_messages tc ON tr.tool_call_id = tc.tool_call_id + JOIN text_messages tm ON tc.text_message_id = tm.id + WHERE tm.conversation_id = $(conversation_id) + """ + + DbHelpers.run_sql( + sql, + %{"conversation_id" => conversation_id}, + topic, + tool_response_row_schema() + ) || [] + end + + def insert(conversation_id, message, topic: topic) + when is_binary(conversation_id) and byte_size(conversation_id) == 16 do + timestamp = DateTime.truncate(DateTime.utc_now(), :second) + + case message.role do + :tool -> + insert_tool_response(message, timestamp, topic) + + :assistant -> + insert_assistant_message(conversation_id, message, timestamp, topic) + + :user -> + insert_user_message(conversation_id, message, timestamp, topic) end end @@ -110,37 +212,237 @@ defmodule ElixirAi.Message do end end - defp encode_tool_calls(nil), do: nil - defp encode_tool_calls(calls), do: Jason.encode!(calls) + defp insert_user_message(conversation_id, message, timestamp, topic) do + {prev_id, prev_table} = get_last_message_ref(conversation_id, topic) + + sql = """ + INSERT INTO text_messages ( + conversation_id, + prev_message_id, + prev_message_table, + role, + content, + inserted_at + ) VALUES ( + $(conversation_id), + $(prev_message_id), + $(prev_message_table), + $(role), + $(content), + $(inserted_at) + ) + """ + + params = %{ + "conversation_id" => conversation_id, + "prev_message_id" => prev_id, + "prev_message_table" => prev_table, + "role" => "user", + "content" => message[:content], + "inserted_at" => timestamp + } + + case DbHelpers.run_sql(sql, params, topic) do + {:error, :db_error} -> {:error, :db_error} + _result -> {:ok, 1} + end + end + + defp insert_assistant_message(conversation_id, message, timestamp, topic) do + {prev_id, prev_table} = get_last_message_ref(conversation_id, topic) + + message_sql = """ + INSERT INTO text_messages ( + conversation_id, + prev_message_id, + prev_message_table, + role, + content, + reasoning_content, + inserted_at + ) VALUES ( + $(conversation_id), + $(prev_message_id), + $(prev_message_table), + $(role), + $(content), + $(reasoning_content), + $(inserted_at) + ) + RETURNING id + """ + + message_params = %{ + "conversation_id" => conversation_id, + "prev_message_id" => prev_id, + "prev_message_table" => prev_table, + "role" => "assistant", + "content" => message[:content], + "reasoning_content" => message[:reasoning_content], + "inserted_at" => timestamp + } + + case DbHelpers.run_sql(message_sql, message_params, topic) do + {:error, :db_error} -> + {:error, :db_error} + + [%{"id" => text_message_id}] -> + if message[:tool_calls] && length(message[:tool_calls]) > 0 do + Enum.each(message[:tool_calls], fn tool_call -> + {tc_prev_id, tc_prev_table} = get_last_message_ref(conversation_id, topic) + + tool_call_sql = """ + INSERT INTO tool_calls_request_messages ( + text_message_id, + prev_message_id, + prev_message_table, + tool_name, + tool_call_id, + arguments, + inserted_at + ) VALUES ( + $(text_message_id), + $(prev_message_id), + $(prev_message_table), + $(tool_name), + $(tool_call_id), + $(arguments)::jsonb, + $(inserted_at) + ) + """ + + tool_call_params = %{ + "text_message_id" => text_message_id, + "prev_message_id" => tc_prev_id, + "prev_message_table" => tc_prev_table, + "tool_name" => tool_call[:name] || tool_call["name"], + "tool_call_id" => tool_call[:id] || tool_call["id"], + "arguments" => + encode_tool_call_arguments(tool_call[:arguments] || tool_call["arguments"]), + "inserted_at" => timestamp + } + + DbHelpers.run_sql(tool_call_sql, tool_call_params, topic) + end) + end + + {:ok, 1} + + _ -> + {:error, :db_error} + end + end + + defp insert_tool_response(message, _timestamp, topic) do + # tool_response_messages has no conversation_id, so look up via the tool_call + tool_call_id = message[:tool_call_id] + + {prev_id, prev_table} = get_last_tool_response_ref(tool_call_id, topic) + + sql = """ + INSERT INTO tool_response_messages ( + tool_call_id, + prev_message_id, + prev_message_table, + content + ) VALUES ( + $(tool_call_id), + $(prev_message_id), + $(prev_message_table), + $(content) + ) + """ + + params = %{ + "tool_call_id" => tool_call_id, + "prev_message_id" => prev_id, + "prev_message_table" => prev_table, + "content" => message[:content] || "" + } + + case DbHelpers.run_sql(sql, params, topic) do + {:error, :db_error} -> {:error, :db_error} + _result -> {:ok, 1} + end + end + + # Returns {id, table_name} of the most recently inserted message in the conversation, + # searching text_messages, tool_calls_request_messages, and tool_response_messages. + defp get_last_message_ref(conversation_id, topic) do + sql = """ + SELECT id, 'text_messages' AS tbl, inserted_at + FROM text_messages WHERE conversation_id = $(conversation_id) + UNION ALL + SELECT tc.id, 'tool_calls_request_messages', tc.inserted_at + FROM tool_calls_request_messages tc + JOIN text_messages tm ON tc.text_message_id = tm.id + WHERE tm.conversation_id = $(conversation_id) + UNION ALL + SELECT tr.id, 'tool_response_messages', tr.inserted_at + FROM tool_response_messages tr + JOIN tool_calls_request_messages tc ON tr.tool_call_id = tc.tool_call_id + JOIN text_messages tm ON tc.text_message_id = tm.id + WHERE tm.conversation_id = $(conversation_id) + ORDER BY inserted_at DESC, id DESC + LIMIT 1 + """ + + case DbHelpers.run_sql(sql, %{"conversation_id" => conversation_id}, topic) do + [%{"id" => id, "tbl" => tbl}] -> {id, tbl} + _ -> {nil, nil} + end + end + + defp get_last_tool_response_ref(tool_call_id, topic) do + sql = """ + SELECT tc.id, 'tool_calls_request_messages' AS tbl + FROM tool_calls_request_messages tc + WHERE tc.tool_call_id = $(tool_call_id) + LIMIT 1 + """ + + case DbHelpers.run_sql(sql, %{"tool_call_id" => tool_call_id}, topic) do + [%{"id" => id, "tbl" => tbl}] -> {id, tbl} + _ -> {nil, nil} + end + end + + defp sort_by_prev_message([], _by_key), do: [] + + defp sort_by_prev_message(rows, _by_key) do + # Find the head: the row whose {prev_message_table, prev_message_id} is not in the set, + # i.e. it has no predecessor among this conversation's messages. + keys = MapSet.new(rows, fn r -> {r._table, r.id} end) + + head = + Enum.find(rows, fn r -> + prev_key = {r[:prev_message_table], r[:prev_message_id]} + is_nil(r[:prev_message_id]) or not MapSet.member?(keys, prev_key) + end) + + if is_nil(head) do + rows + else + # Build a reverse index: prev pointer -> row that points to it + by_prev = + Map.new(rows, fn r -> + {{r[:prev_message_table], r[:prev_message_id]}, r} + end) + + Stream.iterate(head, fn r -> + Map.get(by_prev, {r._table, r.id}) + end) + |> Enum.take_while(&(&1 != nil)) + end + end + + defp encode_tool_call_arguments(args) when is_binary(args), do: args + defp encode_tool_call_arguments(args), do: Jason.encode!(args) defp dump_uuid(id) when is_binary(id) and byte_size(id) == 16, do: {:ok, id} defp dump_uuid(id) when is_binary(id), do: Ecto.UUID.dump(id) defp dump_uuid(_), do: :error - defp decode_message(row) do - row - |> Map.new(fn {k, v} -> {String.to_existing_atom(k), v} end) - |> Map.update!(:role, &String.to_existing_atom/1) - |> Map.update(:tool_calls, nil, fn - nil -> - nil - - json when is_binary(json) -> - json |> Jason.decode!() |> Enum.map(&atomize_keys/1) - - already_decoded -> - Enum.map(already_decoded, &atomize_keys/1) - end) - |> drop_nil_fields() - end - - defp atomize_keys(map) when is_map(map) do - Map.new(map, fn - {k, v} when is_binary(k) -> {String.to_atom(k), v} - {k, v} -> {k, v} - end) - end - defp drop_nil_fields(map) do Map.reject(map, fn {_k, v} -> is_nil(v) end) end diff --git a/lib/elixir_ai_web/components/chat_message.ex b/lib/elixir_ai_web/components/chat_message.ex index ca04fcf..a8a6f39 100644 --- a/lib/elixir_ai_web/components/chat_message.ex +++ b/lib/elixir_ai_web/components/chat_message.ex @@ -199,7 +199,12 @@ defmodule ElixirAiWeb.ChatMessage do """ end - # Dispatches to the appropriate tool call component based on result state + # Dispatches to the appropriate tool call component based on result state. + # Four states: + # :error key present → error (runtime failure) + # :result key present → success (runtime completed) + # :index key present → pending (streaming in-progress) + # none of the above → called (DB-loaded completed call; result is a separate message) attr :tool_call, :map, required: true defp tool_call_item(%{tool_call: tool_call} = assigns) do @@ -208,7 +213,7 @@ defmodule ElixirAiWeb.ChatMessage do assigns = assigns |> assign(:name, tool_call.name) - |> assign(:arguments, tool_call[:arguments] || "") + |> assign(:arguments, tool_call[:arguments]) |> assign(:error, tool_call.error) ~H"<.error_tool_call name={@name} arguments={@arguments} error={@error} />" @@ -217,23 +222,61 @@ defmodule ElixirAiWeb.ChatMessage do assigns = assigns |> assign(:name, tool_call.name) - |> assign(:arguments, tool_call[:arguments] || "") + |> assign(:arguments, tool_call[:arguments]) |> assign(:result, tool_call.result) ~H"<.success_tool_call name={@name} arguments={@arguments} result={@result} />" + Map.has_key?(tool_call, :index) -> + assigns = + assigns + |> assign(:name, tool_call.name) + |> assign(:arguments, tool_call[:arguments]) + + ~H"<.pending_tool_call name={@name} arguments={@arguments} />" + true -> assigns = assigns |> assign(:name, tool_call.name) - |> assign(:arguments, tool_call[:arguments] || "") + |> assign(:arguments, tool_call[:arguments]) - ~H"<.pending_tool_call name={@name} arguments={@arguments} />" + ~H"<.called_tool_call name={@name} arguments={@arguments} />" end end attr :name, :string, required: true - attr :arguments, :string, default: "" + attr :arguments, :any, default: nil + + defp called_tool_call(assigns) do + ~H""" +
+
+ <.tool_call_icon /> + {@name} + + + + + called + +
+ <.tool_call_args arguments={@arguments} /> +
+ """ + end + + attr :name, :string, required: true + attr :arguments, :any, default: nil defp pending_tool_call(assigns) do ~H""" @@ -252,7 +295,7 @@ defmodule ElixirAiWeb.ChatMessage do end attr :name, :string, required: true - attr :arguments, :string, default: "" + attr :arguments, :any, default: nil attr :result, :any, required: true defp success_tool_call(assigns) do @@ -297,7 +340,7 @@ defmodule ElixirAiWeb.ChatMessage do end attr :name, :string, required: true - attr :arguments, :string, default: "" + attr :arguments, :any, default: nil attr :error, :string, required: true defp error_tool_call(assigns) do @@ -327,16 +370,22 @@ defmodule ElixirAiWeb.ChatMessage do """ end - attr :arguments, :string, default: "" + attr :arguments, :any, default: nil - defp tool_call_args(%{arguments: args} = assigns) when args != "" do + defp tool_call_args(%{arguments: args} = assigns) when not is_nil(args) and args != "" do assigns = assign( assigns, :pretty_args, - case Jason.decode(args) do - {:ok, decoded} -> Jason.encode!(decoded, pretty: true) - _ -> args + case args do + s when is_binary(s) -> + case Jason.decode(s) do + {:ok, decoded} -> Jason.encode!(decoded, pretty: true) + _ -> s + end + + other -> + Jason.encode!(other, pretty: true) end ) diff --git a/lib/elixir_ai_web/live/chat_live.ex b/lib/elixir_ai_web/live/chat_live.ex index c6898c6..4b236d0 100644 --- a/lib/elixir_ai_web/live/chat_live.ex +++ b/lib/elixir_ai_web/live/chat_live.ex @@ -39,6 +39,18 @@ defmodule ElixirAiWeb.ChatLive do {:error, :not_found} -> {:ok, push_navigate(socket, to: "/")} + + {:error, reason} -> + Logger.error("Failed to start conversation #{name}: #{inspect(reason)}") + + {:ok, + socket + |> assign(conversation_name: name) + |> assign(user_input: "") + |> assign(messages: []) + |> assign(streaming_response: nil) + |> assign(background_color: "bg-cyan-950/30") + |> assign(db_error: Exception.format(:error, reason))} end end @@ -67,14 +79,17 @@ defmodule ElixirAiWeb.ChatLive do <%= for msg <- @messages do %> <%= cond do %> <% msg.role == :user -> %> - <.user_message content={msg.content} /> + <.user_message content={Map.get(msg, :content) || ""} /> <% msg.role == :tool -> %> - <.tool_result_message content={msg.content} tool_call_id={msg.tool_call_id} /> + <.tool_result_message + content={Map.get(msg, :content) || ""} + tool_call_id={Map.get(msg, :tool_call_id) || ""} + /> <% true -> %> <.assistant_message - content={msg.content} - reasoning_content={msg.reasoning_content} - tool_calls={Map.get(msg, :tool_calls, [])} + content={Map.get(msg, :content) || ""} + reasoning_content={Map.get(msg, :reasoning_content)} + tool_calls={Map.get(msg, :tool_calls) || []} /> <% end %> <% end %> diff --git a/postgres/schema/00-schema.sql b/postgres/schema/00-schema.sql new file mode 100644 index 0000000..e882fec --- /dev/null +++ b/postgres/schema/00-schema.sql @@ -0,0 +1,56 @@ +-- Initial schema + +CREATE TABLE IF NOT EXISTS ai_providers ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name TEXT NOT NULL UNIQUE, + model_name TEXT NOT NULL, + api_token TEXT NOT NULL, + completions_url TEXT NOT NULL, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS conversations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name TEXT NOT NULL UNIQUE, + ai_provider_id UUID NOT NULL REFERENCES ai_providers(id) ON DELETE RESTRICT, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS text_messages ( + id BIGSERIAL PRIMARY KEY, + conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE, + prev_message_id BIGINT, + prev_message_table TEXT CHECK (prev_message_table IN ('text_messages', 'tool_calls_request_messages', 'tool_response_messages')), + role TEXT NOT NULL CHECK (role IN ('user', 'assistant')), + content TEXT, + reasoning_content TEXT, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS tool_calls_request_messages ( + id BIGSERIAL PRIMARY KEY, + text_message_id BIGINT NOT NULL REFERENCES text_messages(id) ON DELETE CASCADE, + prev_message_id BIGINT, + prev_message_table TEXT CHECK (prev_message_table IN ('text_messages', 'tool_calls_request_messages', 'tool_response_messages')), + tool_name TEXT NOT NULL, + tool_call_id TEXT NOT NULL UNIQUE, + arguments JSONB NOT NULL, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS tool_response_messages ( + id BIGSERIAL PRIMARY KEY, + tool_call_id TEXT NOT NULL REFERENCES tool_calls_request_messages(tool_call_id) ON DELETE CASCADE, + prev_message_id BIGINT, + prev_message_table TEXT CHECK (prev_message_table IN ('text_messages', 'tool_calls_request_messages', 'tool_response_messages')), + content TEXT NOT NULL, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_text_messages_prev ON text_messages(prev_message_id); +CREATE INDEX IF NOT EXISTS idx_tool_call_msgs_prev ON tool_calls_request_messages(prev_message_id); +CREATE INDEX IF NOT EXISTS idx_tool_call_msgs_text_msg ON tool_calls_request_messages(text_message_id); +CREATE INDEX IF NOT EXISTS idx_tool_call_msgs_tool_call_id ON tool_calls_request_messages(tool_call_id); +CREATE INDEX IF NOT EXISTS idx_tool_response_msgs_prev ON tool_response_messages(prev_message_id); diff --git a/schema.sql b/schema.sql deleted file mode 100644 index 3d87b1c..0000000 --- a/schema.sql +++ /dev/null @@ -1,33 +0,0 @@ --- drop table if exists messages cascade; --- drop table if exists conversations cascade; --- drop table if exists ai_providers cascade; - - -CREATE TABLE IF NOT EXISTS ai_providers ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - name TEXT NOT NULL UNIQUE, - model_name TEXT NOT NULL, - api_token TEXT NOT NULL, - completions_url TEXT NOT NULL, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE TABLE IF NOT EXISTS conversations ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - name TEXT NOT NULL UNIQUE, - ai_provider_id UUID NOT NULL REFERENCES ai_providers(id) ON DELETE RESTRICT, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE TABLE IF NOT EXISTS messages ( - id BIGSERIAL PRIMARY KEY, - conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE, - role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'tool')), - content TEXT, - reasoning_content TEXT, - tool_calls JSONB, - tool_call_id TEXT, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); diff --git a/test/message_storage_test.exs b/test/message_storage_test.exs index 3cc0e0b..156403e 100644 --- a/test/message_storage_test.exs +++ b/test/message_storage_test.exs @@ -29,10 +29,41 @@ defmodule ElixirAi.MessageStorageTest do String.contains?(sql, "SELECT id FROM conversations") -> [%{"id" => conv_id}] + String.contains?(sql, "SELECT") and String.contains?(sql, "FROM messages m") and + String.contains?(sql, "LEFT JOIN assistant_message_details") -> + # Load messages query + [] + + String.contains?(sql, "SELECT") and String.contains?(sql, "FROM tool_calls") -> + # Load tool calls query + [] + + String.contains?(sql, "SELECT") and String.contains?(sql, "FROM tool_responses") -> + # Load tool responses query + [] + + String.contains?(sql, "INSERT INTO messages") and String.contains?(sql, "RETURNING id") -> + # Assistant message insert - return a fake message_id + send(test_pid, {:insert_assistant_message, params}) + [%{"id" => 123}] + String.contains?(sql, "INSERT INTO messages") -> + # User message insert send(test_pid, {:insert_message, params}) [] + String.contains?(sql, "INSERT INTO tool_calls") -> + send(test_pid, {:insert_tool_call, params}) + [] + + String.contains?(sql, "INSERT INTO tool_responses") -> + send(test_pid, {:insert_tool_response, params}) + [] + + String.contains?(sql, "INSERT INTO assistant_message_details") -> + send(test_pid, {:insert_assistant_details, params}) + [] + true -> [] end @@ -74,7 +105,7 @@ defmodule ElixirAi.MessageStorageTest do ElixirAi.ChatRunner.new_user_message(conv_name, "hi") assert_receive {:insert_message, %{"role" => "user"}}, 2000 - assert_receive {:insert_message, params}, 2000 + assert_receive {:insert_assistant_message, params}, 2000 assert params["role"] == "assistant" assert params["content"] == "Hello from AI" end @@ -104,14 +135,17 @@ defmodule ElixirAi.MessageStorageTest do assert_receive {:insert_message, %{"role" => "user"}}, 2000 - # Assistant message that carries the tool_calls list - assert_receive {:insert_message, params}, 2000 + # Assistant message with tool_calls + assert_receive {:insert_assistant_message, params}, 2000 assert params["role"] == "assistant" - refute is_nil(params["tool_calls"]) - # Tool result message - assert_receive {:insert_message, params}, 2000 - assert params["role"] == "tool" + # Tool call details stored separately + assert_receive {:insert_tool_call, params}, 2000 + assert params["tool_name"] == "store_thing" + assert params["tool_call_id"] == "tc_1" + + # Tool result stored in tool_responses table + assert_receive {:insert_tool_response, params}, 2000 assert params["tool_call_id"] == "tc_1" end end