
DRIT model summary

Diverse Image-to-Image Translation via Disentangled Representations (DRIT)
Hsin-Ying Lee, Hung-Yu Tseng, Jia-Bin Huang, Maneesh Singh, Ming-Hsuan Yang (ECCV 2018)
The original paper is here.
The official GitHub repo is here, targeting Python 3.5-3.6, PyTorch 0.4.0 and CUDA 9
My modified version is here, compatible with Python 3.7.10, PyTorch 1.8.0 and CUDA 11.0

Objectives

  • Aim to disentangle images into a domain-invariant content representation and a domain-specific attribute representation
  • Use the disentangled representations to produce diverse translation outputs

Target

  • Content space is shared among domains
  • Attribute space encodes intra-domain variations

To make the content latents zcx and zcy share the same content space C:

  • Share weights between the last layer of Ecx and Ecy and the first layer of Gx and Gy, so that high-level features from both domains are mapped to the same space (weight sharing alone still cannot guarantee that the same content representation encodes the same information for both domains); a minimal sketch of this sharing follows the list
  • Use a content discriminator Dc for adversarial training on the content codes
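
A rough illustration of the weight-sharing idea (a simplified sketch with made-up layer sizes and module names, not the official networks.py): both content encoders hold a reference to the same final block, so its parameters are shared.

import torch
import torch.nn as nn

# one block instance whose parameters are shared by Ecx and Ecy
shared_last_layer = nn.Conv2d(256, 256, kernel_size=3, padding=1)

class ContentEncoder(nn.Module):
    def __init__(self, shared_layer):
        super().__init__()
        # domain-specific lower layers (illustrative sizes only)
        self.body = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3), nn.ReLU(),
            nn.Conv2d(64, 256, kernel_size=4, stride=2, padding=1), nn.ReLU(),
        )
        self.shared = shared_layer  # same module object -> shared parameters

    def forward(self, x):
        return self.shared(self.body(x))

enc_c_x = ContentEncoder(shared_last_layer)  # plays the role of Ecx
enc_c_y = ContentEncoder(shared_last_layer)  # plays the role of Ecy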


Loss functions

Content adversarial loss:

  • To encourage the content latents from x and y to map into the same shared content space C
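
A minimal sketch of this loss (my simplified version, not the official backward_G_GAN_content; I assume Dc outputs one logit per content code, domain x is labelled 1, domain y is labelled 0, and the encoders push Dc's prediction towards 0.5):

import torch
import torch.nn.functional as F

def content_adv_loss_Dc(Dc, z_content_x, z_content_y):
    # content discriminator: classify which domain each content code came from
    pred_x = torch.sigmoid(Dc(z_content_x.detach()))
    pred_y = torch.sigmoid(Dc(z_content_y.detach()))
    return (F.binary_cross_entropy(pred_x, torch.ones_like(pred_x)) +
            F.binary_cross_entropy(pred_y, torch.zeros_like(pred_y)))

def content_adv_loss_E(Dc, z_content_x, z_content_y):
    # content encoders: make the codes indistinguishable (target 0.5 for both)
    pred_x = torch.sigmoid(Dc(z_content_x))
    pred_y = torch.sigmoid(Dc(z_content_y))
    return (F.binary_cross_entropy(pred_x, 0.5 * torch.ones_like(pred_x)) +
            F.binary_cross_entropy(pred_y, 0.5 * torch.ones_like(pred_y)))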

Cross-cycle consistency loss:

  • Consists of a forward translation (swapping attributes across domains) followed by a backward translation (swapping them back)
  • Encourages reconstruction of the original images after the two consecutive image-to-image translations
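
A sketch of the cross-cycle term (simplified; I assume the generators take (content, attribute) in that order, which may differ from the official forward signatures):

import torch.nn.functional as F

def cross_cycle_loss(Ecx, Ecy, Eax, Eay, Gx, Gy, x, y):
    # forward translation: swap attributes across domains
    u = Gx(Ecy(y), Eax(x))   # domain-x image with y's content and x's attribute
    v = Gy(Ecx(x), Eay(y))   # domain-y image with x's content and y's attribute
    # backward translation: swap back and require L1 reconstruction of the inputs
    x_rec = Gx(Ecy(v), Eax(u))
    y_rec = Gy(Ecx(u), Eay(v))
    return F.l1_loss(x_rec, x) + F.l1_loss(y_rec, y)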

Self-reconstruction loss:

  • To facilitate the training of the encoders and generators, in addition to the cross-cycle reconstruction
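
A sketch of the self-reconstruction term for domain x (the same form applies to domain y; encoder and generator signatures are assumed as above):

import torch.nn.functional as F

def self_recon_loss(Ecx, Eax, Gx, x):
    # reconstruct an image from its own content and attribute codes
    x_rec = Gx(Ecx(x), Eax(x))
    return F.l1_loss(x_rec, x)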

Domain adversarial loss:

  • Domain discriminators Dx and Dy attempt to discriminate between real and generated images in their respective domains
  • To encourage Gx and Gy to generate realistic images
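
A sketch of this loss for domain x with a vanilla cross-entropy GAN formulation (the official code may use a different GAN objective and multi-scale discriminators, so treat this only as the general shape):

import torch
import torch.nn.functional as F

def domain_adv_loss_Dx(Dx, real_x, fake_x):
    # discriminator: real images -> 1, generated images -> 0
    real_pred = torch.sigmoid(Dx(real_x))
    fake_pred = torch.sigmoid(Dx(fake_x.detach()))
    return (F.binary_cross_entropy(real_pred, torch.ones_like(real_pred)) +
            F.binary_cross_entropy(fake_pred, torch.zeros_like(fake_pred)))

def domain_adv_loss_Gx(Dx, fake_x):
    # generator: try to make Dx classify generated images as real
    fake_pred = torch.sigmoid(Dx(fake_x))
    return F.binary_cross_entropy(fake_pred, torch.ones_like(fake_pred))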

KL loss:

  • To encourage the attribute representation to be close to a prior Gaussian distribution
  • For easy sampling at test time
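
A sketch of the KL term, assuming the attribute encoder outputs a mean and log-variance (this is the standard closed-form KL divergence to a standard normal):

import torch

def kl_loss(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, I) ), averaged over batch and latent dimensions
    return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())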

Latent regression loss:

  • To encourage an invertible mapping between the image and the attribute latent space
  • Attempts to reconstruct an attribute code drawn from the prior Gaussian distribution from the image generated with it
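
A sketch of the latent regression term in the x -> y direction (z_dim and the generator signature are assumptions on my side):

import torch
import torch.nn.functional as F

def latent_regression_loss(Ecx, Eay, Gy, x, z_dim=8):
    # draw an attribute code from the prior, generate an image, then recover the code
    z_random = torch.randn(x.size(0), z_dim, device=x.device)
    fake_y = Gy(Ecx(x), z_random)
    z_rec = Eay(fake_y)
    return F.l1_loss(z_rec, z_random)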

Full objective function:
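
Written out, it is a weighted min-max combination of all the losses above (my reconstruction from the loss list, with the lambdas denoting the corresponding weighting hyperparameters):

\min_{G, E^c, E^a} \max_{D, D^c} \;
  \lambda_{adv}^{content} L_{adv}^{content}
  + \lambda_{adv}^{domain} L_{adv}^{domain}
  + \lambda_1^{cc} L_1^{cc}
  + \lambda_1^{rec} L_1^{rec}
  + \lambda_1^{latent} L_1^{latent}
  + \lambda_{KL} L_{KL}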

Code implementation

I have summarised the architecture of the official implementation, together with the key variable names, in the following diagram.

The officially released code targets an old PyTorch version and CUDA 9, so if your GPU does not support CUDA 9 and you have no choice but to run it with a later PyTorch version, you will most probably hit a RuntimeError related to in-place operations. The reason is that in old PyTorch versions (at least up to 0.4.0), a tensor that is modified in place between the forward pass and a later backward pass is not detected, even though the backward then silently uses the updated values. Later PyTorch versions track this with per-tensor version counters and raise an explicit error when a saved tensor's version at backward time no longer matches its version at forward time.
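
A minimal standalone sketch (a toy model of my own, not the DRIT code) that triggers the same class of error on recent PyTorch versions:

import torch
import torch.nn as nn

net = nn.Linear(4, 4)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

x = torch.randn(2, 4, requires_grad=True)
out = net(x)
loss1 = out.pow(2).mean()
loss2 = out.abs().mean()   # second loss sharing the same graph

opt.zero_grad()
loss1.backward(retain_graph=True)  # first backward over the shared graph
opt.step()                         # in-place update bumps the weight's version counter

# RuntimeError on recent PyTorch: one of the variables needed for gradient
# computation has been modified by an inplace operation
loss2.backward()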

If we look at the original code in model.py, we can see that the second backward pass is performed right after the first backward and step operations. So by the time self.backward_G_alone() runs, the parameters saved along its backward path have already been modified in place by the preceding step() calls; autograd detects that their versions no longer match the ones recorded at forward time, and the in-place operation error pops up here.

def update_EG(self):
    # update G, Ec, Ea
    self.enc_c_opt.zero_grad()
    self.enc_a_opt.zero_grad()
    self.gen_opt.zero_grad()
    self.backward_EG()       # first backward to obtain gradient values
    self.enc_c_opt.step()    # in-place parameter update
    self.enc_a_opt.step()    # in-place parameter update
    self.gen_opt.step()      # in-place parameter update

    # update G, Ec
    self.enc_c_opt.zero_grad()
    self.gen_opt.zero_grad()
    self.backward_G_alone()  # second backward: the error is raised here on newer PyTorch
    self.enc_c_opt.step()
    self.gen_opt.step()

The fix is simple: move the self.backward_G_alone() call before the step operations, so that both backward passes run before any parameter is updated in place. The gradients from the two backward passes then simply accumulate in the same .grad buffers and are applied by a single set of step() calls. You can also refer to my forked repo with the modified code here.

def update_EG(self):
    # update G, Ec, Ea
    self.enc_c_opt.zero_grad()
    self.enc_a_opt.zero_grad()
    self.gen_opt.zero_grad()
    self.backward_EG()
    self.backward_G_alone()  # backward_G_alone here to accumulate gradients together before the update
    self.enc_c_opt.step()
    self.enc_a_opt.step()
    self.gen_opt.step()

Results after 1200 epochs on the cat2dog dataset: