Home
vit_float_inference.py

The example below exports a ViT image classifier from PyTorch to ONNX, imports the ONNX graph into numpy-quant, and runs float inference by evaluating the graph node by node, timing the full pass.

import io
import time

import numpy as np
import onnx.shape_inference
import torch
from datasets import load_dataset
from transformers import ViTForImageClassification, ViTImageProcessor

from numpy_quant.model import Model, onnx_operator_implementation
from numpy_quant.tensor import FTensor
# Obtain the ViT torch model and a sample cat image using Hugging Face's infrastructure
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
torch_vit_image_classifier = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
inputs = feature_extractor(image, return_tensors="pt")
# Create ONNX model from torch model
onnx_bytes = io.BytesIO()
torch.onnx.export(
    torch_vit_image_classifier,
    tuple(inputs.values()),
    f=onnx_bytes,
    input_names=['inputs'],
    output_names=['logits'],
    do_constant_folding=True,
    opset_version=17,
)
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
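# Optional sanity check (not part of the original script): validate the exported
# graph with ONNX's built-in checker before handing it to numpy-quant.
onnx.checker.check_model(onnx_model)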
# Import ONNX model to numpy-quant
model = Model.from_onnx(onnx_model)
input_arr = np.random.normal(size=(1, 3, 224, 224)).astype(np.float32)
model.inputs[0].data = FTensor(input_arr)
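# Alternatively, the preprocessed cat image from the feature extractor could be fed
# instead of random data (left commented out so the timing below uses the random input):
# model.inputs[0].data = FTensor(inputs["pixel_values"].numpy())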
# Iterate through nodes updating all variables in the model.
start = time.time()
for node in model.nodes:
    node_inputs = [i.data for i in node.inputs]
    node_outputs = onnx_operator_implementation(node.op, node_inputs, node.attrs)
    for o, tensor in zip(node.outputs, node_outputs):
        o.data = tensor
inf_time = time.time() - start
print(f"Inference time: {inf_time}s")