您好,登錄后才能下訂單哦!
這篇文章主要介紹“Pytorch如何繼承Subset類完成自定義數據拆分”,在日常操作中,相信很多人在Pytorch如何繼承Subset類完成自定義數據拆分問題上存在疑惑,小編查閱了各式資料,整理出簡單好用的操作方法,希望對大家解答”Pytorch如何繼承Subset類完成自定義數據拆分”的疑惑有所幫助!接下來,請跟著小編一起來學習吧!
下面是加載內置訓練數據集的常見操作:
from torchvision.datasets import FashionMNIST from torchvision.transforms import Compose, ToTensor, Normalize RAW_DATA_PATH = './rawdata' transform = Compose( [ToTensor(), Normalize((0.1307,), (0.3081,)) ] ) train_data = FashionMNIST( root=RAW_DATA_PATH, download=True, train=True, transform=transform )
這里的train_data
做為 dataset
對象,它擁有許多熟悉,我們可以通過以下方法獲取樣本數據的分類類別集合、樣本的特征維度、樣本的標簽集合等信息。
classes = train_data.classes num_features = train_data.data[0].shape[0] train_labels = train_data.targets print(classes) print(num_features) print(train_labels)
輸出如下:
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
28
tensor([9, 0, 0, ..., 3, 0, 5])
但是,我們常常會在訓練集的基礎上拆分出驗證集(或者只用部分數據來進行訓練)。我們想到的第一個方法是使用 torch.utils.data.random_split
對 dataset
進行劃分,下面我們假設劃分10000個樣本做為訓練集,其余樣本做為驗證集:
from torch.utils.data import random_split k = 10000 train_data, valid_data = random_split(train_data, [k, len(train_data)-k])
注意我們如果打印 train_data 和 valid_data
的類型,可以看到顯示:
<class 'torch.utils.data.dataset.Subset'>
已經不再是torchvision.datasets.mnist.FashionMNIST
對象,而是一個所謂的 Subset 對象!此時 Subset 對象雖然仍然還存有 data 屬性,但是內置的 target
和 classes
屬性已經不復存在,
比如如果我們強行訪問 valid_data 的 target 屬性:
valid_target = valid_data.target
就會報如下錯誤:
'Subset' object has no attribute 'target'
但如果我們在后續的代碼中常常會將拆分后的數據集也默認為 dataset 對象,那么該如何做到代碼的一致性呢?
這里有一個trick,那就是以繼承 SubSet 類的方式的方式定義一個新的 CustomSubSet 類,使新類在保持 SubSet 類的基本屬性的基礎上,擁有和原本數據集類相似的屬性,如 targets
和 classes
等:
from torch.utils.data import Subset class CustomSubset(Subset): '''A custom subset class''' def __init__(self, dataset, indices): super().__init__(dataset, indices) self.targets = dataset.targets # 保留targets屬性 self.classes = dataset.classes # 保留classes屬性 def __getitem__(self, idx): #同時支持索引訪問操作 x, y = self.dataset[self.indices[idx]] return x, y def __len__(self): # 同時支持取長度操作 return len(self.indices)
然后就引出了第二種劃分方法,即通過初始化 CustomSubset
對象的方式直接對數據集進行劃分(這里為了簡化省略了shuffle的步驟):
import numpy as np from copy import deepcopy origin_data = deepcopy(train_data) train_data = CustomSubset(origin_data, np.arange(k)) valid_data = CustomSubset(origin_data, np.arange(k, len(origin_data))-k)
注意: CustomSubset
類的初始化方法的第二個參數 indices 為樣本索引,我們可以通過 np.arange()
的方法來創建。
然后,我們再訪問 valid_data 對應的 classes 和 targes 屬性:
print(valid_data.classes) print(valid_data.targets)
此時,我們發現可以成功訪問這些屬性了:
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] tensor([9, 0, 0, ..., 3, 0, 5])
當然, CustomSubset
的作用并不只是添加數據集的屬性,我們還可以自定義一些數據預處理操作。
我們將類的結構修改如下:
class CustomSubset(Subset): '''A custom subset class with customizable data transformation''' def __init__(self, dataset, indices, subset_transform=None): super().__init__(dataset, indices) self.targets = dataset.targets self.classes = dataset.classes self.subset_transform = subset_transform def __getitem__(self, idx): x, y = self.dataset[self.indices[idx]] if self.subset_transform: x = self.subset_transform(x) return x, y def __len__(self): return len(self.indices)
我們可以在使用樣本前設置好數據預處理算子:
from torchvision import transforms valid_data.subset_transform = transforms.Compose(\ [transforms.RandomRotation((180,180))])
這樣,我們再像下列這樣用索引訪問取出數據集樣本時,就會自動調用算子完成預處理操作:
print(valid_data[0])
打印結果縮略如下:
(tensor([[[-0.4242, -0.4242, -0.4242, ......-0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]), 9)
到此,關于“Pytorch如何繼承Subset類完成自定義數據拆分”的學習就結束了,希望能夠解決大家的疑惑。理論與實踐的搭配能更好的幫助大家學習,快去試試吧!若想繼續學習更多相關知識,請繼續關注億速云網站,小編會繼續努力為大家帶來更多實用的文章!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。