在PyTorch中,通常通過使用torch.utils.data.Dataset
和torch.utils.data.DataLoader
來加載和處理數據集。
首先,創建一個自定義的數據集類,繼承自torch.utils.data.Dataset
,并實現__len__
和__getitem__
方法。在__getitem__
方法中,可以根據索引加載和預處理數據。
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
# 進行數據預處理
return sample
然后,實例化自定義數據集類并使用torch.utils.data.DataLoader
創建一個數據加載器,指定批量大小和是否打亂數據。
data = [...] # 數據集
dataset = CustomDataset(data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
最后,可以通過迭代數據加載器來訪問數據集中的數據。
for batch in dataloader:
# 處理批量數據
pass