版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進行舉報或認領(lǐng)
文檔簡介
第Pytorch中DataLoader的使用方法詳解目錄一:dataset類構(gòu)建。二:DataLoader使用三:舉例前言加載數(shù)據(jù)datasetdataloader在Pytorch中,torch.utils.data中的Dataset與DataLoader是處理數(shù)據(jù)集的兩個函數(shù),用來處理加載數(shù)據(jù)集。通常情況下,使用的關(guān)鍵在于構(gòu)建dataset類。
一:dataset類構(gòu)建。
在構(gòu)建數(shù)據(jù)集類時,除了__init__(self),還要有__len__(self)與__getitem__(self,item)兩個方法,這三個是必不可少的,至于其它用于數(shù)據(jù)處理的函數(shù),可以任意定義。
classdataset:
def__init__(self,...):
def__len__(self,...):
returnn
def__getitem__(self,item):
returndata[item]
正常情況下,該數(shù)據(jù)集是要繼承Pytorch中Dataset類的,但實際操作中,即使不繼承,數(shù)據(jù)集類構(gòu)建后仍可以用Dataloader()加載的。
在dataset類中,__len__(self)返回數(shù)據(jù)集中數(shù)據(jù)個數(shù),__getitem__(self,item)表示每次返回第item條數(shù)據(jù)。
二:DataLoader使用
在構(gòu)建dataset類后,即可使用DataLoader加載。DataLoader中常用參數(shù)如下:
1.dataset:需要載入的數(shù)據(jù)集,如前面構(gòu)造的dataset類。
2.batch_size:批大小,在神經(jīng)網(wǎng)絡(luò)訓(xùn)練時我們很少逐條數(shù)據(jù)訓(xùn)練,而是幾條數(shù)據(jù)作為一個batch進行訓(xùn)練。
3.shuffle:是否在打亂數(shù)據(jù)集樣本順序。True為打亂,F(xiàn)alse反之。
4.drop_last:是否舍去最后一個batch的數(shù)據(jù)(很多情況下數(shù)據(jù)總數(shù)N與batchsize不整除,導(dǎo)致最后一個batch不為batchsize)。True為舍去,F(xiàn)alse反之。
三:舉例
兔兔以指標(biāo)為1,數(shù)據(jù)個數(shù)為100的數(shù)據(jù)為例。
importtorch
fromtorch.utils.dataimportDataLoader
classdataset:
def__init__(self):
self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32)
self.y=(torch.sin(self.x)+1)/2
def__len__(self):
return100
def__getitem__(self,item):
returnself.x[item],self.y[item]
data=DataLoader(dataset(),batch_size=10,shuffle=True)
forbatchindata:
print(batch)
當(dāng)然,利用這個數(shù)據(jù)集可以進行簡單的神經(jīng)網(wǎng)絡(luò)訓(xùn)練。
fromtorchimportnn
data=DataLoader(dataset(),batch_size=10,shuffle=True)
bp=nn.Sequential(nn.Linear(1,5),
nn.Sigmoid(),
nn.Linear(5,1),
nn.Sigmoid())
optim=torch.optim.Adam(params=bp.parameters())
Loss=nn.MSELoss()
forepochinrange(10):
print('the{}epoch'.format(epoch))
forbatchindata:
yp=bp(batch[0])
loss=Loss(yp,batch[1])
optim.zero_grad()
loss.backward()
optim.step()
ps:下面再給大家補充介紹下Pytorch中DataLoader的使用。
前言
最近開始接觸pytorch,從跑別人寫好的代碼開始,今天需要把輸入數(shù)據(jù)根據(jù)每個batch的最長輸入數(shù)據(jù),填充到一樣的長度(之前是將所有的數(shù)據(jù)直接填充到一樣的長度再輸入)。
剛開始是想偷懶,沒有去認真了解輸入的機制,結(jié)果一直報錯還是要認真學(xué)習(xí)呀!
加載數(shù)據(jù)
pytorch中加載數(shù)據(jù)的順序是:
①創(chuàng)建一個dataset對象
②創(chuàng)建一個dataloader對象
③循環(huán)dataloader對象,將data,label拿到模型中去訓(xùn)練
dataset
你需要自己定義一個class,里面至少包含3個函數(shù):
①__init__:傳入數(shù)據(jù),或者像下面一樣直接在函數(shù)里加載數(shù)據(jù)
②__len__:返回這個數(shù)據(jù)集一共有多少個item
③__getitem__:返回一條訓(xùn)練數(shù)據(jù),并將其轉(zhuǎn)換成tensor
importtorch
fromtorch.utils.dataimportDataset
classMydata(Dataset):
def__init__(self):
a=np.load("D:/Python/nlp/NRE/a.npy",allow_pickle=True)
b=np.load("D:/Python/nlp/NRE/b.npy",allow_pickle=True)
d=np.load("D:/Python/nlp/NRE/d.npy",allow_pickle=True)
c=np.load("D:/Python/nlp/NRE/c.npy")
self.x=list(zip(a,b,d,c))
def__getitem__(self,idx):
assertidxlen(self.x)
returnself.x[idx]
def__len__(self):
returnlen(self.x)
dataloader
參數(shù):
dataset:傳入的數(shù)據(jù)
shuffle=True:是否打亂數(shù)據(jù)
collate_fn:使用這個參數(shù)可以自己操作每個batch的數(shù)據(jù)
dataset=Mydata()
dataloader=DataLoader(dataset,batch_size=2,shuffle=True,collate_fn=mycollate)
下面是將每個batch的數(shù)據(jù)填充到該batch的最大長度
defmycollate(data):
a=[]
b=[]
c=[]
d=[]
max_len=len(data[0][0])
foriindata:
iflen(i[0])max_len:
max_len=len(i[0])
iflen(i[1])max_len:
max_len=len(i[1])
iflen(i[2])max_len:
max_len=len(i[2])
print(max_len)
#填充
foriindata:
iflen(i[0])max_len:
i[0].extend([27]*(max_len-len(i[0])))
iflen(i[1])max_len:
i[1].extend([27]*(max_len-len(i[1])))
iflen(i[2])max_len:
i[2].extend([27]*(max_len-len(i[2])))
a.append(i[0])
b.append(i[1])
d.append(i[2])
c.extend(i[3])
#這里要自己轉(zhuǎn)成tensor
a=torch.Tensor(a)
b=torch.Tensor(b)
c=torch.Tensor(c)
d=torch.Tensor(d)
data1=[a,b,d,c]
print("data1",data1)
returndata1
結(jié)果:
最后循環(huán)該dataloader,拿到數(shù)據(jù)放入模型進行訓(xùn)練:
forii,datainenumerate(test_data_loader):
ifopt.use_gpu:
data=list(map(lambdax:torch.LongTensor(x.long()).cuda(),data))
else:
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
- 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
- 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負責(zé)。
- 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
- 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。
最新文檔
- 滑雪安全培訓(xùn)內(nèi)容課件
- 2025年晉城市政公司筆試及答案
- 滑雪介紹課件
- 滑模施工培訓(xùn)課件
- 滑縣安全培訓(xùn)生產(chǎn)課件
- 溶洞安全培訓(xùn)課件
- 2025 小學(xué)六年級數(shù)學(xué)上冊分數(shù)除法旅行費用計算課件
- 高鐵客運安全課件教學(xué)
- 未來五年交流電牽引采煤機企業(yè)縣域市場拓展與下沉戰(zhàn)略分析研究報告
- 未來五年招待所住宿市場需求變化趨勢與商業(yè)創(chuàng)新機遇分析研究報告
- 2026年中國數(shù)聯(lián)物流備考題庫有限公司招聘備考題庫及參考答案詳解一套
- 四川省樂山市2026屆高一上數(shù)學(xué)期末質(zhì)量檢測試題含解析
- 2025年天津中德應(yīng)用技術(shù)大學(xué)馬克思主義基本原理概論期末考試真題匯編
- 2025青海省交通控股集團有限公司面向社會公開招聘70人筆試歷年參考題庫附帶答案詳解
- 韓語興趣愛好課件
- 青霉素過敏性休克處理
- 70周歲換證三力測試題,老人駕考模擬測試題
- 工地清場協(xié)議書
- 2026年包頭輕工職業(yè)技術(shù)學(xué)院單招職業(yè)適應(yīng)性測試題庫附答案詳解
- 2026年及未來5年市場數(shù)據(jù)中國內(nèi)貿(mào)集裝箱行業(yè)全景評估及投資規(guī)劃建議報告
- 2025貴州鹽業(yè)(集團)有限責(zé)任公司貴陽分公司招聘筆試考試備考題庫及答案解析
評論
0/150
提交評論