跳转到主要内容

基于预训练骨干的图像分割模型。PyTorch。

项目描述

logo
Python库,包含用于图像的神经网络
基于PyTorch的分割。

PyPI version Build Status Documentation Status
Downloads Generic badge

该库的主要特性包括

  • 高级API(仅两行即可创建神经网络)
  • 8种模型架构用于二值和多类分割(包括传奇的Unet)
  • 99种可用的编码器
  • 所有编码器都有预训练的权重,以实现更快和更好的收敛

📚 项目文档 📚

访问Read The Docs项目页面或阅读以下README以了解更多关于Segmentation Models Pytorch(简称SMP)库的信息

📋 目录

  1. 快速入门
  2. 示例
  3. 模型
    1. 架构
    2. 编码器
  4. 模型API
    1. 输入通道
    2. 辅助分类输出
    3. 深度
  5. 安装
  6. 使用该库赢得的竞赛
  7. 贡献
  8. 引用
  9. 许可证

⏳ 快速入门

1. 使用SMP创建您的第一个分割模型

分割模型只是一个PyTorch nn.Module,它可以像下面这样轻松创建

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pretreined weights for encoder initialization
    in_channels=1,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)
  • 查看表格,了解可用的模型架构
  • 查看表格,了解可用的编码器及其相应的权重

2. 配置数据预处理

所有编码器都有预训练的权重。与权重预训练期间相同的方式准备您的数据可能会给您带来更好的结果(更高的指标得分和更快的收敛)。但这只适用于1-2-3通道图像,并且对于您训练整个模型(而不仅仅是解码器),通常是不必要的。

from segmentation_models_pytorch.encoders import get_preprocessing_fn

preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')

恭喜!您已完成!现在您可以使用您喜欢的框架训练您的模型了!

💡 示例

  • 在CamVid数据集上训练汽车分割模型这里
  • 使用Catalyst(PyTorch的高级框架)、CatalystTTAch(PyTorch的TTA库)和Albumentations(快速图像增强库)训练SMP模型 - 这里 Open In Colab

📦 模型

架构

编码器

以下是在SMP中支持的编码器列表。选择适当的编码器系列,点击展开表格并选择特定的编码器和其预训练权重(encoder_nameencoder_weights 参数)。

ResNet
编码器 权重 参数,M
resnet18 imagenet / ssl / swsl 11M
resnet34 imagenet 21M
resnet50 imagenet / ssl / swsl 23M
resnet101 imagenet 42M
resnet152 imagenet 58M
ResNeXt
编码器 权重 参数,M
resnext50_32x4d imagenet / ssl / swsl 22M
resnext101_32x4d ssl / swsl 42M
resnext101_32x8d imagenet / instagram / ssl / swsl 86M
resnext101_32x16d instagram / ssl / swsl 191M
resnext101_32x32d instagram 466M
resnext101_32x48d instagram 826M
ResNeSt
编码器 权重 参数,M
timm-resnest14d imagenet 8M
timm-resnest26d imagenet 15M
timm-resnest50d imagenet 25M
timm-resnest101e imagenet 46M
timm-resnest200e imagenet 68M
timm-resnest269e imagenet 108M
timm-resnest50d_4s2x40d imagenet 28M
timm-resnest50d_1s4x24d imagenet 23M
Res2Ne(X)t
编码器 权重 参数,M
timm-res2net50_26w_4s imagenet 23M
timm-res2net101_26w_4s imagenet 43M
timm-res2net50_26w_6s imagenet 35M
timm-res2net50_26w_8s imagenet 46M
timm-res2net50_48w_2s imagenet 23M
timm-res2net50_14w_8s imagenet 23M
timm-res2next50 imagenet 22M
RegNet(x/y)
编码器 权重 参数,M
timm-regnetx_002 imagenet 2M
timm-regnetx_004 imagenet 4M
timm-regnetx_006 imagenet 5M
timm-regnetx_008 imagenet 6M
timm-regnetx_016 imagenet 8M
timm-regnetx_032 imagenet 14M
timm-regnetx_040 imagenet 20M
timm-regnetx_064 imagenet 24M
timm-regnetx_080 imagenet 37M
timm-regnetx_120 imagenet 43M
timm-regnetx_160 imagenet 52M
timm-regnetx_320 imagenet 105M
timm-regnety_002 imagenet 2M
timm-regnety_004 imagenet 3M
timm-regnety_006 imagenet 5M
timm-regnety_008 imagenet 5M
timm-regnety_016 imagenet 10M
timm-regnety_032 imagenet 17M
timm-regnety_040 imagenet 19M
timm-regnety_064 imagenet 29M
timm-regnety_080 imagenet 37M
timm-regnety_120 imagenet 49M
timm-regnety_160 imagenet 80M
timm-regnety_320 imagenet 141M
SE-Net
编码器 权重 参数,M
senet154 imagenet 113M
se_resnet50 imagenet 26M
se_resnet101 imagenet 47M
se_resnet152 imagenet 64M
se_resnext50_32x4d imagenet 25M
se_resnext101_32x4d imagenet 46M
SK-ResNe(X)t
编码器 权重 参数,M
timm-skresnet18 imagenet 11M
timm-skresnet34 imagenet 21M
timm-skresnext50_32x4d imagenet 25M
DenseNet
编码器 权重 参数,M
densenet121 imagenet 6M
密集连接169 imagenet 12M
密集连接201 imagenet 18M
密集连接161 imagenet 26M
Inception
编码器 权重 参数,M
inceptionresnetv2 imagenet / imagenet+background 54M
inceptionv4 imagenet / imagenet+background 41M
xception imagenet 22M
EfficientNet
编码器 权重 参数,M
efficientnet-b0 imagenet 4M
efficientnet-b1 imagenet 6M
efficientnet-b2 imagenet 7M
efficientnet-b3 imagenet 10M
efficientnet-b4 imagenet 17M
efficientnet-b5 imagenet 28M
efficientnet-b6 imagenet 40M
efficientnet-b7 imagenet 63M
timm-efficientnet-b0 imagenet / advprop / noisy-student 4M
timm-efficientnet-b1 imagenet / advprop / noisy-student 6M
timm-efficientnet-b2 imagenet / advprop / noisy-student 7M
timm-efficientnet-b3 imagenet / advprop / noisy-student 10M
timm-efficientnet-b4 imagenet / advprop / noisy-student 17M
timm-efficientnet-b5 imagenet / advprop / noisy-student 28M
timm-efficientnet-b6 imagenet / advprop / noisy-student 40M
timm-efficientnet-b7 imagenet / advprop / noisy-student 63M
timm-efficientnet-b8 imagenet / advprop 84M
timm-efficientnet-l2 noisy-student 474M
MobileNet
编码器 权重 参数,M
mobilenet_v2 imagenet 2M
DPN
编码器 权重 参数,M
dpn68 imagenet 11M
dpn68b imagenet+5k 11M
dpn92 imagenet+5k 34M
dpn98 imagenet 58M
dpn107 imagenet+5k 84M
dpn131 imagenet 76M
VGG
编码器 权重 参数,M
vgg11 imagenet 9M
vgg11_bn imagenet 9M
vgg13 imagenet 9M
vgg13_bn imagenet 9M
vgg16 imagenet 14M
vgg16_bn imagenet 14M
vgg19 imagenet 20M
vgg19_bn imagenet 20M

