Pix2Pix for Image Translation


Pix2Pix Overview

Pix2Pix is a deep learning image translation model based on the conditional generative adversarial network (cGAN). Proposed by Phillip Isola et al. at CVPR 2017, it can translate semantic label maps to real images, grayscale images to color images, aerial images to maps, daytime scenes to nighttime scenes, and line drawings to real images. Pix2Pix is a classic application of cGAN to supervised image-to-image translation. It consists of two models: a generator and a discriminator.

Traditionally, although these tasks share the same goal of predicting pixels from pixels, each one has been handled with separate, special-purpose machinery. Pix2Pix instead serves as a general framework: it uses the same architecture and objective for every task and changes only the training data, yet still obtains satisfactory results. Many people have used this network to publish their own artwork.

Basic Principles

The cGAN generator differs in principle from the traditional GAN generator. The cGAN generator takes the input image as guidance and keeps trying to generate a "fake" image that confuses the discriminator. Converting an input image into the corresponding "fake" image is essentially a pixel-to-pixel mapping. A traditional GAN generator, by contrast, generates an image from random noise alone, so its output is not tied to a specific input image. This is the difference between cGAN and GAN in image translation tasks. The discriminator in Pix2Pix determines whether an image is a real training image or a "fake" image output by the generator. As the generator and the discriminator keep playing this game, the model approaches a balance point at which the discriminator can no longer tell generated images from real training data, so its judgment is correct only 50% of the time.

Before starting, define the symbols used throughout this tutorial:

  • \(x\): indicates the observed (input) image.

  • \(y\): indicates the real target image paired with \(x\) in the training data.

  • \(z\): indicates random noise.

  • \(G(x,z)\): indicates the generator network, which outputs a "fake" image generated from the observed image \(x\) and random noise \(z\), where \(x\) comes from the training data rather than the generator.

  • \(D(x,G(x,z))\): indicates the discriminator network, which outputs the probability that an image is real. \(x\) comes from the training data, and \(G(x,z)\) comes from the generator.

The objectives of cGAN can be expressed as follows:

\[L_{cGAN}(G,D)=E_{(x,y)}[log(D(x,y))]+E_{(x,z)}[log(1-D(x,G(x,z)))]\]

This formula is the loss function of cGAN. D tries to classify real images and "fake" images correctly, that is, to maximize \(log D(x,y)\). G tries to deceive D with the generated "fake" image \(G(x,z)\), that is, to minimize \(log(1-D(x,G(x,z)))\). The cGAN objective can therefore be written as:

\[arg\min_{G}\max_{D}L_{cGAN}(G,D)\]

pix2pix1
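
As a rough numerical illustration of the cGAN objective, the following sketch estimates \(L_{cGAN}\) from lists of discriminator probabilities. The function name `cgan_value` and the sample probabilities are made up for illustration and are not part of the tutorial's code.

```python
import math

def cgan_value(d_real, d_fake):
    """Sample estimate of L_cGAN from discriminator probabilities.

    d_real holds D(x, y) for real pairs; d_fake holds D(x, G(x, z))
    for generated pairs. Both are plain lists of floats in (0, 1).
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

# A confident, correct discriminator pushes the value toward 0, its maximum.
confident = cgan_value([0.99, 0.98], [0.01, 0.02])
# At the balance point D outputs 0.5 everywhere, which gives 2 * log(0.5).
balanced = cgan_value([0.5, 0.5], [0.5, 0.5])
print(confident, balanced)
```

D wants this value to be large, while G wants it to be small, which is the min-max game in the formula.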

To compare the differences between cGAN and GAN, the objectives of GAN can be expressed as follows:

\[L_{GAN}(G,D)=E_{y}[log(D(y))]+E_{z}[log(1-D(G(z)))]\]

As the formula shows, GAN generates a "fake" image directly from random noise \(z\) without using any information from the observed image \(x\). Previous approaches have found it beneficial to mix the GAN objective with a more traditional loss. The task of the discriminator remains unchanged: distinguish real images from "fake" images. The generator, however, must not only deceive the discriminator but also produce output that is close to the ground truth as measured by the traditional loss. When cGAN is combined with an L1 loss, the additional term is:

\[L_{L1}(G)=E_{(x,y,z)}[||y-G(x,z)||_{1}]\]

Our final objective is:

\[arg\min_{G}\max_{D}L_{cGAN}(G,D)+\lambda L_{L1}(G)\]
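
To make the role of \(\lambda\) concrete, here is a minimal sketch of the generator side of this objective on flat pixel lists. The helper names and the toy pixel values are hypothetical; \(\lambda = 100\) is chosen to match the `lambda_l1 = 100` setting used in the training code of this tutorial.

```python
import math

def l1_distance(target, generated):
    """Mean absolute error between two equally sized flat pixel lists."""
    return sum(abs(t - g) for t, g in zip(target, generated)) / len(target)

def generator_objective(d_fake, target, generated, lam=100.0):
    """Generator side of the objective: log(1 - D(x, G(x, z))) + lambda * L1."""
    adversarial = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return adversarial + lam * l1_distance(target, generated)

target = [0.0, 0.5, 1.0, -1.0]      # toy ground-truth pixels y
generated = [0.1, 0.5, 0.8, -1.0]   # toy generator output G(x, z)
loss = generator_objective([0.5], target, generated)
print(l1_distance(target, generated), loss)
```

With these toy numbers the L1 term contributes about 7.5 while the adversarial term is about -0.69, so the reconstruction term dominates the generator loss when \(\lambda\) is large.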

The image translation problem is essentially a pixel-to-pixel mapping problem. Because Pix2Pix uses the same network structure and objective function for every task, each of the translation tasks listed earlier can be implemented simply by swapping in a different training dataset. This tutorial uses the MindSpore framework to implement Pix2Pix.

Preparations

Configuring the Environment File

You can run this case in either dynamic or static mode on the GPU, CPU, or Ascend platform.

Preparing Data

This tutorial uses a preprocessed version of the facades dataset, which can be read directly with the mindspore.dataset module.

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"

download(url, "./dataset", kind="tar", replace=True)

Data Display

The training set would normally be read by calling Pix2PixDataset and create_train_dataset. Because the dataset downloaded here is already processed, the code below reads it directly with MindDataset.

from mindspore import dataset as ds
import matplotlib.pyplot as plt

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# Visualize some training data.
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
    plt.subplot(3, 10, i)
    plt.axis("off")
    plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()

Creating a Network

After the data is processed, you can build the network. The generator, discriminator, and loss function are discussed one by one. Generator G uses the U-Net structure: the input contour map \(x\) is encoded and then decoded into a realistic image. Discriminator D uses PatchGAN, the conditional discriminator proposed by the authors. Conditioned on the contour map \(x\), D should judge the generated image \(G(x)\) as fake and the real image \(y\) as real.

Generator G Structure

U-Net is a fully convolutional structure proposed by the pattern recognition and image processing group at the University of Freiburg in Germany. It has two parts: the left part is a contracting path formed by convolution and downsampling operations, and the right part is an expanding path formed by convolution and upsampling operations. The input of each block in the expanding path combines the upsampled features from the previous layer with the features from the matching stage of the contracting path. The resulting network is U-shaped, hence the name U-Net. Compared with a plain encoder-decoder network that downsamples to a low resolution and then upsamples back to the original resolution, U-Net adds skip connections: each encoder feature map is concatenated along the channel dimension with the decoder feature map of the same size. This preserves pixel-level details at different resolutions.

pix2pix2
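
The geometry of the two paths can be traced without building the network. The sketch below assumes a 256x256 input, 4x4 kernels with stride 2 and padding 1, and the channel widths used by the generator defined in this tutorial (ngf=64, 8 levels). The helper names are hypothetical.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Output size of a strided convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel=4, stride=2, padding=1):
    """Output size of a transposed convolution along one spatial axis."""
    return (size - 1) * stride - 2 * padding + kernel

def trace_unet(size=256, ngf=64, n_layers=8):
    """Return (encoder, decoder) lists of (spatial_size, channels)."""
    widths = [ngf * min(2 ** i, 8) for i in range(n_layers)]
    encoder = []
    for channels in widths:
        size = conv_out(size)
        encoder.append((size, channels))
    decoder = []
    # Walk back up. Every level except the outermost concatenates the skip
    # feature map, which doubles the channel count.
    for level in range(n_layers - 1, 0, -1):
        size = deconv_out(size)
        decoder.append((size, widths[level - 1] * 2))
    return encoder, decoder

encoder, decoder = trace_unet()
print(encoder[-1])  # bottleneck: (1, 512)
print(decoder[0])   # first decoder stage after concatenation: (2, 1024)
print(decoder[-1])  # input to the outermost up-convolution: (128, 128)
```

The doubled channel counts in the decoder are why the non-innermost up-convolutions take `inner_nc * 2` input channels.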

Defining the U-Net Skip Connection Block

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops

class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            down_norm = nn.BatchNorm2d(inner_nc, affine=False)
            up_norm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
                              stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, pad_mode='pad')
            down = [down_conv]
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]

            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out

U-Net-based Generator

class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='batch', dropout=False):
        super(UNetGenerator, self).__init__()
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)

The original cGAN takes two inputs: the condition x and the noise z. The generator here uses only the condition, so by itself it cannot produce diverse results for the same input. Pix2Pix therefore applies dropout during both training and testing, which provides the randomness needed to generate diverse results.
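
The effect of keeping dropout active at test time can be illustrated with a toy example. This is only a sketch of the idea, not the tutorial's network: `dropout` is a hypothetical pure-Python helper, and the list of ones stands in for a decoder feature map.

```python
import random

def dropout(values, p, rng):
    """Inverted dropout: zero each value with probability p, rescale the rest."""
    keep = 1.0 - p
    return [v / keep if rng.random() < keep else 0.0 for v in values]

features = [1.0] * 64  # stand-in for a decoder feature map

sample_a = dropout(features, 0.5, random.Random(0))
sample_b = dropout(features, 0.5, random.Random(1))
repeat_a = dropout(features, 0.5, random.Random(0))

# The same seed reproduces the mask; a different seed gives a different mask.
# For a fixed input x, this mask is the only source of variation.
print(sample_a == repeat_a, sample_a == sample_b)
```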

PatchGAN-based Discriminator

The PatchGAN structure used by the discriminator can be viewed as a fully convolutional network. Each point in the output matrix corresponds to one patch of the input image, and its value indicates whether that patch is judged real or fake.
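
How large the output matrix is, and how large a patch each point sees, follows from the convolution settings. The sketch below assumes the layer configuration of the Discriminator class in this tutorial with n_layers=3 (five 4x4 convolutions with strides 2, 2, 2, 1, 1 and padding 1) and a 256x256 input. The helper names are hypothetical.

```python
def conv_size(size, kernel, stride, padding):
    """Output size of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) for each convolution, assuming n_layers=3.
LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

def output_size(size, layers=LAYERS):
    """Side length of the discriminator's output matrix."""
    for kernel, stride, padding in layers:
        size = conv_size(size, kernel, stride, padding)
    return size

def receptive_field(layers=LAYERS):
    """Side length of the input patch seen by one output point."""
    rf = 1
    for kernel, stride, _ in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf

print(output_size(256), receptive_field())  # 30 70
```

Under these assumptions the discriminator returns a 30x30 matrix, and each entry judges a 70x70 patch of the input.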

import mindspore.nn as nn

class ConvNormRelu(nn.Cell):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class Discriminator(nn.Cell):
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output

Initialization of Pix2Pix Generator and Discriminator

Instantiate the Pix2Pix generator and discriminator.

import mindspore.nn as nn
from mindspore.common import initializer as init

g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'


net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,
                              ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,
                                  alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

class Pix2Pix(nn.Cell):
    """Pix2Pix model network"""
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        fakeb = self.net_generator(reala)
        return fakeb

Training

Training consists of two parts: discriminator training and generator training. The discriminator is trained to maximize the probability of classifying real and fake images correctly. The generator is trained to produce increasingly convincing fake images. The losses of the two parts are computed separately and recorded during training.

The training process is as follows:

import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor

epoch_num = 100
ckpt_dir = "results/ckpt/"  # trailing slash so that ckpt_dir + "Generator.ckpt" lands inside this directory
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100

def get_lr():
    lrs = [lr] * dataset_size * n_epochs
    lr_epoch = 0
    for epoch in range(n_epochs_decay):
        lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * dataset_size
    lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
    return Tensor(np.array(lrs).astype(np.float32))

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=16)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

def forward_dis(reala, realb):
    # Discriminator loss: real pairs should score 1, generated pairs 0.
    lambda_dis = 0.5
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    pred1 = net_discriminator(reala, realb)
    loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
    loss_dis = loss_d * lambda_dis
    return loss_dis

def forward_gan(reala, realb):
    # Generator loss: fool the discriminator and stay close to the target (L1).
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan

d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)

# Each gradient function differentiates only with respect to its own network.
grad_d = value_and_grad(forward_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forward_gan, None, net_generator.trainable_params())

def train_step(reala, realb):
    loss_dis, d_grads = grad_d(reala, realb)
    loss_gan, g_grads = grad_g(reala, realb)
    d_opt(d_grads)
    g_opt(g_grads)
    return loss_dis, loss_gan

if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)

for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        # total_seconds() covers the whole interval; .microseconds is only the sub-second part.
        delta_ms = (end_time - start_time).total_seconds() * 1000
        if i % 2 == 0:
            print("ms per step:{:.2f}  epoch:{}/{}  step:{}/{}  Dloss:{:.4f}  Gloss:{:.4f} ".format(delta_ms, (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")
ms per step:532.31   epoch:1/100  step:0/25  Dloss:0.6940  Gloss:38.1245
ms per step:304.35   epoch:1/100  step:2/25  Dloss:0.6489  Gloss:39.4826
ms per step:299.15   epoch:1/100  step:4/25  Dloss:0.5506  Gloss:36.7634
ms per step:301.06   epoch:1/100  step:6/25  Dloss:1.6741  Gloss:47.7600
ms per step:299.72   epoch:1/100  step:8/25  Dloss:0.4604  Gloss:39.7121
......                                                            ......
ms per step:290.44   epoch:100/100  step:16/25  Dloss:0.6009  Gloss:9.1915
ms per step:289.95   epoch:100/100  step:18/25  Dloss:0.4617  Gloss:9.8740
ms per step:290.24   epoch:100/100  step:20/25  Dloss:0.4402  Gloss:8.2490
ms per step:287.70   epoch:100/100  step:22/25  Dloss:0.3814  Gloss:9.3652
ms per step:289.41   epoch:100/100  step:24/25  Dloss:0.4199  Gloss:9.2418
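
The first logged Dloss (0.6940) is close to ln 2 ≈ 0.6931. One plausible reading is that an untrained discriminator outputs logits near 0 for both real and generated pairs. The sketch below checks that arithmetic for scalar logits, using the same 0.5 weighting as the discriminator loss in the training code. The helper names are hypothetical, and the explanation is an interpretation rather than something the log itself states.

```python
import math

def bce_with_logits(logit, label):
    """Binary cross-entropy on a raw logit, in a numerically stable form."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def discriminator_loss(logit_real, logit_fake, weight=0.5):
    """Weighted sum of the real-pair and fake-pair cross-entropy terms."""
    return weight * (bce_with_logits(logit_real, 1.0) + bce_with_logits(logit_fake, 0.0))

untrained = discriminator_loss(0.0, 0.0)
trained = discriminator_loss(3.0, -3.0)
print(round(untrained, 4), trained < untrained)  # 0.6931 True
```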

Inference

After training completes, load the weights from the CKPT file into the model with load_checkpoint and load_param_into_net, run inference on some data, and display the results. (The model here was trained for only 100 epochs.)

from mindspore import load_checkpoint, load_param_into_net

param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

The inference results on each dataset are as follows:

pix2pix3

Reference

[1] Phillip Isola, Jun-Yan Zhu, Tinghui Zhou, Alexei A. Efros. Image-to-Image Translation with Conditional Adversarial Networks. CoRR, abs/1611.07004, 2016.