Fine-tune updates #250

Merged
merged 3 commits into from
Feb 23, 2025
26 changes: 13 additions & 13 deletions moondream/finetune/README.md
@@ -39,7 +39,7 @@ wget https://huggingface.co/vikhyatk/moondream2/resolve/main/model.safetensors

We use Weights & Biases (wandb) to track finetuning progress.

To set it up to track your runs, use `wandb login`
To set it up to track your runs, use `wandb login`.

This will take you through creating an account if you don't already have one set up. Enter your API key and you're ready to go.

@@ -58,7 +58,7 @@ We return a more detailed caption of the image than you would get from the base
```
# Add save path
save_file(
    model.state_dict(),
    "",  # update this line ex: "models/moondream_text_finetuned.safetensors"
    "moondream_finetune.safetensors",  # update this line ex: "models/moondream_text_finetuned.safetensors"
)
```

@@ -67,7 +67,7 @@ We return a more detailed caption of the image than you would get from the base
python -m moondream.finetune.finetune_text
```

The process will output a finetuned version of Moondream into your save path. Example output: `models/moondream_text_finetuned.safetensors`
The process will output a finetuned version of Moondream into your save path. Example output: `models/moondream_text_finetuned.safetensors`.

### Test the Finetuned Text Encoder

@@ -87,25 +87,25 @@ For this example, we will be teaching Moondream to detect railroad cracks in ima
Our dataset trains our model such that,

Given the prompt:
`\n\nDetect: crack\n\n`
`\n\nDetect: <class_name>\n\n`

We are returned the coordinates of a detected crack in the following format:
```{'objects': [{'x_min': [X_MIN], 'y_min': [Y_MIN], 'x_max': [X_MAX], 'y_max': [Y_MAX]}]}```
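
As a rough illustration of how a result in the format above might be consumed, the sketch below converts each detection to pixel coordinates. It assumes the coordinates are normalized to the range 0 to 1 relative to the image width and height (the training script divides box values by the image size); the helper name `to_pixel_boxes` and the sample values are hypothetical.

```python
def to_pixel_boxes(result, width, height):
    """Convert normalized detection boxes to integer pixel coordinates."""
    boxes = []
    for obj in result.get("objects", []):
        boxes.append(
            (
                round(obj["x_min"] * width),
                round(obj["y_min"] * height),
                round(obj["x_max"] * width),
                round(obj["y_max"] * height),
            )
        )
    return boxes


# Hypothetical result for a 640x480 image (assumed normalized coordinates).
example = {"objects": [{"x_min": 0.25, "y_min": 0.5, "x_max": 0.75, "y_max": 1.0}]}
print(to_pixel_boxes(example, width=640, height=480))
```

A result with no `objects` key or an empty list yields an empty list of boxes.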

### Setup Dataset Dependencies

1. Visit https://universe.roboflow.com/research-zwl99/railwayvision
2. Download dataset in COCO JSON format into relevant directory (ex: `datasets`)
3. Update path to `annotation_file` (line 169) & `img_dir` (line 170) in `finetune_region.py` to point at the dataset
- `annotation_file` should point to `<dataset_directory>/train/_annotations.coco.json`
- `img_dir` should point to `<dataset_directory>/train/`
4. Double check that you've updated MODEL_PATH to point to the base moondream model in `moondream/finetune/finetune_region.py`
1. Visit https://universe.roboflow.com/waste-detection-l4m9b/waste-detection-ttdir
2. Download dataset in COCO JSON format into relevant directory (ex: `datasets`).
3. Update path to `annotation_file` (line 175) & `img_dir` (line 176) in `finetune_region.py` to point at the dataset.
- `annotation_file` should point to `<dataset_directory>/train/_annotations.coco.json`.
- `img_dir` should point to `<dataset_directory>/train/`.
4. Double check that you've updated MODEL_PATH to point to the base moondream model.
5. Double check that the save path ends in `.safetensors`, otherwise the run will fail.
> Navigate to line 262 in `moondream/finetune/finetune_region.py`
> Navigate to line 287 in `moondream/finetune/finetune_region.py`.
```
# Add save path
save_file(
    model.state_dict(),
    "",  # update this line ex: "models/moondream_region_finetuned.safetensors"
    "moondream_finetune.safetensors",  # update this line ex: "models/moondream_region_finetuned.safetensors"
)
```
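
The path checks in steps 3 to 5 can be run before starting a long job. The following is a minimal sketch using only the standard library; `check_finetune_paths` is a hypothetical helper that is not part of the repository, and the paths passed to it are placeholders.

```python
from pathlib import Path


def check_finetune_paths(annotation_file, img_dir, save_path):
    """Return a list of problems found; an empty list means the paths look usable."""
    problems = []
    if not Path(annotation_file).is_file():
        problems.append(f"annotation file not found: {annotation_file}")
    if not Path(img_dir).is_dir():
        problems.append(f"image directory not found: {img_dir}")
    if Path(save_path).suffix != ".safetensors":
        problems.append(f"save path must end in .safetensors: {save_path}")
    return problems


# Placeholder paths; substitute your own dataset directory and save path.
for problem in check_finetune_paths(
    "datasets/train/_annotations.coco.json",
    "datasets/train/",
    "models/moondream_region_finetuned.safetensors",
):
    print(problem)
```

Nothing is printed when all three checks pass.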

@@ -114,4 +114,4 @@ We are returned the coordinates of a detected crack in the following format:
python -m moondream.finetune.finetune_region
```

The process will output a finetuned version of Moondream into your save path. Example output: `models/moondream_region_finetuned.safetensors`
The process will output a finetuned version of Moondream into your save path. Example output: `models/moondream_region_finetuned.safetensors`.
110 changes: 67 additions & 43 deletions moondream/finetune/finetune_region.py
@@ -12,6 +12,7 @@
import wandb
import random

from ..torch.weights import load_weights_into_model
from ..torch.moondream import MoondreamModel, MoondreamConfig, text_encoder
from ..torch.text import _produce_hidden
from ..torch.region import (
@@ -21,14 +22,15 @@
encode_size,
)


# This is intended to be a basic starting point. Your optimal hyperparams and data may be different.
MODEL_PATH = ""
# Your data should end with the eos token. Here is the textual representation.
ANSWER_EOS = "<|endoftext|>"
LR = 5e-5
LR = 3e-5
EPOCHS = 1
GRAD_ACCUM_STEPS = 64

random.seed(111)


def lr_schedule(step, max_steps):
x = step / max_steps
@@ -70,7 +72,6 @@ def region_loss(
class CocoDataset(Dataset):
"""
Dataset class for COCO type data.
To download the Roboflow railwayvision dataset visit: https://universe.roboflow.com/research-zwl99/railwayvision
Make sure to use COCO JSON format.
"""

@@ -84,6 +85,8 @@ def __init__(self, annotation_file, img_dir, transform=None):

self.images = data["images"]
self.annotations = data["annotations"]
self.categories = data.get("categories", [])
self.class_id_to_name = {cat["id"]: cat["name"] for cat in self.categories}

self.img_id_to_anns = {}
for ann in self.annotations:
@@ -117,7 +120,7 @@ def __getitem__(self, idx):
ann_list = self.img_id_to_anns[image_id]

boxes = []

class_names = []
for ann in ann_list:
bbox = ann["bbox"]
boxes.append(
@@ -128,13 +131,16 @@ def __getitem__(self, idx):
bbox[3] / height,
]
)
category_id = ann.get("category_id")
class_names.append(self.class_id_to_name.get(category_id, "unknown"))

boxes = torch.as_tensor(boxes, dtype=torch.float16)

return {
"image": image,
"boxes": boxes,
"image_id": torch.tensor([image_id], dtype=torch.int64),
"class_names": class_names,
}
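
The category lookup and box normalization added to `CocoDataset` can be tried in isolation. The sketch below uses a tiny hand-written COCO-style dictionary (hypothetical data, not taken from the example dataset) and assumes the standard COCO `bbox` layout of `[x, y, width, height]` in pixels. The function `normalized_boxes_with_names` is illustrative only: it divides each value by the image size, which may differ from exactly what the training script stores.

```python
# Hypothetical COCO-style annotations for a single 200x100 image.
coco = {
    "images": [{"id": 1, "width": 200, "height": 100}],
    "categories": [{"id": 7, "name": "bottle"}],
    "annotations": [
        {"image_id": 1, "category_id": 7, "bbox": [20, 10, 100, 50]},
        {"image_id": 1, "category_id": 99, "bbox": [0, 0, 200, 100]},
    ],
}


def normalized_boxes_with_names(data, image_id):
    """Return (class_name, [x, y, w, h] scaled to 0..1) pairs for one image."""
    id_to_name = {cat["id"]: cat["name"] for cat in data.get("categories", [])}
    image = next(img for img in data["images"] if img["id"] == image_id)
    width, height = image["width"], image["height"]
    pairs = []
    for ann in data["annotations"]:
        if ann["image_id"] != image_id:
            continue
        x, y, w, h = ann["bbox"]
        # Unknown category ids fall back to "unknown", as in the diff above.
        name = id_to_name.get(ann.get("category_id"), "unknown")
        pairs.append((name, [x / width, y / height, w / width, h / height]))
    return pairs


for name, box in normalized_boxes_with_names(coco, image_id=1):
    print(name, box)
```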


@@ -155,12 +161,10 @@ def main():

config = MoondreamConfig()
model = MoondreamModel(config)
# load_weights_into_model(MODEL_PATH, model)
load_weights_into_model(MODEL_PATH, model)

optimizer = AdamW8bit(
[
{"params": model.region.parameters()},
],
[{"params": model.region.parameters()}],
lr=LR,
betas=(0.9, 0.95),
eps=1e-6,
@@ -178,38 +182,42 @@ def main():
i = 0
for epoch in range(EPOCHS):
for sample in dataset:
i += 1

with torch.no_grad():
i += 1
img_emb = model._run_vision_encoder(sample["image"])
bos_emb = text_encoder(
torch.tensor(
[[model.config.tokenizer.bos_id]], device=model.device
),
model.text,
)

# Basic prompt to detect a crack in the railway tracks
instruction = "\n\nDetect: crack\n\n"
instruction_tokens = model.tokenizer.encode(instruction).ids
instruction_emb = text_encoder(
torch.tensor([[instruction_tokens]], device=model.device),
model.text,
).squeeze(0)

eos_token = model.tokenizer.encode(ANSWER_EOS).ids
eos_emb = text_encoder(
torch.tensor([eos_token], device=model.device),
torch.tensor(
[[model.config.tokenizer.eos_id]], device=model.device
),
model.text,
)

boxes_by_class = {}
for box, cls in zip(sample["boxes"], sample["class_names"]):
boxes_by_class.setdefault(cls, []).append(box)

total_loss = 0.0
for class_name, boxes_list in boxes_by_class.items():
with torch.no_grad():
instruction = f"\n\nDetect: {class_name}\n\n"
instruction_tokens = model.tokenizer.encode(instruction).ids
instruction_emb = text_encoder(
torch.tensor([[instruction_tokens]], device=model.device),
model.text,
).squeeze(0)

cs_emb = []
cs_labels = []
c_idx = []
s_idx = []
if len(sample["boxes"]) > 1:
pass

for bb in sample["boxes"]:
for bb in boxes_list:
l_cs = len(cs_emb)
cs_emb.extend(
[
@@ -221,9 +229,14 @@ def main():
c_idx.extend([l_cs, l_cs + 1])
s_idx.append(l_cs + 2)
cs_labels.extend(
[min(max(torch.round(p * 1023), 0), 1023) for p in bb]
[
int(torch.clamp(torch.round(p * 1023), 0, 1023).item())
for p in bb
]
)

if len(cs_emb) == 0:
continue
cs_emb = torch.stack(cs_emb)

inputs_embeds = torch.cat(
@@ -234,43 +247,54 @@ def main():
c_idx = torch.tensor(c_idx) + prefix
s_idx = torch.tensor(s_idx) + prefix

hidden = _produce_hidden(
inputs_embeds=inputs_embeds, w=model.text, config=config.text
)
hidden = _produce_hidden(
inputs_embeds=inputs_embeds, w=model.text, config=config.text
)

loss = region_loss(
hidden_states=hidden,
w=model.region,
labels=torch.stack(cs_labels).to(torch.int64),
c_idx=c_idx,
s_idx=s_idx,
)
loss = region_loss(
hidden_states=hidden,
w=model.region,
labels=torch.tensor(cs_labels, dtype=torch.int64),
c_idx=c_idx,
s_idx=s_idx,
)
total_loss += loss

loss.backward()
total_loss.backward()

if i % GRAD_ACCUM_STEPS == 0:
optimizer.step()
optimizer.zero_grad()

lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
lr_val = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
for param_group in optimizer.param_groups:
param_group["lr"] = lr
pbar.set_postfix({"step": i // GRAD_ACCUM_STEPS, "loss": loss.item()})
param_group["lr"] = lr_val
pbar.set_postfix(
{"step": i // GRAD_ACCUM_STEPS, "loss": total_loss.item()}
)
pbar.update(1)
wandb.log(
{"loss/train": loss.item(), "lr": optimizer.param_groups[0]["lr"]}
{
"loss/train": total_loss.item(),
"lr": optimizer.param_groups[0]["lr"],
}
)
wandb.finish()
# Add save path: ex. home/model.safetensors

# Replace with your desired output location.
save_file(
model.state_dict(),
"",
"moondream_finetune.safetensors",
)
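
The training loop above groups each image's boxes by class name, issues one `Detect: <class_name>` prompt per group, and turns every normalized coordinate into one of 1024 integer bins. The following is a minimal pure-Python sketch of that grouping and quantization, assuming plain floats instead of tensors; `quantize` and `group_labels_by_class` are hypothetical names, and the sample boxes are made up. Python's built-in `round` rounds halves to the nearest even integer, as `torch.round` does.

```python
def quantize(p, bins=1024):
    """Map a normalized value to an integer bin index in [0, bins - 1]."""
    return int(min(max(round(p * (bins - 1)), 0), bins - 1))


def group_labels_by_class(boxes, class_names):
    """Group boxes by class name and flatten each group into bin labels."""
    grouped = {}
    for box, name in zip(boxes, class_names):
        grouped.setdefault(name, []).extend(quantize(p) for p in box)
    return grouped


# Made-up boxes; the third has out-of-range values that get clamped.
boxes = [[0.0, 0.25, 0.75, 1.0], [0.1, 0.1, 0.2, 0.2], [-0.1, 0.0, 1.2, 1.0]]
names = ["bottle", "can", "bottle"]

for class_name, labels in group_labels_by_class(boxes, names).items():
    prompt = f"\n\nDetect: {class_name}\n\n"
    print(repr(prompt), labels)
```

Clamping keeps slightly out-of-range annotations inside the valid bin range instead of producing an invalid label.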


if __name__ == "__main__":
"""
Replace paths with your appropriate paths.
To run: python -m moondream.finetune.finetune_region
1 epoch of fine-tuning on the example 'Waste Detection' dataset results in a
2 percentage point increase in mAP.
Dataset: https://universe.roboflow.com/waste-detection-l4m9b/waste-detection-ttdir
"""
main()
43 changes: 21 additions & 22 deletions moondream/finetune/finetune_text.py
@@ -13,11 +13,12 @@
from ..torch.moondream import MoondreamModel, MoondreamConfig, text_encoder
from ..torch.text import _produce_hidden, _lm_head, TextConfig

# This is intended to be a basic starting point. Your optimal hyperparams and data may be different.
# This is intended to be a basic starting point for fine-tuning the text encoder.
# Your optimal hyperparams and data may be different.
MODEL_PATH = ""
# Your data should end with the eos token. Here is the textual representation.
ANSWER_EOS = "<|endoftext|>"
LR = 5e-6
LR = 3e-6
EPOCHS = 3
GRAD_ACCUM_STEPS = 128

@@ -108,25 +109,23 @@ def main():
i += 1
with torch.no_grad():
img_emb = model._run_vision_encoder(sample["image"])
bos_emb = text_encoder(
torch.tensor(
[[model.config.tokenizer.bos_id]], device=model.device
),
model.text,
)
question_tokens = model.tokenizer.encode(sample["qa"]["question"]).ids
question_emb = text_encoder(
torch.tensor([[question_tokens]], device=model.device),
model.text,
).squeeze(0)
answer_tokens = model.tokenizer.encode(sample["qa"]["answer"]).ids
answer_emb = text_encoder(
torch.tensor([[answer_tokens]], device=model.device),
model.text,
).squeeze(0)
inputs_embeds = torch.cat(
[bos_emb, img_emb[None], question_emb, answer_emb], dim=1
)
bos_emb = text_encoder(
torch.tensor([[model.config.tokenizer.bos_id]], device=model.device),
model.text,
)
question_tokens = model.tokenizer.encode(sample["qa"]["question"]).ids
question_emb = text_encoder(
torch.tensor([[question_tokens]], device=model.device),
model.text,
).squeeze(0)
answer_tokens = model.tokenizer.encode(sample["qa"]["answer"]).ids
answer_emb = text_encoder(
torch.tensor([[answer_tokens]], device=model.device),
model.text,
).squeeze(0)
inputs_embeds = torch.cat(
[bos_emb, img_emb[None], question_emb, answer_emb], dim=1
)
loss = text_loss(
inputs_embeds=inputs_embeds,
w=model.text,
@@ -152,7 +151,7 @@ def main():
# Add save path: ex. home/model.safetensors
save_file(
model.state_dict(),
"",
"moondream_finetune.safetensors",
)
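
To get a feel for how `EPOCHS` and `GRAD_ACCUM_STEPS` interact, the sketch below counts when an optimizer step would fire. It assumes the `i % GRAD_ACCUM_STEPS == 0` pattern shown in `finetune_region.py` also applies here (the relevant lines of this file are not shown in the diff); the dataset size of 300 is a made-up value and `optimizer_step_indices` is a hypothetical helper.

```python
def optimizer_step_indices(num_samples, epochs, grad_accum_steps):
    """Return the running sample counts at which an optimizer step would fire."""
    fired = []
    i = 0
    for _ in range(epochs):
        for _ in range(num_samples):
            i += 1
            # Gradients from each sample accumulate until this condition holds.
            if i % grad_accum_steps == 0:
                fired.append(i)
    return fired


# Assumed dataset size of 300 samples, with EPOCHS = 3 and GRAD_ACCUM_STEPS = 128.
steps = optimizer_step_indices(num_samples=300, epochs=3, grad_accum_steps=128)
print(len(steps), steps)
```

Under this pattern, samples left over after the last full window accumulate gradients that are never applied, so very small datasets may see few or no updates.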

