PyTorch对FAENet的实现,来自《FAENet:用于材料建模的框架平均等变GNN》
项目描述
FAENet:用于材料建模的框架平均等变GNN
此存储库包含了一篇论文的实现,该论文被ICML 2023接受,即《FAENet:用于材料建模的框架平均等变GNN》。更具体地说,您将找到以下内容:
FrameAveraging
:将您的pytorch-geometric数据投影到所有欧几里得变换的规范空间中的转换,如论文中定义。FAENet
:一种用于材料建模的GNN架构。model_forward
:一个高级前向函数,用于计算适用于框架平均方法的对等模型预测,即处理不同的框架并将其映射到对等预测。
此外: https://github.com/vict0rsch/faenet
安装
pip install faenet
⚠️ 上述安装需要 Python >= 3.8
,torch > 1.11
,torch_geometric > 2.1
,据我们所知。还需要mendeleev
和pandas
包来在FAENet中推导物理感知原子嵌入。
入门
框架平均转换
FrameAveraging
是一种适用于 pytorch-geometric Data
对象的变换方法,应在您的 Dataset
类的 get_item()
函数中使用。该方法为原子图推导出一个新的规范位置,对于所有欧几里得对称性都是相同的,并将其存储在数据属性 fa_pos
下。您可以从多种帧平均选项中进行选择,从 全帧平均 (Full FA) 到 随机帧平均 (Stochastic FA)(2D 或 3D),包括具有旋转样本的传统数据增强 DA。有关更多详细信息,请参阅完整的 文档。请注意,尽管此变换针对 pytorch-geometric 数据对象,但由于其核心函数 frame_averaging_2D()
和 frame_averaging_3D()
可以推广到其他数据格式,因此它可以很容易地扩展到新的设置。
import torch
from faenet.transforms import FrameAveraging
frame_averaging = "3D" # symmetry preservation method used: {"3D", "2D", "DA", ""}:
fa_method = "stochastic" # the frame averaging method: {"det", "all", "se3-stochastic", "se3-det", "se3-all", ""}:
transform = FrameAveraging(frame_averaging, fa_method)
transform(data) # transform the PyG graph data
帧平均的模型正向传播
model_forward()
聚合了在应用帧平均时选择的机器学习模型(例如 FAENet)的预测,如论文中的公式(1)所述。显然,将模型直接应用于规范位置(fa_pos
)不会产生等变预测。此方法必须在训练和推理时间应用以计算所有模型预测。它需要 batch
具有位置、批量和帧平均属性(请参阅 文档)。
from faenet.fa_forward import model_forward
preds = model_forward(
batch=batch, # batch from, dataloader
model=model, # FAENet(**kwargs)
frame_averaging="3D", # ["2D", "3D", "DA", ""]
mode="train", # for training
crystal_task=True, # for crystals, with pbc conditions
)
FAENet GNN
FAENet GNN 模型的实现,兼容任何数据集或变换。简而言之,FAENet 是一个非常简单、可扩展且具有表现力的模型。由于它不显式保留数据对称性,因此它具有直接和不受限制地处理原子相对位置的能力,这非常高效且强大。尽管它专门设计为与上述帧平均一起应用,以保留对称性而没有任何设计限制,但请注意,它也可以不与帧平均一起应用。当与帧平均一起应用时,我们需要使用上面的 model_forward()
函数来计算模型预测,model(data)
不够。请注意,此处未给出训练过程,您应参阅原始的 github 仓库。检查 文档 以查看所有输入参数。
请注意,模型假定输入数据(例如下面的 batch
)具有某些属性,如原子序数、批量和位置或边索引。如果您的数据没有这些属性,您可以使用自定义预处理函数,以 utils.py 中的 pbc_preprocess
或 base_preprocess
为灵感。您只需将它们作为参数传递给 FAENet(preprocess
)即可。
from faenet.model import FAENet
preds = FAENet(**kwargs)
model(batch)
评估
eval_model_symmetries()
函数有助于您评估模型的等变性、不变性和其他属性,就像我们在论文中所做的那样。
注意:您可以预测任何原子级或图级属性,尽管代码明确提到了能量和力。
测试
“/tests” 文件夹包含几个有用的单元测试。请随意查看它们以了解如何使用模型。有关更高级的示例,请参阅我们在 ICML 论文中使用的完整 仓库,用于在 OC20 IS2RE、S2EF、QM9 和 QM7-X 数据集上做出预测。
这需要 poetry
。在您运行测试之前,请确保您的环境中已安装 torch
和 torch_geometric
。遗憾的是,由于 CUDA/torch 兼容性, neither torch
nor torch_geometric
是显式依赖项的一部分,并且必须单独安装。
git clone git@github.com:vict0rsch/faenet.git
poetry install --with dev
pytest --cov=faenet --cov-report term-missing
在 Mac 上进行测试时,您可能会遇到 库未加载错误
联系方式
作者:Alexandre Duval (alexandre.duval@mila.quebec) 和 Victor Schmidt (schmidtv@mila.quebec)。欢迎您通过电子邮件或GitHub问题反馈提问和建议。
项目详情
下载文件
下载适用于您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。
源分发
构建分发
faenet-0.1.3.tar.gz的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 0b494464f4e45c14677a5a0cd4112301995501ce0d5f039be184dea605e92a92 |
|
MD5 | 73726aafa8cf85c466e8079fd54949c4 |
|
BLAKE2b-256 | b010bba53c1695205e0ae59ac86608e88992474a1014536523c52ca22012864e |
faenet-0.1.3-py3-none-any.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 4153255e05e3303e7b9f71145ff06fb6ea9ef2c6f4e2b4c6c7bc38bbfe4055aa |
|
MD5 | 2140214e294f3289a2955fb23218455c |
|
BLAKE2b-256 | 943bc578787d2d1796c19f955ef44241d3ecc9a9195f3f3e06bfb09d6d29fb30 |