torch.load函數用于從硬盤上加載保存的PyTorch模型或張量。它可以加載包含模型權重、網絡結構和訓練狀態等信息的.pth、.pt、.pkl等文件,并返回一個包含加載的對象的Python字典。
使用torch.load函數可以方便地加載預訓練模型,以便在新任務上進行微調或推理。加載的模型可以用于評估、生成預測或繼續訓練。
示例用法:
model = torch.load('model.pth')
此外,torch.load函數還可以通過指定一個map_location參數,將模型加載到指定的設備上,例如將模型加載到GPU上:
model = torch.load('model.pth', map_location='cuda:0')