Home

vit_float_inference.py

import io
import time

import numpy as np
from numpy_quant.model import Model, onnx_operator_implementation
import onnx.shape_inference
import torch
from datasets import load_dataset
from numpy_quant.tensor import FTensor
from transformers import ViTForImageClassification, ViTImageProcessor


# Obtain a sample image and the pretrained ViT torch model via Hugging Face's infrastructure
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
torch_vit_image_classifier = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
inputs = feature_extractor(image, return_tensors="pt")
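
# Optional sanity check: run the torch model on the preprocessed cat image and print
# its predicted ImageNet class (standard transformers usage, not required for the export).
with torch.no_grad():
    reference_logits = torch_vit_image_classifier(**inputs).logits
predicted_class = reference_logits.argmax(-1).item()
print("Torch prediction:", torch_vit_image_classifier.config.id2label[predicted_class])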

# Create ONNX model from torch model
onnx_bytes = io.BytesIO()
torch.onnx.export(
    torch_vit_image_classifier,
    tuple(inputs.values()),
    f=onnx_bytes,
    input_names=['inputs'],
    output_names=['logits'],
    do_constant_folding=True,
    opset_version=17,
)
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
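
# Optional: validate the exported graph and report its size before importing it
# (onnx.checker ships with the onnx package and raises if the model is malformed).
onnx.checker.check_model(onnx_model)
print(f"Exported ONNX graph with {len(onnx_model.graph.node)} nodes")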

# Import ONNX model to numpy-quant
model = Model.from_onnx(onnx_model)

# Feed a random float32 input with ViT's expected shape (batch, channels, height, width).
input_arr = np.random.normal(size=(1, 3, 224, 224)).astype(np.float32)
model.inputs[0].data = FTensor(input_arr)

# Run inference: walk the graph node by node, computing every variable in the model.
start = time.time()
for node in model.nodes:
    node_inputs = [i.data for i in node.inputs]
    node_outputs = onnx_operator_implementation(node.op, node_inputs, node.attrs)
    for o, tensor in zip(node.outputs, node_outputs):
        o.data = tensor
inf_time = time.time() - start
print(f"Inference time: {inf_time:.2f}s")