Alpha Release v0.17.0-alpha #84

Merged · 12 commits · Jun 5, 2024
Changes from 7 commits
5 changes: 5 additions & 0 deletions model_demos/README.md
@@ -39,11 +39,13 @@ python cv_demos/resnet/pytorch_resnet.py
| [DeiT](cv_demos/deit/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [DenseNet](cv_demos/densenet/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [DistilBERT](nlp_demos/distilbert/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [DLA](cv_demos/dla/) | ✔️ | ✔️ | ✔️ | TBD |
| [DPR](nlp_demos/dpr/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [DLA](cv_demos/dla/) | ✔️ | ✔️ | ✔️ | v0.15.0-alpha |
| [EfficientNet-Lite](cv_demos/efficientnet_lite/) | ✘ | ✘ | ✔️ | v0.12.3 |
| [Falcon-7B](nlp_demos/falcon/) | ✘ | ✘ | ✔️ | v0.12.3 |
| [FLAN-T5](nlp_demos/flant5/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [FPN](cv_demos/fpn/) | ✔️ | ✔️ | ✔️ | TBD |
| [Fuyu-8B](nlp_demos/fuyu8b/) | ✘ | ✘ | ✘ | v0.12.3 |
| [GhostNet](cv_demos/ghostnet/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [GoogLeNet](cv_demos/googlenet/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
@@ -66,6 +68,8 @@ python cv_demos/resnet/pytorch_resnet.py
| [ResNeXt](cv_demos/resnext/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [RetinaNet](cv_demos/retinanet/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [RoBERTa](nlp_demos/roberta/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [SSD300RESNET50](cv_demos/ssd300_resnet50/) | ✔️ | ✔️ | ✔️ | TBD |
| [SegFormer](cv_demos/segformer/) | ✔️ | ✔️ | ✔️ | TBD |
| [SqueezeBERT](nlp_demos/squeezebert/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [Stable Diffusion](cv_demos/stable_diffusion/) | ✘ | ✘ | ✔️ | v0.12.3 |
| [T5](nlp_demos/t5/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
@@ -80,6 +84,7 @@ python cv_demos/resnet/pytorch_resnet.py
| [XGLM](nlp_demos/xglm/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [YOLOv3](cv_demos/yolo_v3/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [YOLOv5](cv_demos/yolo_v5/) | ✔️ | ✔️ | ✔️ | v0.12.3 |
| [YOLOv6](cv_demos/yolo_v6/) | ✔️ | ✔️ | ✔️ | TBD |

### Legend

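Each linked demo is a self-contained script with a `run_*` entry point, so the newly added rows can be exercised the same way as the existing ones. A minimal sketch (the import path follows the convention used by `tests/test_onnx_dla.py` at the end of this diff):

```python
# Run one of the newly added demos from the model_demos directory.
from cv_demos.dla.onnx_dla import run_dla_onnx

run_dla_onnx("dla34")  # any DLA variant listed in tests/test_onnx_dla.py
```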
83 changes: 83 additions & 0 deletions model_demos/cv_demos/dla/onnx_dla.py
@@ -0,0 +1,83 @@
import os
import urllib.request

import onnx
import pybuda
import requests
import torchvision.transforms as transforms
from PIL import Image
from pybuda._C.backend_api import BackendDevice


def run_dla_onnx(variant):

# Set PyBuda configuration parameters
compiler_cfg = pybuda.config._get_global_compiler_config()
compiler_cfg.balancer_policy = "Ribbon"
compiler_cfg.default_df_override = pybuda._C.Float16_b
os.environ["PYBUDA_RIBBON2"] = "1"

available_devices = pybuda.detect_available_devices()
if available_devices:
arch = available_devices[0]
if arch == BackendDevice.Grayskull:
if variant == "dla102x2":
os.environ["PYBUDA_FORCE_CONV_MULTI_OP_FRACTURE"] = "1"

# Load data sample
url = "https://images.rawpixel.com/image_1300/cHJpdmF0ZS9sci9pbWFnZXMvd2Vic2l0ZS8yMDIyLTA1L3BkMTA2LTA0Ny1jaGltXzEuanBn.jpg"
image = Image.open(requests.get(url, stream=True).raw)
label = "tiger"

# Preprocessing
transform = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
img_tensor = transform(image).unsqueeze(0)

# Download Model
onnx_dir_path = "dla"
onnx_model_path = f"{onnx_dir_path}/{variant}_Opset18.onnx"
    if not os.path.exists(onnx_model_path):
        os.makedirs(onnx_dir_path, exist_ok=True)
url = f"https://github.com/onnx/models/raw/main/Computer_Vision/{variant}_Opset18_timm/{variant}_Opset18.onnx?download="
response = requests.get(url, stream=True)
with open(onnx_model_path, "wb") as f:
f.write(response.content)

# Load model and prepare for evaluation (inference)
model_name = f"dla_{variant}_onnx"
onnx_model = onnx.load(onnx_model_path)
tt_model = pybuda.OnnxModule(model_name, onnx_model, onnx_model_path)

# run inference on Tenstorrent device
output_q = pybuda.run_inference(tt_model, inputs=[(img_tensor,)])
output = output_q.get()[0].value()

# Get ImageNet class mappings
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
image_classes = urllib.request.urlopen(url)
categories = [s.decode("utf-8").strip() for s in image_classes.readlines()]

# Post processing
predicted_value = output.argmax(-1).item()
predicted_label = categories[predicted_value]

# Print outputs
print(f"True Label: {label} | Predicted Label: {predicted_label}")

# Cleanup model files
os.remove(onnx_model_path)
os.rmdir(onnx_dir_path)


if __name__ == "__main__":
run_dla_onnx("dla34")
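One caveat with the download block above: it writes `response.content` to disk without checking the HTTP status, so a failed request would silently leave an HTML error page behind as a `.onnx` file. A more defensive variant (a sketch, not part of this PR) might look like:

```python
import os

import onnx
import requests


def fetch_onnx_model(url: str, path: str) -> onnx.ModelProto:
    """Download an ONNX checkpoint if missing, then validate it."""
    if not os.path.exists(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        response = requests.get(url, stream=True, timeout=60)
        response.raise_for_status()  # fail loudly instead of saving an error page
        with open(path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    model = onnx.load(path)
    onnx.checker.check_model(model)  # structural validation before compilation
    return model
```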
14 changes: 6 additions & 8 deletions model_demos/cv_demos/dla/pytorch_dla.py
@@ -38,7 +38,7 @@ def run_dla_pytorch(variant):

# Load model function
func = variants_func[variant]
model_name = f"dla_{func.__name__}_pytorch"
model_name = f"dla_{variant}_pytorch"

# Set PyBuda configuration parameters
compiler_cfg = pybuda.config._get_global_compiler_config()
@@ -49,14 +49,12 @@ def run_dla_pytorch(variant):
available_devices = pybuda.detect_available_devices()
if available_devices:
arch = available_devices[0]
if arch == BackendDevice.Grayskull:
if func.__name__ == "dla102x2":
os.environ["PYBUDA_FORCE_CONV_MULTI_OP_FRACTURE"] = "1"
elif arch == BackendDevice.Wormhole_B0:
if func.__name__ == "dla60x":
if arch == BackendDevice.Wormhole_B0:
            if variant in ("dla60", "dla60x"):
compiler_cfg.place_on_new_epoch("concatenate_776.dc.concatenate.0")
elif func.__name__ == "dla60x":
os.environ["TT_BACKEND_OVERLAY_MAX_EXTRA_BLOB_SIZE"] = "20480"
elif arch == BackendDevice.Grayskull:
if variant in ("dla102x2", "dla169"):
os.environ["PYBUDA_FORCE_CONV_MULTI_OP_FRACTURE"] = "1"

# Load data sample
url = "https://images.rawpixel.com/image_1300/cHJpdmF0ZS9sci9pbWFnZXMvd2Vic2l0ZS8yMDIyLTA1L3BkMTA2LTA0Ny1jaGltXzEuanBn.jpg"
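A note on the variant gating in this hunk: matching several variants at once needs a membership test (`in`), because comparing a string against a tuple with `==` is always false and silently dead-codes the branch. A minimal illustration:

```python
variant = "dla60"
print(variant == ("dla60", "dla60x"))  # False -- a str never equals a tuple
print(variant in ("dla60", "dla60x"))  # True -- membership is what's intended
```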
52 changes: 46 additions & 6 deletions model_demos/cv_demos/perceiverio/pytorch_perceiverio.py
@@ -5,24 +5,62 @@
import pybuda
import requests
from PIL import Image
from transformers import AutoImageProcessor, PerceiverForImageClassificationConvProcessing
from transformers import (
AutoImageProcessor,
PerceiverForImageClassificationConvProcessing,
PerceiverForImageClassificationFourier,
PerceiverForImageClassificationLearned,
)


def run_perceiverio_pytorch(variant="deepmind/vision-perceiver-conv"):

    # Load image processor and model checkpoint from HuggingFace
model_ckpt = variant
image_processor = AutoImageProcessor.from_pretrained(model_ckpt)
model = PerceiverForImageClassificationConvProcessing.from_pretrained(model_ckpt).eval()
if variant == "deepmind/vision-perceiver-learned":
model = PerceiverForImageClassificationLearned.from_pretrained(model_ckpt)

elif variant == "deepmind/vision-perceiver-conv":
model = PerceiverForImageClassificationConvProcessing.from_pretrained(model_ckpt)

elif variant == "deepmind/vision-perceiver-fourier":
model = PerceiverForImageClassificationFourier.from_pretrained(model_ckpt)

    else:
        raise ValueError(f"The model {variant} is not supported")

model.eval()

# Set PyBuda configuration parameters
compiler_cfg = pybuda.config._get_global_compiler_config()
compiler_cfg.balancer_policy = "Ribbon"
compiler_cfg.default_df_override = pybuda.DataFormat.Float16_b
compiler_cfg.default_dram_parameters = False
compiler_cfg.enable_auto_fusing = False
os.environ["PYBUDA_RIBBON2"] = "1"
os.environ["TT_BACKEND_OVERLAY_MAX_EXTRA_BLOB_SIZE"] = f"{10*1024}"
compiler_cfg.enable_auto_fusing = False

if model_ckpt == "deepmind/vision-perceiver-conv":
compiler_cfg.default_dram_parameters = False
os.environ["TT_BACKEND_OVERLAY_MAX_EXTRA_BLOB_SIZE"] = f"{10*1024}"

if model_ckpt in ["deepmind/vision-perceiver-learned", "deepmind/vision-perceiver-fourier"]:
os.environ["PYBUDA_DISABLE_PADDING_PASS"] = "1"

available_devices = pybuda.detect_available_devices()
if available_devices:
if available_devices[0] == pybuda.BackendDevice.Wormhole_B0:
if model_ckpt == "deepmind/vision-perceiver-conv":
compiler_cfg.balancer_op_override(
"max_pool2d_33.dc.reshape.10.dc.sparse_matmul.13.lc2", "t_stream_shape", (1, 1)
)

if model_ckpt == "deepmind/vision-perceiver-fourier":
os.environ["TT_BACKEND_OVERLAY_MAX_EXTRA_BLOB_SIZE"] = f"{101*1024}"

if available_devices[0] == pybuda.BackendDevice.Grayskull:

if variant in ["deepmind/vision-perceiver-learned", "deepmind/vision-perceiver-fourier"]:
os.environ["TT_BACKEND_OVERLAY_MAX_EXTRA_BLOB_SIZE"] = f"{101*1024}"

# Load data sample
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
@@ -34,7 +72,9 @@ def run_perceiverio_pytorch(variant="deepmind/vision-perceiver-conv"):
pixel_values = inputs["pixel_values"]

# Run inference on Tenstorrent device
output_q = pybuda.run_inference(pybuda.PyTorchModule("pt_perceiver_io", model), inputs=[(pixel_values,)])
output_q = pybuda.run_inference(
pybuda.PyTorchModule("pt_" + str(model_ckpt.split("/")[-1].replace("-", "_")), model), inputs=[(pixel_values,)]
)
output = output_q.get() # return last queue object

# Data postprocessing
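The postprocessing block is truncated in this view. It typically maps the classification logits to an ImageNet label via the model config, following the same pattern as the SegFormer demo below; a sketch continuing from the demo's `output` and `model` variables:

```python
# Sketch of the elided step: map logits to a human-readable label.
logits = output[0].value()  # PyBuda returns queued tensors; .value() yields a torch.Tensor
predicted_class = logits.argmax(-1).item()
print("Predicted label:", model.config.id2label[predicted_class])
```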
61 changes: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
import os

import pybuda
import requests
from PIL import Image
from transformers import AutoImageProcessor, SegformerConfig, SegformerForImageClassification


def get_sample_data(model_name):
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained(model_name)
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
return pixel_values


def run_segformer_image_classification_pytorch(variant="nvidia/mit-b0"):

# Set PyBuda configuration parameters
compiler_cfg = pybuda.config._get_global_compiler_config()
compiler_cfg.balancer_policy = "Ribbon"
compiler_cfg.default_df_override = pybuda.DataFormat.Float16_b
os.environ["PYBUDA_RIBBON2"] = "1"
os.environ["PYBUDA_DISABLE_PADDING_PASS"] = "1"

available_devices = pybuda.detect_available_devices()
if available_devices:
if available_devices[0] == pybuda.BackendDevice.Wormhole_B0:
if variant in ["nvidia/mit-b1", "nvidia/mit-b2", "nvidia/mit-b3", "nvidia/mit-b4", "nvidia/mit-b5"]:

os.environ["PYBUDA_FORCE_CONV_MULTI_OP_FRACTURE"] = "1"

# Set model configurations
config = SegformerConfig.from_pretrained(variant)
config_dict = config.to_dict()
config_dict["return_dict"] = False
config = SegformerConfig(**config_dict)

# Load the model from HuggingFace
model = SegformerForImageClassification.from_pretrained(variant, config=config)
model.eval()

# Load the sample image
pixel_values = get_sample_data(variant)

# Create PyBuda module from PyTorch model
tt_model = pybuda.PyTorchModule("pt_" + str(variant.split("/")[-1].replace("-", "_")), model)

# run inference on Tenstorrent device
output_q = pybuda.run_inference(tt_model, inputs=[(pixel_values,)])
output = output_q.get()

# Print output
predicted_value = output[0].value().argmax(-1).item()
predicted_label = model.config.id2label[predicted_value]
print("Predicted label : ", predicted_label)


if __name__ == "__main__":
run_segformer_image_classification_pytorch()
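A parametrized test in the style of `tests/test_onnx_dla.py` (shown at the end of this diff) would cover the MiT variants. A sketch, with the module path and pytest marker assumed rather than taken from this diff:

```python
import pytest

# Hypothetical import path -- the segformer demo's filename is not shown in
# this diff view; adjust the module name to match the actual file.
from cv_demos.segformer.pytorch_segformer_image_classification import (
    run_segformer_image_classification_pytorch,
)

variants = [
    "nvidia/mit-b0",
    "nvidia/mit-b1",
    "nvidia/mit-b2",
    "nvidia/mit-b3",
    "nvidia/mit-b4",
    "nvidia/mit-b5",
]


@pytest.mark.segformer
@pytest.mark.parametrize("variant", variants, ids=variants)
def test_segformer_image_classification(clear_pybuda, variant):
    run_segformer_image_classification_pytorch(variant)
```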
69 changes: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
import os

import pybuda
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation

torch.multiprocessing.set_sharing_strategy("file_system")


def get_sample_data(model_name):
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained(model_name)
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
return pixel_values


def run_segformer_semseg_pytorch(variant="nvidia/segformer-b0-finetuned-ade-512-512"):

# Set PyBuda configuration parameters
compiler_cfg = pybuda.config._get_global_compiler_config()
compiler_cfg.balancer_policy = "Ribbon"
compiler_cfg.default_df_override = pybuda.DataFormat.Float16_b
os.environ["PYBUDA_RIBBON2"] = "1"
os.environ["PYBUDA_DISABLE_PADDING_PASS"] = "1"

available_devices = pybuda.detect_available_devices()
if available_devices:
if available_devices[0] == pybuda.BackendDevice.Wormhole_B0:
if variant in [
"nvidia/segformer-b1-finetuned-ade-512-512",
"nvidia/segformer-b2-finetuned-ade-512-512",
"nvidia/segformer-b3-finetuned-ade-512-512",
"nvidia/segformer-b4-finetuned-ade-512-512",
]:

os.environ["PYBUDA_FORCE_CONV_MULTI_OP_FRACTURE"] = "1"

elif available_devices[0] == pybuda.BackendDevice.Grayskull:
if variant in [
"nvidia/segformer-b2-finetuned-ade-512-512",
"nvidia/segformer-b3-finetuned-ade-512-512",
"nvidia/segformer-b4-finetuned-ade-512-512",
]:
compiler_cfg.amp_level = 1

# Load the model from HuggingFace
model = SegformerForSemanticSegmentation.from_pretrained(variant)
model.eval()

# Load the sample image
pixel_values = get_sample_data(variant)

# Create PyBuda module from PyTorch model
tt_model = pybuda.PyTorchModule("pt_" + str(variant.split("/")[-1].replace("-", "_")), model)

# run inference on Tenstorrent device
output_q = pybuda.run_inference(tt_model, inputs=[(pixel_values,)])
output = output_q.get()[0].value()

# Print output
print("output=", output)


if __name__ == "__main__":
run_segformer_semseg_pytorch()
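The demo prints raw logits. Turning them into a per-pixel segmentation map follows the standard SegFormer recipe; a sketch continuing from the demo's `output` and `pixel_values`, assuming the usual 1/4-resolution logits of shape `(1, num_labels, H/4, W/4)`:

```python
import torch

# Sketch (not part of this PR): upsample logits to input resolution,
# then take the per-pixel argmax to get ADE20K class indices.
upsampled = torch.nn.functional.interpolate(
    output, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
)
seg_map = upsampled.argmax(dim=1)[0]  # (H, W) tensor of class indices
print("Unique classes in prediction:", seg_map.unique().tolist())
```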
2 changes: 1 addition & 1 deletion model_demos/requirements.txt
@@ -8,4 +8,4 @@ yolov5==7.0.9 # For YOLOv5
soundfile==0.12.1 # For Whisper
librosa==0.10.0 # For Whisper
segmentation-models-pytorch==0.3.3 # For U-Net
diffusers==0.14.0 # For Stable Diffusion
diffusers==0.27.2 # For Stable Diffusion
22 changes: 22 additions & 0 deletions model_demos/tests/test_onnx_dla.py
@@ -0,0 +1,22 @@
import pytest

from cv_demos.dla.onnx_dla import run_dla_onnx

variants = [
"dla34",
"dla46_c",
"dla46x_c",
"dla60x_c",
"dla60",
"dla60x",
"dla102",
"dla102x",
"dla102x2",
"dla169",
]


@pytest.mark.dla
@pytest.mark.parametrize("variant", variants, ids=variants)
def test_dla_onnx(clear_pybuda, variant):
run_dla_onnx(variant)
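The `clear_pybuda` fixture comes from the suite's `conftest.py`, which is not part of this diff. A minimal sketch of what such a fixture typically does, assuming `pybuda.shutdown()` is the appropriate teardown call in this release:

```python
import pytest

import pybuda


@pytest.fixture
def clear_pybuda():
    """Hypothetical fixture: tear down device state between parametrized runs."""
    yield
    pybuda.shutdown()  # assumed cleanup API; adjust to the suite's actual conftest
```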