Broadcast廣播機(jī)制在PytorchTensorNumpy中的使用詳解_第1頁(yè)
Broadcast廣播機(jī)制在PytorchTensorNumpy中的使用詳解_第2頁(yè)
Broadcast廣播機(jī)制在PytorchTensorNumpy中的使用詳解_第3頁(yè)
Broadcast廣播機(jī)制在PytorchTensorNumpy中的使用詳解_第4頁(yè)
Broadcast廣播機(jī)制在PytorchTensorNumpy中的使用詳解_第5頁(yè)
已閱讀5頁(yè),還剩2頁(yè)未讀, 繼續(xù)免費(fèi)閱讀

下載本文檔

版權(quán)說(shuō)明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請(qǐng)進(jìn)行舉報(bào)或認(rèn)領(lǐng)

文檔簡(jiǎn)介

第Broadcast廣播機(jī)制在PytorchTensorNumpy中的使用詳解目錄1.什么是廣播機(jī)制2.廣播機(jī)制的規(guī)則3.代碼舉例4.原地操作

1.什么是廣播機(jī)制

根據(jù)線性代數(shù)的運(yùn)算規(guī)則我們知道,矩陣運(yùn)算往往都是在兩個(gè)矩陣維度相同或者相匹配時(shí)才能運(yùn)算。比如加減法需要兩個(gè)矩陣的維度相同,乘法需要前一個(gè)矩陣的列數(shù)與后一個(gè)矩陣的行數(shù)相等。那么在numpy、tensor里也是同樣的道理,但是在機(jī)器學(xué)習(xí)的某些算法中會(huì)出現(xiàn)兩個(gè)維度不相同也不匹配的矩陣進(jìn)行運(yùn)算,那么這時(shí)候就需要用廣播機(jī)制來(lái)解決,通過(guò)廣播機(jī)制,其tensor參數(shù)可以自動(dòng)擴(kuò)展為相等大小(不需要復(fù)制數(shù)據(jù))。下面我們以tensor為例來(lái)解釋什么是廣播機(jī)制。

2.廣播機(jī)制的規(guī)則

先來(lái)說(shuō)下廣播機(jī)制的規(guī)則,只有遵循下面的規(guī)則兩個(gè)張量才可以進(jìn)行廣播運(yùn)算。

每個(gè)tensor至少有一個(gè)維度;

遍歷tensor所有維度時(shí),從末尾開(kāi)始遍歷(從右往左開(kāi)始遍歷),兩個(gè)tensor存在下列情況

tensor維度相等。

tensor維度不等且其中一個(gè)維度為1或者不存在。

滿足上面的條件才可以進(jìn)行廣播機(jī)制。

3.代碼舉例

相同維度,一定可以broadcast:

importtorch

x=torch.rand(1,2,3)

y=torch.rand(1,2,3)

z=x+y

print(x.shape)

print(y.shape)

print(z.shape)

print(x)

print(y)

print(z)

輸出結(jié)果如下:

torch.Size([1,2,3])

torch.Size([1,2,3])

torch.Size([1,2,3])

tensor([[[0.0322,0.2378,0.4711],

[0.9191,0.0802,0.4002]]])

tensor([[[0.5645,0.9541,0.3089],

[0.7633,0.7400,0.7507]]])

tensor([[[0.5966,1.1919,0.7800],

[1.6825,0.8202,1.1509]]])

有一個(gè)張量沒(méi)有維度,一定不可以進(jìn)行broadcast:

importtorch

x=torch.rand(0)

y=torch.rand(1,2,3)

print(x.shape)

print(y.shape)

z=x+y

print(z.shape)

print(x)

print(y)

print(z)

輸出結(jié)果:

torch.Size([0])

torch.Size([1,2,3])

Traceback(mostrecentcalllast):

FileD:/program/Test/broadcast/test.py,line8,inmodule

z=x+y

RuntimeError:Thesizeoftensora(0)mustmatchthesizeoftensorb(3)atnon-singletondimension2

有一個(gè)張量缺少維度,一定可以進(jìn)行broadcast:

importtorch

x=torch.rand(1,2,3,4)

y=torch.rand(2,3,4)

print(x.shape)

print(y.shape)

z=x+y

print(z.shape)

print(x)

print(y)

print(z)

輸出結(jié)果:

torch.Size([1,2,3,4])

torch.Size([2,3,4])

torch.Size([1,2,3,4])

