Pytorch實(shí)現(xiàn)常用乘法算子TensorRT的示例代碼_第1頁
Pytorch實(shí)現(xiàn)常用乘法算子TensorRT的示例代碼_第2頁
Pytorch實(shí)現(xiàn)常用乘法算子TensorRT的示例代碼_第3頁
Pytorch實(shí)現(xiàn)常用乘法算子TensorRT的示例代碼_第4頁
Pytorch實(shí)現(xiàn)常用乘法算子TensorRT的示例代碼_第5頁
已閱讀5頁,還剩2頁未讀, 繼續(xù)免費(fèi)閱讀

下載本文檔

版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進(jìn)行舉報(bào)或認(rèn)領(lǐng)

文檔簡介

第Pytorch實(shí)現(xiàn)常用乘法算子TensorRT的示例代碼目錄1.乘法運(yùn)算總覽2.乘法算子實(shí)現(xiàn)2.1矩陣乘算子實(shí)現(xiàn)2.2點(diǎn)乘算子實(shí)現(xiàn)本文介紹一下Pytorch中常用乘法的TensorRT實(shí)現(xiàn)。

pytorch用于訓(xùn)練,TensorRT用于推理是很多AI應(yīng)用開發(fā)的標(biāo)配。大家往往更加熟悉pytorch的算子,而不太熟悉TensorRT的算子,這里拿比較常用的乘法運(yùn)算在兩種框架下的實(shí)現(xiàn)做一個(gè)對比,可能會(huì)有更加直觀一些的認(rèn)識。

1.乘法運(yùn)算總覽

先把pytorch中的一些常用的乘法運(yùn)算進(jìn)行一個(gè)總覽:

torch.mm:用于兩個(gè)矩陣(不包括向量)的乘法,如維度(m,n)的矩陣乘以維度(n,p)的矩陣;torch.bmm:用于帶batch的三維向量的乘法,如維度(b,m,n)的矩陣乘以維度(b,n,p)的矩陣;torch.mul:用于同維度矩陣的逐像素點(diǎn)相乘,也即點(diǎn)乘,如維度(m,n)的矩陣點(diǎn)乘維度(m,n)的矩陣。該方法支持廣播,也即支持矩陣和元素點(diǎn)乘;torch.mv:用于矩陣和向量的乘法,矩陣在前,向量在后,如維度(m,n)的矩陣乘以維度為(n)的向量,輸出維度為(m);torch.matmul:用于兩個(gè)張量相乘,或矩陣與向量乘法,作用包含torch.mm、torch.bmm、torch.mv;@:作用相當(dāng)于torch.matmul;*:作用相當(dāng)于torch.mul;

如上進(jìn)行了一些具體羅列,可以歸納出,常用的乘法無非兩種:矩陣乘和點(diǎn)乘,所以下面分這兩類進(jìn)行介紹。

2.乘法算子實(shí)現(xiàn)

2.1矩陣乘算子實(shí)現(xiàn)

先來看看矩陣乘法的pytorch的實(shí)現(xiàn)(以下實(shí)現(xiàn)在終端):

importtorch

#torch.mm

a=torch.randn(66,99)

b=torch.randn(99,88)

c=torch.mm(a,b)

c.shape

torch.size([66,88])

#torch.bmm

a=torch.randn(3,66,99)

b=torch.randn(3,99,77)

c=torch.bmm(a,b)

c.shape

torch.size([3,66,77])

#torch.mv

a=torch.randn(66,99)

b=torch.randn(99)

c=torch.mv(a,b)

c.shape

torch.size([66])

#torch.matmul

a=torch.randn(32,3,66,99)

b=torch.randn(32,3,99,55)

c=torch.matmul(a,b)

c.shape

torch.size([32,3,66,55])

d=a@b

d.shape

torch.size([32,3,66,55])

來看TensorRT的實(shí)現(xiàn),以上乘法都可使用addMatrixMultiply方法覆蓋,對應(yīng)torch.matmul,先來看該方法的定義:

//!

//!\briefAddaMatrixMultiplylayertothenetwork.

//!\paraminput0Thefirstinputtensor(commonlyA).

//!\paramop0Theoperationtoapplytoinput0.

//!\paraminput1Thesecondinputtensor(commonlyB).

//!\paramop1Theoperationtoapplytoinput1.

//!\seeIMatrixMultiplyLayer

//!\warningInt32tensorsarenotvalidinputtensors.

//!\returnThenewmatrixmultiplylayer,ornullptrifitcouldnotbecreated.

