-
Notifications
You must be signed in to change notification settings - Fork 271
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add GPT-4o-mini and Llama 3.3 70B on Stanford Health Care API (#3277)
- Loading branch information
Showing
6 changed files
with
163 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
from typing import Dict, Optional | ||
|
||
from helm.clients.openai_client import OpenAIClient | ||
from helm.common.cache import CacheConfig | ||
from helm.common.optional_dependencies import handle_module_not_found_error | ||
from helm.proxy.retry import NonRetriableException | ||
from helm.tokenizers.tokenizer import Tokenizer | ||
|
||
try: | ||
from openai import AzureOpenAI | ||
except ModuleNotFoundError as e: | ||
handle_module_not_found_error(e, ["openai"]) | ||
|
||
|
||
class AzureOpenAIClient(OpenAIClient): | ||
API_VERSION = "2024-07-01-preview" | ||
|
||
def __init__( | ||
self, | ||
tokenizer: Tokenizer, | ||
tokenizer_name: str, | ||
cache_config: CacheConfig, | ||
api_key: Optional[str] = None, | ||
endpoint: Optional[str] = None, | ||
api_version: Optional[str] = None, | ||
default_headers: Optional[Dict[str, str]] = None, | ||
): | ||
super().__init__( | ||
tokenizer=tokenizer, tokenizer_name=tokenizer_name, cache_config=cache_config, api_key="unused" | ||
) | ||
azure_endpoint = endpoint or os.getenv("AZURE_OPENAI_ENDPOINT") | ||
if not azure_endpoint: | ||
raise NonRetriableException("Must provide Azure endpoint through credentials.conf or AZURE_OPENAI_ENDPOINT") | ||
self.client = AzureOpenAI( | ||
api_key=api_key, | ||
api_version=api_version or AzureOpenAIClient.API_VERSION, | ||
azure_endpoint=azure_endpoint, | ||
default_headers=default_headers, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import Optional | ||
|
||
from helm.clients.openai_client import OpenAIClient | ||
from helm.common.cache import CacheConfig | ||
from helm.common.optional_dependencies import handle_module_not_found_error | ||
from helm.proxy.retry import NonRetriableException | ||
from helm.tokenizers.tokenizer import Tokenizer | ||
|
||
try: | ||
from openai import OpenAI | ||
except ModuleNotFoundError as e: | ||
handle_module_not_found_error(e, ["openai"]) | ||
|
||
|
||
class StanfordHealthCareLlamaClient(OpenAIClient): | ||
""" | ||
Client for accessing Llama models hosted on Stanford Health Care's model API. | ||
Configure by setting the following in prod_env/credentials.conf: | ||
``` | ||
stanfordhealthcareEndpoint: https://your-domain-name/ | ||
stanfordhealthcareApiKey: your-private-key | ||
``` | ||
""" | ||
|
||
CREDENTIAL_HEADER_NAME = "Ocp-Apim-Subscription-Key" | ||
|
||
def __init__( | ||
self, | ||
tokenizer: Tokenizer, | ||
tokenizer_name: str, | ||
cache_config: CacheConfig, | ||
model_name: str, | ||
api_key: Optional[str] = None, | ||
endpoint: Optional[str] = None, | ||
): | ||
super().__init__( | ||
tokenizer=tokenizer, tokenizer_name=tokenizer_name, cache_config=cache_config, api_key="unused" | ||
) | ||
if not endpoint: | ||
raise NonRetriableException("Must provide endpoint through credentials.conf") | ||
if not api_key: | ||
raise NonRetriableException("Must provide API key through credentials.conf") | ||
# Guess the base URL part based on the model name | ||
# Maybe make this configurable instead? | ||
base_url_part = model_name.split("/")[1].lower().removesuffix("-instruct").replace("-", "").replace(".", "") | ||
|
||
base_url = f"{endpoint.strip('/')}/{base_url_part}/v1" | ||
self.client = OpenAI( | ||
api_key="dummy", | ||
base_url=base_url, | ||
default_headers={StanfordHealthCareLlamaClient.CREDENTIAL_HEADER_NAME: api_key}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import Optional | ||
|
||
from helm.clients.azure_openai_client import AzureOpenAIClient | ||
from helm.common.cache import CacheConfig | ||
from helm.proxy.retry import NonRetriableException | ||
from helm.tokenizers.tokenizer import Tokenizer | ||
|
||
|
||
class StanfordHealthCareOpenAIClient(AzureOpenAIClient): | ||
""" | ||
Client for accessing OpenAI models hosted on Stanford Health Care's model API. | ||
Configure by setting the following in prod_env/credentials.conf: | ||
``` | ||
stanfordhealthcareEndpoint: https://your-domain-name/ | ||
stanfordhealthcareApiKey: your-private-key | ||
``` | ||
""" | ||
|
||
API_VERSION = "2023-05-15" | ||
CREDENTIAL_HEADER_NAME = "Ocp-Apim-Subscription-Key" | ||
|
||
def __init__( | ||
self, | ||
tokenizer: Tokenizer, | ||
tokenizer_name: str, | ||
cache_config: CacheConfig, | ||
api_key: Optional[str] = None, | ||
endpoint: Optional[str] = None, | ||
): | ||
if not api_key: | ||
raise NonRetriableException("Must provide API key through credentials.conf") | ||
super().__init__( | ||
tokenizer=tokenizer, | ||
tokenizer_name=tokenizer_name, | ||
cache_config=cache_config, | ||
api_key="unused", | ||
endpoint=endpoint, | ||
api_version=StanfordHealthCareOpenAIClient.API_VERSION, | ||
default_headers={StanfordHealthCareOpenAIClient.CREDENTIAL_HEADER_NAME: api_key}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters