詳解TensorFlow訓(xùn)練網(wǎng)絡(luò)兩種方式_第1頁(yè)
詳解TensorFlow訓(xùn)練網(wǎng)絡(luò)兩種方式_第2頁(yè)
詳解TensorFlow訓(xùn)練網(wǎng)絡(luò)兩種方式_第3頁(yè)
詳解TensorFlow訓(xùn)練網(wǎng)絡(luò)兩種方式_第4頁(yè)
詳解TensorFlow訓(xùn)練網(wǎng)絡(luò)兩種方式_第5頁(yè)
全文預(yù)覽已結(jié)束

下載本文檔

版權(quán)說(shuō)明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請(qǐng)進(jìn)行舉報(bào)或認(rèn)領(lǐng)

文檔簡(jiǎn)介

第詳解TensorFlow訓(xùn)練網(wǎng)絡(luò)兩種方式第一種是要加載全部數(shù)據(jù)形成一個(gè)tensor,然后調(diào)用model.fit()然后指定參數(shù)batch_size進(jìn)行將所有數(shù)據(jù)進(jìn)行分批訓(xùn)練

第二種是自己先將數(shù)據(jù)分批形成一個(gè)迭代器,然后遍歷這個(gè)迭代器,分別訓(xùn)練每個(gè)批次的數(shù)據(jù)

方式一:通過(guò)迭代器

IMAGE_SIZE=1000

#step1:加載數(shù)據(jù)集

(train_images,train_labels),(val_images,val_labels)=tf.keras.datasets.mnist.load_data()

#step2:將圖像歸一化

train_images,val_images=train_images/255.0,val_images/255.0

#step3:設(shè)置訓(xùn)練集大小

train_images=train_images[:IMAGE_SIZE]

val_images=val_images[:IMAGE_SIZE]

train_labels=train_labels[:IMAGE_SIZE]

val_labels=val_labels[:IMAGE_SIZE]

#step4:將圖像的維度變?yōu)?IMAGE_SIZE,28,28,1)

train_images=tf.expand_dims(train_images,axis=3)

val_images=tf.expand_dims(val_images,axis=3)

#step5:將圖像的尺寸變?yōu)?32,32)

train_images=tf.image.resize(train_images,[32,32])

val_images=tf.image.resize(val_images,[32,32])

#step6:將數(shù)據(jù)變?yōu)榈?/p>

train_loader=tf.data.Dataset.from_tensor_slices((train_images,train_labels)).batch(32)

val_loader=tf.data.Dataset.from_tensor_slices((val_images,val_labels)).batch(IMAGE_SIZE)

#step5:導(dǎo)入模型

model=LeNet5()

#讓模型知道輸入數(shù)據(jù)的形式

model.build(input_shape=(1,32,32,1))

#結(jié)局OutputShape為multiple

model.call(Input(shape=(32,32,1)))

#step6:編譯模型

pile(optimizer='adam',

loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),

metrics=['accuracy'])

#權(quán)重保存路徑

checkpoint_path="./weight/cp.ckpt"

#回調(diào)函數(shù),用戶保存權(quán)重

save_callback=tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,

save_best_only=True,

save_weights_only=True,

monitor='val_loss',

verbose=0)

EPOCHS=11

forepochinrange(1,EPOCHS):

#每個(gè)批次訓(xùn)練集誤差

train_epoch_loss_avg=tf.keras.metrics.Mean()

#每個(gè)批次訓(xùn)練集精度

train_epoch_accuracy=tf.keras.metrics.SparseCategoricalAccuracy()

#每個(gè)批次驗(yàn)證集誤差

val_epoch_loss_avg=tf.keras.metrics.Mean()

#每個(gè)批次驗(yàn)證集精度

val_epoch_accuracy=tf.keras.metrics.SparseCategoricalAccuracy()

forx,yintrain_loader:

history=model.fit(x,

validation_data=val_loader,

callbacks=[save_callback],

verbose=0)

#更新誤差,保留上次

train_epoch_loss_avg.update_state(history.history['loss'][0])

#更新精度,保留上次

train_epoch_accuracy.update_state(y,model(x,training=True))

val_epoch_loss_avg.update_state(history.history['val_loss'][0])

val_epoch_accuracy.update_state(next(iter(val_loader))[1],model(next(iter(val_loader))[0],training=True))

#使用.result()計(jì)算每個(gè)批次的誤差和精度結(jié)果

print("Epoch{:d}:trainLoss:{:.3f},trainAccuracy:{:.3%}valLoss:{:.3f},valAccuracy:{:.3%}".format(epoch,

train_epoch_loss_avg.result(),

train_epoch_accuracy.result(),

val_epoch_loss_avg.result(),

val_epoch_accuracy.result()))

方式二:適用model.fit()進(jìn)行分批訓(xùn)練

importmodel_sequential

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.mnist.load_data()

#step2:將圖像歸一化

train_images,test_images=train_images/255.0,test_images/255.0

#step3:將圖像的維度變?yōu)?60000,28,28,1)

train_images=tf.expand_dims(train_images,axis=3)

test_images=tf.expand_dims(test_images,axis=3)

#step4:將圖像尺寸改為(60000,32,32,1)

train_images=tf.image.resize(train_images,[32,32])

test_images=tf.image.resize(test_images,[32,32])

#step5:導(dǎo)入模型

#history=LeNet5()

history=model_sequential.LeNet()

#讓模型知道輸入數(shù)據(jù)的形式

history.build(input_shape=(1,32,32,1))

#history(tf.zeros([1,32,32,1]))

#結(jié)局OutputShape為multiple

history.call(Input(shape=(32,32,1)))

history.summary()

#step6:編譯模型

pile(optimizer='adam',

loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),

metrics=['accuracy'])

#權(quán)重保存路徑

checkpoint_path="./weight/cp.ckpt"

#回調(diào)函數(shù),用戶保存權(quán)重

save_callback=tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,

save_best_only=True,

save_weights_only=True,

monitor='val_loss',

verbose=1)

#step7:訓(xùn)練模型

history=history.fit(train_images,

train_labels,

溫馨提示

  • 1. 本站所有資源如無(wú)特殊說(shuō)明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請(qǐng)下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請(qǐng)聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁(yè)內(nèi)容里面會(huì)有圖紙預(yù)覽,若沒(méi)有圖紙預(yù)覽就沒(méi)有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫(kù)網(wǎng)僅提供信息存儲(chǔ)空間,僅對(duì)用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對(duì)用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對(duì)任何下載內(nèi)容負(fù)責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請(qǐng)與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時(shí)也不承擔(dān)用戶因使用這些下載資源對(duì)自己和他人造成任何形式的傷害或損失。

最新文檔

評(píng)論

0/150

提交評(píng)論