使用神经网络进行条件平均处理效应估计
项目描述
CATENets - 使用神经网络进行条件平均处理效应估计
代码作者: Alicia Curth (amc253@cam.ac.uk)
此存储库包含基于Jax的、sklearn风格的神经网络条件平均处理效应(CATE)估计器的实现,这些实现用于AISTATS21论文"非参数估计异质处理效应:从理论到学习算法"(Curth & vd Schaar,2021a)以及后续NeurIPS21论文"关于异质处理效应估计的归纳偏差"(Curth & vd Schaar,2021b)和NeurIPS21数据集和基准测试论文"在估计CATE方面真的做得很好?对治疗效应估计中机器学习基准测试实践的批判性审视"(Curth等人,2021)。
我们实现了Curth & vd Schaar(2021a)中介绍的SNet类,以及Curth & vd Schaar(2021b)中讨论的FlexTENet和OffsetNet,并重新实现了现有文献中的一些基于NN的算法(Shalit等人(2017),Shi等人(2019),Hassanpour & Greiner(2020))。我们还提供了用于CATE估计的许多所谓元学习器的基于NN的实例化,包括两步伪结果回归估计器(DR-学习器(Kennedy,2020)和单稳健倾向得分(PW)以及回归调整(RA)学习器),Nie & Wager(2017)的R-学习器和Kuenzel等人(2019)的X-学习器。所有列出的论文中使用了catenets.models.jax
中的jax实现;此外,一些模型的pytorch版本(catenets.models.torch
)由Bogdan Cebere提供。
接口
仓库包含一个名为catenets
的包,其中包含用于建模和评估的所有通用代码,以及一个名为experiments
的文件夹,其中包含复制实验结果的代码。在catenets
中实现的全部学习算法(SNet, FlexTENet, OffsetNet, TNet, SNet1 (TARNet), SNet2 (DragonNet), SNet3, DRNet, RANet, PWNet, RNet, XNet
)都带有sklearn风格的包装器,实现了.fit(X, y, w)
和.predict(X)
方法,其中predict默认返回CATE。所有超参数都在catenets.models
文件夹中的相应文件中详细说明。
示例用法
from catenets.models.jax import TNet, SNet
from catenets.experiment_utils.simulation_utils import simulate_treatment_setup
# simulate some data (here: unconfounded, 10 prognostic variables and 5 predictive variables)
X, y, w, p, cate = simulate_treatment_setup(n=2000, n_o=10, n_t=5, n_c=0)
# estimate CATE using TNet
t = TNet()
t.fit(X, y, w)
cate_pred_t = t.predict(X) # without potential outcomes
cate_pred_t, po0_pred_t, po1_pred_t = t.predict(X, return_po=True) # predict potential outcomes too
# estimate CATE using SNet
s = SNet(penalty_orthogonal=0.01)
s.fit(X, y, w)
cate_pred_s = s.predict(X)
可以使用此存储库复制Curth & vd Schaar(2021a)中的所有实验;必要的代码在experiments.experiments_AISTATS21
中。要从shell中这样做,克隆存储库,创建一个新的虚拟环境并运行
pip install -r requirements.txt #install requirements
python run_experiments_AISTATS.py
Options:
--experiment # defaults to 'simulation', 'ihdp' will run ihdp experiments
--setting # different simulation settings in synthetic experiments (can be 1-5)
--models # defaults to None which will train all models considered in paper,
# can be string of model name (e.g 'TNet'), 'plug' for all plugin models,
# 'pseudo' for all pseudo-outcome regression models
--file_name # base file name to write to, defaults to 'results'
--n_repeats # number of experiments to run for each configuration, defaults to 10 (should be set to 100 for IHDP)
同样,Curth & vd Schaar(2021b)中的实验可以使用experiments.experiments_inductivebias_NeurIPS21
中的代码复制(或从shell使用python run_experiments_inductive_bias_NeurIPS.py
)和Curth等人(2021)中的实验可以使用experiments.experiments_benchmarks_NeurIPS21
中的代码复制(也可以从shell使用python run_experiments_benchmarks_NeurIPS
运行catenets实验)。
代码还可以作为Python包安装(catenets
)。从存储库的本地副本运行python setup.py install
。
注意:目前jax仅在macOS和linux上受支持,但可以通过WSL(Windows子系统Linux)在Windows上运行。
引用
如果您使用此软件,请引用相应的论文
@inproceedings{curth2021nonparametric,
title={Nonparametric Estimation of Heterogeneous Treatment Effects: From Theory to Learning Algorithms},
author={Curth, Alicia and van der Schaar, Mihaela},
year={2021},
booktitle={Proceedings of the 24th International Conference on Artificial
Intelligence and Statistics (AISTATS)},
organization={PMLR}
}
@article{curth2021inductive,
title={On Inductive Biases for Heterogeneous Treatment Effect Estimation},
author={Curth, Alicia and van der Schaar, Mihaela},
booktitle={Proceedings of the Thirty-Fifth Conference on Neural Information Processing Systems},
year={2021}
}
@article{curth2021really,
title={Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation},
author={Curth, Alicia and Svensson, David and Weatherall, James and van der Schaar, Mihaela},
booktitle={Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks},
year={2021}
}
项目详情
下载文件
下载适合您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。
源分布
构建发行版
哈希值 用于 catenets-0.2.3-py3-none-macosx_10_14_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | beee8538fa5e4c18f5b2e78b850fabea0d5091e7950f4fed8441481d08738569 |
|
MD5 | c6d794dce6ec574b0fb10c02ee36b11e |
|
BLAKE2b-256 | 51169b83a987c878dc3a41bc44011e314a1602b8e9fde4fa681405380cfb076d |