版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進(jìn)行舉報或認(rèn)領(lǐng)
文檔簡介
第解決pytorchrnn變長輸入序列的問題輸入數(shù)據(jù)是長度不固定的序列數(shù)據(jù),主要講解兩個部分
1、Data.DataLoader的collate_fn用法,以及按batch進(jìn)行padding數(shù)據(jù)
2、pack_padded_sequence和pad_packed_sequence來處理變長序列
collate_fn
Dataloader的collate_fn參數(shù),定義數(shù)據(jù)處理和合并成batch的方式。
由于pack_padded_sequence用到的tensor必須按照長度從大到小排過序的,所以在Collate_fn中,需要完成兩件事,一是把當(dāng)前batch的樣本按照當(dāng)前batch最大長度進(jìn)行padding,二是將padding后的數(shù)據(jù)從大到小進(jìn)行排序。
defpad_tensor(vec,pad):
args:
vec-tensortopad
pad-thesizetopadto
return:
anewtensorpaddedto'pad'
returntorch.cat([vec,torch.zeros(pad-len(vec),dtype=torch.float)],dim=0).data.numpy()
classCollate:
avariantofcallate_fnthatpadsaccordingtothelongestsequencein
abatchofsequences
def__init__(self):
pass
def_collate(self,batch):
args:
batch-listof(tensor,label)
reutrn:
xs-atensorofallexamplesin'batch'beforepaddinglike:
[tensor([1,2,3,4]),
tensor([1,2]),
tensor([1,2,3,4,5])]
ys-aLongTensorofalllabelsinbatchlike:
[1,0,1]
xs=[torch.FloatTensor(v[0])forvinbatch]
ys=torch.LongTensor([v[1]forvinbatch])
#獲得每個樣本的序列長度
seq_lengths=torch.LongTensor([vforvinmap(len,xs)])
max_len=max([len(v)forvinxs])
#每個樣本都padding到當(dāng)前batch的最大長度
xs=torch.FloatTensor([pad_tensor(v,max_len)forvinxs])
#把xs和ys按照序列長度從大到小排序
seq_lengths,perm_idx=seq_lengths.sort(0,descending=True)
xs=xs[perm_idx]
ys=ys[perm_idx]
returnxs,seq_lengths,ys
def__call__(self,batch):
returnself._collate(batch)
定義完collate類以后,在DataLoader中直接使用
train_data=Data.DataLoader(dataset=train_dataset,batch_size=32,num_workers=0,collate_fn=Collate())
torch.nn.utils.rnn.pack_padded_sequence()
pack_padded_sequence將一個填充過的變長序列壓緊。輸入?yún)?shù)包括
input(Variable)-被填充過后的變長序列組成的batchdata
lengths(list[int])-變長序列的原始序列長度
batch_first(bool,optional)-如果是True,input的形狀應(yīng)該是(batch_size,seq_len,input_size)
返回值:一個PackedSequence對象,可以直接作為rnn,lstm,gru的傳入數(shù)據(jù)。
用法:
fromtorch.nn.utils.rnnimportpack_padded_sequence,pad_packed_sequence
#x是填充過后的batch數(shù)據(jù),seq_lengths是每個樣本的序列長度
packed_input=pack_padded_sequence(x,seq_lengths,batch_first=True)
RNN模型
定義了一個單向的LSTM模型,因為處理的是變長序列,forward函數(shù)傳入的值是一個PackedSequence對象,返回值也是一個PackedSequence對象
classModel(nn.Module):
def__init__(self,in_size,hid_size,n_layer,drop=0.1,bi=False):
super(Model,self).__init__()
self.lstm=nn.LSTM(input_size=in_size,
hidden_size=hid_size,
num_layers=n_layer,
batch_first=True,
dropout=drop,
bidirectional=bi)
#分類類別數(shù)目為2
self.fc=nn.Linear(in_features=hid_size,out_features=2)
defforward(self,x):
:paramx:變長序列時,x是一個PackedSequence對象
:return:PackedSequence對象
#lstm_out:tensorofshape(batch,seq_len,num_directions*hidden_size)
lstm_out,_=self.lstm(x)
returnlstm_out
model=Model()
lstm_out=model(packed_input)
torch.nn.utils.rnn.pad_packed_sequence()
這個操作和pack_padded_sequence()是相反的,把壓緊的序列再填充回來。因為前面提到的LSTM模型傳入和返回的都是PackedSequence對象,所以我們?nèi)绻胍逊祷氐腜ackedSequence對象轉(zhuǎn)換回Tensor,就需要用到pad_packed_sequence函數(shù)。
參數(shù)說明:
sequence(PackedSequence)–將要被填充的batch
batch_first(bool,optional)–如果為True,返回的數(shù)據(jù)的形狀為(batch_size,seq_len,input_size)
返回值:一個tuple,包含被填充后的序列,和batch中序列的長度列表。
用法:
#此處lstm_out是一個PackedSequence對象
output,_=pad_packed_sequence(lstm_out)
返回的output是一個形狀為(batch_size,seq_len,input_size)的tensor。
1、pytorch在自定義dataset時,可以在DataLoader的collate_fn參數(shù)中定義對數(shù)據(jù)的變換,操作以及合成batch的方式。
2、處理變長rnn問題時,通過pack_padded_sequence()將填充的batch數(shù)據(jù)轉(zhuǎn)換成PackedSequence對象,直接傳入rnn模型中。通過pad_packed_sequence()來將rnn模型輸出的PackedSequence對象轉(zhuǎn)換回相應(yīng)的Tensor。
補充:pytorch實現(xiàn)不定長輸入的RNN/LSTM/GRU
Asweallknow,RNN循環(huán)神經(jīng)網(wǎng)絡(luò)(及其改進(jìn)模型LSTM、GRU)可以處理序列的順序信息,如人類自然語言。但是在實際場景中,我們常常向模型輸入一個批次(batch)的數(shù)據(jù),這個批次中的每個序列往往不是等長的。
pytorch提供的模型(nn.RNN,nn.LSTM,nn.GRU)是支持可變長序列的處理的,但條件是傳入的數(shù)據(jù)必須按序列長度排序。本文針對以下兩種場景提出解決方法。
1、每個樣本只有一個序列:(seq,label),其中seq是一個長度不定的序列。則使用pytorch訓(xùn)練時,我們將按列把一個批次的數(shù)據(jù)輸入網(wǎng)絡(luò),seq這一列的形狀就是(batch_size,seq_len),經(jīng)過編碼層(如word2vec)之后的形狀是(batch_size,seq_len,emb_size)。
2、情況1的拓展:每個樣本有兩個(或多個)序列,如(seq1,seq2,label)。這種樣本形式在問答系統(tǒng)、推薦系統(tǒng)多見。
通用解決方案
定義ImprovedRnn類。與nn.RNN,nn.LSTM,nn.GRU相比,除了此兩點【①forward函數(shù)多一個參數(shù)lengths表示每個seq的長度】【②初始化函數(shù)(__init__)第一個參數(shù)module必須指定三者之一】外,使用方法完全相同。
importtorch
fromtorchimportnn
classImprovedRnn(nn.Module):
def__init__(self,module,*args,**kwargs):
assertmodulein(nn.RNN,nn.LSTM,nn.GRU)
super().__init__()
self.module=module(*args,**kwargs)
defforward(self,input,lengths):#inputshape(batch_size,seq_len,input_size)
ifnothasattr(self,'_flattened'):
self.module.flatten_parameters()
setattr(self,'_flattened',True)
max_len=input.shape[1]
#enforce_sorted=False則自動按lengths排序,并且返回值package.unsorted_indices可用于恢復(fù)原順序
package=nn.utils.rnn.pack_padded_sequence(input,lengths.cpu(),batch_first=self.module.batch_first,enforce_sorted=False)
result,hidden=self.module(package)
#total_length參數(shù)一般不需要,因為lengths列表中一般含最大值。但分布式訓(xùn)練時是將一個batch切分了,故一定要有!
result,lens=nn.utils.rnn.pad_packed_sequence(result,batch_first=self.module.batch_first,total_length=max_len)
returnresult[package.unsorted_indices],hidden#outputshape(batch_size,seq_len,rnn_hidden_size)
使用示例:
classTestNet(nn.Module):
def__init__(self,word_emb,gru_in,gru_out):
super().__init__()
self.encode=nn.Embedding.from_pretrained(torch.Tensor(word_emb))
self.rnn=ImprovedRnn(nn.RNN,input_size=gru_in,hidden_size=gru_out,
batch_first=True,bidirectional=True)
defforward(self,seq1,s
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
- 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
- 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負(fù)責(zé)。
- 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
- 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。
最新文檔
- 小學(xué)防校園欺凌培訓(xùn)制度
- 公司法務(wù)培訓(xùn)體系制度
- 駕校教練員安全培訓(xùn)制度
- 培訓(xùn)教學(xué)質(zhì)量控制制度
- 美術(shù)培訓(xùn)機構(gòu)人事制度
- 旅社消防安全培訓(xùn)制度
- 血液凈化室技師培訓(xùn)制度
- 殘疾培訓(xùn)學(xué)員管理制度
- 后勤員工培訓(xùn)管理制度
- 停車泊位收費員培訓(xùn)制度
- 《實踐論》《矛盾論》導(dǎo)讀課件
- 中試基地運營管理制度
- 老年病康復(fù)訓(xùn)練治療講課件
- DB4201-T 617-2020 武漢市架空管線容貌管理技術(shù)規(guī)范
- 藥品追溯碼管理制度
- 腳手架國際化標(biāo)準(zhǔn)下的發(fā)展趨勢
- 購銷合同范本(塘渣)8篇
- 生鮮業(yè)務(wù)采購合同協(xié)議
- GB/T 4340.2-2025金屬材料維氏硬度試驗第2部分:硬度計的檢驗與校準(zhǔn)
- 銷售合同評審管理制度
- 資產(chǎn)評估員工管理制度
評論
0/150
提交評論