Composability #219

Merged: 7 commits into main from composability on Jan 7, 2025

Conversation

@rahul-tuli (Member) commented Dec 2, 2024

This PR enables composability in the ModelCompressor!

Features Added

  1. Support for targets and ignore in SparseCompressors
    Adds the flexibility to specify which layers are compressed and which are skipped (see the sketch below this list).

  2. Composability for compress/decompress pathways

  3. Typing updates
    Improves type annotations for some methods.

  4. New tests

    • Composability Test to validate the integration of compression features.
    • FP8 Test to verify composability with FP8 quantization.
    • Tests for expanding targets for SparseCompressors
    • Adds some testing utilities as well

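A minimal sketch of how targets and ignore entries could be resolved against a model, assuming the recipe-style syntax shown later in this PR (exact module names, module class names such as "Linear", or regex patterns prefixed with "re:"). The helper name and signature are illustrative only, not the actual compressed-tensors API:

import re
from typing import Iterable, Set

import torch


def expand_targets(
    model: torch.nn.Module, targets: Iterable[str], ignore: Iterable[str]
) -> Set[str]:
    """Resolve target/ignore entries to concrete module names.

    Entries may be exact module names, module class names (e.g. "Linear"),
    or regex patterns prefixed with "re:", mirroring the recipe syntax below.
    """

    def matches(name: str, module: torch.nn.Module, patterns: Iterable[str]) -> bool:
        for pattern in patterns:
            if pattern.startswith("re:"):
                if re.search(pattern[3:], name):
                    return True
            elif name == pattern or module.__class__.__name__ == pattern:
                return True
        return False

    return {
        name
        for name, module in model.named_modules()
        if matches(name, module, targets) and not matches(name, module, ignore)
    }

With targets=["Linear"] and ignore=["lm_head"], for example, this resolves to every Linear submodule except the LM head.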
Test Script

import torch
import argparse
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from llmcompressor.transformers import oneshot
from compressed_tensors import ModelCompressor, QUANTIZATION_CONFIG_NAME

# Constants for default values
DEFAULT_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DEFAULT_DATASET = "gsm8k"
DEFAULT_DATASET_CONFIG_NAME = "main"


def load_model_and_tokenizer(model_id: str):
    """Load the model and tokenizer from the Hugging Face Hub."""
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_id, 
        device_map="auto", 
        torch_dtype="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)
    return model, tokenizer


def apply_recipe(model: torch.nn.Module, recipe_path: str, dataset: str, tokenizer, dataset_config_name: str = None):
    """Apply a sparsity or quantization recipe to the model."""
    oneshot(
        model=model, 
        recipe=recipe_path, 
        dataset=dataset, 
        tokenizer=tokenizer,
        dataset_config_name=dataset_config_name,
    )


def generate_sample_output(model: torch.nn.Module, tokenizer):
    """Generate a sample output to verify model functionality."""
    print("========== SAMPLE GENERATION ==============")
    input_text = "Hello my name is"
    input_ids = tokenizer(
        text=input_text, 
        return_tensors="pt"
    ).input_ids.to("cuda")
    output = model.generate(
        input_ids=input_ids, 
        max_new_tokens=20
    )
    print(tokenizer.decode(output[0]))
    print("==========================================")


def save_model_and_tokenizer(model: torch.nn.Module, tokenizer, save_dir: str, save_compressed: bool = True):
    """Save the model and tokenizer locally."""
    model.save_pretrained(save_directory=save_dir, save_compressed=save_compressed)
    tokenizer.save_pretrained(save_directory=save_dir)
    return save_dir


def get_save_directory(model_id: str, recipe_style: str, dataset_name: str, compressor_name:str) -> str:
    """Generate a save directory name based on the model and recipe style."""
    model_name = model_id.split("/")[-1] if "/" in model_id else model_id
    
    if dataset_name not in model_name:
        model_name = f"{model_name}-{dataset_name}"
    
    if "-Uncompressed" in model_name:
        model_name = model_name.replace("-Uncompressed", "")
    
    return f"nm-testing/{model_name}-{recipe_style}-{compressor_name}"


def parse_arguments():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Apply sparsity or quantization recipes and save the model locally.")
    parser.add_argument(
        "--model_id",
        type=str,
        default=DEFAULT_MODEL_ID,
        help="Model ID from Hugging Face Hub or local directory."
    )
    parser.add_argument(
        "--recipe_path",
        type=str,
        required=True,
        help="Path to the recipe YAML file."
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=DEFAULT_DATASET,
        help="Dataset name for recipe application."
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="Dataset configuration name."
    )
    parser.add_argument(
        "--save_uncompressed",
        action="store_true",
        help="Flag to check decompression of the model."
    )
    parser.add_argument(
        "--check_decompression",
        action="store_true",
        help="Flag to check decompression of the model."
    )
    parser.add_argument(
        "--print_generation",
        action="store_true",
        help="Flag to print a sample generation."
    )

    args = parser.parse_args()
    if args.dataset == DEFAULT_DATASET and args.dataset_config_name is None:
        args.dataset_config_name = DEFAULT_DATASET_CONFIG_NAME

    print("\nParsed CLI Arguments:")
    for arg, value in vars(args).items():
        print(f"  {arg}: {value}")
    print("\n")
    return args


def check_decompression(uncompressed_model: torch.nn.Module, compressed_path: str, base_model_id: str):
    """Verify that the decompressed model matches the uncompressed model."""
    decompressed_model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=base_model_id, 
        torch_dtype="auto", 
        device_map="auto"
    )

    config = AutoConfig.from_pretrained(pretrained_model_name_or_path=compressed_path)
    compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)

    model_compressor = ModelCompressor.from_compression_config(compression_config)
    model_compressor.decompress(
        model=decompressed_model, 
        model_path=compressed_path
    )

    # decompressed inference
    print("Running inference on the decompressed model.")
    generate_sample_output(model=decompressed_model, tokenizer=AutoTokenizer.from_pretrained(base_model_id))
    

    decompressed_state = decompressed_model.state_dict()
    uncompressed_state = uncompressed_model.state_dict()

    for key in decompressed_state.keys():
        assert key in uncompressed_state, f"Missing tensor: {key}"
        if key.endswith("weight"):
            original_tensor = uncompressed_state[key]
            decompressed_tensor = decompressed_state[key].to(original_tensor.dtype)
            diff = torch.abs(original_tensor - decompressed_tensor)
            if torch.any(diff > 0.06).item():
                print(key, torch.max(diff))

    print("Decompressed model parameters match the original model.")
    print("Inference test passed on the decompressed model.")


def main():
    """Main workflow for recipe application, saving, and optional checks."""
    args = parse_arguments()
    recipe_style = Path(args.recipe_path).stem

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(model_id=args.model_id)

    # Apply sparsity or quantization recipe
    apply_recipe(
        model=model, 
        recipe_path=args.recipe_path, 
        dataset=args.dataset,
        dataset_config_name=args.dataset_config_name,
        tokenizer=tokenizer
    )

    # Optionally generate a sample output
    if args.print_generation:
        generate_sample_output(model=model, tokenizer=tokenizer)


    compressor_name = "Uncompressed" if args.save_uncompressed else "BitM"

    # Save the processed model
    save_dir = get_save_directory(
        model_id=args.model_id, 
        recipe_style=recipe_style,
        dataset_name=args.dataset,
        compressor_name=compressor_name
    )
    compressed_path = save_model_and_tokenizer(
        model=model, 
        tokenizer=tokenizer, 
        save_dir=save_dir,
        save_compressed=not args.save_uncompressed
    )
    print(f"Model saved successfully at: {compressed_path}")

    # Optional decompression check
    if args.check_decompression:
        check_decompression(
            uncompressed_model=model, 
            compressed_path=compressed_path, 
            base_model_id=args.model_id
        )


if __name__ == "__main__":
    main()

Invocation Command
python oneshot_and_save.py --model_id nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-Uncompressed \
--recipe_path /home/rahul/llm-compressor/local/compressors/bitmask/recipes/chnl_wts_per_tok_dyn_act_fp8.yaml \
--dataset gsm8k --dataset_config_name main --check_decompression --print_generation

Recipe
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            ignore: ["lm_head"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 8
                        type: float
                        strategy: channel
                        dynamic: false
                        symmetric: true
                    input_activations:
                        num_bits: 8
                        type: float
                        strategy: token
                        dynamic: true
                        symmetric: true
                    targets: ["Linear"]
    pruning_modifiers:
        ConstantPruningModifier:
            targets: [
                're:.*q_proj.weight',
                're:.*k_proj.weight', 
                're:.*v_proj.weight',
                're:.*o_proj.weight',
                're:.*gate_proj.weight',
                're:.*up_proj.weight',
                're:.*down_proj.weight',
            ]
            start: 0

Script Output
2024-12-19T14:34:13.424544+0000 | finalize | INFO - Compression lifecycle finalized for 1 modifiers
========== SAMPLE GENERATION ==============
<s> Hello my name is my name. I am my name. I am my name. I am my name. I am
==========================================
Checking whether model follows 2:4 sparsity structure: 100%|██████████████████████████████████████████████████████████████████| 155/155 [00:00<00:00, 162.49it/s]
2024-12-19T14:34:22.923130+0000 | get_model_compressor | INFO - Inferring a sparsity configuration requires a global sparsity calculation. This can be costly for large models. To skip the calculation of compression statistics set skip_compression_stats=True
Calculating model sparsity: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 509/509 [00:00<00:00, 608.69it/s]
Checking whether model follows 2:4 sparsity structure: 100%|██████████████████████████████████████████████████████████████████| 155/155 [00:00<00:00, 194.59it/s]
Calculating quantization compression ratio: 246it [00:02, 99.84it/s] 
Quantized Compression: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 509/509 [00:00<00:00, 982.05it/s]
Compressing model: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 355/355 [00:01<00:00, 196.93it/s]
Model saved successfully at: nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM
Decompressing model: 355it [00:54,  6.53it/s] 
Decompressing model: 154it [00:00, 2357.32it/s]
Running inference on the decompressed model.
========== SAMPLE GENERATION ==============
<s> Hello my name is my name. I am my name. I am my name. I am my name. I am
==========================================
Decompressed model parameters match the original model.
Inference test passed on the decompressed model.
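The log above also hints at the order the composed pathways follow: quantized compression runs first, bitmask compression then packs the result, and decompression appears to reverse that order. A purely conceptual sketch of that round trip is below; the compressor objects and method names are placeholders for illustration, not the ModelCompressor internals:

def composed_compress(state_dict, quant_compressor, sparse_compressor):
    # Placeholder objects for illustration only.
    # 1. Pack the quantized tensors (FP8 channel-wise weights in this recipe).
    compressed = quant_compressor.compress(state_dict)
    # 2. Bitmask-compress the targeted sparse layers of the already quantized
    #    state dict, leaving anything in the ignore list untouched.
    return sparse_compressor.compress(compressed)


def composed_decompress(compressed, quant_compressor, sparse_compressor):
    # Decompression runs in reverse: unpack the bitmask first, then dequantize.
    return quant_compressor.decompress(sparse_compressor.decompress(compressed))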

@rahul-tuli rahul-tuli changed the base branch from main to add-24-compressor December 2, 2024 22:21
horheynm previously approved these changes Dec 3, 2024
@horheynm horheynm marked this pull request as ready for review December 3, 2024 18:33
@rahul-tuli rahul-tuli marked this pull request as draft December 17, 2024 20:09
@rahul-tuli rahul-tuli changed the base branch from add-24-compressor to add-targets-and-ignore-support December 17, 2024 20:20
@rahul-tuli rahul-tuli changed the base branch from add-targets-and-ignore-support to main December 19, 2024 14:59
@rahul-tuli rahul-tuli dismissed horheynm’s stale review December 19, 2024 14:59

The base branch was changed.

@rahul-tuli rahul-tuli requested a review from dsikka December 19, 2024 15:15
@rahul-tuli rahul-tuli self-assigned this Dec 19, 2024
@rahul-tuli rahul-tuli marked this pull request as ready for review December 19, 2024 15:16
horheynm previously approved these changes Dec 19, 2024
@dsikka (Contributor) left a comment:

Overall looks good.
I'd check it out in llmcompressor and make sure CI and nightly tests pass.
After that, we can merge.

@dsikka (Contributor) left a comment:

LGTM, approval pending tests passing.

Enable: Operations on state_dict to allow composability
Add: Composability for compress/decompress pathways
Update: Typing for a few methods
Add: Composability Test
Add: Some testing utils
Update _replace_weight to work with updates from `85b473e`
Add docstring to _replace_weights
Update failing test
@dsikka (Contributor) left a comment:

LGTM. Good job integrating this cleanly with quantization!

dsikka previously approved these changes Jan 3, 2025
@kylesayrs (Contributor) left a comment:

There may be some opportunities for code clarity, but it looks correct to me.

@kylesayrs (Contributor) commented:

LGTM! Although it would be really nice to simplify how the state dict is loaded in a separate PR when we get the chance.


@dsikka dsikka merged commit 7801f00 into main Jan 7, 2025
1 check passed
@dsikka dsikka deleted the composability branch January 7, 2025 00:28