From 779c7b97d82ac9f1255b150db837dace1ccc0190 Mon Sep 17 00:00:00 2001
From: Vijay Viswanathan
Date: Tue, 12 Sep 2023 13:47:56 -0400
Subject: [PATCH] Support HuggingFace's inference API (#352)

* Support HuggingFace's inference API

* Use positive temperature for prompt parser
---
 prompt2model/prompt_parser/instr_parser.py | 2 +-
 prompt2model/utils/api_tools.py            | 5 +++++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/prompt2model/prompt_parser/instr_parser.py b/prompt2model/prompt_parser/instr_parser.py
index 81125487f..d6d18a0d3 100644
--- a/prompt2model/prompt_parser/instr_parser.py
+++ b/prompt2model/prompt_parser/instr_parser.py
@@ -93,7 +93,7 @@ def parse_from_prompt(self, prompt: str) -> None:
         response: openai.ChatCompletion | Exception = (
             chat_api.generate_one_completion(
                 parsing_prompt_for_chatgpt,
-                temperature=0,
+                temperature=0.01,
                 presence_penalty=0,
                 frequency_penalty=0,
             )
diff --git a/prompt2model/utils/api_tools.py b/prompt2model/utils/api_tools.py
index d7143711a..bce8a1dd7 100644
--- a/prompt2model/utils/api_tools.py
+++ b/prompt2model/utils/api_tools.py
@@ -45,6 +45,7 @@ def __init__(
         self,
         model_name: str = "gpt-3.5-turbo",
         max_tokens: int | None = None,
+        api_base: str | None = None,
     ):
         """Initialize APIAgent with model_name and max_tokens.

@@ -52,9 +53,11 @@ def __init__(
             model_name: Name fo the model to use (by default, gpt-3.5-turbo).
             max_tokens: The maximum number of tokens to generate. Defaults to
                 the max value for the model if available through litellm.
+            api_base: Custom endpoint for Hugging Face's inference API.
         """
         self.model_name = model_name
         self.max_tokens = max_tokens
+        self.api_base = api_base
         if max_tokens is None:
             try:
                 self.max_tokens = litellm.utils.get_max_tokens(model_name)
@@ -99,6 +102,7 @@ def generate_one_completion(
             messages=[
                 {"role": "user", "content": f"{prompt}"},
             ],
+            api_base=self.api_base,
             temperature=temperature,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
@@ -144,6 +148,7 @@ async def _throttled_completion_acreate(
         return await acompletion(
             model=model,
             messages=messages,
+            api_base=self.api_base,
             temperature=temperature,
             max_tokens=max_tokens,
             n=n,
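
Usage note: with api_base exposed on APIAgent, a caller can point requests at
a custom endpoint such as Hugging Face's inference API. Below is a minimal
sketch, not part of this patch; the endpoint URL and model name are
hypothetical placeholders, and it assumes litellm routes requests by the
"huggingface/" model-name prefix:

    from prompt2model.utils.api_tools import APIAgent

    # Hypothetical Hugging Face model and endpoint; substitute your own.
    agent = APIAgent(
        model_name="huggingface/bigscience/bloom",
        max_tokens=256,
        api_base="https://my-endpoint.endpoints.huggingface.cloud",
    )

    # Same call signature the prompt parser uses. The temperature is a small
    # positive value (0.01 rather than 0), matching this patch's change --
    # likely because some hosted backends reject a temperature of exactly zero.
    completion = agent.generate_one_completion(
        "Summarize the following text: ...",
        temperature=0.01,
        presence_penalty=0,
        frequency_penalty=0,
    )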