跳转到主要内容

表示学习模型、技术、回调、实用工具的集合,用于创建细胞形状、形态和细胞内组织的潜在变量模型。

项目描述

CytoDL

PyTorch Lightning Config: Hydra Template

描述

作为艾伦细胞科学研究所使命的一部分,旨在理解人类诱导多能干细胞通过哪些原理建立和维护细胞结构的稳健动态定位,CytoDL旨在统一理解二维和三维生物数据(如图像、点云和表格数据)的深度学习方法。

CytoDL的大部分底层结构基于lightning-hydra-template组织 - 我们强烈建议您熟悉他们的(简短)文档,以获取运行训练、覆盖等详细说明。

目前可用的代码大致分为两个领域:图像到图像的转换和表示学习。图像到图像的代码(表示为im2im)包含配置文件,详细说明了如何使用条件生成对抗网络(GAN)进行分辨率增强(例如,从20x图像预测100x图像)、语义和实例分割,以及无标签预测。我们还提供了用于在2D和3D图像上使用视觉Transformer(ViT)主干进行Masked Autoencoder(MAE)和联合嵌入预测架构(JEPA)预训练的配置,以及从这些预训练特征中训练分割解码器的配置。表示学习代码包括各种变分自编码器(VAE)架构和对比学习方法,如VICReg。由于依赖性问题,目前不支持在Windows上使用等变自编码器。

由于我们依赖于最新的pytorch版本,希望在使用GPU硬件上训练和运行模型的用户需要最新的NVIDIA驱动程序。拥有较老GPU的用户不应期望代码能够直接运行。同样,我们目前不支持在Mac GPU上训练/预测。在大多数情况下,当GPU训练失败时,基于CPU的训练应该可以工作。

对于im2im模型,我们提供了一些示例3D图像,用于训练基本的图像到图像转换类型的模型,并为用户提供默认模型配置文件,以便用户熟悉框架并准备在他们自己的数据上训练和应用这些模型。请注意,这些默认模型非常小,并在高度下采样数据上训练,以便测试能够高效运行 - 为了最佳性能,应增加模型大小并从数据配置中删除下采样。

运行方法

安装依赖项。依赖项是平台特定的,请将PLATFORM替换为您的平台 - 即linuxwindowsmac

# clone project
git clone https://github.com/AllenCellModeling/cyto-dl
cd cyto-dl

# [OPTIONAL] create conda environment
conda create -n myenv python=3.9
conda activate myenv

pip install -r requirements/PLATFORM/requirements.txt

# [OPTIONAL] install extra dependencies - equivariance related
pip install -r requirements/PLATFORM/equiv-requirements.txt

pip install -e .


#[OPTIONAL] if you want to use default experiments on example data
python scripts/download_test_data.py

API

from cyto_dl.api import CytoDLModel

model = CytoDLModel()
model.download_example_data()
model.load_default_experiment("segmentation", output_dir="./output", overrides=["trainer=cpu"])
model.print_config()
model.train()

# [OPTIONAL] async training
await model.train(run_async=True)

大多数模型通过在数据配置中传递数据路径来工作。对于在内存中已经存在的数据集进行训练或预测,您可以直接将数据传递给模型。请注意,此用例主要用于程序性使用(例如,在工作流程或Jupyter笔记本中),而不是通过正常的CLI。一个用于此用例的可能配置设置实验配置使用im2im/segmentation_array实验进行演示。对于训练,数据必须作为包含键“train”和“val”的字典传递,其中键对应于数据配置。

from cyto_dl.api import CytoDLModel
import numpy as np

model = CytoDLModel()
model.load_default_experiment("segmentation_array", output_dir="./output")
model.print_config()

# create CZYX dummy data
data = {
    "train": [{"raw": np.random.randn(1, 40, 256, 256), "seg": np.ones((1, 40, 256, 256))}],
    "val": [{"raw": np.random.randn(1, 40, 256, 256), "seg": np.ones((1, 40, 256, 256))}],
}
model.train(data=data)

对于预测,数据必须作为numpy数组列表传递。结果预测将以字典形式处理,其中每个键对应于模型配置中的任务头,对应的值按BC(Z)YX顺序排列。

from cyto_dl.api import CytoDLModel
import numpy as np
from cyto_dl.utils import extract_array_predictions

model = CytoDLModel()
model.load_default_experiment(
    "segmentation_array", output_dir="./output", overrides=["data=im2im/numpy_dataloader_predict"]
)
model.print_config()

# create CZYX dummy data
data = [np.random.rand(1, 32, 64, 64), np.random.rand(1, 32, 64, 64)]

_, _, output = model.predict(data=data)
preds = extract_array_predictions(output)

使用configs/experiment/中选择的实验配置训练模型

#gpu
python cyto_dl/train.py experiment=im2im/experiment_name.yaml trainer=gpu

#cpu
python cyto_dl/train.py experiment=im2im/experiment_name.yaml trainer=cpu

您可以通过以下方式从命令行覆盖任何参数

python cyto_dl/train.py trainer.max_epochs=20 datamodule.batch_size=64

项目详情


下载文件

下载您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源分布

cyto-dl-0.4.1.tar.gz (426.4 KB 查看哈希值

上传时间

构建分布

cyto_dl-0.4.1-py3-none-any.whl (294.5 kB 查看哈希值)

上传时间: Python 3

由以下支持

AWSAWS云计算和安全赞助商DatadogDatadog监控FastlyFastlyCDNGoogleGoogle下载分析MicrosoftMicrosoftPSF赞助商PingdomPingdom监控SentrySentry错误日志StatusPageStatusPage状态页