import_vit_from_onnx.py
import io

import onnx.shape_inference
import torch
from datasets import load_dataset
from transformers import ViTForImageClassification, ViTImageProcessor

from numpy_quant.model import Model
# Obtain a pretrained ViT torch model and an example input via Hugging Face's infrastructure
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
torch_vit_image_classifier = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")  # processor matching the classifier checkpoint
inputs = feature_extractor(image, return_tensors="pt")
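# `inputs` is a BatchEncoding holding a single "pixel_values" tensor of shape
# (1, 3, 224, 224), the only input the exported ONNX graph will expect.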
# Create ONNX model from torch model
onnx_bytes = io.BytesIO()
torch.onnx.export(
    torch_vit_image_classifier,
    tuple(inputs.values()),
    f=onnx_bytes,
    input_names=["inputs"],
    output_names=["logits"],
    do_constant_folding=True,
    opset_version=17,
)
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
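# Optional sanity check (an addition, not part of the original script): validate
# that the exported graph is well-formed before handing it to numpy-quant.
onnx.checker.check_model(onnx_model)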
# Import ONNX model to numpy-quant
model = Model.from_onnx(onnx_model)
# Print the nodes, values, inputs, and outputs of the model
print(model)
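# Reference check (an illustrative addition using only standard transformers
# API, not numpy-quant): run the original torch model on the same inputs so
# the imported graph has a known-good output to compare against.
with torch.no_grad():
    torch_logits = torch_vit_image_classifier(**inputs).logits
predicted_class = torch_logits.argmax(-1).item()
# For the cats-image sample this typically prints "Egyptian cat".
print(torch_vit_image_classifier.config.id2label[predicted_class])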