Commit

[not4land] Local torchao benchmark
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Jan 16, 2025
1 parent 675fb8f commit cf4b61e
Showing 8 changed files with 243 additions and 17 deletions.
13 changes: 13 additions & 0 deletions manual_cron.sh
@@ -0,0 +1,13 @@
target_hour=22
target_min=00
while true
do
    current_hour=$(date +%H)
    current_min=$(date +%M)
    if [ "$current_hour" -eq "$target_hour" ] && [ "$current_min" -eq "$target_min" ] ; then # fires at most once per minute thanks to the sleep below
        echo "Cron job started at $(date)"
        sh cron_script.sh > local_cron_log 2>local_cron_err
        echo "Cron job executed at $(date)"
    fi
    sleep 60
done
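Note: this polling loop stands in for a real crontab entry, so it has to outlive the login shell; one assumed way to launch it is nohup sh manual_cron.sh & (with cron_script.sh in the same working directory).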
2 changes: 1 addition & 1 deletion requirements.txt
@@ -9,7 +9,7 @@ pytest
pytest-benchmark
requests
tabulate
-git+https://github.com/huggingface/pytorch-image-models.git@730b907
+# git+https://github.com/huggingface/pytorch-image-models.git@730b907
# this version of transformers is required by Liger-Kernel
# https://github.com/linkedin/Liger-Kernel/blob/main/pyproject.toml#L23
transformers==4.44.2
57 changes: 57 additions & 0 deletions upload_to_s3.py
@@ -0,0 +1,57 @@
import io
import json
import os
from functools import lru_cache
from typing import Any

import boto3

@lru_cache
def get_s3_resource() -> Any:
    return boto3.resource("s3")

def upload_to_s3(
    bucket_name: str,
    key: str,
    json_path: str,
) -> None:
    print(f"Writing {json_path} documents to S3")
    data = []
    with open(f"{os.path.splitext(json_path)[0]}.json", "r") as f:
        for line in f:
            data.append(json.loads(line))

    # Re-serialize as newline-delimited JSON, one benchmark record per line
    body = io.StringIO()
    for benchmark_entry in data:
        json.dump(benchmark_entry, body)
        body.write("\n")

    try:
        get_s3_resource().Object(
            bucket_name,
            key,
        ).put(
            Body=body.getvalue(),
            ContentType="application/json",
        )
    except Exception as e:
        print("failed to upload to S3:", e)
        return
    print("Done!")

if __name__ == "__main__":
    import argparse
    import datetime
    import subprocess

    parser = argparse.ArgumentParser(
        description="Upload benchmark result JSON file to S3 for the ClickHouse benchmark database"
    )
    parser.add_argument(
        "--json-path", type=str, help="path of the JSON file to upload", required=True
    )
    args = parser.parse_args()
    # Midnight of the current day doubles as a stable, per-day timestamp for the S3 key
    today = datetime.date.today()
    today = datetime.datetime.combine(today, datetime.time.min)
    today_timestamp = str(int(today.timestamp()))
    print("Today timestamp:", today_timestamp)
    # The short hostname keys results from different machines separately
    hostname = subprocess.check_output(["hostname", "-s"]).decode("utf-8").strip()
    upload_to_s3(
        "ossci-benchmarks",
        f"v3/pytorch/ao/{hostname}/torchbenchmark-torchbench-{today_timestamp}.json",
        args.json_path,
    )
17 changes: 14 additions & 3 deletions userbenchmark/dynamo/dynamobench/common.py
@@ -62,6 +62,7 @@
    same,
)
from torch._logging.scribe import open_source_signpost
from userbenchmark.dynamo.dynamobench.utils import benchmark_and_write_json_result


try:
@@ -555,8 +556,17 @@ def output_signpost(data, args, suite, error=None):
)


-def nothing(f):
-    return f
+def nothing(model_iter_fn):
+    def _apply(module: torch.nn.Module, example_inputs: Any):
+        # Normalize example_inputs into positional args or kwargs
+        if isinstance(example_inputs, dict):
+            args = ()
+            kwargs = example_inputs
+        else:
+            args = example_inputs
+            kwargs = {}
+        # Record an eager-mode (compile=False) "noquant" baseline measurement
+        benchmark_and_write_json_result(module, args, kwargs, "noquant", "cuda", compile=False)
+        model_iter_fn(module, example_inputs)
+    return _apply


@functools.lru_cache(None)
@@ -4147,8 +4157,9 @@ def get_example_inputs(self):
"int8dynamic",
"int8weightonly",
"int4weightonly",
"autoquant",
"noquant",
"autoquant",
"autoquant-all",
],
default=None,
help="Measure speedup of torchao quantization with TorchInductor baseline",
48 changes: 46 additions & 2 deletions userbenchmark/dynamo/dynamobench/torchao_backend.py
@@ -1,7 +1,7 @@
from typing import Any, Callable

import torch

from userbenchmark.dynamo.dynamobench.utils import benchmark_and_write_json_result

def setup_baseline():
    from torchao.quantization.utils import recommended_inductor_config_setter
@@ -20,10 +20,21 @@ def torchao_optimize_ctx(quantization: str):
        quantize_,
    )
    from torchao.utils import unwrap_tensor_subclass
    import torchao

    def inner(model_iter_fn: Callable):
        def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
            if getattr(module, "_quantized", None) is None:
                if quantization == "noquant":
                    # Benchmark the unquantized model once before any transform
                    if isinstance(example_inputs, dict):
                        args = ()
                        kwargs = example_inputs
                    else:
                        args = example_inputs
                        kwargs = {}

                    benchmark_and_write_json_result(module, args, kwargs, "noquant", "cuda")

                if quantization == "int8dynamic":
                    quantize_(
                        module,
@@ -34,7 +45,30 @@ def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
                    quantize_(module, int8_weight_only(), set_inductor_config=False)
                elif quantization == "int4weightonly":
                    quantize_(module, int4_weight_only(), set_inductor_config=False)
-               if quantization == "autoquant":
+               if quantization == "autoquant-all":
                    # Autoquant over the full class list, not just the default subset
                    autoquant(
                        module,
                        error_on_unseen=False,
                        set_inductor_config=False,
                        qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST,
                    )
                    # One calibration forward pass so autoquant can pick kernels
                    if isinstance(example_inputs, dict):
                        module(**example_inputs)
                    else:
                        module(*example_inputs)
                    from torchao.quantization.autoquant import AUTOQUANT_CACHE

                    if len(AUTOQUANT_CACHE) == 0:
                        raise Exception(  # noqa: TRY002
                            "NotAutoquantizable"
                            f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
                        )

                    if isinstance(example_inputs, dict):
                        args = ()
                        kwargs = example_inputs
                    else:
                        args = example_inputs
                        kwargs = {}

                    torchao.quantization.utils.recommended_inductor_config_setter()
                    benchmark_and_write_json_result(module, args, kwargs, quantization, "cuda")
                elif quantization == "autoquant":
                    autoquant(module, error_on_unseen=False, set_inductor_config=False)
                    if isinstance(example_inputs, dict):
                        module(**example_inputs)
Expand All @@ -47,6 +81,16 @@ def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
"NotAutoquantizable"
f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
)

if isinstance(example_inputs, dict):
args = ()
kwargs = example_inputs
else:
args = example_inputs
kwargs = {}

torchao.quantization.utils.recommended_inductor_config_setter()
benchmark_and_write_json_result(module, args, kwargs, quantization, "cuda")
else:
unwrap_tensor_subclass(module)
setattr(module, "_quantized", True) # noqa: B010
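For orientation, a minimal sketch of how this factory is typically driven, assuming torchao is installed, a CUDA device is available, and torchao_optimize_ctx returns inner as the code above suggests; the toy model, inputs, and iteration function are hypothetical:

import torch
from userbenchmark.dynamo.dynamobench.torchao_backend import (
    setup_baseline,
    torchao_optimize_ctx,
)

setup_baseline()
apply_ctx = torchao_optimize_ctx("autoquant")  # or "autoquant-all", "noquant", ...

def model_iter_fn(module, example_inputs):
    # hypothetical iteration function: one forward pass
    return module(*example_inputs)

run = apply_ctx(model_iter_fn)  # quantizes and benchmarks on the first call
model = torch.nn.Sequential(torch.nn.Linear(16, 16)).cuda()  # toy model
run(model, (torch.randn(2, 16, device="cuda"),))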
86 changes: 86 additions & 0 deletions userbenchmark/dynamo/dynamobench/utils.py
@@ -0,0 +1,86 @@
import datetime
import hashlib
import json
import os
import platform
import time

import torch

def get_arch_name() -> str:
    if torch.cuda.is_available():
        # e.g. "NVIDIA A100-SXM4-40GB"
        return torch.cuda.get_device_name()
    else:
        # This returns x86_64 or arm64 (for aarch64)
        return platform.machine()


def write_json_result(output_json_path, headers, row):
    """
    Write the result into JSON format, so that it can be uploaded to the benchmark database
    to be displayed on the OSS dashboard. The JSON format is defined at
    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
    """
    mapping_headers = {headers[i]: v for i, v in enumerate(row)}
    # Derive a stable per-day "workflow": the midnight timestamp serves as
    # workflow_id and a hash of the date stands in for a commit SHA, since
    # these are local (non-CI) runs.
    today = datetime.date.today()
    sha_hash = hashlib.sha256(str(today).encode("utf-8")).hexdigest()
    first_second = datetime.datetime.combine(today, datetime.time.min)
    workflow_id = int(first_second.timestamp())
    job_id = workflow_id + 1
    record = {
        "timestamp": int(time.time()),
        "schema_version": "v3",
        "name": "devvm local benchmark",
        "repo": "pytorch/ao",
        "head_branch": "main",
        "head_sha": sha_hash,
        "workflow_id": workflow_id,
        "run_attempt": 1,
        "job_id": job_id,
        "benchmark": {
            "name": "TorchAO benchmark",
            "mode": "inference",
            "dtype": mapping_headers["dtype"],
            "extra_info": {
                "device": mapping_headers["device"],
                "arch": mapping_headers["arch"],
                "min_sqnr": None,
                "compile": mapping_headers["compile"],
            },
        },
        "model": {
            "name": mapping_headers["name"],
            "type": "model",
            # TODO: make this configurable
            "origins": ["torchbench"],
        },
        "metric": {
            "name": mapping_headers["metric"],
            "benchmark_values": [mapping_headers["actual"]],
            "target_value": mapping_headers["target"],
        },
    }

    with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f:
        print(json.dumps(record), file=f)

def benchmark_and_write_json_result(model, args, kwargs, quantization, device, compile=True):
    print(quantization + " run")
    from torchao.utils import benchmark_model

    if compile:
        model = torch.compile(model, mode="max-autotune")
    # Warm up for 20 iterations, then take the average over 100
    benchmark_model(model, 20, args, kwargs)
    elapsed_time = benchmark_model(model, 100, args, kwargs)
    print("elapsed_time: ", elapsed_time, " milliseconds")

    if hasattr(model, "_orig_mod"):
        # torch.compile wraps the original module
        name = model._orig_mod.__class__.__name__
    else:
        # eager
        name = model.__class__.__name__

    headers = ["name", "dtype", "compile", "device", "arch", "metric", "actual", "target"]
    arch = get_arch_name()
    dtype = quantization
    performance_result = [name, dtype, compile, device, arch, "time_ms(avg)", elapsed_time, None]
    _OUTPUT_JSON_PATH = "benchmark_results"
    write_json_result(_OUTPUT_JSON_PATH, headers, performance_result)
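A minimal sketch of exercising this helper end to end, assuming torchao is installed and a CUDA device is available; the toy module and inputs are hypothetical, and each call appends one v3-schema record to benchmark_results.json:

import torch
from userbenchmark.dynamo.dynamobench.utils import benchmark_and_write_json_result

model = torch.nn.Linear(256, 256).cuda().to(torch.bfloat16)  # toy model
x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16)
# Eager baseline record ("noquant", compile=False), then a compiled one
benchmark_and_write_json_result(model, (x,), {}, "noquant", "cuda", compile=False)
benchmark_and_write_json_result(model, (x,), {}, "noquant", "cuda")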
6 changes: 2 additions & 4 deletions userbenchmark/group_bench/configs/torch_ao.yaml
@@ -10,7 +10,5 @@ metrics:
test_group:
  test_batch_size_default:
    subgroup:
-     - extra_args:
-     - extra_args: --quantization int8dynamic
-     - extra_args: --quantization int8weightonly
-     - extra_args: --quantization int4weightonly
+     - extra_args: --quantization noquant
+     - extra_args: --quantization autoquant
31 changes: 24 additions & 7 deletions userbenchmark/torchao/run.py
@@ -12,18 +12,18 @@


def _get_ci_args(
-   backend: str, modelset: str, dtype, mode: str, device: str, experiment: str
+   quantization: str, modelset: str, dtype, mode: str, device: str, experiment: str
) -> List[List[str]]:
    if modelset == "timm":
        modelset_full_name = "timm_models"
    else:
        modelset_full_name = modelset
-   output_file_name = f"torchao_{backend}_{modelset_full_name}_{dtype}_{mode}_{device}_{experiment}.csv"
+   output_file_name = f"torchao_{quantization}_{modelset_full_name}_{dtype}_{mode}_{device}_{experiment}.csv"
    ci_args = [
        "--progress",
        f"--{modelset}",
        "--quantization",
-       f"{backend}",
+       f"{quantization}",
        f"--{mode}",
        f"--{dtype}",
        f"--{experiment}",
@@ -32,16 +32,33 @@ def _get_ci_args(
    ]
    return ci_args

def _get_eager_baseline_args(quantization: str, modelset: str, dtype, mode: str, device: str, experiment: str):
    if modelset == "timm":
        modelset_full_name = "timm_models"
    else:
        modelset_full_name = modelset
    output_file_name = f"torchao_{quantization}_{modelset_full_name}_{dtype}_{mode}_{device}_{experiment}_eager.csv"
    ci_args = [
        "--progress",
        f"--{modelset}",
        f"--{mode}",
        f"--{dtype}",
        f"--{experiment}",
        "--nothing",
        "--output",
        f"{str(OUTPUT_DIR.joinpath(output_file_name).resolve())}",
    ]
    return ci_args

def _get_full_ci_args(modelset: str) -> List[List[str]]:
-   backends = ["autoquant", "int8dynamic", "int8weightonly", "noquant"]
+   quantizations = ["autoquant-all", "autoquant", "noquant"]
    modelset = [modelset]
    dtype = ["bfloat16"]
    mode = ["inference"]
    device = ["cuda"]
-   experiment = ["performance", "accuracy"]
-   cfgs = itertools.product(*[backends, modelset, dtype, mode, device, experiment])
-   return [_get_ci_args(*cfg) for cfg in cfgs]
+   experiment = ["performance"]
+   cfgs = itertools.product(*[quantizations, modelset, dtype, mode, device, experiment])
+   return [_get_ci_args(*cfg) for cfg in cfgs] + [
+       _get_eager_baseline_args("noquant", modelset[0], dtype[0], mode[0], device[0], experiment[0])
+   ]
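Concretely, the product now yields 3 quantization modes × 1 modelset × 1 dtype × 1 mode × 1 device × 1 experiment = 3 compiled benchmark configurations, and the appended eager noquant baseline brings the total to 4 runs per modelset.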


def _get_output(pt2_args):
