下載本文檔
版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請(qǐng)進(jìn)行舉報(bào)或認(rèn)領(lǐng)
文檔簡介
第Python實(shí)現(xiàn)LeNet網(wǎng)絡(luò)模型的訓(xùn)練及預(yù)測這部分代碼是顯示測試集當(dāng)中前五張圖片,運(yùn)行后會(huì)顯示5張拼接的圖片
由于這個(gè)數(shù)據(jù)集的圖片都比較小都是32x32的尺寸,有些可能也看的不太清楚,圖中顯示的是真實(shí)標(biāo)簽,注:顯示圖片的代碼可能會(huì)這個(gè)報(bào)警(ClippinginputdatatothevalidrangeforimshowwithRGBdata([0…1]forfloatsor[0…255]forintegers).),警告解決的方法:將圖片數(shù)組轉(zhuǎn)成uint8類型即可,即plt.imshow(npimg.astype(‘uint8'),但是那樣顯示出來的圖片會(huì)變,所以暫時(shí)可以先不用管。
(5).初始化模型
數(shù)據(jù)圖片處理完了,下面就是我們的正式訓(xùn)練過程
net=LeNet()
#定義損失函數(shù),nn.CrossEntropyLoss()自帶softmax函數(shù),所以模型的最后一層不需要softmax進(jìn)行激活
loss_function=nn.CrossEntropyLoss()
#定義優(yōu)化器,優(yōu)化模型所有參數(shù)
optimizer=optim.Adam(net.parameters(),lr=0.001)
首先初始化LeNet網(wǎng)絡(luò),定義交叉熵?fù)p失函數(shù),以及Adam優(yōu)化器,關(guān)于注釋寫的,我們可以ctrl+鼠標(biāo)左鍵查看CrossEntropyLoss(),翻到CrossEntropyLoss類,可以看到注釋寫的這個(gè)標(biāo)準(zhǔn)包含LogSoftmax函數(shù),所以搭建LetNet模型的最后一層沒有使用softmax激活函數(shù)
(6).訓(xùn)練模型及保存模型參數(shù)
forepochinrange(5):
#初始損失設(shè)置為0
running_loss=0
#循環(huán)訓(xùn)練集,從1開始
forstep,datainenumerate(train_loader,start=1):
inputs,labels=data
#優(yōu)化器的梯度清零,每次循環(huán)都需要清零,否則梯度會(huì)無限疊加,相當(dāng)于增加批次大小
optimizer.zero_grad()
#將圖片數(shù)據(jù)輸入模型中得到輸出
outputs=net(inputs)
#傳入預(yù)測值和真實(shí)值,計(jì)算當(dāng)前損失值
loss=loss_function(outputs,labels)
#損失反向傳播
loss.backward()
#進(jìn)行梯度更新(更新W,b)
optimizer.step()
#計(jì)算該輪的總損失,因?yàn)閘oss是tensor類型,所以需要用item()取到值
running_loss+=loss.item()
#每500次進(jìn)行日志的打印,對(duì)測試集進(jìn)行測試
ifstep%500==0:
#torch.no_grad()就是上下文管理,測試時(shí)不需要梯度更新,不跟蹤梯度
withtorch.no_grad():
#傳入所有測試集圖片進(jìn)行預(yù)測
outputs=net(test_img)
#torch.max()中dim=1是因?yàn)榻Y(jié)果為(batch,10)的形式,我們只需要取第二個(gè)維度的最大值,第二個(gè)維度是包含十個(gè)類別每個(gè)類別的概率的向量
#max這個(gè)函數(shù)返回[最大值,最大值索引],我們只需要取索引就行了,所以用[1]
predict_y=torch.max(outputs,dim=1)[1]
#(predict_y==test_label)相同返回True,不相等返回False,sum()對(duì)正確結(jié)果進(jìn)行疊加,最后除測試集標(biāo)簽的總個(gè)數(shù)
#因?yàn)橛?jì)算的變量都是tensor,所以需要用item()拿到取值
accuracy=(predict_y==test_label).sum().item()/test_label.size(0)
#running_loss/500是計(jì)算每一個(gè)step的loss,即每一步的損失
print('[%d,%5d]train_loss:%.3ftest_accuracy:%.3f'%
(epoch+1,step,running_loss/500,accuracy))
running_loss=0.0
print('FinishedTraining!')
save_path='lenet.pth'
#保存模型,字典形式
torch.save(net.state_dict(),save_path)
這段代碼注釋寫的很清楚,大家仔細(xì)看就能看懂,流程不復(fù)雜,多看幾遍就能理解,最后再對(duì)訓(xùn)練好的模型進(jìn)行保存就好了(* ̄︶ ̄)
2.預(yù)測腳本
上面已經(jīng)訓(xùn)練好了模型,得到了lenet.pth參數(shù)文件,預(yù)測就很簡單了,可以去網(wǎng)上隨便找一張數(shù)據(jù)集包含的類別圖片,將模型參數(shù)文件載入模型,通過對(duì)圖像進(jìn)行一點(diǎn)處理,喂入模型即可,下面奉上代碼:
importtorch
importnumpyasnp
importtorchvision.transformsastransforms
fromPILimportImage
frompytorch.lenet.modelimportLeNet
classes=('plane','car','bird','cat','deer',
'dog','frog','horse','ship','truck')
transforms=transforms.Compose(
#對(duì)數(shù)據(jù)圖片調(diào)整大小
[transforms.Resize([32,32]),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]
net=LeNet()
#加載預(yù)訓(xùn)練模型
net.load_state_dict(torch.load('lenet.pth'))
#網(wǎng)上隨便找的貓的圖片
img_path='../../Photo/cat2.jpg'
img=Image.open(img_path)
#圖片的處理
img=transforms(img)
#增加一個(gè)維度,(channels,height,width)-------(batch,channels,height,width),pytorch要求必須輸入這樣的shape
img=torch.unsqueeze(img,dim=0)
withtorch.no_grad():
output=net(img)
#dim=1,只取[batch,10]中10個(gè)類別的那個(gè)維度,取預(yù)測結(jié)果的最大值索引,并轉(zhuǎn)換為numpy類型
prediction1=torch.max(output,dim=1)[1].data.numpy()
#用softmax()預(yù)測出一個(gè)概率矩陣
prediction2=torch.softmax(output,dim=1)
#得到概率最大的值得索引
prediction2=np.argmax(predictio
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請(qǐng)下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請(qǐng)聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會(huì)有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
- 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
- 5. 人人文庫網(wǎng)僅提供信息存儲(chǔ)空間,僅對(duì)用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對(duì)用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對(duì)任何下載內(nèi)容負(fù)責(zé)。
- 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請(qǐng)與我們聯(lián)系,我們立即糾正。
- 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時(shí)也不承擔(dān)用戶因使用這些下載資源對(duì)自己和他人造成任何形式的傷害或損失。
最新文檔
- 衛(wèi)生用品更衣室管理制度
- 衛(wèi)生院行風(fēng)督查制度
- 衛(wèi)生院三病物資管理制度
- 生活區(qū)衛(wèi)生物品管理制度
- 衛(wèi)生院疾病預(yù)防管理制度
- 衛(wèi)生所規(guī)范管理制度
- 養(yǎng)殖場日常衛(wèi)生管理制度
- 幼兒園8項(xiàng)衛(wèi)生管理制度
- 衛(wèi)生所首診負(fù)責(zé)制度
- 衛(wèi)生院新冠病人轉(zhuǎn)診制度
- 2026屆杭州高級(jí)中學(xué)高二上數(shù)學(xué)期末聯(lián)考試題含解析
- 2026年陜西氫能產(chǎn)業(yè)發(fā)展有限公司所屬單位社會(huì)公開招聘備考題庫及1套參考答案詳解
- 2026年及未來5年中國無取向硅鋼片行業(yè)市場深度分析及發(fā)展趨勢預(yù)測報(bào)告
- 棄土場規(guī)范規(guī)章制度
- 2026年水下機(jī)器人勘探報(bào)告及未來五至十年深海資源報(bào)告
- 2025年3月29日事業(yè)單位聯(lián)考(職測+綜應(yīng))ABCDE類筆試真題及答案解析
- 雙重預(yù)防體系建設(shè)自評(píng)報(bào)告模板
- 高血壓教學(xué)查房復(fù)習(xí)過程教案(2025-2026學(xué)年)
- 感控PDCA持續(xù)質(zhì)量改進(jìn)
- 2025年云服務(wù)器采購合同協(xié)議
- 補(bǔ)氣血培訓(xùn)課件
評(píng)論
0/150
提交評(píng)論