Fix model name error in finetune (stanfordnlp#7880)
* fix finetune model name error

* remove unnecessary nonlocal
TomeHirata authored Feb 28, 2025
1 parent 893fc27 commit 79b8808
Showing 6 changed files with 35 additions and 25 deletions.
4 changes: 2 additions & 2 deletions docs/docs/index.md
@@ -384,7 +384,7 @@ Given a few tens or hundreds of representative _inputs_ of your task and a _metr

 ```python linenums="1"
 import dspy
-dspy.configure(lm=dspy.LM('gpt-4o-mini-2024-07-18'))
+dspy.configure(lm=dspy.LM('openai/gpt-4o-mini-2024-07-18'))
 # Define the DSPy module for classification. It will use the hint at training time, if available.
 signature = dspy.Signature("text -> label").with_updated_fields('label', type_=Literal[tuple(CLASSES)])
@@ -394,7 +394,7 @@ Given a few tens or hundreds of representative _inputs_ of your task and a _metr
 optimizer = dspy.BootstrapFinetune(metric=(lambda x, y, trace=None: x.label == y.label), num_threads=24)
 optimized = optimizer.compile(classify, trainset=trainset)

-optimized_classifier(text="What does a pending cash withdrawal mean?")
+optimized(text="What does a pending cash withdrawal mean?")
 ```

 **Possible Output (from the last line):**
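Both docs fixes are small but load-bearing: `dspy.LM` model strings follow the LiteLLM `provider/model` convention, so the `openai/` prefix is what lets DSPy resolve the serving and fine-tuning provider, and the compiled program was bound to the name `optimized`, not `optimized_classifier`. A minimal sketch of the corrected configuration (the model ID is just the one used in the docs page):

```python
import dspy

# The "openai/" prefix identifies the provider; BootstrapFinetune relies on it
# to select the matching fine-tuning backend (see dspy/clients/openai.py below).
lm = dspy.LM("openai/gpt-4o-mini-2024-07-18")
dspy.configure(lm=lm)
```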
16 changes: 8 additions & 8 deletions dspy/clients/databricks.py
@@ -172,21 +172,21 @@ def finetune(
     train_data_format: Optional[Union[TrainDataFormat, str]] = "chat",
     train_kwargs: Optional[Dict[str, Any]] = None,
 ) -> str:
-    if isinstance(data_format, str):
-        if data_format == "chat":
-            data_format = TrainDataFormat.CHAT
-        elif data_format == "completion":
-            data_format = TrainDataFormat.COMPLETION
+    if isinstance(train_data_format, str):
+        if train_data_format == "chat":
+            train_data_format = TrainDataFormat.CHAT
+        elif train_data_format == "completion":
+            train_data_format = TrainDataFormat.COMPLETION
         else:
             raise ValueError(
-                f"String `train_data_format` must be one of 'chat' or 'completion', but received: {data_format}."
+                f"String `train_data_format` must be one of 'chat' or 'completion', but received: {train_data_format}."
             )

     if "train_data_path" not in train_kwargs:
         raise ValueError("The `train_data_path` must be provided to finetune on Databricks.")
     # Add the file name to the directory path.
     train_kwargs["train_data_path"] = DatabricksProvider.upload_data(
-        train_data, train_kwargs["train_data_path"], data_format
+        train_data, train_kwargs["train_data_path"], train_data_format
     )

     try:
@@ -236,7 +236,7 @@ def finetune(
     model_to_deploy = train_kwargs.get("register_to")
     job.endpoint_name = model_to_deploy.replace(".", "_")
     DatabricksProvider.deploy_finetuned_model(
-        model_to_deploy, data_format, databricks_host, databricks_token, deploy_timeout
+        model_to_deploy, train_data_format, databricks_host, databricks_token, deploy_timeout
     )
     job.launch_completed = True
     # The finetuned model name should be in the format: "databricks/<endpoint_name>".
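The bug in both hunks is the same stale name: the parameter is `train_data_format`, but the body still read `data_format`, which (assuming no outer binding happened to exist) raises a `NameError` as soon as the function runs. A standalone sketch of the normalization step, with a stand-in enum for `dspy.clients.utils_finetune.TrainDataFormat`:

```python
from enum import Enum
from typing import Union

class TrainDataFormat(Enum):
    # Stand-in for dspy.clients.utils_finetune.TrainDataFormat.
    CHAT = "chat"
    COMPLETION = "completion"

def normalize_train_data_format(
    train_data_format: Union[TrainDataFormat, str],
) -> TrainDataFormat:
    # Accept either the enum or its string spelling, as the Databricks
    # provider does; reject anything else with the same error message.
    if isinstance(train_data_format, str):
        if train_data_format == "chat":
            return TrainDataFormat.CHAT
        elif train_data_format == "completion":
            return TrainDataFormat.COMPLETION
        raise ValueError(
            "String `train_data_format` must be one of 'chat' or 'completion', "
            f"but received: {train_data_format}."
        )
    return train_data_format

assert normalize_train_data_format("chat") is TrainDataFormat.CHAT
```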
8 changes: 6 additions & 2 deletions dspy/clients/lm_local.py
@@ -12,6 +12,10 @@
 from typing import Any, Dict, List, Optional
 from dspy.clients.provider import TrainingJob, Provider
 from dspy.clients.utils_finetune import TrainDataFormat, save_data
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from dspy.clients.lm import LM

 logger = logging.getLogger(__name__)

@@ -118,7 +122,7 @@ def get_logs() -> str:

     @staticmethod
-    def kill(lm: 'LM', launch_kwargs: Optional[Dict[str, Any]] = None):
+    def kill(lm: "LM", launch_kwargs: Optional[Dict[str, Any]] = None):
         from sglang.utils import terminate_process
         if not hasattr(lm, "process"):
             logger.info("No running server to kill.")
@@ -227,7 +231,7 @@ def train_sft_locally(model_name, train_data, train_kwargs):

     hf_dataset = Dataset.from_list(train_data)
     def tokenize_function(example):
-        return encode_sft_example(example, tokenizer, max_seq_length)
+        return encode_sft_example(example, tokenizer, train_kwargs["max_seq_length"])
     tokenized_dataset = hf_dataset.map(tokenize_function, batched=False)
     tokenized_dataset.set_format(type="torch")
     tokenized_dataset = tokenized_dataset.filter(lambda example: (example["labels"] != -100).any())
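The last hunk pairs with the "remove unnecessary nonlocal" note in the commit message: the closure now reads the limit from `train_kwargs` directly. The underlying Python rule is that a nested function may *read* an enclosing variable freely; `nonlocal` is required only to *rebind* it. A minimal illustration of that rule:

```python
def outer() -> int:
    max_seq_length = 512

    def reader() -> int:
        # Reading an enclosing name needs no declaration.
        return max_seq_length

    def rebinder() -> None:
        # `nonlocal` is required only because we assign to the name.
        nonlocal max_seq_length
        max_seq_length = 1024

    rebinder()
    return reader()

assert outer() == 1024
```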
23 changes: 13 additions & 10 deletions dspy/clients/openai.py
@@ -92,17 +92,13 @@ def __init__(self):

     @staticmethod
     def is_provider_model(model: str) -> bool:
-        # Filter the provider_prefix, if exists
-        provider_prefix = "openai/"
-        if model.startswith(provider_prefix):
-            model = model[len(provider_prefix):]
+        model = OpenAIProvider._remove_provider_prefix(model)

         # Check if the model is a base OpenAI model
         # TODO(enhance) The following list can be replaced with
         # openai.models.list(), but doing so might require a key. Is there a
         # way to get the list of models without a key?
-        valid_model_names = _OPENAI_MODELS
-        if model in valid_model_names:
+        if model in _OPENAI_MODELS:
             return True

         # Check if the model is a fine-tuned OpenAI model. Fine-tuned OpenAI
@@ -113,10 +109,15 @@ def is_provider_model(model: str) -> bool:
         # model names by making a call to the OpenAI API to be more exact, but
         # this might require an API key with the right permissions.
         match = re.match(r"ft:([^:]+):", model)
-        if match and match.group(1) in valid_model_names:
+        if match and match.group(1) in _OPENAI_MODELS:
             return True

         return False

+    @staticmethod
+    def _remove_provider_prefix(model: str) -> str:
+        provider_prefix = "openai/"
+        return model.replace(provider_prefix, "")
+
     @staticmethod
     def finetune(
@@ -126,6 +127,8 @@ def finetune(
         train_data_format: Optional[TrainDataFormat],
         train_kwargs: Optional[Dict[str, Any]] = None,
     ) -> str:
+        model = OpenAIProvider._remove_provider_prefix(model)
+
         print("[OpenAI Provider] Validating the data format")
         OpenAIProvider.validate_data_format(train_data_format)

@@ -138,7 +141,7 @@ def finetune(
         job.provider_file_id = provider_file_id

         print("[OpenAI Provider] Starting remote training")
-        provider_job_id = OpenAIProvider.start_remote_training(
+        provider_job_id = OpenAIProvider._start_remote_training(
             train_file_id=job.provider_file_id,
             model=model,
             train_kwargs=train_kwargs,
@@ -231,9 +234,9 @@ def upload_data(data_path: str) -> str:
         return provider_file.id

     @staticmethod
-    def start_remote_training(
+    def _start_remote_training(
         train_file_id: str,
-        model: id,
+        model: str,
         train_kwargs: Optional[Dict[str, Any]] = None
     ) -> str:
         train_kwargs = train_kwargs or {}
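Two details here are easy to miss. The old annotation `model: id` pointed at Python's builtin `id` function rather than a type, so `str` is the real fix. And `str.replace` removes `"openai/"` anywhere in the string, where `str.removeprefix` (Python 3.9+) would strip only a leading occurrence. A runnable sketch of the recognition logic, with `_OPENAI_MODELS` reduced to an illustrative subset:

```python
import re

# Illustrative subset; dspy/clients/openai.py keeps the full list in _OPENAI_MODELS.
_OPENAI_MODELS = {"gpt-4o-mini-2024-07-18", "gpt-4o", "gpt-3.5-turbo"}

def _remove_provider_prefix(model: str) -> str:
    # Mirrors the new helper: replace() drops the substring anywhere,
    # not just at the front; removeprefix("openai/") would be stricter.
    return model.replace("openai/", "")

def is_provider_model(model: str) -> bool:
    model = _remove_provider_prefix(model)
    if model in _OPENAI_MODELS:
        return True
    # Fine-tuned OpenAI models are named like "ft:<base-model>:<org>::<job-id>".
    match = re.match(r"ft:([^:]+):", model)
    return bool(match and match.group(1) in _OPENAI_MODELS)

assert is_provider_model("openai/gpt-4o")
assert is_provider_model("ft:gpt-4o-mini-2024-07-18:acme::abc123")
assert not is_provider_model("databricks/my-endpoint")
```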
7 changes: 5 additions & 2 deletions dspy/clients/provider.py
@@ -4,7 +4,10 @@
 from typing import Any, Dict, List, Optional, Union

 from dspy.clients.utils_finetune import TrainDataFormat
+from typing import TYPE_CHECKING

+if TYPE_CHECKING:
+    from dspy.clients.lm import LM

 class TrainingJob(Future):
     def __init__(
@@ -44,14 +47,14 @@ def is_provider_model(model: str) -> bool:
         return False

     @staticmethod
-    def launch(lm: 'LM', launch_kwargs: Optional[Dict[str, Any]] = None):
+    def launch(lm: "LM", launch_kwargs: Optional[Dict[str, Any]] = None):
         # Note that "launch" and "kill" methods might be called even if there
         # is a launched LM or no launched LM to kill. These methods should be
         # resilient to such cases.
         pass

     @staticmethod
-    def kill(lm: 'LM', launch_kwargs: Optional[Dict[str, Any]] = None):
+    def kill(lm: "LM", launch_kwargs: Optional[Dict[str, Any]] = None):
         # We assume that LM.launch_kwargs dictionary will contain the necessary
         # information for a provider to launch and/or kill an LM. This is the
         # reason why the argument here is named launch_kwargs and not
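`provider.py` and `lm_local.py` gain the same guarded import. `typing.TYPE_CHECKING` is `False` at runtime, so `LM` is imported only by static type checkers, which avoids a circular import between the provider module and `dspy/clients/lm.py`; the string form `"LM"` in the signatures defers evaluation of the annotation. A minimal sketch of the pattern under those assumptions:

```python
from typing import TYPE_CHECKING, Any, Dict, Optional

if TYPE_CHECKING:
    # Evaluated by type checkers only, never at runtime, so the
    # provider module and lm.py can reference each other safely.
    from dspy.clients.lm import LM

class Provider:
    @staticmethod
    def kill(lm: "LM", launch_kwargs: Optional[Dict[str, Any]] = None) -> None:
        # "LM" is a string (forward reference), so it is not resolved
        # when this module is imported.
        pass
```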
2 changes: 1 addition & 1 deletion dspy/teleprompt/bootstrap_finetune.py
@@ -132,7 +132,7 @@ def finetune_lms(finetune_dict) -> Dict[Any, LM]:

     key_to_job = {}
     for key, finetune_kwargs in finetune_dict.items():
-        lm = finetune_kwargs.pop("lm")
+        lm: LM = finetune_kwargs.pop("lm")
         # TODO: The following line is a hack. We should re-think how to free
         # up resources for fine-tuning. This might mean introducing a new
         # provider method (e.g. prepare_for_finetune) that can be called
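The annotation is for readers and type checkers rather than the runtime: popping from an untyped kwargs dict yields `Any`, and `lm: LM = ...` reasserts the intended type. A tiny sketch, with a stub class standing in for `dspy.clients.lm.LM`:

```python
from typing import Any, Dict

class LM:
    # Stub standing in for dspy.clients.lm.LM.
    pass

def pop_lm(finetune_kwargs: Dict[str, Any]) -> LM:
    # finetune_kwargs.pop("lm") is typed Any; the annotation narrows it
    # back to LM so downstream attribute access stays checked.
    lm: LM = finetune_kwargs.pop("lm")
    return lm

job_lm = pop_lm({"lm": LM()})
assert isinstance(job_lm, LM)
```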
