Add LLAMA 3.1 Json tool call with Bumblebee (#198)
* Add LLAMA 3.1 Json tool call

* add Example and remove unused var warning

* fix review findings Readme and debug print

* run mix format

* Add LLAMA custom tool calling

* Remove IO.inspect

* fix prep_and_validate_messages

* fix merge error

* remove IO.inspect

* Fix empty tool and Test current data

* Add Code to make Parsec Optional
marcnnn authored Jan 22, 2025
1 parent 9cdfe23 commit f1afe1a
Showing 10 changed files with 1,345 additions and 10 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -218,7 +218,8 @@ For example, if a locally running service provided that feature, the following c

Bumblebee hosted chat models are supported. There is built-in support for Llama 2, Mistral, and Zephyr models.

Currently, function calling is NOT supported with these models.
Currently, function calling is only supported via Llama 3.1 JSON tool calling; function calling with Llama 2, Mistral, and Zephyr is NOT supported.
There is an example notebook in the notebook folder.

ChatBumblebee.new!(%{
serving: @serving_name,
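As a rough sketch (not part of this commit), selecting the new format might look like the following; the serving module name, the tool definition, and the callback return shape are illustrative assumptions:

alias LangChain.ChatModels.ChatBumblebee
alias LangChain.Function

# Hypothetical Nx.Serving name; use whatever serving you registered.
model =
  ChatBumblebee.new!(%{
    serving: Llama31Chat,
    template_format: :llama_3_1_json_tool_calling
  })

# A tool definition that the template renders into the prompt.
weather_fn =
  Function.new!(%{
    name: "get_weather",
    description: "Get the current weather for a city.",
    function: fn %{"city" => city} = _args, _context ->
      {:ok, "Sunny in #{city}"}
    end
  })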
219 changes: 211 additions & 8 deletions lib/chat_models/chat_bumblebee.ex
@@ -108,6 +108,9 @@ defmodule LangChain.ChatModels.ChatBumblebee do
alias LangChain.MessageDelta
alias LangChain.Utils.ChatTemplates
alias LangChain.Callbacks
alias LangChain.Message.ToolCall
alias LangChain.Utils.Parser.LLAMA_3_1_CustomToolParser
alias LangChain.Utils.Parser.LLAMA_3_2_CustomToolParser

@behaviour ChatModel

@@ -123,14 +126,18 @@ defmodule LangChain.ChatModels.ChatBumblebee do
# # more focused and deterministic.
# field :temperature, :float, default: 1.0

field :template_format, Ecto.Enum, values: [
:inst,
:im_start,
:zephyr,
:phi_4,
:llama_2,
:llama_3
]
field :template_format, Ecto.Enum,
values: [
:inst,
:im_start,
:zephyr,
:phi_4,
:llama_2,
:llama_3,
:llama_3_1_json_tool_calling,
:llama_3_1_custom_tool_calling,
:llama_3_2_custom_tool_calling
]

# The bumblebee model may compile differently based on the stream true/false
# option on the serving. Therefore, streaming should be enabled on the
@@ -242,6 +249,46 @@
@doc false
@spec do_serving_request(t(), [Message.t()], [Function.t()]) ::
list() | struct() | {:error, String.t()}

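# Note: the three clauses below differ only in the matched template format;
# each renders the tool definitions into the prompt alongside the messages
# before running the serving.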
def do_serving_request(
%ChatBumblebee{template_format: :llama_3_1_json_tool_calling} = model,
messages,
functions
) do
prompt =
ChatTemplates.apply_chat_template_with_tools!(messages, model.template_format, functions)

model.serving
|> Nx.Serving.batched_run(%{text: prompt, seed: model.seed})
|> do_process_response(model)
end

def do_serving_request(
%ChatBumblebee{template_format: :llama_3_1_custom_tool_calling} = model,
messages,
functions
) do
prompt =
ChatTemplates.apply_chat_template_with_tools!(messages, model.template_format, functions)

model.serving
|> Nx.Serving.batched_run(%{text: prompt, seed: model.seed})
|> do_process_response(model)
end

def do_serving_request(
%ChatBumblebee{template_format: :llama_3_2_custom_tool_calling} = model,
messages,
functions
) do
prompt =
ChatTemplates.apply_chat_template_with_tools!(messages, model.template_format, functions)

model.serving
|> Nx.Serving.batched_run(%{text: prompt, seed: model.seed})
|> do_process_response(model)
end

def do_serving_request(%ChatBumblebee{} = model, messages, _functions) do
prompt = ChatTemplates.apply_chat_template!(messages, model.template_format)

@@ -250,7 +297,163 @@
|> do_process_response(model)
end
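# A sketch of exercising the tool-calling clauses above directly, reusing the
# illustrative `model` and `weather_fn` from the README sketch (normally the
# ChatModel entry point does this for you):
messages = [
  Message.new_system!("You are a helpful assistant."),
  Message.new_user!("What is the weather in Paris?")
]

# do_serving_request/3 returns the processed response as a list of messages.
[%Message{role: :assistant} = _response] =
  ChatBumblebee.do_serving_request(model, messages, [weather_fn])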

@doc false
def do_process_response(
%{results: [%{text: "[" <> _ = content, token_summary: token_summary}]},
%ChatBumblebee{template_format: :llama_3_2_custom_tool_calling} = model
)
when is_binary(content) do
if !Code.ensure_loaded?(NimbleParsec) do
raise "Install NimbleParsec to use custom tool calling"
end

fire_token_usage_callback(model, token_summary)

case LLAMA_3_2_CustomToolParser.parse(content) do
{:ok, functions} ->
case Message.new(%{
role: :assistant,
status: :complete,
content: content,
tool_calls:
Enum.with_index(functions, fn %{
function_name: name,
parameters: parameters
},
i ->
ToolCall.new!(%{
call_id: Integer.to_string(i),
name: name,
arguments: parameters
})
end)
}) do
{:ok, message} ->
# execute the callback with the final message
Callbacks.fire(model.callbacks, :on_llm_new_message, [model, message])
# return a list of the complete message. As a list for compatibility.
[message]

{:error, changeset} ->
reason = Utils.changeset_error_to_string(changeset)
Logger.error("Failed to create non-streamed full message: #{inspect(reason)}")
{:error, reason}
end

{:error, _} ->
case Message.new(%{role: :assistant, status: :complete, content: content}) do
{:ok, message} ->
# execute the callback with the final message
Callbacks.fire(model.callbacks, :on_llm_new_message, [model, message])
# return a list of the complete message. As a list for compatibility.
[message]

{:error, changeset} ->
reason = Utils.changeset_error_to_string(changeset)
Logger.error("Failed to create non-streamed full message: #{inspect(reason)}")
{:error, reason}
end
end
end
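# For illustration: the bracketed custom-call syntax the "[" <> _ clause above
# matches looks roughly like this. The exact grammar lives in
# LLAMA_3_2_CustomToolParser (not shown in this diff); the result shape is
# inferred from the pattern matches above.
output = ~s{[get_weather(city="Paris"), get_time(zone="CET")]}

{:ok,
 [
   %{function_name: "get_weather", parameters: %{"city" => "Paris"}},
   %{function_name: "get_time", parameters: %{"zone" => "CET"}}
 ]} = LLAMA_3_2_CustomToolParser.parse(output)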

def do_process_response(
%{results: [%{text: "<" <> _ = content, token_summary: token_summary}]},
%ChatBumblebee{template_format: :llama_3_1_custom_tool_calling} = model
)
when is_binary(content) do
if !Code.ensure_loaded?(NimbleParsec) do
raise "Install NimbleParsec to use custom tool calling"
end

fire_token_usage_callback(model, token_summary)

case LLAMA_3_1_CustomToolParser.parse(content) do
{:ok,
%{
function_name: name,
parameters: parameters
}} ->
case Message.new(%{
role: :assistant,
status: :complete,
content: content,
tool_calls: [ToolCall.new!(%{call_id: "test", name: name, arguments: parameters})]
}) do
{:ok, message} ->
# execute the callback with the final message
Callbacks.fire(model.callbacks, :on_llm_new_message, [model, message])
# return a list of the complete message. As a list for compatibility.
[message]

{:error, changeset} ->
reason = Utils.changeset_error_to_string(changeset)
Logger.error("Failed to create non-streamed full message: #{inspect(reason)}")
{:error, reason}
end

{:error, _} ->
case Message.new(%{role: :assistant, status: :complete, content: content}) do
{:ok, message} ->
# execute the callback with the final message
Callbacks.fire(model.callbacks, :on_llm_new_message, [model, message])
# return a list of the complete message. As a list for compatibility.
[message]

{:error, changeset} ->
reason = Utils.changeset_error_to_string(changeset)
Logger.error("Failed to create non-streamed full message: #{inspect(reason)}")
{:error, reason}
end
end
end
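# For illustration: a Llama 3.1 custom tool call as matched by the "<" <> _
# clause above. The payload is an assumed example; the result shape comes from
# the {:ok, %{function_name: ..., parameters: ...}} pattern in the code.
output = ~s(<function=get_weather>{"city": "Paris"}</function>)

{:ok, %{function_name: "get_weather", parameters: %{"city" => "Paris"}}} =
  LLAMA_3_1_CustomToolParser.parse(output)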

def do_process_response(
%{results: [%{text: "{" <> _ = content, token_summary: token_summary}]},
%ChatBumblebee{template_format: :llama_3_1_json_tool_calling} = model
)
when is_binary(content) do
fire_token_usage_callback(model, token_summary)

case Jason.decode(content) do
{:ok,
%{
"name" => name,
"parameters" => parameters
}} ->
case Message.new(%{
role: :assistant,
status: :complete,
content: content,
tool_calls: [ToolCall.new!(%{call_id: "test", name: name, arguments: parameters})]
}) do
{:ok, message} ->
# execute the callback with the final message
Callbacks.fire(model.callbacks, :on_llm_new_message, [model, message])
# return a list of the complete message. As a list for compatibility.
[message]

{:error, changeset} ->
reason = Utils.changeset_error_to_string(changeset)
Logger.error("Failed to create non-streamed full message: #{inspect(reason)}")
{:error, reason}
end

{:error, _} ->
case Message.new(%{role: :assistant, status: :complete, content: content}) do
{:ok, message} ->
# execute the callback with the final message
Callbacks.fire(model.callbacks, :on_llm_new_message, [model, message])
# return a list of the complete message. As a list for compatibility.
[message]

{:error, changeset} ->
reason = Utils.changeset_error_to_string(changeset)
Logger.error("Failed to create non-streamed full message: #{inspect(reason)}")
{:error, reason}
end
end
end
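# The JSON path needs no extra parser dependency: the "{" <> _ clause decodes
# plain JSON with Jason. An assumed example payload:
output = ~s({"name": "get_weather", "parameters": {"city": "Paris"}})

{:ok, %{"name" => "get_weather", "parameters" => %{"city" => "Paris"}}} =
  Jason.decode(output)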

def do_process_response(
%{results: [%{text: content, token_summary: token_summary}]},
%ChatBumblebee{} = model
