FIGR: Few-shot Image Generation with Reptile

  • 2019-01-08 08:15:08
  • Louis Clouâtre, Marc Demers

Abstract

Generative Adversarial Networks (GAN) boast impressive capacity to generate realistic images. However, like much of the field of deep learning, they require an inordinate amount of data to produce results, thereby limiting their usefulness in generating novelty. In the same vein, recent advances in meta-learning have opened the door to many few-shot learning applications. In the present work, we propose Few-shot Image Generation using Reptile (FIGR), a GAN meta-trained with Reptile. Our model successfully generates novel images on both MNIST and Omniglot with as little as 4 images from an unseen class. We further contribute FIGR-8, a new dataset for few-shot image generation, which contains 1,548,944 icons categorized in over 18,409 classes. Trained on FIGR-8, initial results show that our model can generalize to more advanced concepts (such as "bird" and "knife") from as few as 8 samples from a previously unseen class of images and as little as 10 training steps through those 8 images. This work demonstrates the potential of training a GAN for few-shot image generation and aims to set a new benchmark for future work in the domain.

 


Louis Clouâtre
École Polytechnique de Montréal
   Marc Demers
McGill University

1 Introduction

Generative Adversarial Networks [7] have helped bridge the gap between human and artificial intelligence with regard to understanding and manipulating images. GANs, however, require several orders of magnitude more data points than humans to generate comprehensible images from a given class. This impairs the ability of GANs to generate novelty. In many cases, if the data is abundant enough to successfully train a GAN, there is little purpose to generating more of this data.

On the other hand, recent advances in meta-learning, like the MAML [6] and Reptile [15] algorithms, have allowed models to perform well on novel data sampled from the same distribution as the training data. These meta-learning algorithms have seen direct applications in supervised and reinforcement learning, but not in image generation. Being very general, they may also be applicable to few-shot image generation. This paper defines the problem of few-shot image generation and introduces an approach to GAN training for Few-shot Image Generation with Reptile (FIGR). In addition, this paper introduces FIGR-8, a dataset of 1,548,944 black-and-white pictograms, ideograms, icons, emoticons, and depictions of objects or concepts, categorized in 18,409 classes. We contribute this dataset as a challenging benchmark for one- and few-shot image generation approaches. Following training, our approach is able to generate images from a previously unseen class with as few as 4 samples from that class.

In summary, our main contributions are:

  • We develop a novel approach for training GANs for few-shot image generation.

  • We contribute a challenging dataset for that same task.

The applications of few-shot image generation are broad, but we mainly foresee this approach assisting creative processes. Artists or designers who lack time or creative inspiration for multiple versions of an image could sketch a limited number of drawings and have the trained model generate multiple similar versions of the sketches.

2 Related work

2.1 Meta-learning

MAML is currently the most widely used approach for few-shot meta-learning. Several variants of the algorithm exist. They all have conditions that make them ill-suited to meta-training a GAN. First, they rely on the direction of the loss function being linked to the quality of the model. For GANs, this assumption cannot be made. Second, they rely on being able to evaluate performance on a test set during training. There is no clear way to do that for a GAN.

2.2 Few-Shot Image Generation

To our knowledge, Lake et al. (2015) [13] provides the first successful attempt at one-shot or few-shot image generation. To achieve this on the Omniglot dataset introduced in the same paper, both the images and stroke data are used to train a Bayesian model through Bayesian Program Learning. It represents concepts, such as a pen stroke, as simple probabilistic programs and hierarchically combines them to generate images. This yields a model that can be trained on a single image of a previously unseen letter and generate novel samples of the same letter. It generates binary images.

Rezende et al. (2016) [17] uses a sequential generative model to achieve one-shot generation. The inference process uses an attention [3] module to have a Variational Auto Encoder [12] attend to a section of the generated image sequentially. Unlike in Lake et al. (2015), it trains on pure image data (without requiring stroke data), making this approach much more general. It generates binary images of size 28×28 and 52×52 on the Omniglot dataset with one-shot learning.

Bartunov and Vetrov (2018) [4] uses matching networks to achieve few-shot image generation. In essence, matching networks [18] are memory-assisted networks that leverage an external memory by employing an attention [3] module to quickly learn new concepts. It assumes that the concepts stored are somewhat similar to the new out-of-sample concepts. This approach is equally trained on pure image data and does not require a lengthy sequential inference period. It generates binary images of size 28×28 on the Omniglot dataset using few-shot learning.

Several issues can be found with the aforementioned approaches that no prior work seems to address:

  • The use of small binary images for all generative models seems to imply scalability issues.

  • All approaches are limited to the Omniglot dataset for one- and few-shot image generation. This dataset has several issues that will be expanded upon in Section 2.3.

  • None of the approaches use an architecture that has shown the potential to generate highly realistic images, as GANs have.

2.3 Omniglot

The Omniglot dataset [13] is the current baseline dataset for the one- or few-shot image generation task. Details about the dataset can be found in Section 4.2. There are two main issues with using this dataset as a benchmark.

  • All classes within the dataset are very similar. They all represent roughly the same concept: a character.

  • The classes lack complexity. All classes in Omniglot are simple handwritten characters that can be explained and generated through the composition of learned pen strokes [13].

We believe that a proper image generation benchmark should encompass a greater variety of classes and more complex classes to have real-life applications or the hope of applications on natural images.

3 Few-shot Image Generation with Reptile

Generative Adversarial Networks GANs are generative models that learn a generator network G to map a random noise vector z to an image y, such that G(z)=y. To accomplish this, we use a discriminator network D and real images x drawn from the distribution we want to generate from. D is trained on both x and y to distinguish the "fake" images y from the "real" images x, while G is trained to fool D. This adversarial game played between the two models leads to G being able to generate images that resemble the ones from x [7].

Few-shot image generation We define the few-shot image generation problem with the help of the meta-learning problem set-up found in Finn et al. (2017) [6] and Nichol et al. (2018) [15]. In this problem we assume access to a set of tasks T containing multiple tasks τ, where each individual task τ is an image generation problem with one class of images Xτ and a loss Lτ. We define Lτ as the ability of a human to discriminate between a group of generated images and a group of real images sampled from task Xτ, as described in Lake et al. (2015) [13]. We do not conduct human benchmarking in this paper, as this will be part of follow-up work. We nevertheless keep it in the task description, as we believe it is essential for a proper metric to exist.

The aim is to find, through meta-training, parameters Φ, that can quickly, meaning with little data and little training, converge on a random task τ to minimize an associated loss Lτ.

In essence, we want to:

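$$\underset{\Phi}{\text{minimize}} \; \mathbb{E}_{\tau}\left[ L_{\tau}\left( U_{\tau}^{k}(\Phi) \right) \right] \tag{1}$$

where $U_{\tau}^{k}(\Phi)$ is the operator that updates Φ k times using $x_n$, a total of n data points sampled from $X_{\tau}$ [15].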

MNIST As an example, the MNIST dataset contains 10 classes (the 10 digits). In the few-shot image generation problem, they represent 10 tasks to solve, τ0 to τ9. We choose τ0 to τ8 to be the training tasks and τ9 to be the test task. Through meta-training on τ0 to τ8, we aim to obtain a set of parameters Φ that will quickly converge on a new τ. We choose n to be 4, meaning that we aim for our meta-trained Φ to converge to generating images of 9's with only 4 images sampled from τ9.
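The task set-up described above can be sketched as follows. This is a minimal illustration rather than the FIGR code: the helper names `group_by_class` and `sample_task` are hypothetical, and a short list of stand-in labels replaces the real MNIST labels.

```python
import random
from typing import Dict, List, Sequence, Tuple

N_SHOT = 4                       # n in the text
TRAIN_TASKS = list(range(9))     # digits 0 to 8 are the training tasks
TEST_TASK = 9                    # digit 9 is the held-out test task


def group_by_class(labels: Sequence[int]) -> Dict[int, List[int]]:
    """Map each class label to the indices of its images."""
    groups: Dict[int, List[int]] = {}
    for index, label in enumerate(labels):
        groups.setdefault(label, []).append(index)
    return groups


def sample_task(groups: Dict[int, List[int]], tasks: Sequence[int],
                rng: random.Random, n: int = N_SHOT) -> Tuple[int, List[int]]:
    """Pick one task at random and draw n distinct image indices from it."""
    task = rng.choice(list(tasks))
    return task, rng.sample(groups[task], n)


# Stand-in labels (assumption): 20 images per digit instead of real MNIST.
labels = [digit for digit in range(10) for _ in range(20)]
groups = group_by_class(labels)
rng = random.Random(0)
train_task, train_indices = sample_task(groups, TRAIN_TASKS, rng)
test_task, test_indices = sample_task(groups, [TEST_TASK], rng)
```

During meta-training only `TRAIN_TASKS` would be sampled; the four indices drawn for `TEST_TASK` play the role of the 4 images of 9's used for adaptation.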

FIGR In FIGR, Φ corresponds to both the generator network G and the discriminator network D. U corresponds to one step of Stochastic Gradient Descent [5] on D and G using Wasserstein loss [1] with gradient-penalty [8].
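To make the two losses concrete, here is a small sketch of the Wasserstein critic loss with gradient penalty and the matching generator loss. It assumes a linear critic, so that the gradient with respect to the input is available in closed form. The real model uses residual networks and automatic differentiation, and the penalty weight of 10 is the default from Gulrajani et al. [8], not a value stated in this paper.

```python
import math
from typing import List, Sequence

GP_WEIGHT = 10.0  # assumption: penalty coefficient from Gulrajani et al. [8]


def critic(w: Sequence[float], b: float, x: Sequence[float]) -> float:
    """Toy linear critic D(x) = w . x + b."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def mean(values: List[float]) -> float:
    return sum(values) / len(values)


def critic_loss(w: Sequence[float], b: float,
                real: List[List[float]], fake: List[List[float]]) -> float:
    """Wasserstein critic loss with gradient penalty for the linear critic."""
    wasserstein = (mean([critic(w, b, y) for y in fake])
                   - mean([critic(w, b, x) for x in real]))
    # For a linear critic the gradient of D with respect to its input equals w
    # at every interpolated point, so the penalty reduces to a single term.
    grad_norm = math.sqrt(sum(wi * wi for wi in w))
    return wasserstein + GP_WEIGHT * (grad_norm - 1.0) ** 2


def generator_loss(w: Sequence[float], b: float,
                   fake: List[List[float]]) -> float:
    """Wasserstein generator loss: raise the critic score of fake images."""
    return -mean([critic(w, b, y) for y in fake])


real = [[1.0, 1.0], [0.8, 1.2]]    # toy "real" samples
fake = [[0.0, 0.0], [0.2, -0.2]]   # toy "generated" samples
w, b = [0.6, 0.8], 0.0             # unit-norm weights, so the penalty vanishes
d_loss = critic_loss(w, b, real, fake)
g_loss = generator_loss(w, b, fake)
```

Doubling the weights to `[1.2, 1.6]` gives a gradient norm of 2 and therefore adds a penalty of 10 to the critic loss, which is how the penalty keeps the critic close to 1-Lipschitz.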

The adapted Reptile pseudocode for meta-training the model is given in Algorithm 1. The algorithm is composed of an outer loop and an inner loop. The inner loop applies K steps of the operator U to a copy of the parameters Φ on task τ. Once we have those adapted weights Wτ, we can proceed to the outer loop. We set the gradient of Φ to be equal to Φ-Wτ. We then take one step with the Adam optimizer [11].

Algorithm 1: FIGR training

    Initialize Φd, the discriminator parameter vector
    Initialize Φg, the generator parameter vector
    for iteration 1, 2, 3, … do
        Make a copy of Φd resulting in Wd
        Make a copy of Φg resulting in Wg
        Sample task τ
        Sample n images from Xτ resulting in xτ
        for K > 1 iterations do
            Generate latent vector z
            Generate fake images y with z and Wg
            Perform step of SGD update on Wd with Wasserstein GP loss and xτ and y
            Generate latent vector z
            Perform step of SGD update on Wg with Wasserstein loss and z
        end for
        Set Φd gradient to be Φd - Wd
        Perform step of Adam update on Φd
        Set Φg gradient to be Φg - Wg
        Perform step of Adam update on Φg
    end for
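The structure of the training loop can be sketched on a toy problem. The code below is not the FIGR implementation. It replaces the GAN with a single parameter vector and a quadratic loss per task, and uses plain SGD instead of Adam for the outer step. The learning rates and task optima are made-up values.

```python
import random
from typing import List

INNER_LR = 0.1    # assumption: toy values, not the paper's hyperparameters
OUTER_LR = 0.1
K = 10            # inner-loop steps


def inner_loop(phi: List[float], target: List[float]) -> List[float]:
    """K SGD steps on a copy of phi for the toy loss 0.5 * ||w - target||^2."""
    w = list(phi)
    for _ in range(K):
        w = [wi - INNER_LR * (wi - ti) for wi, ti in zip(w, target)]
    return w


def reptile_step(phi: List[float], target: List[float]) -> List[float]:
    """One outer iteration: treat (phi - W) as the gradient of phi."""
    w = inner_loop(phi, target)
    meta_grad = [pi - wi for pi, wi in zip(phi, w)]
    # Plain SGD stands in here for the Adam outer step of Algorithm 1.
    return [pi - OUTER_LR * gi for pi, gi in zip(phi, meta_grad)]


# Hypothetical tasks: each one is defined by the optimum of its loss.
tasks = [[1.0, 2.0], [3.0, 2.0], [2.0, 0.0], [2.0, 4.0]]
rng = random.Random(0)
phi = [10.0, -10.0]
for _ in range(2000):
    phi = reptile_step(phi, rng.choice(tasks))
```

In this toy setting each outer step moves `phi` a fixed fraction of the way toward the sampled task's optimum, so `phi` ends up fluctuating inside the region spanned by the task optima. In FIGR the same update is applied separately to Φd and Φg.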

Once meta-trained, we use a similar process, described in Algorithm 2, to generate novel images from a sampled test class.

Algorithm 2: FIGR generation

    Using Wd, a copy of the meta-trained Φd
    Using Wg, a copy of the meta-trained Φg
    Sample test task τ
    Sample n images as xτ from Xτ
    for K iterations do
        Generate latent vector z
        Generate fake images y with z and Wg
        Perform step of SGD update on Wd with Wasserstein GP loss and xτ and y
        Generate latent vector z
        Perform step of SGD update on Wg with Wasserstein loss and z
    end for
    Generate latent vector z
    Generate fake images y

For every task τ there exist optimal discriminator and generator weights Wdτ and Wgτ. Intuitively, Reptile initializes the weights Φd and Φg to the point in parameter space that minimizes the distance between Φd, Φg, Wdτ and Wgτ for all τ, or

$$\underset{T}{\text{minimize}} \; \left( \Phi_{d} - W_{d}^{\tau} \right) + \left( \Phi_{g} - W_{g}^{\tau} \right) \tag{2}$$

Hence, for a sampled task τ, a model optimized with Reptile can quickly and with few data points converge to the optimal point Wdτ, Wgτ from Φd, Φg. If the test tasks are close enough to the training task and if the training tasks are numerous enough, Φd and Φg are likely to be close to a test τ’s Wdτ and Wgτ. This makes for rapid and easy generalization from few data points.
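This intuition can be checked on a toy problem. The sketch below assumes a quadratic loss per task and made-up parameter values. It only illustrates that K fine-tuning steps from an initialization close to the test optimum end nearer to it than the same K steps from a distant initialization.

```python
from typing import List

INNER_LR = 0.1   # assumption: toy value
K = 10           # fine-tuning steps, matching the 10 steps used in the experiments


def fine_tune(phi: List[float], target: List[float], steps: int = K) -> List[float]:
    """Adapt a copy of the parameters to one test task with SGD steps."""
    w = list(phi)  # the meta-trained phi itself is never modified
    for _ in range(steps):
        w = [wi - INNER_LR * (wi - ti) for wi, ti in zip(w, target)]
    return w


def distance(a: List[float], b: List[float]) -> float:
    """Euclidean distance between two parameter vectors."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) ** 0.5


test_optimum = [2.5, 1.5]    # hypothetical unseen task
meta_trained = [2.0, 2.0]    # hypothetical initialization near the training tasks
far_away = [10.0, -10.0]     # hypothetical initialization without meta-training

gap_meta = distance(fine_tune(meta_trained, test_optimum), test_optimum)
gap_far = distance(fine_tune(far_away, test_optimum), test_optimum)
```

Here `gap_meta` is much smaller than `gap_far`, because both runs shrink their initial distance by the same factor and the meta-trained start is closer to begin with.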

Figure 1: Sample taken from the FIGR-8 dataset. Items from 120 out of 18,409 classes are displayed and one class (cow) is (non-extensively) detailed

Reptile is broadly similar to joint training, and is effectively identical to it with a K of 1. However, by taking more gradient steps, we prioritize learning features that would be hard to reach, unlike joint training. Consider a 2D parameter space, a K of 10 and a task τ. If a local minimum for parameter 1, Wτ1, is reached after 2 gradient steps, while a local minimum for parameter 2, Wτ2, is not reached after K steps, it is probable that:

$$\Phi_{1} - W_{\tau 1} < \Phi_{2} - W_{\tau 2} \tag{3}$$

This would result in a larger outer loop update in the parameter space that is not readily attainable from Φ and smaller updates in the parameter space in which the model already possesses the ability to converge quickly.
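A small numeric sketch of this 2D example follows. It assumes a fixed step length per inner update, with steps clipped so they never overshoot, and made-up optima chosen so that parameter 1 converges after 2 steps while parameter 2 does not converge within K steps. The helper `step_toward` is hypothetical.

```python
from typing import List, Optional

K = 10      # inner-loop steps, as in the example above
LR = 0.1    # assumption: fixed step length per update


def step_toward(w: float, optimum: float, lr: float = LR) -> float:
    """Move w toward optimum by at most lr, without overshooting."""
    gap = optimum - w
    if abs(gap) <= lr:
        return optimum
    return w + lr if gap > 0 else w - lr


phi = [0.0, 0.0]
optimum = [-0.2, -5.0]   # hypothetical task: parameter 1 is close, parameter 2 is far

w: List[float] = list(phi)
steps_to_reach_first: Optional[int] = None
for step in range(1, K + 1):
    w = [step_toward(wi, oi) for wi, oi in zip(w, optimum)]
    if steps_to_reach_first is None and w[0] == optimum[0]:
        steps_to_reach_first = step

# Reptile uses (phi - W) as the gradient, so its size sets the outer update.
meta_grad = [pi - wi for pi, wi in zip(phi, w)]
```

In this set-up the first component of `meta_grad` is about 0.2 and the second about 1.0, so the outer update is roughly five times larger along the parameter that the inner loop could not finish optimizing.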

4 Datasets

4.1 MNIST

MNIST [14] is the first dataset chosen, as its simplicity allows us to iterate quickly through model ideas. The MNIST dataset contains 28×28 grayscale images of the 10 digits. We use the 60,000 training set images for all experiments.

4.2 Omniglot

Omniglot [13] is arguably the de facto dataset for few-shot image generation. It contains 1623 unique types of characters originating from 50 alphabets, each of which has been handwritten once by each of 20 different individuals. In contrast to MNIST, Omniglot allows us to train our model on a much larger number of image classes and to test its out-of-sample performance on a wider set of classes.

4.3 FIGR-8

For the sake of testing the limits of our model, we compiled 1,548,944 images separated into 18,409 conceptually different classes, a dataset which we named FIGR-8. Each class contains between 8 and a few thousand images of a similar theme. The icons are black-and-white representations of objects, concepts, patterns or designs that have been created by designers and artists and compiled into one dataset. 120 classes out of 18,409 are pictured in Figure 1. Every image is square, of format 1×200×200. The relative cumulative density of classes in the database is represented in Figure 2.

Figure 2: Relative cumulative density of the number of elements in each class in the FIGR-8 dataset

We expect this dataset to be more challenging for training the meta-learning model, as it contains a wide variety of samples inside each class and a substantial number of classes. Hopefully, the large number of classes will let the model quickly understand the underlying concept even if every sample from a class does not represent the class's concept in the same manner. Some icons have complex patterns and details, which poses a greater challenge than the existing datasets for one- or few-shot image generation tasks. All in all, the FIGR-8 dataset constitutes a tough yet achievable benchmark for few-shot image generation tasks.

5 Experiments

5.1 Model architecture

All reported models were trained with the Wasserstein loss [1] with gradient penalty [8]. We found that a simple DCGAN [16] with a binary cross-entropy loss, meta-trained in the same way, also yielded positive results on MNIST [14]. More complex datasets, such as Omniglot [13] and FIGR-8, were more challenging and required the Wasserstein loss with gradient penalty for the model to succeed. Both the generator and the discriminator are built with residual neural networks [10] with 18 layers. The discriminator uses layer normalization [2] as prescribed in Gulrajani et al. (2017) [8]. The generator also uses layer normalization, since batch normalization requires running statistics which are incompatible with Reptile's meta-update.

All rectified linear units are Parametric ReLU [9] (PReLU). PReLU is the authors’ preferred rectified linear activation function. However, any other rectified linear activation function should yield comparable results.

All images are resized with bilinear interpolation to 32×32 or 64×64. All images are in grayscale format and normalized to have values between -1 and 1. No data augmentation was used. Results were sampled every 10,000 meta-training steps, and experiments took between 50,000 and 250,000 meta-training steps to converge. All experiments were run on a single Tesla V100 on Google Cloud Platform (GCP). Training a model for 250,000 meta-training steps with n=4 on Omniglot took 125 hours with this setup. Table 1 at the end of this paper shows hyperparameters for all experiments.
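As an illustration of the normalization step, the sketch below maps 8-bit grayscale values to the range from -1 to 1 and back. The 8-bit input range and the exact mapping are assumptions, since the paper does not spell them out, and the bilinear resizing is not shown.

```python
from typing import List


def normalize_pixels(image: List[List[int]]) -> List[List[float]]:
    """Map 8-bit grayscale values in [0, 255] to floats in [-1, 1]."""
    return [[pixel / 127.5 - 1.0 for pixel in row] for row in image]


def denormalize_pixels(image: List[List[float]]) -> List[List[int]]:
    """Invert normalize_pixels, e.g. to save generated samples as images."""
    return [[int(round((value + 1.0) * 127.5)) for value in row] for row in image]


tiny = [[0, 128], [255, 64]]   # stand-in for a resized 32x32 grayscale image
scaled = normalize_pixels(tiny)
restored = denormalize_pixels(scaled)
```

A generator ending in a tanh-like output would produce values in the same range, which is one common reason for this choice. Whether FIGR does so is not stated in the text.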

5.2 Empirical Validation

In contrast with prior work, our model works on grayscale images rather than binary images. Our model also works without an external memory, a lengthy sequential inference process or additional training data in the form of pen stroke information. We believe that our approach, being built on top of GANs, has the best capacity to generalize to more challenging problems.

Shown below are the results of generating unseen test classes on our three datasets. The first row of every figure that follows represents the training data (circled in red). The following three rows are images generated by the model fine-tuned on those data points for 10 gradient steps. All images present results on previously unseen test classes. If unspecified, n=4.

MNIST The MNIST data was rescaled to 32×32 pixels. The training classes are the digits from 0 to 8. The test class is the digit 9.

Figure 3: MNIST; 50,000 update; 10 gradient steps

In Figure 3, we can see good results on MNIST after 50,000 meta-training steps. This validates our approach on a toy problem.

Omniglot The Omniglot data was resized to 32×32 and 64×64. The training classes were all 1623 characters in the dataset minus 20 randomly sampled character classes held out for the test set.

Figure 4: Omniglot; 140,000 update; 10 gradient steps
Figure 5: Omniglot; 230,000 update; 10 gradient steps

On simpler Omniglot characters like the one shown in Figure 4, the model converges to good results after 140,000 meta-training steps. On more complex characters, even after 230,000 meta-training steps results are still lacking and humans can easily distinguish between most generated characters and the real ones. This is pictured in Figure 5.

As for the 64×64 images, a batch size of 8 was required to generate good results. In this case, after 150,000 meta-training steps, around half the generated characters could conceivably fool a human judge. This is pictured in Figure 6.

FIGR-8 The FIGR-8 data was resized to 32×32 pixels. The training classes were all 18,409 classes minus 50 randomly sampled classes held out for the test set. Here, n=8 was used for all experiments.
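The class-level hold-out used for Omniglot and FIGR-8 can be sketched as follows. The helper `split_classes` and the seed are hypothetical, and integer ids stand in for the real class names. Only the class counts and hold-out sizes come from the text.

```python
import random
from typing import List, Sequence, Tuple


def split_classes(class_ids: Sequence[int], n_test: int,
                  seed: int = 0) -> Tuple[List[int], List[int]]:
    """Randomly hold out n_test classes; the rest are meta-training classes."""
    rng = random.Random(seed)
    test = set(rng.sample(list(class_ids), n_test))
    train = [c for c in class_ids if c not in test]
    return train, sorted(test)


# Class counts and hold-out sizes are taken from the text; the seed is arbitrary.
omniglot_train, omniglot_test = split_classes(range(1623), n_test=20)
figr8_train, figr8_test = split_classes(range(18409), n_test=50)
```

This leaves 1,603 training classes for Omniglot and 18,359 for FIGR-8. Because the split is made over classes rather than images, every test image belongs to a class the model has never seen.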

For the FIGR-8 dataset, arguably none of the generated images pictured in Figures 7, 8 and 9 can fool a human. However, the model learns key features of the images very quickly, such as a birdlike shape or an ice cream cone.

Figure 6: Omniglot; 150,000 update; 10 gradient steps; 64×64; n=8
Figure 7: FIGR-8; 80,000 update; 10 gradient steps; n=8
Figure 8: FIGR-8; 90,000 update; 10 gradient steps; n=8
Figure 9: FIGR-8; 100,000 update; 10 gradient steps; n=8

6 Conclusion

We have shown that Reptile can be used to effectively train Generative Adversarial Networks for few-shot image generation. Using meta-training on a dataset containing several similar classes of images, we can learn to generate images from an unseen class with as few as 4 samples on the MNIST and Omniglot datasets. This is done with no lengthy inference time, no external memory and no additional data. No hyperparameter tuning is required: the base parameters used are stable throughout experiments. It is, to our knowledge, the first GAN trained for few-shot image generation. Results show that our approach is able to quickly learn and generate simple concepts as well as complex ones. Preliminary results on FIGR-8 show that a complex concept such as "bird" can be learned. To date, no other few-shot image generation model has managed to generate images other than handwritten characters. The low amount of data required to generate images, once the model is pretrained, opens the door to several applications that were previously gated by the high amount of data required.

We have also built, and will release for open source use, FIGR-8, a dataset containing over 18,409 different classes and over 1,548,944 images. Hopefully, this dataset will become a strong benchmark in the task of few-shot image generation.

Several future directions should be explored:

  • Generating multi-channel and/or larger images, such as with the CIFAR-100 dataset or the ImageNet dataset.

  • Modifying batch normalization layers to be able to meta-train through them.

  • Exploiting the wide variety of GAN architectures available.

  • Using FIGR on ImageNet to make a pretrained GAN model for fine-tuning and transfer learning in the same capacity that ImageNet models are used for fine-tuning computer-vision models.

The code for the FIGR implementation can be found at https://github.com/OctThe16th/FIGR and the FIGR-8 database can be found at https://github.com/marcdemers/FIGR-8 and bit.ly/FIGR-8.

References

  • [1] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In Doina Precup and Yee Whye Teh, editors, Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pages 214–223, International Convention Centre, Sydney, Australia, 06–11 Aug 2017. PMLR.
  • [2] Lei Jimmy Ba, Ryan Kiros, and Geoffrey E. Hinton. Layer normalization. CoRR, abs/1607.06450, 2016.
  • [3] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. arXiv e-prints, abs/1409.0473, September 2014.
  • [4] Sergey Bartunov and Dmitry P. Vetrov. Few-shot generative modelling with generative matching networks. In AISTATS, 2018.
  • [5] Léon Bottou. Large-scale machine learning with stochastic gradient descent. In Yves Lechevallier and Gilbert Saporta, editors, Proceedings of COMPSTAT’2010, pages 177–186, Heidelberg, 2010. Physica-Verlag HD.
  • [6] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. CoRR, abs/1703.03400, 2017.
  • [7] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Z. Ghahramani, M. Welling, C. Cortes, N. D. Lawrence, and K. Q. Weinberger, editors, Advances in Neural Information Processing Systems 27, pages 2672–2680. Curran Associates, Inc., 2014.
  • [8] Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, and Aaron Courville. Improved training of wasserstein gans, 2017.
  • [9] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Delving deep into rectifiers: Surpassing human-level performance on imagenet classification. 2015 IEEE International Conference on Computer Vision (ICCV), Dec 2015.
  • [10] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Jun 2016.
  • [11] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization, 2014.
  • [12] Diederik P. Kingma and Max Welling. Auto-encoding variational bayes. CoRR, abs/1312.6114, 2013.
  • [13] Brenden M. Lake, Ruslan Salakhutdinov, Jason Gross, and Joshua B. Tenenbaum. One shot learning of simple visual concepts.
  • [14] Yann LeCun and Corinna Cortes. MNIST handwritten digit database. 2010.
  • [15] Alex Nichol, Joshua Achiam, and John Schulman. On first-order meta-learning algorithms, 2018.
  • [16] Alec Radford, Luke Metz, and Soumith Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks, 2015.
  • [17] Danilo Rezende, Shakir Mohamed, Ivo Danihelka, Karol Gregor, and Daan Wierstra. One-shot generalization in deep generative models. In Maria Florina Balcan and Kilian Q. Weinberger, editors, Proceedings of The 33rd International Conference on Machine Learning, volume 48 of Proceedings of Machine Learning Research, pages 1521–1529, New York, New York, USA, 20–22 Jun 2016. PMLR.
  • [18] Oriol Vinyals, Charles Blundell, Tim Lillicrap, Koray Kavukcuoglu, and Daan Wierstra. Matching networks for one shot learning. In D. D. Lee, M. Sugiyama, U. V. Luxburg, I. Guyon, and R. Garnett, editors, Advances in Neural Information Processing Systems 29, pages 3630–3638. Curran Associates, Inc., 2016.

| Hyperparameter | MNIST | Omniglot | FIGR-8 |
| --- | --- | --- | --- |
| Inner learning rate | 0.0001 | 0.0001 | 0.0001 |
| Outer learning rate | 0.00001 | 0.00001 | 0.00001 |
| Training size n | 4 | 4 and 8 | 8 |
| Inner loops K | 10 | 10 | 10 |
| Image resize | 32×32 | 32×32 and 64×64 | 32×32 |
| Grayscale | True | True | True |
| Validation classes | 1 | 20 | 50 |

Table 1: Hyperparameters for all experiments