您好,登錄后才能下訂單哦!
這篇文章主要講解了如何實現Pytorch通過保存為ONNX模型轉TensorRT5,內容清晰明了,對此有興趣的小伙伴可以學習一下,相信大家閱讀完之后會有幫助。
1 Pytorch以ONNX方式保存模型
def saveONNX(model, filepath): ''' 保存ONNX模型 :param model: 神經網絡模型 :param filepath: 文件保存路徑 ''' # 神經網絡輸入數據類型 dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda') torch.onnx.export(model, dummy_input, filepath, verbose=True)
2 利用TensorRT5中ONNX解析器構建Engine
def ONNX_build_engine(onnx_file_path): ''' 通過加載onnx文件,構建engine :param onnx_file_path: onnx文件路徑 :return: engine ''' # 打印日志 G_LOGGER = trt.Logger(trt.Logger.WARNING) with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser: builder.max_batch_size = 100 builder.max_workspace_size = 1 << 20 print('Loading ONNX file from path {}...'.format(onnx_file_path)) with open(onnx_file_path, 'rb') as model: print('Beginning ONNX file parsing') parser.parse(model.read()) print('Completed parsing of ONNX file') print('Building an engine from file {}; this may take a while...'.format(onnx_file_path)) engine = builder.build_cuda_engine(network) print("Completed creating Engine") # 保存計劃文件 # with open(engine_file_path, "wb") as f: # f.write(engine.serialize()) return engine
3 構建TensorRT運行引擎進行預測
def loadONNX2TensorRT(filepath): ''' 通過onnx文件,構建TensorRT運行引擎 :param filepath: onnx文件路徑 ''' # 計算開始時間 Start = time() engine = self.ONNX_build_engine(filepath) # 讀取測試集 datas = DataLoaders() test_loader = datas.testDataLoader() img, target = next(iter(test_loader)) img = img.numpy() target = target.numpy() img = img.ravel() context = engine.create_execution_context() output = np.empty((100, 10), dtype=np.float32) # 分配內存 d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize) d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize) bindings = [int(d_input), int(d_output)] # pycuda操作緩沖區 stream = cuda.Stream() # 將輸入數據放入device cuda.memcpy_htod_async(d_input, img, stream) # 執行模型 context.execute_async(100, bindings, stream.handle, None) # 將預測結果從從緩沖區取出 cuda.memcpy_dtoh_async(output, d_output, stream) # 線程同步 stream.synchronize() print("Test Case: " + str(target)) print("Prediction: " + str(np.argmax(output, axis=1))) print("tensorrt time:", time() - Start) del context del engine
看完上述內容,是不是對如何實現Pytorch通過保存為ONNX模型轉TensorRT5有進一步的了解,如果還想學習更多內容,歡迎關注億速云行業資訊頻道。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。