tensor([[[[0.0094,0.1863,0.2657,0.3782],

[0.3296,0.7454,0.2080,0.4156],

[0.2092,0.5414,0.1053,0.3872]],

[[0.8161,0.3554,0.7352,0.2116],

[0.7459,0.1662,0.7555,0.4548],

[0.2611,0.0353,0.1862,0.5948]]]])

tensor([[[0.4637,0.3938,0.2039,0.3892],

[0.4146,0.8713,0.3947,0.5345],

[0.2401,0.3800,0.3747,0.8381]],

[[0.0459,0.1242,0.3529,0.1527],

[0.2361,0.2850,0.8671,0.8040],

[0.6575,0.4075,0.8156,0.2638]]])

tensor([[[[0.4730,0.5801,0.4695,0.7674],

[0.7442,1.6167,0.6027,0.9501],

[0.4493,0.9214,0.4800,1.2253]],

[[0.8620,0.4796,1.0881,0.3643],

[0.9820,0.4512,1.6227,1.2588],

[0.9186,0.4428,1.0018,0.8586]]]])

上面的張量y跟張量x相比缺少一個(gè)維度,根據(jù)廣播機(jī)制的規(guī)則我們從最后一個(gè)維度進(jìn)行匹配,后面三個(gè)維度都一樣,張量y的缺少一個(gè)維度,于是觸發(fā)廣播機(jī)制。

兩個(gè)張量的維度不相等,其中有一個(gè)張量的對(duì)應(yīng)維度為1或者缺失,一定可以進(jìn)行broadcast:

importtorch

x=torch.rand(1,2,3,4)

y=torch.rand(2,1,1)

print(x.shape)

print(y.shape)

z=x+y

print(z.shape)

print(x)

print(y)

print(z)

輸出結(jié)果:

torch.Size([1,2,3,4])

torch.Size([2,1,1])

torch.Size([1,2,3,4])

tensor([[[[0.8670,0.0134,0.7929,0.4109],

[0.3595,0.8457,0.2819,0.8470],

[0.5040,0.9281,0.9161,0.7305]],

[[0.3798,0.3866,0.4680,0.5744],

[0.6984,0.6501,0.2235,0.3099],

[0.9861,0.8598,0.7635,0.3238]]]])

tensor([[[0.3393]],

[[0.1775]]])

tensor([[[[1.2062,0.3527,1.1322,0.7501],

[0.6987,1.1850,0.6212,1.1863],

[0.8433,1.2674,1.2554,1.0698]],

[[0.5574,0.5641,0.6455,0.7519],

[0.8759,0.8276,0.4010,0.4875],

[1.1636,1.0373,0.9410,0.5013]]]])

以上就是廣播機(jī)制的操作,只要記住幾個(gè)規(guī)則就行了,注意tensor在進(jìn)行運(yùn)算的時(shí)候是從后往前匹配運(yùn)算的。

4.原地操作

在進(jìn)行廣播機(jī)制的時(shí)候我們要注意一個(gè)原地操作運(yùn)算,什么是原地操作運(yùn)算?原地操作運(yùn)算就是指改變一個(gè)tensor的值的時(shí)候,不經(jīng)過(guò)復(fù)制操作,而是直接在原來(lái)的內(nèi)存上改變它的值。在pytorch中經(jīng)常加后綴來(lái)代表原地操作符,例:.add_()、.scatter(),原地操作不允許tensor使用廣播機(jī)制那樣來(lái)改變張量形狀維度大小,如下例子所示。

importtorch

x=torch.rand(1,3,1)

y=torch.rand(3,1,7)

print(x.shape)

print(y.shape)

z=x.add_(y)

print(z.shape)

print(x)

print(y)

print(z)

輸出結(jié)果:

torch.Size([1,3,1])

torch.Size([3,1,7])

Tr

溫馨提示

  • 1. 本站所有資源如無(wú)特殊說(shuō)明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請(qǐng)下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請(qǐng)聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁(yè)內(nèi)容里面會(huì)有圖紙預(yù)覽,若沒(méi)有圖紙預(yù)覽就沒(méi)有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫(kù)網(wǎng)僅提供信息存儲(chǔ)空間,僅對(duì)用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對(duì)用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對(duì)任何下載內(nèi)容負(fù)責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請(qǐng)與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時(shí)也不承擔(dān)用戶因使用這些下載資源對(duì)自己和他人造成任何形式的傷害或損失。

最新文檔

評(píng)論

0/150

提交評(píng)論