* ssl, swsl - 在ImageNet上的半监督和弱监督学习 (repo).

🔁 模型API

  • model.encoder - 预训练骨干网络,用于提取不同空间分辨率的特征
  • model.decoder - 根据模型架构(Unet/Linknet/PSPNet/FPN)而定
  • model.segmentation_head - 最后的块,用于生成所需数量的掩码通道(包括可选的上采样和激活)
  • model.classification_head - 可选的块,在编码器顶部创建分类头
  • model.forward(x) - 依次将x通过模型的编码器、解码器和分割头(以及指定的分类头)
输入通道

输入通道参数允许您创建处理具有任意数量通道的张量的模型。如果您使用从imagenet预训练的权重,则第一个卷积层的权重将被重用于1或2个通道的输入,对于输入通道>4,第一个卷积层的权重将被随机初始化。

model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))
辅助分类输出

所有模型都支持aux_params参数,默认设置为None。如果aux_params = None,则不创建分类辅助输出,否则模型将不仅生成mask输出,还生成形状为NClabel输出。分类头由GlobalPooling->Dropout(可选)->Linear->Activation(可选)层组成,可以通过aux_params进行如下配置

aux_params=dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.5,               # dropout ratio, default is None
    activation='sigmoid',      # activation function, default is None
    classes=4,                 # define number of output labels
)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
mask, label = model(x)
深度

深度参数指定编码器中下采样操作的数量,因此您可以通过指定较小的depth来使模型变轻。

model = smp.Unet('resnet34', encoder_depth=4)

🛠️ 安装

PyPI版本

$ pip install segmentation-models-pytorch

从源码获取的最新版本

$ pip install git+https://github.com/qubvel/segmentation_models.pytorch

🏆 该库赢得的竞赛

Segmentation Models 包在图像分割竞赛中得到广泛应用。 在这里 您可以找到竞赛、获胜者的名字和他们的解决方案的链接。

🤝 贡献

运行测试
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider
生成表格
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py

📝 引用

@misc{Yakubovskiy:2019,
  Author = {Pavel Yakubovskiy},
  Title = {Segmentation Models Pytorch},
  Year = {2020},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models.pytorch}}
}

🛡️ 许可证

项目是在MIT许可证下分发的

项目详情


下载文件

下载您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源代码发行版

d3m_segmentation_models_pytorch-0.1.3.tar.gz (40.0 kB 查看散列)

上传于 源代码

构建版本