用于合成表格数据的生成对抗训练
项目描述
来自MIT Data to AI Lab的开源项目。
TGAN
用于合成表格数据的生成对抗训练。
TGAN是一个表格数据合成器。它可以从真实数据生成完全合成的数据。目前,TGAN可以生成数值列和分类列。
- 免费软件:MIT许可证
- 文档:https://DAI-Lab.github.io/TGAN
- 主页:https://github.com/DAI-Lab/TGAN
需求
Python
TGAN 已开发并在Python 3.5,3.6 和 3.7 上运行。
此外,尽管这不是严格要求的,但为了防止与运行TGAN的系统中的其他软件发生干扰,强烈推荐使用virtualenv。
安装
安装TGAN最简单和推荐的方式是使用pip
pip install tgan
或者,您也可以克隆仓库并从源代码安装它
git clone git@github.com:DAI-Lab/TGAN.git
cd TGAN
make install
对于开发,您可以使用make install-develop
来安装所有必要的测试和代码检查依赖项
数据格式
输入格式
为了能够采样新的合成数据,TGAN首先需要被拟合到现有数据。
此拟合过程的输入数据必须是一个满足以下规则的单一表格
- 没有缺失值。
- 有
int
、float
、str
或bool
类型的列。 - 每一列只包含一种类型的数据。
这样的表格的例子可能是
str_column | float_column | int_column | bool_column |
---|---|---|---|
'green' | 0.15 | 10 | True |
'blue' | 7.25 | 23 | False |
'red' | 10.00 | 1 | False |
'yellow' | 5.50 | 17 | True |
如您所见,此表格包含4列:str_column
、float_column
、int_column
和bool_column
,每一列都是受支持的值类型的示例。注意,任何行都没有缺失值。
注意:正确识别哪些列是数值的非常重要,这意味着它们代表一个量级,哪些是分类的,因为在数据的预处理过程中,数值和分类列将被不同地处理。
输出格式
TGAN的输出是与输入表格具有相同列的采样数据表,行数与请求的数量相同。
演示数据集
TGAN包含一些数据集,可用于开发和演示目的。这些数据集来自UCI机器学习仓库,并且已经预处理以符合输入格式部分中指定的要求。
这些数据集可以在hdi-project-tgan AWS S3存储桶中浏览和直接下载
人口普查数据集
此数据集包含一个表格,包含人口普查信息,并标记了收入是否超过每年50,000美元。它是一个包含199522行和41列的单个csv文件。在这些41列中,只有7列被标识为连续的。在TGAN中,此数据集称为census
。
覆盖类型
此数据集包含一个带有不同森林覆盖类型标记的单一表格。它是一个包含465588行和55列的单个csv文件。在这些55列中,有10列被标识为连续的。在TGAN中,此数据集称为covertype
。
快速入门
在本简短教程中,我们将引导您完成一系列步骤,帮助您开始使用TGAN的基本用法,以便从给定数据集中生成样本。
注意:以下示例也包含在一个Jupyter笔记本中,您可以在virtualenv中运行以下命令来执行它
pip install jupyter
jupyter notebook examples/Usage_Example.ipynb
1. 加载数据
第一步是加载数据,这些数据将用于拟合TGAN。为此,我们首先导入函数tgan.data.load_data
,并用要加载数据集的名称调用它。
在这种情况下,我们将加载census
数据集,我们将在后续步骤中使用它,并获取两个对象
-
data
,它将包含一个pandas.DataFrame
,其中包含来自census
数据集的表格数据,准备用于拟合模型。 -
continuous_columns
,它将包含一个包含连续列索引的list
。
>>> from tgan.data import load_demo_data
>>> data, continuous_columns = load_demo_data('census')
>>> data.head(3).T[:10]
0 1 2
0 73 58 18
1 Not in universe Self-employed-not incorporated Not in universe
2 0 4 0
3 0 34 0
4 High school graduate Some college but no degree 10th grade
5 0 0 0
6 Not in universe Not in universe High school
7 Widowed Divorced Never married
8 Not in universe or children Construction Not in universe or children
9 Not in universe Precision production craft & repair Not in universe
>>> continuous_columns
[0, 5, 16, 17, 18, 29, 38]
2. 创建TGAN实例
下一步是导入TGAN并创建模型实例。
为此,我们需要导入tgan.model.TGANModel
类,并用continuous_columns
作为唯一参数调用它。
这将创建一个具有默认参数的TGAN实例
>>> from tgan.model import TGANModel
>>> tgan = TGANModel(continuous_columns)
3. 拟合模型
一旦你有了一个TGAN
实例,你可以调用它的fit
方法,传入你之前加载的data
,以开始拟合过程
>>> tgan.fit(data)
这个过程不会返回任何东西,但是拟合的进度将被打印在屏幕上。
注意 根据你运行系统的性能和选择的模型参数,这一步可能需要几个小时。
4. 生成新数据
一旦模型被拟合,你就可以通过调用TGAN
实例的sample
方法并传入所需的样本数量来生成新的样本
>>> num_samples = 1000
>>> samples = tgan.sample(num_samples)
>>> samples.head(3).T[:10]
0 1 2
0 12 27 56
0 12 27 56
1 Not in universe Self-employed-not incorporated Private
2 0 4 35
3 0 34 22
4 Children Some college but no degree Some college but no degree
5 0 0 500
6 Not in universe Not in universe Not in universe
7 Never married Married-civilian spouse present Married-civilian spouse present
8 Not in universe or children Construction Finance insurance and real estate
9 Not in universe Precision production craft & repair Adm support including clerical
返回的对象samples
是一个包含具有与输入数据相同格式的合成数据表格的pandas.DataFrame
,包含我们请求的1000行。
5. 保存和加载模型
在上面的步骤中,我们看到了拟合过程可能需要很长时间,所以我们可能不想每次生成样本时都要拟合。相反,我们可以拟合一次模型,保存它,然后每次我们需要采样新数据时都加载它。
如果我们有一个拟合的模型,我们可以通过调用它的save
方法来保存它,该方法只接受一个参数,即模型将被存储的路径。同样,TGANModel.load
允许通过传入存储模型的路径来加载存储在磁盘上的模型。
>>> model_path = 'models/mymodel.pkl'
>>> tgan.save(model_path)
Model saved successfully.
请注意,如果文件已经存在,TGAN将避免覆盖它,除非传递了force=True
参数。
>>> tgan.save(model_path)
The indicated path already exists. Use `force=True` to overwrite.
为此
>>> tgan.save(model_path, force=True)
Model saved successfully.
一旦模型被保存,就可以通过使用TGANModel.load
方法将其作为TGAN
实例重新加载
>>> new_tgan = TGANModel.load(model_path)
>>> new_samples = new_tgan.sample(num_samples)
>>> new_samples.head(3).T[:10]
0 1 2
0 12 27 56
0 12 27 56
1 Not in universe Self-employed-not incorporated Private
2 0 4 35
3 0 34 22
4 Children Some college but no degree Some college but no degree
5 0 0 500
6 Not in universe Not in universe Not in universe
7 Never married Married-civilian spouse present Married-civilian spouse present
8 Not in universe or children Construction Finance insurance and real estate
9 Not in universe Precision production craft & repair Adm support including clerical
此时,我们可以使用此模型实例来生成更多样本。
加载自定义数据集
在上面的步骤中,我们使用了一些演示数据,但我们没有向你展示如何加载你自己的数据集。
为此,你需要从你的数据集中生成一个pandas.DataFrame
对象。如果你的数据集是csv
格式,你可以通过使用pandas.read_csv
并传入你想要加载的CSV文件的路径来实现。
此外,你需要创建一个0索引的列索引列表,以考虑为连续。
例如,如果我们想加载一个本地CSV文件path/to/my.csv
,其中连续列是前4列,即索引[0, 1, 2, 3]
,我们会这样做
>>> import pandas as pd
>>> data = pd.read_csv('data/census.csv')
>>> continuous_columns = [0, 1, 2, 3]
现在你可以使用continuous_columns
来创建一个TGAN
实例,并使用data
来fit
它,就像我们之前做的那样
>>> from tgan.model import TGANModel
>>> tgan = TGANModel(continuous_columns)
>>> tgan.fit(data)
模型参数
如果你想要改变TGANModel
的默认行为,例如不同的batch_size
或num_epochs
,你可以在创建实例时传递不同的参数。
模型一般行为
- continous_columns (
list[int]
, required): 被视为连续的列索引列表。 - output (
str
, default=output
): 存储模型及其组件的路径。
神经网络定义和拟合
- max_epoch (
int
, default=100
): 训练期间使用的epoch数量。 - steps_per_epoch (
int
, 默认=10000
):每个epoch运行的步数。 - save_checkpoints(
bool
, 默认=True):是否在每个训练epoch后存储模型检查点。 - restore_session(
bool
, 默认=True):是否从最后一个检查点继续训练。 - batch_size (
int
, 默认=200
):每次步中馈送到模型的批次大小。 - z_dim (
int
, 默认=100
):生成器噪声输入中的维度数量。 - noise (
float
, 默认=0.2
):添加到分类列的高斯噪声的上限。 - l2norm (
float
, 默认=0.00001
):计算损失时的L2正则化系数。 - learning_rate (
float
, 默认=0.001
):优化器的学习率。 - num_gen_rnn (
int
, 默认=400
):生成器中RNN单元的数量。 - num_gen_feature (
int
, 默认=100
):生成器中全连接层中的单元数量。 - num_dis_layers (
int
, 默认=2
):判别器中的层数。 - num_dis_hidden (
int
, 默认=200
):判别器中每层的单元数量。 - optimizer (
str
, 默认=AdamOptimizer
):在fit
过程中使用的优化器名称。可能的值有:[GradientDescentOptimizer
,AdamOptimizer
,AdadeltaOptimizer
]。
如果您想创建与步骤2中创建的实例相同的实例,但以显式方式传递参数,可以使用以下行实现
>>> from tgan.model import TGANModel
>>> tgan = TGANModel(
...: continuous_columns,
...: output='output',
...: max_epoch=5,
...: steps_per_epoch=10000,
...: save_checkpoints=True,
...: restore_session=True,
...: batch_size=200,
...: z_dim=200,
...: noise=0.2,
...: l2norm=0.00001,
...: learning_rate=0.001,
...: num_gen_rnn=100,
...: num_gen_feature=100,
...: num_dis_layers=1,
...: num_dis_hidden=100,
...: optimizer='AdamOptimizer'
...: )
命令行界面
我们包括一个命令行界面,允许用户访问TGAN功能。目前只支持一个操作。
随机超参数搜索
输入
要为给定数据集运行最佳模型超参数的随机搜索,我们需要以下内容
-
一个数据集,在csv文件中,没有缺失值,只有类型为
bool
、str
、int
或float
的列,并且每列只有一个类型,如输入格式中所述。 -
包含搜索配置的JSON文件。此配置应包含
name
:实验的名称。将创建一个具有此名称的文件夹。num_random_search
:超参数搜索中的迭代次数。train_csv
:包含数据集的csv文件的路径。continuous_cols
:要考虑为连续的列索引列表,从0开始。epoch
:训练模型时的epoch数。steps_per_epoch
:每个epoch中的优化步数。sample_rows
:评估模型时要采样的行数。
您可以在examples/config.json中看到一个这样的json文件的示例,您可以下载并用作模板。
执行
一旦我们准备好了所有东西,我们可以使用以下命令启动随机超参数搜索
tgan experiments config.json results.json
其中第一个参数config.json
是配置JSON的路径,第二个参数results.json
是存储执行摘要的路径。
这将运行随机搜索,它基本上包括以下步骤
- 我们获取并划分我们的数据以供测试和训练使用。
- 我们随机选择要测试的超参数。
- 然后,对于每个超参数组合,我们使用真实的训练数据T训练一个TGAN模型并生成一个合成的训练数据集Tsynth。
- 然后,我们在真实和合成的数据集上训练机器学习模型。
- 我们使用这些训练好的模型在真实测试数据上运行,并查看它们的性能。
输出
实验完成后,以下内容可以得到
- 一个JSON文件,在上述示例中称为
results.json
,包含实验的摘要。此JSON将包含每个实验的键name
,并且在该键上,将包含长度为num_random_search
的数组,其中包含选定的参数及其评估分数。对于像示例这样的配置,摘要将如下所示
{
'census': [
{
"steps_per_epoch" : 10000,
"num_gen_feature" : 300,
"num_dis_hidden" : 300,
"batch_size" : 100,
"num_gen_rnn" : 400,
"score" : 0.937802280415988,
"max_epoch" : 5,
"num_dis_layers" : 4,
"learning_rate" : 0.0002,
"z_dim" : 100,
"noise" : 0.2
},
... # 9 more nodes
]
}
- 一组文件夹,每个文件夹的名称都按照JSON配置文件中指定的
name
命名,包含在experiments
文件夹中。在每个文件夹中,可以找到采样数据和模型。对于一个像示例中的配置,它将看起来像这样
experiments/
census/
data/ # Sampled data with each of the models in the random search.
model_0/
logs/ # Training logs
model/ # Tensorflow model checkpoints
model_1/ # 9 more folders, one for each model in the random search
...
研究
第一个 TAGN 版本作为 Lei Xu 和 Kalyan Veeramachaneni 所著的 《使用生成对抗网络合成表格数据》 研究论文的支持软件构建。
论文中提到的软件的精确版本可以在发布部分以 研究预发布 的形式找到。
引用
如果您使用了 TGAN,请引用以下工作
Lei Xu, Kalyan Veeramachaneni. 2018. 使用生成对抗网络合成表格数据。
@article{xu2018synthesizing,
title={Synthesizing Tabular Data using Generative Adversarial Networks},
author={Xu, Lei and Veeramachaneni, Kalyan},
journal={arXiv preprint arXiv:1811.11264},
year={2018}
}
您可以在这里找到原始论文。
历史
0.1.0
- 首次发布在 PyPI 上。
项目详情
下载文件
下载适用于您平台的应用程序文件。如果您不确定选择哪个,请了解更多关于 安装软件包 的信息。
源分发
构建分发
tgan-0.1.0.tar.gz 的散列
算法 | 散列摘要 | |
---|---|---|
SHA256 | 20f5335815cf587f7ef9ac81a784dccecfab77a64d99319f9776070efa302e84 |
|
MD5 | 8b0b9b0c912edcf94d5c0e082571669c |
|
BLAKE2b-256 | 9540c25a8a9663c9b7d52fd531207347ef8f2dc2b4fa00cd9ae15299fb3647ff |
tgan-0.1.0-py2.py3-none-any.whl 的散列
算法 | 散列摘要 | |
---|---|---|
SHA256 | 76a42879e0969826e57c9cbcc5eb8b1d5c3c3c37624a91b4faa45e9c78f56fcd |
|
MD5 | 36431430dc6ca7f699cab0301d587c7b |
|
BLAKE2b-256 | 90e13174c7fb5e2bc6b28771797e74af060f05cb7275aa222e40be56108c303e |