什么是生成式对抗网络GAN
(本教程的代码,训练数据全部来自—《深度学习框架Pytroch入门与实践》, many thanks,此书采用pytorch 0.4.0版本,API接口与1.0有所差别,1.0版本pytroch中已经不推荐使用Variable)
开发/测试环境
- Ubuntu 18.04
- anaconda3, python3.6
- pycharm
- pytroch 1.0
训练过程
刚开始训练,输入为噪声向量, 生成的图像也是噪声
image.png
训练了几十次迭代之后
image.png
随着迭代次数增加,逐渐产生轮廓,仔细观察刚开始生成的图像为黑白灰度图像,没有彩色信息。
image.png
image.png
image.png
继续迭代,逐渐产生了彩色信息。
image.png
image.png
image.png
image.png
image.png
image.png
image.png
Loss曲线的变化
image.png
image.png
image.png
image.png
image.png
使用GPU进行训练
CPU进行训练太慢了,笔者采用Intel i7 5500u CPU进行训练,一秒钟大概只能迭代一次,而且batch size设置为4~8。之后切换到GPU上(Nvdia 1080ti), 单块GPU, 计算速度为20~30iter/sec, batch size=64, 直观上比CPU计算块20倍多。
image.png
迭代30K次
image.png
Process
深度录屏_选择区域_20190206000807.gif
深度录屏_TeamViewer_20190206000840.gif
深度录屏_TeamViewer_20190206001653.gif
代码
网络定义
GAN网络不同于一般的分类网络,由2部分组成: 生成器,判别器。
生成器
NetG
输入: 1x100x1x1 (NxCxHxW) 100维的噪声向量
输出: 1x3x96x96 3(Channels)x96(Height)x96(Width)的图像
from torch import nn
class NetG(nn.Module):
'''
生成器定义
'''
def __init__(self, opt):
super(NetG, self).__init__()
ngf = opt.ngf # 生成器feature map数
self.main = nn.Sequential(
# 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
nn.ConvTranspose2d(opt.nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# 上一步的输出形状:(ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# 上一步的输出形状: (ngf*4) x 8 x 8
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# 上一步的输出形状: (ngf*2) x 16 x 16
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# 上一步的输出形状:(ngf) x 32 x 32
nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
nn.Tanh() # 输出范围 -1~1 故而采用Tanh
# 输出形状:3 x 96 x 96
)
def forward(self, input):
return self.main(input)
判别器
NetD
输入: 1x3x96x96 的图像
输出: 1x1x1x1 的一个数,表示概率值
class NetD(nn.Module):
'''
判别器定义
'''
def __init__(self, opt):
super(NetD, self).__init__()
ndf = opt.ndf
self.main = nn.Sequential(
# 输入 3 x 96 x 96
nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# 输出 (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# 输出 (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# 输出 (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# 输出 (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid() # 输出一个数(概率)
)
def forward(self, input):
return self.main(input).view(-1)
参数配置
- batch_size
- learning_rate
- max_epoch 最大迭代epoch个数
import os
import ipdb
import torch as t
import torchvision as tv
import tqdm
from model import NetG, NetD
from torch.autograd import Variable
from torchnet.meter import AverageValueMeter
class Config(object):
data_path = 'data/' # 数据集存放路径
num_workers = 4 # 多进程加载数据所用的进程数
image_size = 96 # 图片尺寸
batch_size = 16
max_epoch = 200
lr1 = 2e-4 # 生成器的学习率
lr2 = 2e-4 # 判别器的学习率
beta1=0.5 # Adam优化器的beta1参数
gpu=False # 是否使用GPU
nz=100 # 噪声维度
ngf = 64 # 生成器feature map数
ndf = 64 # 判别器feature map数
save_path = 'imgs/' #生成图片保存路径
vis = True # 是否使用visdom可视化
env = 'GAN' # visdom的env
plot_every = 20 # 每间隔20 batch,visdom画图一次
debug_file = '/tmp/debuggan' # 存在该文件则进入debug模式
d_every = 1 # 每1个batch训练一次判别器
g_every = 5 # 每5个batch训练一次生成器
decay_every = 10 # 没10个epoch保存一次模型
netd_path = './checkpoints/netd_100.pth' # 'checkpoints/netd_.pth' #预训练模型
netg_path = './checkpoints/netg_100.pth' # 'checkpoints/netg_211.pth'
# 只测试不训练
gen_img = 'result.png'
# 从512张生成的图片中保存最好的64张
gen_num = 64
gen_search_num = 512
gen_mean = 0 # 噪声的均值
gen_std = 1 #噪声的方差
opt = Config()
训练
训练生成器网络
训练判别器网络
def train(**kwargs):
for k_,v_ in kwargs.items():
setattr(opt,k_,v_)
if opt.vis:
from visualize import Visualizer
vis = Visualizer(opt.env)
transforms = tv.transforms.Compose([
tv.transforms.Scale(opt.image_size),
tv.transforms.CenterCrop(opt.image_size),
tv.transforms.ToTensor(),
tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = tv.datasets.ImageFolder(opt.data_path,transform=transforms)
dataloader = t.utils.data.DataLoader(dataset,
batch_size = opt.batch_size,
shuffle = True,
num_workers= opt.num_workers,
drop_last=True
)
# 定义网络
netg, netd = NetG(opt), NetD(opt)
map_location=lambda storage, loc: storage
if opt.netd_path:
netd.load_state_dict(t.load(opt.netd_path, map_location = map_location))
if opt.netg_path:
netg.load_state_dict(t.load(opt.netg_path, map_location = map_location))
# 定义优化器和损失
optimizer_g = t.optim.Adam(netg.parameters(),opt.lr1,betas=(opt.beta1, 0.999))
optimizer_d = t.optim.Adam(netd.parameters(),opt.lr2,betas=(opt.beta1, 0.999))
criterion = t.nn.BCELoss()
# 真图片label为1,假图片label为0
# noises为生成网络的输入
true_labels = Variable(t.ones(opt.batch_size))
fake_labels = Variable(t.zeros(opt.batch_size))
fix_noises = Variable(t.randn(opt.batch_size,opt.nz,1,1))
noises = Variable(t.randn(opt.batch_size,opt.nz,1,1))
errord_meter = AverageValueMeter()
errorg_meter = AverageValueMeter()
if opt.gpu:
netd.cuda()
netg.cuda()
criterion.cuda()
true_labels,fake_labels = true_labels.cuda(), fake_labels.cuda()
fix_noises,noises = fix_noises.cuda(),noises.cuda()
epochs = range(opt.max_epoch)
for epoch in iter(epochs):
for ii,(img,_) in tqdm.tqdm(enumerate(dataloader)):
real_img = Variable(img)
if opt.gpu:
real_img=real_img.cuda()
if ii%opt.d_every==0:
# 训练判别器
optimizer_d.zero_grad()
## 尽可能的把真图片判别为正确
output = netd(real_img)
error_d_real = criterion(output,true_labels)
error_d_real.backward()
## 尽可能把假图片判别为错误
noises.data.copy_(t.randn(opt.batch_size,opt.nz,1,1))
fake_img = netg(noises).detach() # 根据噪声生成假图
output = netd(fake_img)
error_d_fake = criterion(output,fake_labels)
error_d_fake.backward()
optimizer_d.step()
error_d = error_d_fake + error_d_real
errord_meter.add(error_d.data.item())
if ii%opt.g_every==0:
# 训练生成器
optimizer_g.zero_grad()
noises.data.copy_(t.randn(opt.batch_size,opt.nz,1,1))
fake_img = netg(noises)
output = netd(fake_img)
error_g = criterion(output,true_labels)
error_g.backward()
optimizer_g.step()
errorg_meter.add(error_g.data.item())
if opt.vis and ii%opt.plot_every == opt.plot_every-1:
## 可视化
if os.path.exists(opt.debug_file):
ipdb.set_trace()
fix_fake_imgs = netg(fix_noises)
vis.images(fix_fake_imgs.data.cpu().numpy()[:64]*0.5+0.5,win='fixfake')
vis.images(real_img.data.cpu().numpy()[:64]*0.5+0.5,win='real')
vis.plot('errord',errord_meter.value()[0])
vis.plot('errorg',errorg_meter.value()[0])
if epoch%opt.decay_every==0:
# 保存模型、图片
tv.utils.save_image(fix_fake_imgs.data[:64],'%s/%s.png' %(opt.save_path,epoch),normalize=True,range=(-1,1))
t.save(netd.state_dict(),'checkpoints/netd_%s.pth' %epoch)
t.save(netg.state_dict(),'checkpoints/netg_%s.pth' %epoch)
errord_meter.reset()
errorg_meter.reset()
optimizer_g = t.optim.Adam(netg.parameters(),opt.lr1,betas=(opt.beta1, 0.999))
optimizer_d = t.optim.Adam(netd.parameters(),opt.lr2,betas=(opt.beta1, 0.999))
visualize.py
#coding:utf8
from itertools import chain
import visdom
import torch
import time
import torchvision as tv
import numpy as np
class Visualizer():
'''
封装了visdom的基本操作,但是你仍然可以通过`self.vis.function`
调用原生的visdom接口
'''
def __init__(self, env='default', **kwargs):
import visdom
self.vis = visdom.Visdom(env=env, **kwargs)
# 画的第几个数,相当于横座标
# 保存(’loss',23) 即loss的第23个点
self.index = {}
self.log_text = ''
def reinit(self,env='default',**kwargs):
'''
修改visdom的配置
'''
self.vis = visdom.Visdom(env=env,**kwargs)
return self
def plot_many(self, d):
'''
一次plot多个
@params d: dict (name,value) i.e. ('loss',0.11)
'''
for k, v in d.iteritems():
self.plot(k, v)
def img_many(self, d):
for k, v in d.iteritems():
self.img(k, v)
def plot(self, name, y):
'''
self.plot('loss',1.00)
'''
x = self.index.get(name, 0)
self.vis.line(Y=np.array([y]), X=np.array([x]),
win=(name),
opts=dict(title=name),
update=None if x == 0 else 'append'
)
self.index[name] = x + 1
def img(self, name, img_):
'''
self.img('input_img',t.Tensor(64,64))
'''
if len(img_.size())<3:
img_ = img_.cpu().unsqueeze(0)
self.vis.image(img_.cpu(),
win=unicode(name),
opts=dict(title=name)
)
def img_grid_many(self,d):
for k, v in d.iteritems():
self.img_grid(k, v)
def img_grid(self, name, input_3d):
'''
一个batch的图片转成一个网格图,i.e. input(36,64,64)
会变成 6*6 的网格图,每个格子大小64*64
'''
self.img(name, tv.utils.make_grid(
input_3d.cpu()[0].unsqueeze(1).clamp(max=1,min=0)))
def log(self,info,win='log_text'):
'''
self.log({'loss':1,'lr':0.0001})
'''
self.log_text += ('[{time}] {info} <br>'.format(
time=time.strftime('%m%d_%H%M%S'),\
info=info))
self.vis.text(self.log_text,win='log_text')
def __getattr__(self, name):
return getattr(self.vis, name)
测试
输入: 1x100x1x1的噪声向量
输出: 1x3x96x96 的图像
def generate(**kwargs):
'''
随机生成动漫头像,并根据netd的分数选择较好的
'''
for k_,v_ in kwargs.items():
setattr(opt,k_,v_)
netg, netd = NetG(opt).eval(), NetD(opt).eval()
noises = t.randn(opt.gen_search_num,opt.nz,1,1).normal_(opt.gen_mean,opt.gen_std)
noises = Variable(noises, volatile=True)
map_location=lambda storage, loc: storage
print(opt.netd_path)
print(opt.netg_path)
netd.load_state_dict(t.load(opt.netd_path, map_location='cpu'))
netg.load_state_dict(t.load(opt.netg_path, map_location='cpu'))
# netd.load_state_dict(t.load(opt.netd_path, map_location= map_location))
# netg.load_state_dict(t.load(opt.netg_path, map_location= map_location))
if opt.gpu:
netd.cuda()
netg.cuda()
noises = noises.cuda()
# 生成图片,并计算图片在判别器的分数
fake_img = netg(noises)
scores = netd(fake_img).data
# 挑选最好的某几张
indexs = scores.topk(opt.gen_num)[1]
result = []
for ii in indexs:
result.append(fake_img.data[ii])
# 保存图片
tv.utils.save_image(t.stack(result),opt.gen_img,normalize=True,range=(-1,1))
生成的图像:
result.png