Pix2Pix for Image Translation
Pix2Pix Overview
Pix2Pix is a deep learning image translation model based on the conditional generative adversarial network (cGAN). It was proposed by Phillip Isola et al. at CVPR 2017. It can translate semantic label maps into photos, grayscale images into color images, aerial photos into maps, daytime scenes into nighttime scenes, and line drawings into photos. Pix2Pix is a classic application of cGAN to supervised image-to-image translation. It consists of two models: a generator and a discriminator.
All of these tasks share the same goal of predicting pixels from pixels, yet each has traditionally been handled by a separate, special-purpose system. Pix2Pix instead serves as a general framework. It uses the same architecture and objective for every task, trains only on different data, and still obtains satisfactory results. Many people have since used this network to create and publish their own artworks.
Basic Principles
The cGAN generator differs in principle from the traditional GAN generator. The cGAN generator uses the input image as guidance and continually tries to generate "fake" images that fool the discriminator. Converting an input image into its corresponding "fake" image is essentially a pixel-to-pixel mapping. A traditional GAN generator, by contrast, generates an image from random noise, so any control over the output has to come from other constraints. This is the difference between cGAN and GAN in image translation tasks. The discriminator in Pix2Pix decides whether an image is a real training image or a "fake" image produced by the generator. As the generator and the discriminator compete, the model approaches an equilibrium. At that point the generated images are indistinguishable from the real training data, and the discriminator judges correctly only 50% of the time.
Before going further, define the symbols used throughout this tutorial:
\(x\): the observed (input) image, taken from the training data.
\(y\): the real target image paired with \(x\) in the training data.
\(z\): random noise.
\(G(x,z)\): the generator network, which produces a "fake" image from the observed image \(x\) and the random noise \(z\).
\(D(x,\cdot)\): the discriminator network, which outputs the probability that an image is real given the observed image \(x\). It receives either the real pair \((x,y)\) or the generated pair \((x,G(x,z))\).
The objective of cGAN can be expressed as follows:
\[\mathcal{L}_{cGAN}(G,D)=\mathbb{E}_{x,y}[\log D(x,y)]+\mathbb{E}_{x,z}[\log(1-D(x,G(x,z)))]\]
This formula is the loss function of cGAN. D tries to correctly classify real images and "fake" images, that is, to maximize \(\mathcal{L}_{cGAN}\) by pushing \(D(x,y)\) toward 1 and \(D(x,G(x,z))\) toward 0. G tries to deceive D with the generated "fake" image \(G(x,z)\), that is, to minimize \(\log(1-D(x,G(x,z)))\). The cGAN objective can therefore be written compactly as a minimax problem:
\[G^{*}=\arg\min_{G}\max_{D}\mathcal{L}_{cGAN}(G,D)\]
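For comparison, consider an unconditional GAN. Here \(y\) is a real image, and the discriminator never sees \(x\). Its objective can be expressed as follows:
\[\mathcal{L}_{GAN}(G,D)=\mathbb{E}_{y}[\log D(y)]+\mathbb{E}_{z}[\log(1-D(G(z)))]\]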
The formula shows that GAN generates a "fake" image directly from random noise \(z\) without using any information from the observed image \(x\). Previous approaches have found it beneficial to combine the GAN objective with a more traditional loss. The discriminator's task is unchanged: distinguish real images from "fake" images. The generator, however, must not only fool the discriminator but also produce output close to the real target image \(y\) under a traditional loss. Pix2Pix uses the L1 distance for this term, since according to the paper it causes less blurring than L2:
\[\mathcal{L}_{L1}(G)=\mathbb{E}_{x,y,z}[\lVert y-G(x,z)\rVert_{1}]\]
The final objective is:
\[G^{*}=\arg\min_{G}\max_{D}\mathcal{L}_{cGAN}(G,D)+\lambda\mathcal{L}_{L1}(G)\]
where \(\lambda\) weights the L1 term.
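To make the objective concrete, the following pure-Python sketch evaluates its terms on a few hand-picked numbers. It computes the cGAN term (average of log D on real pairs plus average of log(1 - D) on generated pairs), the L1 term (mean absolute pixel difference), and their weighted sum. The helper names `cgan_value`, `l1_distance`, and `pix2pix_objective` and all sample values are illustrative assumptions. Expectations are replaced by sample averages, and the weight of 100 mirrors the `lambda_l1` used later in the training code.

```python
import math

def cgan_value(d_real, d_fake):
    """Sample estimate of L_cGAN(G, D).

    d_real: discriminator probabilities D(x, y) on real pairs, in (0, 1).
    d_fake: discriminator probabilities D(x, G(x, z)) on generated pairs.
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

def l1_distance(y, g):
    """Sample estimate of L_L1(G): mean absolute difference between target y and output G(x, z)."""
    return sum(abs(a - b) for a, b in zip(y, g)) / len(y)

def pix2pix_objective(d_real, d_fake, y, g, lam=100.0):
    """Value of L_cGAN(G, D) + lam * L_L1(G) for fixed samples."""
    return cgan_value(d_real, d_fake) + lam * l1_distance(y, g)

# A discriminator that outputs 0.5 everywhere (the 50% equilibrium) gives 2 * log(0.5).
print(cgan_value([0.5, 0.5], [0.5, 0.5]))
# A confident, correct discriminator pushes L_cGAN toward its upper bound of 0.
print(cgan_value([0.99], [0.01]))
# An output that is off by 0.1 per pixel adds about 100 * 0.1 = 10 to the objective.
print(pix2pix_objective([0.5], [0.5], y=[0.2, -0.4], g=[0.3, -0.5]))
```

The large weight on the L1 term means that, early in training, the generator's loss is dominated by pixel error rather than by the adversarial term.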
Image translation is essentially a pixel-to-pixel mapping problem. Pix2Pix uses the same network structure and objective function for every task. As a result, each of the tasks listed above can be implemented by swapping in a different training dataset. This tutorial uses the MindSpore framework to implement the Pix2Pix application.
Preparations
Configuring the Environment File
You can run this case in either dynamic graph (PyNative) or static graph mode on the GPU, CPU, or Ascend platform.
Preparing Data
This tutorial uses a preprocessed version of the facades dataset, which can be read directly with mindspore.dataset.
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"
download(url, "./dataset", kind="tar", replace=True)
Data Display
Normally, the training set would be built by calling Pix2PixDataset and create_train_dataset. Here, the downloaded dataset has already been processed, so it is read directly with MindDataset.
from mindspore import dataset as ds
import matplotlib.pyplot as plt
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# Visualize some training data.
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
    plt.subplot(3, 10, i)
    plt.axis("off")
    # Convert CHW with values in [-1, 1] to HWC with values in [0, 1] for display.
    plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()
Creating a Network
After the data is processed, you can set up the network. The generator, discriminator, and loss function are discussed one by one below. Generator G uses the U-Net structure: the input contour map \(x\) is encoded and then decoded into a realistic image. Discriminator D uses PatchGAN, the conditional discriminator proposed by the authors. Conditioned on the contour map \(x\), D should judge the generated image \(G(x)\) as fake and the real image \(y\) as real.
Generator G Structure
U-Net is a fully convolutional structure proposed by the pattern recognition and image processing team at the University of Freiburg in Germany. It has two parts. The left part is a contracting path made of convolution and downsampling operations. The right part is an expanding path made of convolution and upsampling operations. The input of each block on the expanding path is the concatenation of two things: the upsampled features from the layer below, and the corresponding features from the contracting path. The network is U-shaped, hence the name U-Net. A plain encoder-decoder downsamples to a low resolution and then upsamples back to the original resolution. U-Net adds skip connections on top of this. Each encoder feature map is concatenated along the channel axis with the decoder feature map of the same size. This preserves pixel-level details at different resolutions.
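The following pure-Python sketch traces how the spatial size and channel count change through the U-Net generator defined in this section. It assumes a 256×256 RGB input, `ngf=64`, and `n_layers=8`, the values used later in this tutorial. Every down_conv uses kernel size 4, stride 2, and padding 1, so it halves the resolution, and every up_conv doubles it again. The sketch also shows why intermediate up_conv layers take `inner_nc * 2` input channels: each submodule returns its output concatenated with its input. The helpers `down_size` and `unet_trace` are illustrative and are not MindSpore APIs.

```python
def down_size(size, kernel=4, stride=2, padding=1):
    """Spatial size after one U-Net down_conv layer."""
    return (size + 2 * padding - kernel) // stride + 1

def unet_trace(size=256, out_ch=3, ngf=64, n_layers=8):
    """Return (encoder, decoder) lists of (spatial size, channels) after each block."""
    # (outer_nc, inner_nc) of each UNetSkipConnectionBlock, outermost first,
    # matching the arguments passed in UNetGenerator.
    blocks = [(out_ch, ngf), (ngf, ngf * 2), (ngf * 2, ngf * 4), (ngf * 4, ngf * 8)]
    blocks += [(ngf * 8, ngf * 8)] * (n_layers - 5)  # middle blocks
    blocks += [(ngf * 8, ngf * 8)]                   # innermost block
    encoder = []
    for _, inner in blocks:
        size = down_size(size)          # down_conv: stride 2 halves the size
        encoder.append((size, inner))   # and produces inner_nc channels
    decoder = []
    for depth in range(len(blocks) - 1, -1, -1):
        outer, _ = blocks[depth]
        size *= 2                       # up_conv (transposed, stride 2) doubles the size
        # Non-outermost blocks concatenate up_conv's outer_nc channels with
        # their outer_nc-channel input; the outermost block has no skip.
        channels = outer if depth == 0 else 2 * outer
        decoder.append((size, channels))
    return encoder, decoder

encoder, decoder = unet_trace()
print(encoder[-1])   # (1, 512): the 1x1 bottleneck
print(decoder[-1])   # (256, 3): the RGB output
```

With eight stride-2 layers, a 256×256 input is compressed to a 1×1 bottleneck. This is why `n_layers=8` pairs with 256×256 images.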
Defining the U-Net Skip Connection Block
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            down_norm = nn.BatchNorm2d(inner_nc, affine=False)
            up_norm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
                              stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            # Outermost block: no normalization, Tanh output in [-1, 1], no skip connection.
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, pad_mode='pad')
            down = [down_conv]
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # Innermost block (bottleneck): no submodule, so up_conv receives inner_nc channels.
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            # Intermediate block: the submodule returns inner_nc * 2 channels (output + skip).
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]
            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost
    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            # Skip connection: concatenate the block output with its input along channels.
            out = ops.concat((out, x), axis=1)
        return out
U-Net-based Generator
class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()
        # Build the U-Net from the innermost block outwards.
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)
    def construct(self, x):
        return self.model(x)
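The original cGAN takes two kinds of input: the condition \(x\) and the noise \(z\). The generator here uses only the condition information, so on its own it cannot produce diverse results. In the paper, Pix2Pix instead provides noise in the form of dropout, applied at both training and test time, so that the outputs can vary. In the code above, dropout is controlled by the dropout argument of UNetGenerator, which defaults to False and is left unchanged when the generator is instantiated in this tutorial.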
PatchGAN-based Discriminator
The PatchGAN discriminator is a fully convolutional network whose output is a matrix rather than a single scalar. Each element of the output matrix corresponds to a patch of the input image (its receptive field). Its value indicates whether that patch is real or fake.
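The following pure-Python sketch shows what "each element corresponds to a patch" means for the Discriminator defined below. It computes the size of the output map and the receptive field of a single output element, assuming a 256×256 input and `n_layers=3` (the values used in this tutorial). The channel count does not affect these sizes. The discriminator's 6 input channels are simply \(x\) and \(y\) (3 + 3) concatenated. The helper names are illustrative. The result is a 30×30 map in which each element sees a 70×70 input patch, which corresponds to the 70×70 PatchGAN discussed in the Pix2Pix paper.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution layer."""
    return (size + 2 * padding - kernel) // stride + 1

def patchgan_layers(n_layers=3, kernel=4):
    """(kernel, stride, padding) of every conv in the Discriminator:
    one stride-2 input conv, n_layers - 1 stride-2 ConvNormRelu blocks,
    one stride-1 ConvNormRelu block, and the final stride-1 conv to 1 channel."""
    return [(kernel, 2, 1)] * n_layers + [(kernel, 1, 1)] * 2

def output_size(size, layers):
    for k, s, p in layers:
        size = conv_out(size, k, s, p)
    return size

def receptive_field(layers):
    """Input patch size seen by one output element, computed from the last layer back."""
    r = 1
    for k, s, _ in reversed(layers):
        r = (r - 1) * s + k
    return r

layers = patchgan_layers()
print(output_size(256, layers))  # 30: a 30x30 map of real/fake scores
print(receptive_field(layers))   # 70: each score judges a 70x70 patch
```

Changing `n_layers` trades patch size against locality. For example, `n_layers=1` gives a 16×16 receptive field.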
import mindspore.nn as nn
class ConvNormRelu(nn.Cell):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)
    def construct(self, x):
        output = self.features(x)
        return output
class Discriminator(nn.Cell):
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        # Final 1-channel conv: one real/fake logit per patch.
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)
    def construct(self, x, y):
        # Condition on x by concatenating it with the (real or generated) image y along channels.
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output
Initialization of Pix2Pix Generator and Discriminator
Instantiate the Pix2Pix generator and discriminator.
import mindspore.nn as nn
from mindspore.common import initializer as init
g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6  # x and y concatenated along channels (3 + 3)
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize convolution weights and BatchNorm parameters of a network in place."""
    for _, cell in net.cells_and_names():
        if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
            if init_type == 'normal':
                cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
            elif init_type == 'xavier':
                cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
            elif init_type == 'constant':
                cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif isinstance(cell, nn.BatchNorm2d):
            cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
            cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,
                              ngf=g_ngf, n_layers=g_layers)
init_weights(net_generator, init_type, init_gain)
net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,
                                  alpha=alpha, n_layers=d_layers)
init_weights(net_discriminator, init_type, init_gain)
class Pix2Pix(nn.Cell):
    """Pix2Pix model network"""
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator
    def construct(self, reala):
        fakeb = self.net_generator(reala)
        return fakeb
Training
Training is divided into two parts: discriminator training and generator training. The discriminator is trained to maximize the probability of correctly telling real image pairs from generated ones. The generator is trained to produce "fake" images that the discriminator accepts as real and that stay close to the target images. The losses of the two parts are computed separately at every step and recorded in d_losses and g_losses.
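The discriminator and generator losses in the training code combine nn.BCEWithLogitsLoss and nn.L1Loss. The following pure-Python sketch restates those two loss functions on plain lists of logits and pixel values, to show how the terms and weights fit together. It assumes mean reduction, which is the default for both MindSpore losses. The weights 0.5, 0.5, and 100 are copied from the training code. The function names are illustrative, and this is not the MindSpore implementation.

```python
import math

def bce_with_logits(logits, target):
    """Mean binary cross-entropy on raw logits, in the numerically stable form
    max(l, 0) - l * t + log(1 + exp(-|l|))."""
    total = 0.0
    for l in logits:
        total += max(l, 0.0) - l * target + math.log1p(math.exp(-abs(l)))
    return total / len(logits)

def discriminator_loss(pred_real, pred_fake, lambda_dis=0.5):
    # Real pairs (x, y) should be scored as 1, generated pairs (x, G(x)) as 0.
    return lambda_dis * (bce_with_logits(pred_real, 1.0) + bce_with_logits(pred_fake, 0.0))

def generator_loss(pred_fake, fake, real, lambda_gan=0.5, lambda_l1=100.0):
    # The generator wants its pairs scored as 1 and its pixels close to the target (L1).
    l1 = sum(abs(a - b) for a, b in zip(fake, real)) / len(fake)
    return lambda_gan * bce_with_logits(pred_fake, 1.0) + lambda_l1 * l1

# Logits of 0 mean the discriminator outputs probability 0.5 everywhere.
print(discriminator_loss([0.0], [0.0]))   # log(2), about 0.693
```

An undecided discriminator yields a loss of about 0.693, close to the Dloss printed at the first training step in the log. The generator loss, by contrast, is dominated by the L1 term because of its weight of 100.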
The training process is as follows:
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor
epoch_num = 100
ckpt_dir = "results/ckpt/"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100
def get_lr():
    # Constant learning rate for n_epochs, then linear decay over n_epochs_decay.
    lrs = [lr] * dataset_size * n_epochs
    lr_epoch = 0
    for epoch in range(n_epochs_decay):
        lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * dataset_size
    lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
    return Tensor(np.array(lrs).astype(np.float32))
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=16)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()
def forward_dis(reala, realb):
    # Discriminator loss: real pairs -> 1, generated pairs -> 0, scaled by lambda_dis.
    lambda_dis = 0.5
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    pred1 = net_discriminator(reala, realb)
    loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
    loss_dis = loss_d * lambda_dis
    return loss_dis
def forward_gan(reala, realb):
    # Generator loss: adversarial term (fool D) plus a heavily weighted L1 term.
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan
d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
# Each gradient function differentiates only with respect to its own network's parameters.
grad_d = value_and_grad(forward_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forward_gan, None, net_generator.trainable_params())
def train_step(reala, realb):
    loss_dis, d_grads = grad_d(reala, realb)
    loss_gan, g_grads = grad_g(reala, realb)
    d_opt(d_grads)
    g_opt(g_grads)
    return loss_dis, loss_gan
if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)
g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)
for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        # Total elapsed time of this step in milliseconds.
        delta = (end_time - start_time).total_seconds() * 1000
        if i % 2 == 0:
            print("ms per step:{:.2f} epoch:{}/{} step:{}/{} Dloss:{:.4f} Gloss:{:.4f} ".format(delta, (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")
ms per step:532.31 epoch:1/100 step:0/25 Dloss:0.6940 Gloss:38.1245
ms per step:304.35 epoch:1/100 step:2/25 Dloss:0.6489 Gloss:39.4826
ms per step:299.15 epoch:1/100 step:4/25 Dloss:0.5506 Gloss:36.7634
ms per step:301.06 epoch:1/100 step:6/25 Dloss:1.6741 Gloss:47.7600
ms per step:299.72 epoch:1/100 step:8/25 Dloss:0.4604 Gloss:39.7121
...... ......
ms per step:290.44 epoch:100/100 step:16/25 Dloss:0.6009 Gloss:9.1915
ms per step:289.95 epoch:100/100 step:18/25 Dloss:0.4617 Gloss:9.8740
ms per step:290.24 epoch:100/100 step:20/25 Dloss:0.4402 Gloss:8.2490
ms per step:287.70 epoch:100/100 step:22/25 Dloss:0.3814 Gloss:9.3652
ms per step:289.41 epoch:100/100 step:24/25 Dloss:0.4199 Gloss:9.2418
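The get_lr function in the training code keeps the learning rate constant for n_epochs and then decays it linearly over n_epochs_decay. Two details are worth checking when reusing it. First, the list is built with dataset_size = 400 entries per epoch, while the log shows 25 steps per epoch. Second, epoch_num = 100 equals n_epochs, so training ends before the decay phase begins. The pure-Python sketch below restates the same rule with a hypothetical steps_per_epoch argument. It assumes, as with MindSpore dynamic learning rates, that the optimizer consumes one list entry per step. linear_decay_schedule is an illustrative name, not a MindSpore API.

```python
def linear_decay_schedule(base_lr, steps_per_epoch, n_epochs, n_epochs_decay):
    """Per-step learning rates: constant for n_epochs, then linear decay over n_epochs_decay.

    Mirrors the rule in get_lr(), but counts optimizer steps per epoch
    (steps_per_epoch) instead of the number of images.
    """
    lrs = [base_lr] * (steps_per_epoch * n_epochs)
    for epoch in range(n_epochs_decay):
        lrs += [base_lr * (n_epochs_decay - epoch) / n_epochs_decay] * steps_per_epoch
    return lrs

sched = linear_decay_schedule(0.0002, steps_per_epoch=25, n_epochs=100, n_epochs_decay=100)
print(len(sched))   # 5000 steps = 200 epochs x 25 steps
```

Under these assumptions, the full constant-then-decay schedule covers n_epochs + n_epochs_decay = 200 epochs. Training for only 100 epochs therefore uses a constant learning rate throughout.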
Inference
After training completes, load the weight parameters from the CKPT file into the generator with load_checkpoint and load_param_into_net. Then run inference on a batch of data and display the results. (Training ran for only 100 epochs.)
from mindspore import load_checkpoint, load_param_into_net
param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    # Top row: input images; bottom row: generated images.
    plt.subplot(2, 10, i + 1)
    plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
The inference effect of each dataset is as follows:
Reference
[1] Phillip Isola, Jun-Yan Zhu, Tinghui Zhou, Alexei A. Efros. Image-to-Image Translation with Conditional Adversarial Networks. CoRR, abs/1611.07004, 2016.