版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進(jìn)行舉報(bào)或認(rèn)領(lǐng)
文檔簡介
第Pytorch實(shí)現(xiàn)常用乘法算子TensorRT的示例代碼目錄1.乘法運(yùn)算總覽2.乘法算子實(shí)現(xiàn)2.1矩陣乘算子實(shí)現(xiàn)2.2點(diǎn)乘算子實(shí)現(xiàn)本文介紹一下Pytorch中常用乘法的TensorRT實(shí)現(xiàn)。
pytorch用于訓(xùn)練,TensorRT用于推理是很多AI應(yīng)用開發(fā)的標(biāo)配。大家往往更加熟悉pytorch的算子,而不太熟悉TensorRT的算子,這里拿比較常用的乘法運(yùn)算在兩種框架下的實(shí)現(xiàn)做一個(gè)對比,可能會(huì)有更加直觀一些的認(rèn)識。
1.乘法運(yùn)算總覽
先把pytorch中的一些常用的乘法運(yùn)算進(jìn)行一個(gè)總覽:
torch.mm:用于兩個(gè)矩陣(不包括向量)的乘法,如維度(m,n)的矩陣乘以維度(n,p)的矩陣;torch.bmm:用于帶batch的三維向量的乘法,如維度(b,m,n)的矩陣乘以維度(b,n,p)的矩陣;torch.mul:用于同維度矩陣的逐像素點(diǎn)相乘,也即點(diǎn)乘,如維度(m,n)的矩陣點(diǎn)乘維度(m,n)的矩陣。該方法支持廣播,也即支持矩陣和元素點(diǎn)乘;torch.mv:用于矩陣和向量的乘法,矩陣在前,向量在后,如維度(m,n)的矩陣乘以維度為(n)的向量,輸出維度為(m);torch.matmul:用于兩個(gè)張量相乘,或矩陣與向量乘法,作用包含torch.mm、torch.bmm、torch.mv;@:作用相當(dāng)于torch.matmul;*:作用相當(dāng)于torch.mul;
如上進(jìn)行了一些具體羅列,可以歸納出,常用的乘法無非兩種:矩陣乘和點(diǎn)乘,所以下面分這兩類進(jìn)行介紹。
2.乘法算子實(shí)現(xiàn)
2.1矩陣乘算子實(shí)現(xiàn)
先來看看矩陣乘法的pytorch的實(shí)現(xiàn)(以下實(shí)現(xiàn)在終端):
importtorch
#torch.mm
a=torch.randn(66,99)
b=torch.randn(99,88)
c=torch.mm(a,b)
c.shape
torch.size([66,88])
#torch.bmm
a=torch.randn(3,66,99)
b=torch.randn(3,99,77)
c=torch.bmm(a,b)
c.shape
torch.size([3,66,77])
#torch.mv
a=torch.randn(66,99)
b=torch.randn(99)
c=torch.mv(a,b)
c.shape
torch.size([66])
#torch.matmul
a=torch.randn(32,3,66,99)
b=torch.randn(32,3,99,55)
c=torch.matmul(a,b)
c.shape
torch.size([32,3,66,55])
d=a@b
d.shape
torch.size([32,3,66,55])
來看TensorRT的實(shí)現(xiàn),以上乘法都可使用addMatrixMultiply方法覆蓋,對應(yīng)torch.matmul,先來看該方法的定義:
//!
//!\briefAddaMatrixMultiplylayertothenetwork.
//!\paraminput0Thefirstinputtensor(commonlyA).
//!\paramop0Theoperationtoapplytoinput0.
//!\paraminput1Thesecondinputtensor(commonlyB).
//!\paramop1Theoperationtoapplytoinput1.
//!\seeIMatrixMultiplyLayer
//!\warningInt32tensorsarenotvalidinputtensors.
//!\returnThenewmatrixmultiplylayer,ornullptrifitcouldnotbecreated.
IMatrixMultiplyLayer*addMatrixMultiply(
ITensorinput0,MatrixOperationop0,ITensorinput1,MatrixOperationop1)noexcept
returnmImpl-addMatrixMultiply(input0,op0,input1,op1);
可以看到這個(gè)方法有四個(gè)傳參,對應(yīng)兩個(gè)張量和其operation。來看這個(gè)算子在TensorRT中怎么添加:
//構(gòu)造張量Tensor0
nvinfer1::IConstantLayer*Constant_layer0=m_network-addConstant(tensorShape0,value0);
//構(gòu)造張量Tensor1
nvinfer1::IConstantLayer*Constant_layer1=m_network-addConstant(tensorShape1,value1);
//添加矩陣乘法
nvinfer1::IMatrixMultiplyLayer*Matmul_layer=m_network-addMatrixMultiply(Constant_layer0-getOutput(0),matrix0Type,Constant_layer1-getOutput(0),matrix2Type);
//獲取輸出
matmulOutput=Matmul_layer-getOputput(0);
2.2點(diǎn)乘算子實(shí)現(xiàn)
再來看看點(diǎn)乘的pytorch的實(shí)現(xiàn)(以下實(shí)現(xiàn)在終端):
importtorch
#torch.mul
a=torch.randn(66,99)
b=torch.randn(66,99)
c=torch.mul(a,b)
c.shape
torch.size([66,99])
d=0.125
e=torch.mul(a,d)
e.shape
torch.size([66,99])
f=a*b
f.shape
torch.size([66,99])
來看TensorRT的實(shí)現(xiàn),以上乘法都可使用addScale方法覆蓋,這在圖像預(yù)處理中十分常用,先來看該方法的定義:
//!
//!\briefAddaScalelayertothenetwork.
//!\paraminputTheinputtensortothelayer.
//!Thistensorisrequiredtohaveaminimumof3dimensionsinimplicitbatchmode
//!andaminimumof4dimensionsinexplicitbatchmode.
//!\parammodeThescalingmode.
//!\paramshiftTheshiftvalue.
//!\paramscaleThescalevalue.
//!\parampowerThepowervalue.
//!Iftheweightsareavailable,thenthesizeofweightsaredependentontheScaleMode.
//!For::kUNIFORM,thenumberofweightsequals1.
//!For::kCHANNEL,thenumberofweightsequalsthechanneldimension.
//!For::kELEMENTWISE,thenumberofweightsequalstheproductofthelastthreedimensionsoftheinput.
//!\seeaddScaleNd
//!\seeIScaleLayer
//!\warningInt32tensorsarenotvalidinputtensors.
//!\returnThenewScalelayer,ornullptrifitcouldnotbecreated.
IScaleLayer*addScale(ITensorinput,ScaleModemode,Weightsshift,Weightsscale,Weightspower)noexcept
returnmImpl-addScale(input,mode,shift,scale,power);
可以看到有三個(gè)模式:
kUNIFORM:weights為一個(gè)值,對應(yīng)張量乘一個(gè)元素;kCHANNEL:weights維度和輸入張量通道的c維度對應(yīng),可以做一些以通道為基準(zhǔn)的預(yù)處理;kELEMENTWISE:weights維度和輸入張量的c、h、w對應(yīng),不考慮batch,所以是輸入的后三維;
再來看這個(gè)算子在TensorRT中怎么添加:
//構(gòu)造張量input
nvinfer1::IConstantLayer*Constant_layer=m_network-addConstant(tensorShape,value);
//scalemode選擇,kUNIFORM、kCHANNEL、kELEMENTWISE
scalemode=kUNIFORM;
//構(gòu)建Weights類型的shift、scale、power,其中volume為元素?cái)?shù)量
nvinfer1::WeightsscaleShift{nvinfer1::DataType::kFLOAT,nullptr,volume};
nvinfer1::WeightsscaleScale{nvinfer1::DataType::kFLOAT,nullptr,volume};
nvinfer1::WeightsscalePower{nvinfer1::DataType::kFLOAT,nullptr,volume};
//!!注意這里還需要對shift、scale、power的values進(jìn)行賦值,若只是乘法只需要對scale進(jìn)行賦值就行
//添加張量乘法
nvinfer1::IScaleLayer*Scale_layer=m_network-addScale(Constant_layer-getOutput(0),scalemode,scaleShift,sca
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會(huì)有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
- 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
- 5. 人人文庫網(wǎng)僅提供信息存儲(chǔ)空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負(fù)責(zé)。
- 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
- 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時(shí)也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。
最新文檔
- 二硫化碳生產(chǎn)工測試驗(yàn)證評優(yōu)考核試卷含答案
- 電力通信運(yùn)維員崗前規(guī)章制度考核試卷含答案
- 片基流延工誠信道德能力考核試卷含答案
- 電子玻璃制品鍍膜工安全宣教測試考核試卷含答案
- 安全員考試請假條
- 2025年超細(xì)銀粉末、銀鈀粉、鈀粉、鉑粉項(xiàng)目合作計(jì)劃書
- 2026年智能心率帶項(xiàng)目營銷方案
- 2025年江蘇省南通市中考物理真題卷含答案解析
- 2025年山東省日照市中考英語真題卷含答案解析
- 2025康復(fù)醫(yī)學(xué)與技術(shù)專業(yè)知識題庫及答案
- 民法典物業(yè)管理解讀課件
- 2025年中國汽輪機(jī)導(dǎo)葉片市場調(diào)查研究報(bào)告
- 中班幼兒戶外游戲活動(dòng)實(shí)施現(xiàn)狀研究-以綿陽市Y幼兒園為例
- 特色休閑農(nóng)場設(shè)計(jì)規(guī)劃方案
- 采購部門月度匯報(bào)
- 新華書店管理辦法
- 檔案專業(yè)人員公司招聘筆試題庫及答案
- 工程竣工移交單(移交甲方、物業(yè))
- 來料檢驗(yàn)控制程序(含表格)
- 2025年鈦合金閥項(xiàng)目可行性研究報(bào)告
評論
0/150
提交評論