Pytorch中DataLoader的使用方法詳解_第1頁
Pytorch中DataLoader的使用方法詳解_第2頁
Pytorch中DataLoader的使用方法詳解_第3頁
Pytorch中DataLoader的使用方法詳解_第4頁
Pytorch中DataLoader的使用方法詳解_第5頁
已閱讀5頁,還剩2頁未讀, 繼續(xù)免費閱讀

下載本文檔

版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進行舉報或認領(lǐng)

文檔簡介

第Pytorch中DataLoader的使用方法詳解目錄一:dataset類構(gòu)建。二:DataLoader使用三:舉例前言加載數(shù)據(jù)datasetdataloader在Pytorch中,torch.utils.data中的Dataset與DataLoader是處理數(shù)據(jù)集的兩個函數(shù),用來處理加載數(shù)據(jù)集。通常情況下,使用的關(guān)鍵在于構(gòu)建dataset類。

一:dataset類構(gòu)建。

在構(gòu)建數(shù)據(jù)集類時,除了__init__(self),還要有__len__(self)與__getitem__(self,item)兩個方法,這三個是必不可少的,至于其它用于數(shù)據(jù)處理的函數(shù),可以任意定義。

classdataset:

def__init__(self,...):

def__len__(self,...):

returnn

def__getitem__(self,item):

returndata[item]

正常情況下,該數(shù)據(jù)集是要繼承Pytorch中Dataset類的,但實際操作中,即使不繼承,數(shù)據(jù)集類構(gòu)建后仍可以用Dataloader()加載的。

在dataset類中,__len__(self)返回數(shù)據(jù)集中數(shù)據(jù)個數(shù),__getitem__(self,item)表示每次返回第item條數(shù)據(jù)。

二:DataLoader使用

在構(gòu)建dataset類后,即可使用DataLoader加載。DataLoader中常用參數(shù)如下:

1.dataset:需要載入的數(shù)據(jù)集,如前面構(gòu)造的dataset類。

2.batch_size:批大小,在神經(jīng)網(wǎng)絡(luò)訓(xùn)練時我們很少逐條數(shù)據(jù)訓(xùn)練,而是幾條數(shù)據(jù)作為一個batch進行訓(xùn)練。

3.shuffle:是否在打亂數(shù)據(jù)集樣本順序。True為打亂,F(xiàn)alse反之。

4.drop_last:是否舍去最后一個batch的數(shù)據(jù)(很多情況下數(shù)據(jù)總數(shù)N與batchsize不整除,導(dǎo)致最后一個batch不為batchsize)。True為舍去,F(xiàn)alse反之。

三:舉例

兔兔以指標(biāo)為1,數(shù)據(jù)個數(shù)為100的數(shù)據(jù)為例。

importtorch

fromtorch.utils.dataimportDataLoader

classdataset:

def__init__(self):

self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32)

self.y=(torch.sin(self.x)+1)/2

def__len__(self):

return100

def__getitem__(self,item):

returnself.x[item],self.y[item]

data=DataLoader(dataset(),batch_size=10,shuffle=True)

forbatchindata:

print(batch)

當(dāng)然,利用這個數(shù)據(jù)集可以進行簡單的神經(jīng)網(wǎng)絡(luò)訓(xùn)練。

fromtorchimportnn

data=DataLoader(dataset(),batch_size=10,shuffle=True)

bp=nn.Sequential(nn.Linear(1,5),

nn.Sigmoid(),

nn.Linear(5,1),

nn.Sigmoid())

optim=torch.optim.Adam(params=bp.parameters())

Loss=nn.MSELoss()

forepochinrange(10):

print('the{}epoch'.format(epoch))

forbatchindata:

yp=bp(batch[0])

loss=Loss(yp,batch[1])

optim.zero_grad()

loss.backward()

optim.step()

ps:下面再給大家補充介紹下Pytorch中DataLoader的使用。

前言

最近開始接觸pytorch,從跑別人寫好的代碼開始,今天需要把輸入數(shù)據(jù)根據(jù)每個batch的最長輸入數(shù)據(jù),填充到一樣的長度(之前是將所有的數(shù)據(jù)直接填充到一樣的長度再輸入)。

剛開始是想偷懶,沒有去認真了解輸入的機制,結(jié)果一直報錯還是要認真學(xué)習(xí)呀!

加載數(shù)據(jù)

pytorch中加載數(shù)據(jù)的順序是:

①創(chuàng)建一個dataset對象

②創(chuàng)建一個dataloader對象

③循環(huán)dataloader對象,將data,label拿到模型中去訓(xùn)練

dataset

你需要自己定義一個class,里面至少包含3個函數(shù):

①__init__:傳入數(shù)據(jù),或者像下面一樣直接在函數(shù)里加載數(shù)據(jù)

②__len__:返回這個數(shù)據(jù)集一共有多少個item

③__getitem__:返回一條訓(xùn)練數(shù)據(jù),并將其轉(zhuǎn)換成tensor

importtorch

fromtorch.utils.dataimportDataset

classMydata(Dataset):

def__init__(self):

a=np.load("D:/Python/nlp/NRE/a.npy",allow_pickle=True)

b=np.load("D:/Python/nlp/NRE/b.npy",allow_pickle=True)

d=np.load("D:/Python/nlp/NRE/d.npy",allow_pickle=True)

c=np.load("D:/Python/nlp/NRE/c.npy")

self.x=list(zip(a,b,d,c))

def__getitem__(self,idx):

assertidxlen(self.x)

returnself.x[idx]

def__len__(self):

returnlen(self.x)

dataloader

參數(shù):

dataset:傳入的數(shù)據(jù)

shuffle=True:是否打亂數(shù)據(jù)

collate_fn:使用這個參數(shù)可以自己操作每個batch的數(shù)據(jù)

dataset=Mydata()

dataloader=DataLoader(dataset,batch_size=2,shuffle=True,collate_fn=mycollate)

下面是將每個batch的數(shù)據(jù)填充到該batch的最大長度

defmycollate(data):

a=[]

b=[]

c=[]

d=[]

max_len=len(data[0][0])

foriindata:

iflen(i[0])max_len:

max_len=len(i[0])

iflen(i[1])max_len:

max_len=len(i[1])

iflen(i[2])max_len:

max_len=len(i[2])

print(max_len)

#填充

foriindata:

iflen(i[0])max_len:

i[0].extend([27]*(max_len-len(i[0])))

iflen(i[1])max_len:

i[1].extend([27]*(max_len-len(i[1])))

iflen(i[2])max_len:

i[2].extend([27]*(max_len-len(i[2])))

a.append(i[0])

b.append(i[1])

d.append(i[2])

c.extend(i[3])

#這里要自己轉(zhuǎn)成tensor

a=torch.Tensor(a)

b=torch.Tensor(b)

c=torch.Tensor(c)

d=torch.Tensor(d)

data1=[a,b,d,c]

print("data1",data1)

returndata1

結(jié)果:

最后循環(huán)該dataloader,拿到數(shù)據(jù)放入模型進行訓(xùn)練:

forii,datainenumerate(test_data_loader):

ifopt.use_gpu:

data=list(map(lambdax:torch.LongTensor(x.long()).cuda(),data))

else:

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。

評論

0/150

提交評論