要將PyTorch模型轉換為ONNX格式,可以按照以下步驟操作:
pip install torch torchvision onnx
import torch
import torch.onnx as onnx
# 加載PyTorch模型
model = torch.load('path_to_model.pth')
# 設置模型為評估模式
model.eval()
# 創建模擬輸入數據
dummy_input = torch.randn(1, 3, 224, 224)
# 導出模型為ONNX格式
onnx.export(model, dummy_input, 'path_to_output.onnx')
在上面的示例中,path_to_model.pth
是你的PyTorch模型的文件路徑,path_to_output.onnx
是導出的ONNX模型的文件路徑。