跳转到主要内容

用于合成表格数据的生成对抗训练

项目描述

DAI-Lab 来自MIT Data to AI Lab的开源项目。

PyPi Shield Travis CI Shield

TGAN

用于合成表格数据的生成对抗训练。

TGAN是一个表格数据合成器。它可以从真实数据生成完全合成的数据。目前,TGAN可以生成数值列和分类列。

需求

Python

TGAN 已开发并在Python 3.53.63.7 上运行。

此外,尽管这不是严格要求的,但为了防止与运行TGAN的系统中的其他软件发生干扰,强烈推荐使用virtualenv

安装

安装TGAN最简单和推荐的方式是使用pip

pip install tgan

或者,您也可以克隆仓库并从源代码安装它

git clone git@github.com:DAI-Lab/TGAN.git
cd TGAN
make install

对于开发,您可以使用make install-develop来安装所有必要的测试和代码检查依赖项

数据格式

输入格式

为了能够采样新的合成数据,TGAN首先需要被拟合到现有数据。

拟合过程的输入数据必须是一个满足以下规则的单一表格

  • 没有缺失值。
  • intfloatstrbool类型的列。
  • 每一列只包含一种类型的数据。

这样的表格的例子可能是

str_column float_column int_column bool_column
'green' 0.15 10 True
'blue' 7.25 23 False
'red' 10.00 1 False
'yellow' 5.50 17 True

如您所见,此表格包含4列:str_columnfloat_columnint_columnbool_column,每一列都是受支持的值类型的示例。注意,任何行都没有缺失值。

注意:正确识别哪些列是数值的非常重要,这意味着它们代表一个量级,哪些是分类的,因为在数据的预处理过程中,数值和分类列将被不同地处理。

输出格式

TGAN的输出是与输入表格具有相同列的采样数据表,行数与请求的数量相同。

演示数据集

TGAN包含一些数据集,可用于开发和演示目的。这些数据集来自UCI机器学习仓库,并且已经预处理以符合输入格式部分中指定的要求。

这些数据集可以在hdi-project-tgan AWS S3存储桶中浏览和直接下载

人口普查数据集

此数据集包含一个表格,包含人口普查信息,并标记了收入是否超过每年50,000美元。它是一个包含199522行和41列的单个csv文件。在这些41列中,只有7列被标识为连续的。在TGAN中,此数据集称为census

覆盖类型

此数据集包含一个带有不同森林覆盖类型标记的单一表格。它是一个包含465588行和55列的单个csv文件。在这些55列中,有10列被标识为连续的。在TGAN中,此数据集称为covertype

快速入门

在本简短教程中,我们将引导您完成一系列步骤,帮助您开始使用TGAN的基本用法,以便从给定数据集中生成样本。

注意:以下示例也包含在一个Jupyter笔记本中,您可以在virtualenv中运行以下命令来执行它

pip install jupyter
jupyter notebook examples/Usage_Example.ipynb

1. 加载数据

第一步是加载数据,这些数据将用于拟合TGAN。为此,我们首先导入函数tgan.data.load_data,并用要加载数据集的名称调用它。

在这种情况下,我们将加载census数据集,我们将在后续步骤中使用它,并获取两个对象

  1. data,它将包含一个pandas.DataFrame,其中包含来自census数据集的表格数据,准备用于拟合模型。

  2. continuous_columns,它将包含一个包含连续列索引的list

>>> from tgan.data import load_demo_data
>>> data, continuous_columns = load_demo_data('census')
>>> data.head(3).T[:10]
                              0                                     1                             2
0                            73                                    58                            18
1               Not in universe        Self-employed-not incorporated               Not in universe
2                             0                                     4                             0
3                             0                                    34                             0
4          High school graduate            Some college but no degree                    10th grade
5                             0                                     0                             0
6               Not in universe                       Not in universe                   High school
7                       Widowed                              Divorced                 Never married
8   Not in universe or children                          Construction   Not in universe or children
9               Not in universe   Precision production craft & repair               Not in universe

>>> continuous_columns
[0, 5, 16, 17, 18, 29, 38]

2. 创建TGAN实例

下一步是导入TGAN并创建模型实例。

为此,我们需要导入tgan.model.TGANModel类,并用continuous_columns作为唯一参数调用它。

这将创建一个具有默认参数的TGAN实例

>>> from tgan.model import TGANModel
>>> tgan = TGANModel(continuous_columns)

3. 拟合模型

一旦你有了一个TGAN实例,你可以调用它的fit方法,传入你之前加载的data,以开始拟合过程

>>> tgan.fit(data)

这个过程不会返回任何东西,但是拟合的进度将被打印在屏幕上。

注意 根据你运行系统的性能和选择的模型参数,这一步可能需要几个小时。

4. 生成新数据

一旦模型被拟合,你就可以通过调用TGAN实例的sample方法并传入所需的样本数量来生成新的样本

>>> num_samples = 1000
>>> samples = tgan.sample(num_samples)
>>> samples.head(3).T[:10]
                                         0                                     1                                   2
