跳转到主要内容

使用神经网络进行条件平均处理效应估计

项目描述

CATENets - 使用神经网络进行条件平均处理效应估计

CATENets Tests Documentation Status License

代码作者: Alicia Curth (amc253@cam.ac.uk)

此存储库包含基于Jax的、sklearn风格的神经网络条件平均处理效应(CATE)估计器的实现,这些实现用于AISTATS21论文"非参数估计异质处理效应:从理论到学习算法"(Curth & vd Schaar,2021a)以及后续NeurIPS21论文"关于异质处理效应估计的归纳偏差"(Curth & vd Schaar,2021b)和NeurIPS21数据集和基准测试论文"在估计CATE方面真的做得很好?对治疗效应估计中机器学习基准测试实践的批判性审视"(Curth等人,2021)。

我们实现了Curth & vd Schaar(2021a)中介绍的SNet类,以及Curth & vd Schaar(2021b)中讨论的FlexTENet和OffsetNet,并重新实现了现有文献中的一些基于NN的算法(Shalit等人(2017),Shi等人(2019),Hassanpour & Greiner(2020))。我们还提供了用于CATE估计的许多所谓元学习器的基于NN的实例化,包括两步伪结果回归估计器(DR-学习器(Kennedy,2020)和单稳健倾向得分(PW)以及回归调整(RA)学习器),Nie & Wager(2017)的R-学习器和Kuenzel等人(2019)的X-学习器。所有列出的论文中使用了catenets.models.jax中的jax实现;此外,一些模型的pytorch版本(catenets.models.torch)由Bogdan Cebere提供。

接口

仓库包含一个名为catenets的包,其中包含用于建模和评估的所有通用代码,以及一个名为experiments的文件夹,其中包含复制实验结果的代码。在catenets中实现的全部学习算法(SNet, FlexTENet, OffsetNet, TNet, SNet1 (TARNet), SNet2 (DragonNet), SNet3, DRNet, RANet, PWNet, RNet, XNet)都带有sklearn风格的包装器,实现了.fit(X, y, w).predict(X)方法,其中predict默认返回CATE。所有超参数都在catenets.models文件夹中的相应文件中详细说明。

示例用法

from catenets.models.jax import TNet, SNet
from catenets.experiment_utils.simulation_utils import simulate_treatment_setup

# simulate some data (here: unconfounded, 10 prognostic variables and 5 predictive variables)
X, y, w, p, cate = simulate_treatment_setup(n=2000, n_o=10, n_t=5, n_c=0)

# estimate CATE using TNet
t = TNet()
t.fit(X, y, w)
cate_pred_t = t.predict(X)  # without potential outcomes
cate_pred_t, po0_pred_t, po1_pred_t = t.predict(X, return_po=True)  # predict potential outcomes too

# estimate CATE using SNet
s = SNet(penalty_orthogonal=0.01)
s.fit(X, y, w)
cate_pred_s = s.predict(X)

可以使用此存储库复制Curth & vd Schaar(2021a)中的所有实验;必要的代码在experiments.experiments_AISTATS21中。要从shell中这样做,克隆存储库,创建一个新的虚拟环境并运行

pip install -r requirements.txt #install requirements
python run_experiments_AISTATS.py
Options:
--experiment # defaults to 'simulation', 'ihdp' will run ihdp experiments
--setting # different simulation settings in synthetic experiments (can be 1-5)
--models # defaults to None which will train all models considered in paper,
         # can be string of model name (e.g 'TNet'), 'plug' for all plugin models,
         # 'pseudo' for all pseudo-outcome regression models

--file_name # base file name to write to, defaults to 'results'
--n_repeats # number of experiments to run for each configuration, defaults to 10 (should be set to 100 for IHDP)

同样,Curth & vd Schaar(2021b)中的实验可以使用experiments.experiments_inductivebias_NeurIPS21中的代码复制(或从shell使用python run_experiments_inductive_bias_NeurIPS.py)和Curth等人(2021)中的实验可以使用experiments.experiments_benchmarks_NeurIPS21中的代码复制(也可以从shell使用python run_experiments_benchmarks_NeurIPS运行catenets实验)。

代码还可以作为Python包安装(catenets)。从存储库的本地副本运行python setup.py install

注意:目前jax仅在macOS和linux上受支持,但可以通过WSL(Windows子系统Linux)在Windows上运行。

引用

如果您使用此软件,请引用相应的论文

@inproceedings{curth2021nonparametric,
  title={Nonparametric Estimation of Heterogeneous Treatment Effects: From Theory to Learning Algorithms},
  author={Curth, Alicia and van der Schaar, Mihaela},
    year={2021},
  booktitle={Proceedings of the 24th International Conference on Artificial
  Intelligence and Statistics (AISTATS)},
  organization={PMLR}
}

@article{curth2021inductive,
  title={On Inductive Biases for Heterogeneous Treatment Effect Estimation},
  author={Curth, Alicia and van der Schaar, Mihaela},
  booktitle={Proceedings of the Thirty-Fifth Conference on Neural Information Processing Systems},
  year={2021}
}


@article{curth2021really,
  title={Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation},
  author={Curth, Alicia and Svensson, David and Weatherall, James and van der Schaar, Mihaela},
  booktitle={Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks},
  year={2021}
}

项目详情


下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源分布

此版本没有可用的源分布文件。请参阅生成发行存档的教程

构建发行版

catenets-0.2.3-py3-none-macosx_10_14_x86_64.whl (130.2 kB 查看散列)

上传时间 Python 3 macOS 10.14+ x86-64

catenets-0.2.3-py3-none-any.whl (131.3 kB 查看散列)

上传时间 Python 3

支持者