pytorch 如何把圖像數(shù)據(jù)集進(jìn)行劃分成train,test和val_第1頁
pytorch 如何把圖像數(shù)據(jù)集進(jìn)行劃分成train,test和val_第2頁
pytorch 如何把圖像數(shù)據(jù)集進(jìn)行劃分成train,test和val_第3頁
pytorch 如何把圖像數(shù)據(jù)集進(jìn)行劃分成train,test和val_第4頁
pytorch 如何把圖像數(shù)據(jù)集進(jìn)行劃分成train,test和val_第5頁
全文預(yù)覽已結(jié)束

下載本文檔

版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進(jìn)行舉報或認(rèn)領(lǐng)

文檔簡介

第pytorch如何把圖像數(shù)據(jù)集進(jìn)行劃分成train,test和val2、目錄結(jié)構(gòu):

|---data

|---dslr

|---images

|---back_pack

|---a.jpg

|---b.jpg

3、轉(zhuǎn)換后的格式如圖

目錄結(jié)構(gòu)為:

|---datanews

|---dslr

|---images

|---test

|---train

|---valid

|---back_pack

|---a.jpg

|---b.jpg

4、代碼如下:

4.1先創(chuàng)建同樣結(jié)構(gòu)的層級結(jié)構(gòu)

4.2然后講原始數(shù)據(jù)按照比例劃分

4.3移入到對應(yīng)的文件目錄里面

importos,random,shutil

defmake_dir(source,target):

創(chuàng)建和源文件相似的文件路徑函數(shù)

:paramsource:源文件位置

:paramtarget:目標(biāo)文件位置

dir_names=os.listdir(source)

fornamesindir_names:

foriin['train','valid','test']:

path=target+'/'+i+'/'+names

ifnotos.path.exists(path):

os.makedirs(path)

defdivideTrainValiTest(source,target):

創(chuàng)建和源文件相似的文件路徑

:paramsource:源文件位置

:paramtarget:目標(biāo)文件位置

#得到源文件下的種類

pic_name=os.listdir(source)

#對于每一類里的數(shù)據(jù)進(jìn)行操作

forclassesinpic_name:

#得到這一種類的圖片的名字

pic_classes_name=os.listdir(os.path.join(source,classes))

random.shuffle(pic_classes_name)

#按照8:1:1比例劃分

train_list=pic_classes_name[0:int(0.8*len(pic_classes_name))]

valid_list=pic_classes_name[int(0.8*len(pic_classes_name)):int(0.9*len(pic_classes_name))]

test_list=pic_classes_name[int(0.9*len(pic_classes_name)):]

#對于每個圖片,移入到對應(yīng)的文件夾里面

fortrain_picintrain_list:

shutil.copyfile(source+'/'+classes+'/'+train_pic,target+'/train/'+classes+'/'+train_pic)

forvalidation_picinvalid_list:

shutil.copyfile(source+'/'+classes+'/'+validation_pic,

target+'/valid/'+classes+'/'+validation_pic)

fortest_picintest_list:

shutil.copyfile(source+'/'+classes+'/'+test_pic,target+'/test/'+classes+'/'+test_pic)

if__name__=='__main__':

filepath=r'../data/dslr/images'

dist=r'../datanews/dslr/images'

make_dir(filepath,dist)

divideTrainValiTest(filepath,dist)

補(bǔ)充:pytorch中數(shù)據(jù)集的劃分方法及eError:take():argument'index'(position1)mustbeTensor,notnumpy.ndarray錯誤原因

在使用pytorch框架時,難免需要對數(shù)據(jù)集進(jìn)行訓(xùn)練集和驗證集的劃分,一般使用sklearn.model_selection中的train_test_split方法

該方法使用如下:

fromsklearn.model_selectionimporttrain_test_split

importnumpyasnp

importtorch

importtorch.autogradimportVariable

fromtorch.utils.dataimportDataLoader

traindata=np.load(train_path)#image_num*W*H

trainlabel=np.load(train_label_path)

train_data=traindata[:,np.newaxis,...]

train_label_data=trainlabel[:,np.newaxis,...]

x_tra,x_val,y_tra,y_val=train_test_split(train_data,train_label_data,test_size=0.1,random_state=0)#訓(xùn)練集和驗證集使用9:1

x_tra=Variable(torch.from_numpy(x_tra))

x_tra=x_tra.float()

y_tra=Variable(torch.from_numpy(y_tra))

y_tra=y_tra.float()

x_val=Variable(torch.from_numpy(x_val))

x_val=x_val.float()

y_val=Variable(torch.from_numpy(y_val))

y_val=y_val.float()

#訓(xùn)練集的DataLoader

traindataset=torch.utils.data.TensorDataset(x_tra,y_tra)

trainloader=DataLoader(dataset=traindataset,num_workers=opt.threads,batch_size=8,shuffle=True)

#驗證集的DataLoader

validataset=torch.utils.data.TensorDataset(x_val,y_val)

valiloader=DataLoader(dataset=validataset,num_workers=opt.threads,batch_size=opt.batchSize,shuffle=True)

注意:如果按照如下方式使用,就會報eError:take():argument'index'(position1)mustbeTensor,notnumpy.ndarray錯誤

fromsklearn.model_selectionimporttrain_test_split

importnumpyasnp

importtorch

importtorch.autogradimportVariable

fromtorch.utils.dataimportDataLoader

traindata=np.load(train_path)#image_num*W*H

trainlabel=np.load(train_label_path)

train_data=traindata[:,np.newaxis,...]

train_label_data=trainlabel[:,np.newaxis,...]

x_train=Variable(torch.from_numpy(train_data))

x_train=x_train.float()

y_train=Variable(torch.from_numpy(train_label_data))

y_train=y_train.float()

#將原始的訓(xùn)練數(shù)據(jù)集分為訓(xùn)練集和驗證集,后面就可以使用早停機(jī)制

x_tra,x_v

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護(hù)處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負(fù)責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。

最新文檔

評論

0/150

提交評論