Conditional Tabular GAN - CTGAN
One of the most interesting ideas in machine learning over the last decade is the GAN architecture for generative modeling. While GANs have shown remarkable success in generating high-quality images and other continuous data types, tabular data poses unique challenges: it typically mixes discrete and continuous variables with complex dependencies between them. Traditional GAN architectures struggle to capture these intricacies, leading to poor performance when applied directly to tabular data. CTGAN addresses these challenges by introducing a specialized framework for generating realistic synthetic tabular data.
CTGAN Architecture
The CTGAN model is based on the GAN framework, comprising two main components: a generator and a discriminator (a minimal sketch of both follows the list).
- Generator: The generator is responsible for producing synthetic data that resembles the real data. It takes as input a noise vector sampled from a standard normal distribution and a conditioning vector. The conditioning vector encodes specific values of categorical variables, enabling the generation of samples with particular characteristics. The generator outputs a vector representing a synthetic data sample, which includes both categorical and continuous features.
- Discriminator: The discriminator’s role is to distinguish between real and synthetic data samples. It is a binary classifier trained to correctly separate real data from the generator’s output. The discriminator receives both real and synthetic data as input and outputs a probability score indicating whether the input is real or fake.
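Below is a minimal PyTorch sketch of these two components. The layer sizes and the plain fully connected stacks are illustrative assumptions, not the architecture of the reference implementation.

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, noise_dim, cond_dim, data_dim):
        super().__init__()
        # The generator consumes noise concatenated with the conditional
        # vector that encodes the requested categorical value(s).
        self.net = nn.Sequential(
            nn.Linear(noise_dim + cond_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            # Mixed continuous + one-hot output (the real model applies
            # tanh / gumbel-softmax per feature type).
            nn.Linear(256, data_dim),
        )

    def forward(self, noise, cond):
        return self.net(torch.cat([noise, cond], dim=1))

class Discriminator(nn.Module):
    def __init__(self, data_dim, cond_dim):
        super().__init__()
        # The discriminator also sees the conditional vector, so it can
        # judge whether a sample is realistic *for that condition*.
        self.net = nn.Sequential(
            nn.Linear(data_dim + cond_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1),  # single real/fake score
        )

    def forward(self, sample, cond):
        return self.net(torch.cat([sample, cond], dim=1))
```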
Key Innovations in CTGAN
CTGAN introduces several key innovations that enable it to handle the complexities of tabular data:
- Conditional Vector and Mode-Specific Normalization: The conditional vector allows the generator to generate data conditioned on specific categorical values, which is particularly useful when some features depend strongly on others. For example, in a dataset with gender and age features, conditioning on gender can help the generator produce age values that are realistic for that gender. Mode-specific normalization handles the continuous features: each continuous column is modeled with a variational Gaussian mixture, and every value is normalized against the mean and standard deviation of the mode it belongs to rather than with a single global scale. This lets the generator represent multimodal, non-Gaussian columns that a plain min-max normalization would flatten (a sketch of this normalization follows the list).
- Training by Sampling: To address imbalance in categorical columns, CTGAN does not sample conditions uniformly from the data. Instead, it picks a discrete column at random and then samples a category with probability proportional to the logarithm of its frequency, so underrepresented categories are seen during training far more often than their raw share of the data would allow. This ensures that the generator learns to produce realistic samples across all categories, not just the majority ones (see the sampling sketch after the list).
- Cross-Entropy (Logistic) Loss for the Generator: The generator’s loss function is augmented with a cross-entropy term between the conditional vector and the categorical output the generator actually produced. This penalizes the generator for ignoring the condition and, together with training by sampling, keeps it from collapsing onto a handful of majority categories, mitigating mode collapse, a common issue in GANs where the generator produces a limited variety of samples (see the loss sketch after the list).
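The following sketch shows mode-specific normalization for a single continuous column, assuming scikit-learn's BayesianGaussianMixture as the variational Gaussian mixture. For simplicity it assigns each value to its most likely mode rather than sampling one.

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture

def mode_specific_normalize(column, n_modes=10):
    column = column.reshape(-1, 1)
    vgm = BayesianGaussianMixture(n_components=n_modes, max_iter=200)
    vgm.fit(column)
    modes = vgm.predict(column)                  # mode assignment per value
    means = vgm.means_[modes, 0]
    stds = np.sqrt(vgm.covariances_[modes, 0, 0])
    # Normalize each value against the mean/std of *its own* mode, so a
    # multimodal column is not squashed by a single global scale.
    alpha = (column[:, 0] - means) / (4 * stds)  # normalized scalar per value
    beta = np.eye(n_modes)[modes]                # one-hot mode indicator
    return alpha, beta
```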
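Training by sampling can be sketched as below; the function name and the DataFrame layout are illustrative assumptions.

```python
import numpy as np
import pandas as pd

def sample_condition(df, discrete_columns, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    # 1. Pick a discrete column uniformly at random.
    col = rng.choice(discrete_columns)
    # 2. Pick a category with probability proportional to the *log* of its
    #    frequency, so rare categories are seen far more often than their
    #    raw share of the data would allow.
    counts = df[col].value_counts()
    log_freq = np.log(counts.values + 1)
    probs = log_freq / log_freq.sum()
    value = rng.choice(counts.index.to_numpy(), p=probs)
    # 3. The real batch is then drawn only from rows matching the condition,
    #    so the discriminator sees consistent (sample, condition) pairs.
    real_row = df[df[col] == value].sample(1)
    return col, value, real_row
```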
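Finally, a hypothetical helper showing how the cross-entropy condition term augments the usual adversarial loss. The argument names and the weighting factor are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def conditional_generator_loss(adv_loss, cat_logits, cond_labels, weight=1.0):
    # adv_loss: the usual adversarial loss from the discriminator's score
    # cat_logits: generator logits for the conditioned column, (batch, K)
    # cond_labels: index of the requested category per sample, (batch,)
    ce = F.cross_entropy(cat_logits, cond_labels)
    # The generator is penalized whenever its categorical output does not
    # match the condition it was asked to satisfy.
    return adv_loss + weight * ce
```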
Training Process
The training process for CTGAN follows the standard GAN procedure of alternating discriminator and generator updates (a skeletal loop is sketched at the end of this section).
- Preprocessing: The data is preprocessed by converting categorical variables into one-hot encoded vectors and normalizing continuous variables with mode-specific normalization. Each training batch then pairs real samples with their corresponding conditional vectors (a preprocessing sketch follows).
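A small preprocessing sketch, reusing the mode_specific_normalize helper from the earlier sketch; the column names and data are purely illustrative.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "age": np.random.normal(40, 12, 1000),
    "gender": np.random.choice(["F", "M"], 1000),
})

# Categorical columns become one-hot vectors.
onehot = pd.get_dummies(df["gender"], prefix="gender").to_numpy(dtype=float)

# Continuous columns get mode-specific normalization instead of a single
# global min-max or z-score scaling.
alpha, beta = mode_specific_normalize(df["age"].to_numpy())

# The transformed row representation fed to the discriminator.
real_rows = np.hstack([alpha.reshape(-1, 1), beta, onehot])
```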
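And a skeletal alternating training loop, reusing the Generator and Discriminator sketches from above. It matches the binary-classifier discriminator described in this section with a standard BCE loss (the paper's reference implementation uses a WGAN-GP critic instead), and the real batch and condition here are random placeholders.

```python
import torch
import torch.nn.functional as F

noise_dim, cond_dim, data_dim = 128, 2, 13
G = Generator(noise_dim, cond_dim, data_dim)
D = Discriminator(data_dim, cond_dim)
g_opt = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.9))
d_opt = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.9))

for step in range(1000):
    # Placeholders standing in for a real conditioned batch.
    cond = torch.zeros(64, cond_dim); cond[:, 0] = 1.0
    real = torch.randn(64, data_dim)

    # Discriminator update: real vs. generated, both with the condition.
    fake = G(torch.randn(64, noise_dim), cond).detach()
    d_loss = (
        F.binary_cross_entropy_with_logits(D(real, cond), torch.ones(64, 1))
        + F.binary_cross_entropy_with_logits(D(fake, cond), torch.zeros(64, 1))
    )
    d_opt.zero_grad(); d_loss.backward(); d_opt.step()

    # Generator update: fool the discriminator under the same condition.
    fake = G(torch.randn(64, noise_dim), cond)
    g_loss = F.binary_cross_entropy_with_logits(D(fake, cond), torch.ones(64, 1))
    g_opt.zero_grad(); g_loss.backward(); g_opt.step()
```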