在PyTorch中,可以使用torchvision.datasets
模塊來加載常用的數據集。該模塊提供了對以下常用數據集的支持:
加載數據集的一般步驟如下:
from torchvision import datasets
from torchvision import transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
這里的變換是將圖像轉換為張量,并進行歸一化處理。
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
這里的root
參數指定數據集的下載和存儲路徑,train
參數表示加載訓練集還是測試集,transform
參數指定對數據集進行的變換,download
參數表示是否下載數據集(僅在第一次運行時需要下載)。
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
這里的batch_size
參數指定每個批次的樣本數,shuffle
參數表示是否對數據進行隨機打亂。
通過上述步驟,就能夠加載和使用PyTorch中的數據集進行訓練和測試。