ECOGAN is a Generative Adversarial Network (GAN) for generating images from imbalanced data. It applies a similarity-based distance learning method to imbalanced data learning.
Imbalanced data refers to data in which the elements constituting the dataset (class, object, scale, etc.) are not uniformly distributed. In this experiment, we train on data that follows a long-tailed distribution, i.e., the number of samples per category is inconsistent, as shown in Figure (b).
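As a concrete illustration, a long-tailed split of the kind shown in Figure (b) is commonly built by subsampling each class with an exponentially decaying ratio. The sketch below shows one such construction (the function name and the `imb_factor` parameterization are ours; the exact recipe used for the `*_LT` datasets here may differ):

```python
import numpy as np

def make_long_tail_indices(labels, num_classes=10, imb_factor=0.01, seed=0):
    """Subsample per-class indices so class sizes decay exponentially.
    `imb_factor` is the ratio between the smallest and largest class."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    max_n = np.bincount(labels).max()
    keep = []
    for c in range(num_classes):
        # class c keeps max_n * imb_factor ** (c / (num_classes - 1)) samples
        n_c = int(max_n * imb_factor ** (c / (num_classes - 1)))
        idx = np.where(labels == c)[0]
        keep.extend(rng.choice(idx, size=n_c, replace=False))
    return np.sort(np.asarray(keep))
```

Applied to CIFAR-10 (5,000 training images per class) with `imb_factor=0.01`, this keeps 5,000 images of class 0 down to 50 images of class 9.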
The following figure shows schematic diagrams of discriminators previously proposed for imbalanced data learning. BAGAN (a) first pointed out the problems that arise when learning imbalanced data with generative models and was the first to propose pre-training with an autoencoder. Unlike BAGAN, IDA-GAN (b) pre-trains with a variational autoencoder and splits the discriminator's single output into two, alleviating the training conflict between the generator and the discriminator. EBGAN (c) multiplies the latent space by embeddings of the class information so that class information is learned during pre-training. Finally, ours (d) proposes a novel structure that enables cosine similarity-based contrastive learning for imbalanced data learning.
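As an illustration of (d), a discriminator can expose, alongside the usual real/fake output, an L2-normalized embedding that is compared against learned class representations by cosine similarity. The sketch below is our own minimal rendition of that idea (all names such as `ProjectionDiscriminator`, `proj_head`, and `class_anchor` are illustrative, not taken from the paper):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionDiscriminator(nn.Module):
    """Shared backbone feeding (i) a real/fake logit and (ii) a unit-norm
    embedding compared to learned class anchors via cosine similarity."""
    def __init__(self, backbone, feat_dim, embed_dim, num_classes):
        super().__init__()
        self.backbone = backbone                         # any conv feature extractor
        self.adv_head = nn.Linear(feat_dim, 1)           # adversarial (real/fake) output
        self.proj_head = nn.Linear(feat_dim, embed_dim)  # embedding for metric learning
        self.class_anchor = nn.Parameter(torch.randn(num_classes, embed_dim))

    def forward(self, x):
        h = self.backbone(x).flatten(1)                  # (B, feat_dim) features
        adv = self.adv_head(h)
        z = F.normalize(self.proj_head(h), dim=1)        # L2-normalized embedding
        anchors = F.normalize(self.class_anchor, dim=1)
        cos_sim = z @ anchors.t()                        # cosine similarity to each class
        return adv, z, cos_sim
```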
The figure below shows schematic diagrams of the training processes of previously proposed metric learning methods and of conditioning methods used in conditional generative models. Unlike these methods, ours (f) uses the relationships between all samples within a batch for training, alleviating the learning imbalance for minority-class data.
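A minimal sketch of a loss that uses every pair within a batch, in the spirit of (f); the temperature value and the masking details are assumptions rather than the exact ECO loss:

```python
import torch

def batch_cosine_contrastive_loss(z, labels, temperature=0.1):
    """Supervised contrastive-style loss over all pairs in a batch.
    z: (B, D) L2-normalized embeddings; labels: (B,) class ids."""
    sim = z @ z.t() / temperature                          # (B, B) scaled cosine similarities
    B = z.size(0)
    self_mask = torch.eye(B, dtype=torch.bool, device=z.device)
    pos_mask = labels.unsqueeze(0).eq(labels.unsqueeze(1)) & ~self_mask
    sim = sim.masked_fill(self_mask, float("-inf"))        # drop self-pairs
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    pos_cnt = pos_mask.sum(1).clamp(min=1)
    per_sample = -(log_prob * pos_mask).sum(1) / pos_cnt   # mean log-likelihood of same-class pairs
    return per_sample[pos_mask.any(1)].mean()              # skip samples with no positive in batch
```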
Experiments were conducted in three settings to compare performance:
- performance comparison with existing metric learning methods
- analysis of why hinge loss-based loss functions struggle to learn imbalanced data
- performance comparison with existing pre-training methods
Since the existing metric learning methods were proposed for balanced-data settings, we first run experiments on balanced data. We then repeat the experiments on imbalanced data to confirm that our proposed method is more effective for imbalanced data learning than the existing metric learning methods.
The figure below visualizes the evaluation metrics (FID, IS) measured during generator training; a sketch of how these metrics can be computed follows the results table below. The top two rows show training on balanced data, the bottom two rows on imbalanced data. On balanced data, our method performs on par with or better than conventional metric learning methods. In particular, it matches D2D-CE, a loss designed to mitigate the misclassification problem of earlier metric learning losses, which indicates that our method is likewise robust to misclassification. On imbalanced data, by contrast, the performance of existing metric learning methods stops improving due to mode collapse. This confirms that our method is robust to the misclassification problem, especially when learning imbalanced data, and does not suffer from training failures such as mode collapse.
| Method | Data | FID (↓) | IS (↑) |
|---|---|---|---|
| 2C [20] | balanced | 6.63 | 9.22 |
| D2D-CE [27] | balanced | 4.71 | 9.76 |
| ECO (ours) | balanced | 4.88 | 9.77 |
| 2C [20] | imbalanced | 29.04 | 6.15 |
| D2D-CE [27] | imbalanced | 42.65 | 5.74 |
| ECO (ours) | imbalanced | 25.53 | 6.56 |
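For reference, FID and IS can both be computed with `torchmetrics` (our choice for illustration; the repository's actual evaluation code may differ). Images are passed as `uint8` tensors in `(N, 3, H, W)` layout:

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

fid = FrechetInceptionDistance(feature=2048)
is_metric = InceptionScore()

# Stand-in batches; in practice these come from the dataset and the generator.
real = torch.randint(0, 256, (64, 3, 32, 32), dtype=torch.uint8)
fake = torch.randint(0, 256, (64, 3, 32, 32), dtype=torch.uint8)

fid.update(real, real=True)
fid.update(fake, real=False)
is_metric.update(fake)

print("FID:", fid.compute().item())
mean, std = is_metric.compute()
print("IS:", mean.item())
```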
The following figure visualizes the evaluation metrics measured while training networks of different sizes with the D2D-CE loss. D2D-CE applies a hinge to the loss so that easily classified samples are excluded from training, focusing learning on samples that are hard to classify (see the sketch below). With imbalanced data, however, minority classes have few training samples in absolute terms, so the generator exploits the regions the discriminator has not yet learned before the minority classes are modeled accurately; our analysis is that this causes mode collapse early in training.
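A minimal sketch of the hinge mechanism described above, in the style of D2D-CE but not its exact definition (the margins `m_p`, `m_n` and the temperature are placeholder values):

```python
import torch
import torch.nn.functional as F

def hinged_similarity_loss(pos_sim, neg_sim, m_p=0.98, m_n=0.02, temperature=0.1):
    """pos_sim: (B,) similarity to the positive; neg_sim: (B, K) to negatives.
    Positives above m_p and negatives below m_n are clamped, so easy samples
    contribute zero gradient and training focuses on hard ones."""
    pos = torch.clamp(pos_sim - m_p, max=0.0)   # easy positives -> 0
    neg = torch.clamp(neg_sim - m_n, min=0.0)   # easy negatives -> 0
    logits = torch.cat([pos.unsqueeze(1), neg], dim=1) / temperature
    # cross-entropy with the positive pair as the target class (index 0)
    target = torch.zeros(pos_sim.size(0), dtype=torch.long, device=pos_sim.device)
    return F.cross_entropy(logits, target)
```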
| Model | Data | Best step | FID (↓) | IS (↑) | Pre-trained | Sampling |
|---|---|---|---|---|---|---|
| BAGAN [10] | FashionMNIST_LT | 64000 | 92.61 | 2.81 | TRUE | - |
| EBGAN [12] | FashionMNIST_LT | 120000 | 27.40 | 2.43 | TRUE | - |
| EBGAN [12] | FashionMNIST_LT | 150000 | 30.10 | 2.38 | FALSE | - |
| ECOGAN (ours) | FashionMNIST_LT | 126000 | 32.91 | 2.91 | - | FALSE |
| ECOGAN (ours) | FashionMNIST_LT | 120000 | 20.02 | 2.63 | - | TRUE |
| BAGAN [10] | CIFAR10_LT | 76000 | 125.77 | 2.14 | TRUE | - |
| EBGAN [12] | CIFAR10_LT | 144000 | 60.11 | 2.36 | TRUE | - |
| EBGAN [12] | CIFAR10_LT | 150000 | 68.90 | 2.29 | FALSE | - |
| ECOGAN (ours) | CIFAR10_LT | 144000 | 51.71 | 2.83 | - | FALSE |
| ECOGAN (ours) | CIFAR10_LT | 138000 | 43.79 | 2.74 | - | TRUE |
| EBGAN [12] | Places_LT | 150000 | 136.92 | 2.57 | FALSE | - |
| EBGAN [12] | Places_LT | 144000 | 144.04 | 2.46 | TRUE | - |
| ECOGAN (ours) | Places_LT | 105000 | 91.55 | 3.02 | - | FALSE |
| ECOGAN (ours) | Places_LT | 75000 | 95.43 | 3.01 | - | TRUE |
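The "Sampling" column above presumably refers to class-balanced oversampling of minority classes; the following is a minimal sketch of such a sampler in PyTorch (the `balanced_loader` helper and all details are ours, not this repository's):

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def balanced_loader(dataset, labels, batch_size=64):
    """Draw minority-class samples more often so batches are roughly class-balanced."""
    labels = torch.as_tensor(labels)
    class_count = torch.bincount(labels).float()
    sample_weights = (1.0 / class_count)[labels]   # inverse class frequency per sample
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```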
Prepare the datasets in the following directory structure:
```
data
└── CIFAR10_LT, FashionMNIST_LT or Places_LT
    ├── train
    │   ├── cls0
    │   │   ├── train0.png
    │   │   ├── train1.png
    │   │   └── ...
    │   ├── cls1
    │   └── ...
    └── valid
        ├── cls0
        │   ├── valid0.png
        │   ├── valid1.png
        │   └── ...
        ├── cls1
        └── ...
```
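With this layout, each split can be loaded directly with `torchvision`'s `ImageFolder`, which maps the class folders to integer labels (a usage sketch, assuming `torchvision` is how the data is consumed):

```python
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor()])
train_set = datasets.ImageFolder("data/CIFAR10_LT/train", transform=transform)
valid_set = datasets.ImageFolder("data/CIFAR10_LT/valid", transform=transform)
print(train_set.classes)  # ['cls0', 'cls1', ...] -> labels 0, 1, ...
```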