This repository is a fork of the original LLaVA project, modified to fine-tune LLaVA 1.6 Mistral-7B on the RLAIF-V dataset with Odds Ratio Preference Optimization (ORPO), using LoRA and QLoRA for parameter-efficient adaptation. This approach enhances multimodal understanding by leveraging reinforcement learning from AI feedback (RLAIF) while keeping training efficient.
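For orientation, here is a minimal sketch of the ORPO objective: a standard SFT loss on the preferred response plus a weighted odds-ratio term that pushes the preferred response's odds above the rejected one's. The function below is illustrative only; the loss actually used here is implemented in `train_orpo.py`, and the `beta` value is a placeholder.

```python
import torch
import torch.nn.functional as F

def orpo_loss(chosen_logps, rejected_logps, sft_nll, beta=0.1):
    """Illustrative ORPO objective.

    chosen_logps / rejected_logps: mean per-token log-probabilities of the
    preferred and dispreferred responses under the policy; sft_nll: NLL of
    the preferred response; beta: weight of the odds-ratio term (illustrative).
    """
    # log-odds of a response: log(p / (1 - p)), computed in log space
    log_odds_chosen = chosen_logps - torch.log1p(-torch.exp(chosen_logps))
    log_odds_rejected = rejected_logps - torch.log1p(-torch.exp(rejected_logps))
    # odds-ratio term: encourage higher odds for the preferred response
    odds_ratio_term = -F.logsigmoid(log_odds_chosen - log_odds_rejected)
    return sft_nll + beta * odds_ratio_term.mean()
```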
- Fine-Tune LLaVA 1.6 Mistral-7B on the RLAIF-V dataset
- Odds Ratio Preference Optimization (ORPO) for alignment
- LoRA & QLoRA fine-tuning for efficient model adaptation
- Support for xFormers & FlashAttention to reduce memory overhead
- Evaluation scripts for LoRA fine-tuned models
- Dataset preparation based on `create_splits.py`
- Interactive Gradio demo for inference
Clone this repository and set up the environment:

```bash
git clone https://github.com/YOUR_USERNAME/LLaVA1.6-Mistral-Finetune-ORPO-RLAIF-V.git
cd LLaVA1.6-Mistral-Finetune-ORPO-RLAIF-V

conda create -n llava python=3.10 -y
conda activate llava

pip install --upgrade pip
pip install -e ".[train]"
pip install -r additional_requirements.txt
pip install flash-attn --no-build-isolation
```
The base LLaVA 1.6 Mistral-7B model can be downloaded from Hugging Face:

```bash
git lfs install
git clone https://huggingface.co/liuhaotian/llava-v1.6-mistral-7b checkpoints/llava1.6-mistral-7b
```
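Alternatively, the same checkpoint can be fetched with `huggingface_hub` (the repo id below assumes the public `liuhaotian/llava-v1.6-mistral-7b` repository):

```python
# Optional alternative to git-lfs: download the checkpoint with huggingface_hub.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="liuhaotian/llava-v1.6-mistral-7b",
    local_dir="checkpoints/llava1.6-mistral-7b",
)
```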
This fine-tuning pipeline uses the RLAIF-V dataset, which is processed with `create_splits.py`.
- Run dataset split creation with `python create_splits.py` (a rough sketch of this step follows the list).
- The processed splits are saved locally to `./rlaif-v-train-only/` and `./rlaif-v-validation-only/`.
- Ensure these dataset paths are correctly referenced in `train_orpo.py`.
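The exact split logic lives in `create_splits.py`; conceptually it does something like the following (the dataset id is the public RLAIF-V release on Hugging Face, and the split ratio and seed are illustrative):

```python
# Rough sketch of the split step; see create_splits.py for the real logic.
from datasets import load_dataset

ds = load_dataset("openbmb/RLAIF-V-Dataset", split="train")
splits = ds.train_test_split(test_size=0.05, seed=42)  # ratio/seed are illustrative

splits["train"].save_to_disk("./rlaif-v-train-only")
splits["test"].save_to_disk("./rlaif-v-validation-only")
```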
To incorporate HH-RLHF, download and process it:

```bash
python download_hh-rlhf_dataset.py
```

This will save the dataset locally for joint training.
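For reference, the download amounts to roughly the following (the `Anthropic/hh-rlhf` id is the public release; the local path is illustrative):

```python
# Rough sketch of the HH-RLHF download; see download_hh-rlhf_dataset.py for the real logic.
from datasets import load_dataset

hh = load_dataset("Anthropic/hh-rlhf")   # splits: train / test; columns: chosen / rejected
hh.save_to_disk("./hh-rlhf")             # local path is illustrative
```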
Run the fine-tuning script with the required arguments:
```bash
python train_orpo.py \
  --train_data_path ./rlaif-v-train-only \
  --val_data_path ./rlaif-v-validation-only \
  --model_name ../../llava1.6-mistral-7b \
  --output_dir ./llava-output \
  --per_device_train_batch_size 32 \
  --gradient_accumulation_steps 32 \
  --max_steps 500 \
  --logging_steps 10 \
  --report_to wandb \
  --save_strategy steps \
  --save_steps 100 \
  --save_total_limit 20 \
  --learning_rate 2e-5 \
  --weight_decay 0.01 \
  --warmup_ratio 0.03 \
  --lr_scheduler_type cosine \
  --gradient_checkpointing \
  --remove_unused_columns False \
  --bf16 False \
  --max_length 2048
```
To enable QLoRA, add the following flag:

```bash
--use_qlora
```
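Under the hood, `--use_qlora` corresponds to the usual bitsandbytes + peft recipe: load the base model in 4-bit NF4, prepare it for k-bit training, and attach LoRA adapters. The sketch below is illustrative; the model id, target modules, and LoRA ranks actually used are defined in `train_orpo.py`.

```python
# Illustrative QLoRA wiring; the real configuration lives in train_orpo.py.
import torch
from transformers import LlavaNextForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Transformers-format checkpoint used as an example; swap in your local path.
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,        # values are illustrative
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```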
For joint training with HH-RLHF, use:

```bash
python train_orpo_RLAIF_HH-RLHF.py
```
- Model Checkpoints: `checkpoints/llava1.6-mistral-7b-finetune`
- LoRA Adapters: `checkpoints/llava1.6-mistral-7b-finetune_lora`
- Training Script: `train_orpo.py`
- Hyperparameters: configurable inside `train_orpo.py`
Once training is complete, you can evaluate the fine-tuned model:

```bash
python eval_lora_orpo_rlaif.py \
  --model_path llava1.6-mistral-7b \
  --checkpoint_path ./llava1.6-mistral-7b-RLAIF-V-ORPO/checkpoint-99000 \
  --dataset_path ./rlaif-v-validation-only \
  --batch_size 512 \
  --report_wandb \
  --evaluate_original False
```
This evaluation script reports key performance metrics, including F1-score and ROUGE, to assess the fine-tuned model’s quality.
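For context, these metrics compare the generated answer against the preferred (chosen) response. A minimal sketch of how such metrics can be computed, using the `evaluate` library and a simple token-overlap F1 (both assumptions, not necessarily what the script uses internally):

```python
# Minimal sketch of the reported metrics; the real computation is in eval_lora_orpo_rlaif.py.
from collections import Counter

import evaluate

rouge = evaluate.load("rouge")

def token_f1(prediction: str, reference: str) -> float:
    """Token-overlap F1 between a generated answer and the reference answer."""
    pred_tokens, ref_tokens = prediction.split(), reference.split()
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

preds = ["a cat is sitting on the mat"]
refs = ["a cat sits on a mat"]
print(rouge.compute(predictions=preds, references=refs))
print(token_f1(preds[0], refs[0]))
```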
Before deployment, you need to merge the LoRA adapter with the base model:
```bash
python -m peft.merge_adapters \
  --base_model_path llava1.6-mistral-7b \
  --lora_adapter_path checkpoints/llava1.6-mistral-7b-finetune_lora \
  --output_path checkpoints/llava1.6-mistral-7b-merged
```

This will create a fully merged model in `checkpoints/llava1.6-mistral-7b-merged`.
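If the `peft.merge_adapters` entry point is not available in your installed `peft` version, the same merge can be done with peft's Python API. The sketch below assumes the base model loads through the Transformers Auto classes; use the loader from `train_orpo.py` if it differs.

```python
# Merge the LoRA adapter into the base weights and save a standalone checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE = "llava1.6-mistral-7b"
ADAPTER = "checkpoints/llava1.6-mistral-7b-finetune_lora"
OUT = "checkpoints/llava1.6-mistral-7b-merged"

base = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.float16)
merged = PeftModel.from_pretrained(base, ADAPTER).merge_and_unload()  # folds LoRA into the base weights
merged.save_pretrained(OUT)
AutoTokenizer.from_pretrained(BASE).save_pretrained(OUT)
```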
You can test your fine-tuned model through an interactive Gradio demo. First, install or upgrade the UI dependencies:

```bash
pip install --upgrade gradio fastapi pydantic
```
Run the Gradio application:

```bash
python gradio_app.py \
  --base_model_path ../../llava1.6-mistral-7b \
  --lora_weights_path ./llava1.6-mistral-7b-RLAIF-V-ORPO/checkpoint-99000 \
  --share
```
This launches an interactive UI for testing the fine-tuned multimodal model.
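For reference, a minimal Gradio loop over a Transformers-format LLaVA 1.6 checkpoint looks roughly like the sketch below. The model id `llava-hf/llava-v1.6-mistral-7b-hf` is a public example used only for illustration; the repo's own app is `gradio_app.py`, which also supports loading LoRA weights on top of the base model.

```python
# Minimal illustrative Gradio demo; the repo's full app is gradio_app.py.
import gradio as gr
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

MODEL = "llava-hf/llava-v1.6-mistral-7b-hf"  # example public checkpoint

processor = LlavaNextProcessor.from_pretrained(MODEL)
model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL, torch_dtype=torch.float16, device_map="auto"
)

def answer(image, question):
    prompt = f"[INST] <image>\n{question} [/INST]"  # LLaVA 1.6 Mistral prompt format
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=256)
    return processor.decode(output[0], skip_special_tokens=True)

demo = gr.Interface(
    fn=answer,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
    outputs="text",
)
demo.launch(share=True)
```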
This project follows the Apache 2.0 License. Users must comply with any dataset/model-specific licensing agreements.
This work builds upon:
- LLaVA for multimodal instruction tuning.
- Mistral 7B as the base LLM.
- RLAIF-V Dataset for AI feedback-based fine-tuning.
- HH-RLHF Dataset for improving model alignment.