Python實(shí)現(xiàn)LeNet網(wǎng)絡(luò)模型的訓(xùn)練及預(yù)測_第1頁
Python實(shí)現(xiàn)LeNet網(wǎng)絡(luò)模型的訓(xùn)練及預(yù)測_第2頁
Python實(shí)現(xiàn)LeNet網(wǎng)絡(luò)模型的訓(xùn)練及預(yù)測_第3頁
Python實(shí)現(xiàn)LeNet網(wǎng)絡(luò)模型的訓(xùn)練及預(yù)測_第4頁
Python實(shí)現(xiàn)LeNet網(wǎng)絡(luò)模型的訓(xùn)練及預(yù)測_第5頁
全文預(yù)覽已結(jié)束

下載本文檔

版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請(qǐng)進(jìn)行舉報(bào)或認(rèn)領(lǐng)

文檔簡介

第Python實(shí)現(xiàn)LeNet網(wǎng)絡(luò)模型的訓(xùn)練及預(yù)測這部分代碼是顯示測試集當(dāng)中前五張圖片,運(yùn)行后會(huì)顯示5張拼接的圖片

由于這個(gè)數(shù)據(jù)集的圖片都比較小都是32x32的尺寸,有些可能也看的不太清楚,圖中顯示的是真實(shí)標(biāo)簽,注:顯示圖片的代碼可能會(huì)這個(gè)報(bào)警(ClippinginputdatatothevalidrangeforimshowwithRGBdata([0…1]forfloatsor[0…255]forintegers).),警告解決的方法:將圖片數(shù)組轉(zhuǎn)成uint8類型即可,即plt.imshow(npimg.astype(‘uint8'),但是那樣顯示出來的圖片會(huì)變,所以暫時(shí)可以先不用管。

(5).初始化模型

數(shù)據(jù)圖片處理完了,下面就是我們的正式訓(xùn)練過程

net=LeNet()

#定義損失函數(shù),nn.CrossEntropyLoss()自帶softmax函數(shù),所以模型的最后一層不需要softmax進(jìn)行激活

loss_function=nn.CrossEntropyLoss()

#定義優(yōu)化器,優(yōu)化模型所有參數(shù)

optimizer=optim.Adam(net.parameters(),lr=0.001)

首先初始化LeNet網(wǎng)絡(luò),定義交叉熵?fù)p失函數(shù),以及Adam優(yōu)化器,關(guān)于注釋寫的,我們可以ctrl+鼠標(biāo)左鍵查看CrossEntropyLoss(),翻到CrossEntropyLoss類,可以看到注釋寫的這個(gè)標(biāo)準(zhǔn)包含LogSoftmax函數(shù),所以搭建LetNet模型的最后一層沒有使用softmax激活函數(shù)

(6).訓(xùn)練模型及保存模型參數(shù)

forepochinrange(5):

#初始損失設(shè)置為0

running_loss=0

#循環(huán)訓(xùn)練集,從1開始

forstep,datainenumerate(train_loader,start=1):

inputs,labels=data

#優(yōu)化器的梯度清零,每次循環(huán)都需要清零,否則梯度會(huì)無限疊加,相當(dāng)于增加批次大小

optimizer.zero_grad()

#將圖片數(shù)據(jù)輸入模型中得到輸出

outputs=net(inputs)

#傳入預(yù)測值和真實(shí)值,計(jì)算當(dāng)前損失值

loss=loss_function(outputs,labels)

#損失反向傳播

loss.backward()

#進(jìn)行梯度更新(更新W,b)

optimizer.step()

#計(jì)算該輪的總損失,因?yàn)閘oss是tensor類型,所以需要用item()取到值

running_loss+=loss.item()

#每500次進(jìn)行日志的打印,對(duì)測試集進(jìn)行測試

ifstep%500==0:

#torch.no_grad()就是上下文管理,測試時(shí)不需要梯度更新,不跟蹤梯度

withtorch.no_grad():

#傳入所有測試集圖片進(jìn)行預(yù)測

outputs=net(test_img)

#torch.max()中dim=1是因?yàn)榻Y(jié)果為(batch,10)的形式,我們只需要取第二個(gè)維度的最大值,第二個(gè)維度是包含十個(gè)類別每個(gè)類別的概率的向量

#max這個(gè)函數(shù)返回[最大值,最大值索引],我們只需要取索引就行了,所以用[1]

predict_y=torch.max(outputs,dim=1)[1]

#(predict_y==test_label)相同返回True,不相等返回False,sum()對(duì)正確結(jié)果進(jìn)行疊加,最后除測試集標(biāo)簽的總個(gè)數(shù)

#因?yàn)橛?jì)算的變量都是tensor,所以需要用item()拿到取值

accuracy=(predict_y==test_label).sum().item()/test_label.size(0)

#running_loss/500是計(jì)算每一個(gè)step的loss,即每一步的損失

print('[%d,%5d]train_loss:%.3ftest_accuracy:%.3f'%

(epoch+1,step,running_loss/500,accuracy))

running_loss=0.0

print('FinishedTraining!')

save_path='lenet.pth'

#保存模型,字典形式

torch.save(net.state_dict(),save_path)

這段代碼注釋寫的很清楚,大家仔細(xì)看就能看懂,流程不復(fù)雜,多看幾遍就能理解,最后再對(duì)訓(xùn)練好的模型進(jìn)行保存就好了(* ̄︶ ̄)

2.預(yù)測腳本

上面已經(jīng)訓(xùn)練好了模型,得到了lenet.pth參數(shù)文件,預(yù)測就很簡單了,可以去網(wǎng)上隨便找一張數(shù)據(jù)集包含的類別圖片,將模型參數(shù)文件載入模型,通過對(duì)圖像進(jìn)行一點(diǎn)處理,喂入模型即可,下面奉上代碼:

importtorch

importnumpyasnp

importtorchvision.transformsastransforms

fromPILimportImage

frompytorch.lenet.modelimportLeNet

classes=('plane','car','bird','cat','deer',

'dog','frog','horse','ship','truck')

transforms=transforms.Compose(

#對(duì)數(shù)據(jù)圖片調(diào)整大小

[transforms.Resize([32,32]),

transforms.ToTensor(),

transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]

net=LeNet()

#加載預(yù)訓(xùn)練模型

net.load_state_dict(torch.load('lenet.pth'))

#網(wǎng)上隨便找的貓的圖片

img_path='../../Photo/cat2.jpg'

img=Image.open(img_path)

#圖片的處理

img=transforms(img)

#增加一個(gè)維度,(channels,height,width)-------(batch,channels,height,width),pytorch要求必須輸入這樣的shape

img=torch.unsqueeze(img,dim=0)

withtorch.no_grad():

output=net(img)

#dim=1,只取[batch,10]中10個(gè)類別的那個(gè)維度,取預(yù)測結(jié)果的最大值索引,并轉(zhuǎn)換為numpy類型

prediction1=torch.max(output,dim=1)[1].data.numpy()

#用softmax()預(yù)測出一個(gè)概率矩陣

prediction2=torch.softmax(output,dim=1)

#得到概率最大的值得索引

prediction2=np.argmax(predictio

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請(qǐng)下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請(qǐng)聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會(huì)有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫網(wǎng)僅提供信息存儲(chǔ)空間,僅對(duì)用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對(duì)用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對(duì)任何下載內(nèi)容負(fù)責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請(qǐng)與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時(shí)也不承擔(dān)用戶因使用這些下載資源對(duì)自己和他人造成任何形式的傷害或損失。

評(píng)論

0/150

提交評(píng)論