解決pytorch rnn 變長輸入序列的問題_第1頁
解決pytorch rnn 變長輸入序列的問題_第2頁
解決pytorch rnn 變長輸入序列的問題_第3頁
解決pytorch rnn 變長輸入序列的問題_第4頁
解決pytorch rnn 變長輸入序列的問題_第5頁
已閱讀5頁,還剩3頁未讀, 繼續(xù)免費閱讀

下載本文檔

版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進(jìn)行舉報或認(rèn)領(lǐng)

文檔簡介

第解決pytorchrnn變長輸入序列的問題輸入數(shù)據(jù)是長度不固定的序列數(shù)據(jù),主要講解兩個部分

1、Data.DataLoader的collate_fn用法,以及按batch進(jìn)行padding數(shù)據(jù)

2、pack_padded_sequence和pad_packed_sequence來處理變長序列

collate_fn

Dataloader的collate_fn參數(shù),定義數(shù)據(jù)處理和合并成batch的方式。

由于pack_padded_sequence用到的tensor必須按照長度從大到小排過序的,所以在Collate_fn中,需要完成兩件事,一是把當(dāng)前batch的樣本按照當(dāng)前batch最大長度進(jìn)行padding,二是將padding后的數(shù)據(jù)從大到小進(jìn)行排序。

defpad_tensor(vec,pad):

args:

vec-tensortopad

pad-thesizetopadto

return:

anewtensorpaddedto'pad'

returntorch.cat([vec,torch.zeros(pad-len(vec),dtype=torch.float)],dim=0).data.numpy()

classCollate:

avariantofcallate_fnthatpadsaccordingtothelongestsequencein

abatchofsequences

def__init__(self):

pass

def_collate(self,batch):

args:

batch-listof(tensor,label)

reutrn:

xs-atensorofallexamplesin'batch'beforepaddinglike:

[tensor([1,2,3,4]),

tensor([1,2]),

tensor([1,2,3,4,5])]

ys-aLongTensorofalllabelsinbatchlike:

[1,0,1]

xs=[torch.FloatTensor(v[0])forvinbatch]

ys=torch.LongTensor([v[1]forvinbatch])

#獲得每個樣本的序列長度

seq_lengths=torch.LongTensor([vforvinmap(len,xs)])

max_len=max([len(v)forvinxs])

#每個樣本都padding到當(dāng)前batch的最大長度

xs=torch.FloatTensor([pad_tensor(v,max_len)forvinxs])

#把xs和ys按照序列長度從大到小排序

seq_lengths,perm_idx=seq_lengths.sort(0,descending=True)

xs=xs[perm_idx]

ys=ys[perm_idx]

returnxs,seq_lengths,ys

def__call__(self,batch):

returnself._collate(batch)

定義完collate類以后,在DataLoader中直接使用

train_data=Data.DataLoader(dataset=train_dataset,batch_size=32,num_workers=0,collate_fn=Collate())

torch.nn.utils.rnn.pack_padded_sequence()

pack_padded_sequence將一個填充過的變長序列壓緊。輸入?yún)?shù)包括

input(Variable)-被填充過后的變長序列組成的batchdata

lengths(list[int])-變長序列的原始序列長度

batch_first(bool,optional)-如果是True,input的形狀應(yīng)該是(batch_size,seq_len,input_size)

返回值:一個PackedSequence對象,可以直接作為rnn,lstm,gru的傳入數(shù)據(jù)。

用法:

fromtorch.nn.utils.rnnimportpack_padded_sequence,pad_packed_sequence

#x是填充過后的batch數(shù)據(jù),seq_lengths是每個樣本的序列長度

packed_input=pack_padded_sequence(x,seq_lengths,batch_first=True)

RNN模型

定義了一個單向的LSTM模型,因為處理的是變長序列,forward函數(shù)傳入的值是一個PackedSequence對象,返回值也是一個PackedSequence對象

classModel(nn.Module):

def__init__(self,in_size,hid_size,n_layer,drop=0.1,bi=False):

super(Model,self).__init__()

self.lstm=nn.LSTM(input_size=in_size,

hidden_size=hid_size,

num_layers=n_layer,

batch_first=True,

dropout=drop,

bidirectional=bi)

#分類類別數(shù)目為2

self.fc=nn.Linear(in_features=hid_size,out_features=2)

defforward(self,x):

:paramx:變長序列時,x是一個PackedSequence對象

:return:PackedSequence對象

#lstm_out:tensorofshape(batch,seq_len,num_directions*hidden_size)

lstm_out,_=self.lstm(x)

returnlstm_out

model=Model()