0                                       12                                    27                                  56

0                                       12                                    27                                  56
1                          Not in universe        Self-employed-not incorporated                             Private
2                                        0                                     4                                  35
3                                        0                                    34                                  22
4                                 Children            Some college but no degree          Some college but no degree
5                                        0                                     0                                 500
6                          Not in universe                       Not in universe                     Not in universe
7                            Never married       Married-civilian spouse present     Married-civilian spouse present
8              Not in universe or children                          Construction   Finance insurance and real estate
9                          Not in universe   Precision production craft & repair      Adm support including clerical

返回的对象samples是一个包含具有与输入数据相同格式的合成数据表格的pandas.DataFrame,包含我们请求的1000行。

5. 保存和加载模型

在上面的步骤中,我们看到了拟合过程可能需要很长时间,所以我们可能不想每次生成样本时都要拟合。相反,我们可以拟合一次模型,保存它,然后每次我们需要采样新数据时都加载它。

如果我们有一个拟合的模型,我们可以通过调用它的save方法来保存它,该方法只接受一个参数,即模型将被存储的路径。同样,TGANModel.load允许通过传入存储模型的路径来加载存储在磁盘上的模型。

>>> model_path = 'models/mymodel.pkl'
>>> tgan.save(model_path)
Model saved successfully.

请注意,如果文件已经存在,TGAN将避免覆盖它,除非传递了force=True参数。

>>> tgan.save(model_path)
The indicated path already exists. Use `force=True` to overwrite.

为此

>>> tgan.save(model_path, force=True)
Model saved successfully.

一旦模型被保存,就可以通过使用TGANModel.load方法将其作为TGAN实例重新加载

>>> new_tgan = TGANModel.load(model_path)
>>> new_samples = new_tgan.sample(num_samples)
>>> new_samples.head(3).T[:10]

                                         0                                     1                                   2
0                                       12                                    27                                  56

0                                       12                                    27                                  56
1                          Not in universe        Self-employed-not incorporated                             Private
2                                        0                                     4                                  35
3                                        0                                    34                                  22
4                                 Children            Some college but no degree          Some college but no degree
5                                        0                                     0                                 500
6                          Not in universe                       Not in universe                     Not in universe
7                            Never married       Married-civilian spouse present     Married-civilian spouse present
8              Not in universe or children                          Construction   Finance insurance and real estate
9                          Not in universe   Precision production craft & repair      Adm support including clerical

此时,我们可以使用此模型实例来生成更多样本。

加载自定义数据集

在上面的步骤中,我们使用了一些演示数据,但我们没有向你展示如何加载你自己的数据集。

为此,你需要从你的数据集中生成一个pandas.DataFrame对象。如果你的数据集是csv格式,你可以通过使用pandas.read_csv并传入你想要加载的CSV文件的路径来实现。

此外,你需要创建一个0索引的列索引列表,以考虑为连续。

例如,如果我们想加载一个本地CSV文件path/to/my.csv,其中连续列是前4列,即索引[0, 1, 2, 3],我们会这样做

>>> import pandas as pd
>>> data = pd.read_csv('data/census.csv')
>>> continuous_columns = [0, 1, 2, 3]

现在你可以使用continuous_columns来创建一个TGAN实例,并使用datafit它,就像我们之前做的那样

>>> from tgan.model import TGANModel
>>> tgan = TGANModel(continuous_columns)
>>> tgan.fit(data)

模型参数

如果你想要改变TGANModel的默认行为,例如不同的batch_sizenum_epochs,你可以在创建实例时传递不同的参数。

模型一般行为

  • continous_columns (list[int], required): 被视为连续的列索引列表。
  • output (str, default=output): 存储模型及其组件的路径。

神经网络定义和拟合

  • max_epoch (int, default=100): 训练期间使用的epoch数量。
  • steps_per_epoch (int, 默认=10000):每个epoch运行的步数。
  • save_checkpoints(bool, 默认=True):是否在每个训练epoch后存储模型检查点。
  • restore_session(bool, 默认=True):是否从最后一个检查点继续训练。
  • batch_size (int, 默认=200):每次步中馈送到模型的批次大小。
  • z_dim (int, 默认=100):生成器噪声输入中的维度数量。
  • noise (float, 默认=0.2):添加到分类列的高斯噪声的上限。
  • l2norm (float, 默认=0.00001):计算损失时的L2正则化系数。
  • learning_rate (float, 默认=0.001):优化器的学习率。
  • num_gen_rnn (int, 默认=400):生成器中RNN单元的数量。
  • num_gen_feature (int, 默认=100):生成器中全连接层中的单元数量。
  • num_dis_layers (int, 默认=2):判别器中的层数。
  • num_dis_hidden (int, 默认=200):判别器中每层的单元数量。
  • optimizer (str, 默认=AdamOptimizer):在fit过程中使用的优化器名称。可能的值有:[GradientDescentOptimizer, AdamOptimizer, AdadeltaOptimizer]。

