A demo for feature extraction with vgg19 in pytorch

这里提取灰度图的特征,所以我把它堆叠了三次变成了三通道,vgg-19的预训练模型可以在pytorch提供的官方地址下载

import torch,os
import scipy.io as sio
import numpy as np
import scipy.misc as sc
from torchvision import models
from torch.autograd import Variable

model = models.vgg19()
model.load_state_dict(torch.load('models/vgg19-dcbb9e9d.pth'))
model.cuda()

data_file = os.listdir('img/')
for pic in data_file:
    pic_path = 'img/'+pic
    print(pic_path)
    data = sc.imread(pic_path)
    data = np.resize(data,[224,224])
    img = np.zeros((1, 3, 224, 224)).astype(np.float32)
    for i in range(3):
        img[:,i,:,:] = data
    img = torch.from_numpy(img)
    img = Variable(img).cuda()
    feature = model(img)
    tmp = []
    for key,parm in enumerate(model.classifier.parameters()):
        if key == 0 or key == 2 or key == 4:
            continue
        parm = parm.cpu()
        tmp.append(parm.data.numpy())
    sio.savemat('feature/'+pic[:-4]+'.mat',{'feature_4096_first':tmp[0],
                                            'feature_4096_second': tmp[1],
                                            'feature_1000': tmp[2]})
    # break

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。