
import_vit_from_onnx.py

import io

import onnx
import onnx.shape_inference
import torch
from datasets import load_dataset
from transformers import ViTForImageClassification, ViTImageProcessor

from numpy_quant.model import Model


# Obtain the ViT torch model using Hugging Face's infrastructure
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
torch_vit_image_classifier = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
# Use the image processor that matches the classifier checkpoint
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
inputs = feature_extractor(image, return_tensors="pt")
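
# Sanity check (not part of the original script): ViT-Base/16 expects a
# batch of 224x224 RGB images, so the preprocessed pixel_values tensor
# should have shape (batch, channels, height, width) = (1, 3, 224, 224).
assert inputs["pixel_values"].shape == (1, 3, 224, 224)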

# Create ONNX model from torch model
onnx_bytes = io.BytesIO()
torch.onnx.export(
    torch_vit_image_classifier,
    tuple(inputs.values()),
    f=onnx_bytes,
    input_names=['inputs'],
    output_names=['logits'],
    do_constant_folding=True,
    opset_version=17,
)
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
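
# Optional validation step (an addition, not part of the original script):
# onnx.checker raises if the exported graph is structurally malformed,
# which catches export problems before handing the model to numpy-quant.
onnx.checker.check_model(onnx_model)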

# Import the ONNX model into numpy-quant
model = Model.from_onnx(onnx_model)

# Print the nodes, values, inputs, and outputs of the model
print(model)
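
# Reference check (an addition, not part of the original script): run the
# torch model on the same inputs so the numpy-quant import can later be
# compared against the expected logits and predicted class.
with torch.no_grad():
    torch_logits = torch_vit_image_classifier(**inputs).logits
predicted_idx = torch_logits.argmax(-1).item()
print(torch_vit_image_classifier.config.id2label[predicted_idx])  # e.g. "Egyptian cat"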