IMatrixMultiplyLayer*addMatrixMultiply(

ITensorinput0,MatrixOperationop0,ITensorinput1,MatrixOperationop1)noexcept

returnmImpl-addMatrixMultiply(input0,op0,input1,op1);

可以看到這個(gè)方法有四個(gè)傳參,對應(yīng)兩個(gè)張量和其operation。來看這個(gè)算子在TensorRT中怎么添加:

//構(gòu)造張量Tensor0

nvinfer1::IConstantLayer*Constant_layer0=m_network-addConstant(tensorShape0,value0);

//構(gòu)造張量Tensor1

nvinfer1::IConstantLayer*Constant_layer1=m_network-addConstant(tensorShape1,value1);

//添加矩陣乘法

nvinfer1::IMatrixMultiplyLayer*Matmul_layer=m_network-addMatrixMultiply(Constant_layer0-getOutput(0),matrix0Type,Constant_layer1-getOutput(0),matrix2Type);

//獲取輸出

matmulOutput=Matmul_layer-getOputput(0);

2.2點(diǎn)乘算子實(shí)現(xiàn)

再來看看點(diǎn)乘的pytorch的實(shí)現(xiàn)(以下實(shí)現(xiàn)在終端):

importtorch

#torch.mul

a=torch.randn(66,99)

b=torch.randn(66,99)

c=torch.mul(a,b)

c.shape

torch.size([66,99])

d=0.125

e=torch.mul(a,d)

e.shape

torch.size([66,99])

f=a*b

f.shape

torch.size([66,99])

來看TensorRT的實(shí)現(xiàn),以上乘法都可使用addScale方法覆蓋,這在圖像預(yù)處理中十分常用,先來看該方法的定義:

//!

//!\briefAddaScalelayertothenetwork.

//!\paraminputTheinputtensortothelayer.

//!Thistensorisrequiredtohaveaminimumof3dimensionsinimplicitbatchmode

//!andaminimumof4dimensionsinexplicitbatchmode.

//!\parammodeThescalingmode.

//!\paramshiftTheshiftvalue.

//!\paramscaleThescalevalue.

//!\parampowerThepowervalue.

//!Iftheweightsareavailable,thenthesizeofweightsaredependentontheScaleMode.

//!For::kUNIFORM,thenumberofweightsequals1.

//!For::kCHANNEL,thenumberofweightsequalsthechanneldimension.

//!For::kELEMENTWISE,thenumberofweightsequalstheproductofthelastthreedimensionsoftheinput.

//!\seeaddScaleNd

//!\seeIScaleLayer

//!\warningInt32tensorsarenotvalidinputtensors.

//!\returnThenewScalelayer,ornullptrifitcouldnotbecreated.

IScaleLayer*addScale(ITensorinput,ScaleModemode,Weightsshift,Weightsscale,Weightspower)noexcept

returnmImpl-addScale(input,mode,shift,scale,power);

可以看到有三個(gè)模式:

kUNIFORM:weights為一個(gè)值,對應(yīng)張量乘一個(gè)元素;kCHANNEL:weights維度和輸入張量通道的c維度對應(yīng),可以做一些以通道為基準(zhǔn)的預(yù)處理;kELEMENTWISE:weights維度和輸入張量的c、h、w對應(yīng),不考慮batch,所以是輸入的后三維;

再來看這個(gè)算子在TensorRT中怎么添加:

//構(gòu)造張量input

nvinfer1::IConstantLayer*Constant_layer=m_network-addConstant(tensorShape,value);

//scalemode選擇,kUNIFORM、kCHANNEL、kELEMENTWISE

scalemode=kUNIFORM;

//構(gòu)建Weights類型的shift、scale、power,其中volume為元素?cái)?shù)量

nvinfer1::WeightsscaleShift{nvinfer1::DataType::kFLOAT,nullptr,volume};

nvinfer1::WeightsscaleScale{nvinfer1::DataType::kFLOAT,nullptr,volume};

nvinfer1::WeightsscalePower{nvinfer1::DataType::kFLOAT,nullptr,volume};

//!!注意這里還需要對shift、scale、power的values進(jìn)行賦值,若只是乘法只需要對scale進(jìn)行賦值就行

//添加張量乘法

nvinfer1::IScaleLayer*Scale_layer=m_network-addScale(Constant_layer-getOutput(0),scalemode,scaleShift,sca

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會(huì)有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫網(wǎng)僅提供信息存儲(chǔ)空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負(fù)責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時(shí)也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。

評論

0/150

提交評論