跳转到主要内容

PyTorch对FAENet的实现,来自《FAENet:用于材料建模的框架平均等变GNN》

项目描述

💻  代码   •   文档  📑

Python Documentation Status


FAENet:用于材料建模的框架平均等变GNN

此存储库包含了一篇论文的实现,该论文被ICML 2023接受,即《FAENet:用于材料建模的框架平均等变GNN》。更具体地说,您将找到以下内容:

  • FrameAveraging:将您的pytorch-geometric数据投影到所有欧几里得变换的规范空间中的转换,如论文中定义。
  • FAENet:一种用于材料建模的GNN架构。
  • model_forward:一个高级前向函数,用于计算适用于框架平均方法的对等模型预测,即处理不同的框架并将其映射到对等预测。

此外: https://github.com/vict0rsch/faenet

安装

pip install faenet

⚠️ 上述安装需要 Python >= 3.8torch > 1.11torch_geometric > 2.1,据我们所知。还需要mendeleevpandas包来在FAENet中推导物理感知原子嵌入。

入门

框架平均转换

FrameAveraging 是一种适用于 pytorch-geometric Data 对象的变换方法,应在您的 Dataset 类的 get_item() 函数中使用。该方法为原子图推导出一个新的规范位置,对于所有欧几里得对称性都是相同的,并将其存储在数据属性 fa_pos 下。您可以从多种帧平均选项中进行选择,从 全帧平均 (Full FA)随机帧平均 (Stochastic FA)(2D 或 3D),包括具有旋转样本的传统数据增强 DA。有关更多详细信息,请参阅完整的 文档。请注意,尽管此变换针对 pytorch-geometric 数据对象,但由于其核心函数 frame_averaging_2D()frame_averaging_3D() 可以推广到其他数据格式,因此它可以很容易地扩展到新的设置。

import torch
from faenet.transforms import FrameAveraging

frame_averaging = "3D"  # symmetry preservation method used: {"3D", "2D", "DA", ""}:
fa_method = "stochastic"  # the frame averaging method: {"det", "all", "se3-stochastic", "se3-det", "se3-all", ""}:
transform = FrameAveraging(frame_averaging, fa_method)
transform(data)  # transform the PyG graph data

帧平均的模型正向传播

model_forward() 聚合了在应用帧平均时选择的机器学习模型(例如 FAENet)的预测,如论文中的公式(1)所述。显然,将模型直接应用于规范位置(fa_pos)不会产生等变预测。此方法必须在训练和推理时间应用以计算所有模型预测。它需要 batch 具有位置、批量和帧平均属性(请参阅 文档)。

from faenet.fa_forward import model_forward

preds = model_forward(
    batch=batch,   # batch from, dataloader
    model=model,  # FAENet(**kwargs)
    frame_averaging="3D", # ["2D", "3D", "DA", ""]
    mode="train",  # for training 
    crystal_task=True,  # for crystals, with pbc conditions
)

FAENet GNN

FAENet GNN 模型的实现,兼容任何数据集或变换。简而言之,FAENet 是一个非常简单、可扩展且具有表现力的模型。由于它不显式保留数据对称性,因此它具有直接和不受限制地处理原子相对位置的能力,这非常高效且强大。尽管它专门设计为与上述帧平均一起应用,以保留对称性而没有任何设计限制,但请注意,它也可以不与帧平均一起应用。当与帧平均一起应用时,我们需要使用上面的 model_forward() 函数来计算模型预测,model(data) 不够。请注意,此处未给出训练过程,您应参阅原始的 github 仓库。检查 文档 以查看所有输入参数。

请注意,模型假定输入数据(例如下面的 batch)具有某些属性,如原子序数、批量和位置或边索引。如果您的数据没有这些属性,您可以使用自定义预处理函数,以 utils.py 中的 pbc_preprocessbase_preprocess 为灵感。您只需将它们作为参数传递给 FAENet(preprocess)即可。

from faenet.model import FAENet

preds = FAENet(**kwargs)
model(batch)

FAENet architecture

评估

eval_model_symmetries() 函数有助于您评估模型的等变性、不变性和其他属性,就像我们在论文中所做的那样。

注意:您可以预测任何原子级或图级属性,尽管代码明确提到了能量和力。

测试

“/tests” 文件夹包含几个有用的单元测试。请随意查看它们以了解如何使用模型。有关更高级的示例,请参阅我们在 ICML 论文中使用的完整 仓库,用于在 OC20 IS2RE、S2EF、QM9 和 QM7-X 数据集上做出预测。

这需要 poetry。在您运行测试之前,请确保您的环境中已安装 torchtorch_geometric。遗憾的是,由于 CUDA/torch 兼容性, neither torch nor torch_geometric 是显式依赖项的一部分,并且必须单独安装。

git clone git@github.com:vict0rsch/faenet.git
poetry install --with dev
pytest --cov=faenet --cov-report term-missing

在 Mac 上进行测试时,您可能会遇到 库未加载错误

联系方式

作者:Alexandre Duval (alexandre.duval@mila.quebec) 和 Victor Schmidt (schmidtv@mila.quebec)。欢迎您通过电子邮件或GitHub问题反馈提问和建议。

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源分发

faenet-0.1.3.tar.gz (23.2 kB 查看哈希值)

上传时间

构建分发

faenet-0.1.3-py3-none-any.whl (23.5 kB 查看哈希值)

上传时间 Python 3

由以下支持