如果您想创建与步骤2中创建的实例相同的实例,但以显式方式传递参数,可以使用以下行实现

>>> from tgan.model import TGANModel
>>> tgan = TGANModel(
   ...:     continuous_columns,
   ...:     output='output',
   ...:     max_epoch=5,
   ...:     steps_per_epoch=10000,
   ...:     save_checkpoints=True,
   ...:     restore_session=True,
   ...:     batch_size=200,
   ...:     z_dim=200,
   ...:     noise=0.2,
   ...:     l2norm=0.00001,
   ...:     learning_rate=0.001,
   ...:     num_gen_rnn=100,
   ...:     num_gen_feature=100,
   ...:     num_dis_layers=1,
   ...:     num_dis_hidden=100,
   ...:     optimizer='AdamOptimizer'
   ...: )

命令行界面

我们包括一个命令行界面,允许用户访问TGAN功能。目前只支持一个操作。

随机超参数搜索

输入

要为给定数据集运行最佳模型超参数的随机搜索,我们需要以下内容

  • 一个数据集,在csv文件中,没有缺失值,只有类型为boolstrintfloat的列,并且每列只有一个类型,如输入格式中所述。

  • 包含搜索配置的JSON文件。此配置应包含

    • name:实验的名称。将创建一个具有此名称的文件夹。
    • num_random_search:超参数搜索中的迭代次数。
    • train_csv:包含数据集的csv文件的路径。
    • continuous_cols:要考虑为连续的列索引列表,从0开始。
    • epoch:训练模型时的epoch数。
    • steps_per_epoch:每个epoch中的优化步数。
    • sample_rows:评估模型时要采样的行数。

您可以在examples/config.json中看到一个这样的json文件的示例,您可以下载并用作模板。

执行

一旦我们准备好了所有东西,我们可以使用以下命令启动随机超参数搜索

tgan experiments config.json results.json

其中第一个参数config.json是配置JSON的路径,第二个参数results.json是存储执行摘要的路径。

这将运行随机搜索,它基本上包括以下步骤

  1. 我们获取并划分我们的数据以供测试和训练使用。
  2. 我们随机选择要测试的超参数。
  3. 然后,对于每个超参数组合,我们使用真实的训练数据T训练一个TGAN模型并生成一个合成的训练数据集Tsynth。
  4. 然后,我们在真实和合成的数据集上训练机器学习模型。
  5. 我们使用这些训练好的模型在真实测试数据上运行,并查看它们的性能。

输出

实验完成后,以下内容可以得到

  • 一个JSON文件,在上述示例中称为results.json,包含实验的摘要。此JSON将包含每个实验的键name,并且在该键上,将包含长度为num_random_search的数组,其中包含选定的参数及其评估分数。对于像示例这样的配置,摘要将如下所示
{
    'census': [
        {
            "steps_per_epoch" : 10000,
            "num_gen_feature" : 300,
            "num_dis_hidden" : 300,
            "batch_size" : 100,
            "num_gen_rnn" : 400,
            "score" : 0.937802280415988,
            "max_epoch" : 5,
            "num_dis_layers" : 4,
            "learning_rate" : 0.0002,
            "z_dim" : 100,
            "noise" : 0.2
        },
        ... # 9 more nodes
    ]
}
  • 一组文件夹,每个文件夹的名称都按照JSON配置文件中指定的 name 命名,包含在 experiments 文件夹中。在每个文件夹中,可以找到采样数据和模型。对于一个像示例中的配置,它将看起来像这样
experiments/
  census/
    data/       # Sampled data with each of the models in the random search.
    model_0/
      logs/     # Training logs
      model/    # Tensorflow model checkpoints
    model_1/    # 9 more folders, one for each model in the random search
    ...

研究

第一个 TAGN 版本作为 Lei Xu 和 Kalyan Veeramachaneni 所著的 《使用生成对抗网络合成表格数据》 研究论文的支持软件构建。

论文中提到的软件的精确版本可以在发布部分以 研究预发布 的形式找到。

引用

如果您使用了 TGAN,请引用以下工作

Lei Xu, Kalyan Veeramachaneni. 2018. 使用生成对抗网络合成表格数据。

@article{xu2018synthesizing,
  title={Synthesizing Tabular Data using Generative Adversarial Networks},
  author={Xu, Lei and Veeramachaneni, Kalyan},
  journal={arXiv preprint arXiv:1811.11264},
  year={2018}
}

您可以在这里找到原始论文。

历史

0.1.0

  • 首次发布在 PyPI 上。

项目详情


下载文件

下载适用于您平台的应用程序文件。如果您不确定选择哪个,请了解更多关于 安装软件包 的信息。

源分发

tgan-0.1.0.tar.gz (83.9 kB 查看散列)

上传时间

构建分发

tgan-0.1.0-py2.py3-none-any.whl (27.2 kB 查看散列)

上传时间 Python 2 Python 3

支持者

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误日志 StatusPage StatusPage 状态页面