lstm_out=model(packed_input)

torch.nn.utils.rnn.pad_packed_sequence()

這個操作和pack_padded_sequence()是相反的,把壓緊的序列再填充回來。因為前面提到的LSTM模型傳入和返回的都是PackedSequence對象,所以我們?nèi)绻胍逊祷氐腜ackedSequence對象轉(zhuǎn)換回Tensor,就需要用到pad_packed_sequence函數(shù)。

參數(shù)說明:

sequence(PackedSequence)–將要被填充的batch

batch_first(bool,optional)–如果為True,返回的數(shù)據(jù)的形狀為(batch_size,seq_len,input_size)

返回值:一個tuple,包含被填充后的序列,和batch中序列的長度列表。

用法:

#此處lstm_out是一個PackedSequence對象

output,_=pad_packed_sequence(lstm_out)

返回的output是一個形狀為(batch_size,seq_len,input_size)的tensor。

1、pytorch在自定義dataset時,可以在DataLoader的collate_fn參數(shù)中定義對數(shù)據(jù)的變換,操作以及合成batch的方式。

2、處理變長rnn問題時,通過pack_padded_sequence()將填充的batch數(shù)據(jù)轉(zhuǎn)換成PackedSequence對象,直接傳入rnn模型中。通過pad_packed_sequence()來將rnn模型輸出的PackedSequence對象轉(zhuǎn)換回相應(yīng)的Tensor。

補充:pytorch實現(xiàn)不定長輸入的RNN/LSTM/GRU

Asweallknow,RNN循環(huán)神經(jīng)網(wǎng)絡(luò)(及其改進(jìn)模型LSTM、GRU)可以處理序列的順序信息,如人類自然語言。但是在實際場景中,我們常常向模型輸入一個批次(batch)的數(shù)據(jù),這個批次中的每個序列往往不是等長的。

pytorch提供的模型(nn.RNN,nn.LSTM,nn.GRU)是支持可變長序列的處理的,但條件是傳入的數(shù)據(jù)必須按序列長度排序。本文針對以下兩種場景提出解決方法。

1、每個樣本只有一個序列:(seq,label),其中seq是一個長度不定的序列。則使用pytorch訓(xùn)練時,我們將按列把一個批次的數(shù)據(jù)輸入網(wǎng)絡(luò),seq這一列的形狀就是(batch_size,seq_len),經(jīng)過編碼層(如word2vec)之后的形狀是(batch_size,seq_len,emb_size)。

2、情況1的拓展:每個樣本有兩個(或多個)序列,如(seq1,seq2,label)。這種樣本形式在問答系統(tǒng)、推薦系統(tǒng)多見。

通用解決方案

定義ImprovedRnn類。與nn.RNN,nn.LSTM,nn.GRU相比,除了此兩點【①forward函數(shù)多一個參數(shù)lengths表示每個seq的長度】【②初始化函數(shù)(__init__)第一個參數(shù)module必須指定三者之一】外,使用方法完全相同。

importtorch

fromtorchimportnn

classImprovedRnn(nn.Module):

def__init__(self,module,*args,**kwargs):

assertmodulein(nn.RNN,nn.LSTM,nn.GRU)

super().__init__()

self.module=module(*args,**kwargs)

defforward(self,input,lengths):#inputshape(batch_size,seq_len,input_size)

ifnothasattr(self,'_flattened'):

self.module.flatten_parameters()

setattr(self,'_flattened',True)

max_len=input.shape[1]

#enforce_sorted=False則自動按lengths排序,并且返回值package.unsorted_indices可用于恢復(fù)原順序

package=nn.utils.rnn.pack_padded_sequence(input,lengths.cpu(),batch_first=self.module.batch_first,enforce_sorted=False)

result,hidden=self.module(package)

#total_length參數(shù)一般不需要,因為lengths列表中一般含最大值。但分布式訓(xùn)練時是將一個batch切分了,故一定要有!

result,lens=nn.utils.rnn.pad_packed_sequence(result,batch_first=self.module.batch_first,total_length=max_len)

returnresult[package.unsorted_indices],hidden#outputshape(batch_size,seq_len,rnn_hidden_size)

使用示例:

classTestNet(nn.Module):

def__init__(self,word_emb,gru_in,gru_out):

super().__init__()

self.encode=nn.Embedding.from_pretrained(torch.Tensor(word_emb))

self.rnn=ImprovedRnn(nn.RNN,input_size=gru_in,hidden_size=gru_out,

batch_first=True,bidirectional=True)

defforward(self,seq1,s

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負(fù)責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。

評論

0/150

提交評論