基于预训练骨干的图像分割模型。PyTorch。
项目描述
Python库,包含用于图像的神经网络
基于PyTorch的分割。
该库的主要特性包括
- 高级API(仅两行即可创建神经网络)
- 8种模型架构用于二值和多类分割(包括传奇的Unet)
- 99种可用的编码器
- 所有编码器都有预训练的权重,以实现更快和更好的收敛
📚 项目文档 📚
访问Read The Docs项目页面或阅读以下README以了解更多关于Segmentation Models Pytorch(简称SMP)库的信息
📋 目录
⏳ 快速入门
1. 使用SMP创建您的第一个分割模型
分割模型只是一个PyTorch nn.Module,它可以像下面这样轻松创建
import segmentation_models_pytorch as smp
model = smp.Unet(
encoder_name="resnet34", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
encoder_weights="imagenet", # use `imagenet` pretreined weights for encoder initialization
in_channels=1, # model input channels (1 for grayscale images, 3 for RGB, etc.)
classes=3, # model output channels (number of classes in your dataset)
)
2. 配置数据预处理
所有编码器都有预训练的权重。与权重预训练期间相同的方式准备您的数据可能会给您带来更好的结果(更高的指标得分和更快的收敛)。但这只适用于1-2-3通道图像,并且对于您训练整个模型(而不仅仅是解码器),通常是不必要的。
from segmentation_models_pytorch.encoders import get_preprocessing_fn
preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
恭喜!您已完成!现在您可以使用您喜欢的框架训练您的模型了!
💡 示例
- 在CamVid数据集上训练汽车分割模型这里。
- 使用Catalyst(PyTorch的高级框架)、Catalyst、TTAch(PyTorch的TTA库)和Albumentations(快速图像增强库)训练SMP模型 - 这里
📦 模型
架构
- Unet [论文] [文档]
- Unet++ [论文] [文档]
- Linknet [论文] [文档]
- FPN [论文] [文档]
- PSPNet [论文] [文档]
- PAN [论文] [文档]
- DeepLabV3 [论文] [文档]
- DeepLabV3+ [论文] [文档]
编码器
以下是在SMP中支持的编码器列表。选择适当的编码器系列,点击展开表格并选择特定的编码器和其预训练权重(encoder_name
和 encoder_weights
参数)。
ResNet
编码器 | 权重 | 参数,M |
---|---|---|
resnet18 | imagenet / ssl / swsl | 11M |
resnet34 | imagenet | 21M |
resnet50 | imagenet / ssl / swsl | 23M |
resnet101 | imagenet | 42M |
resnet152 | imagenet | 58M |
ResNeXt
编码器 | 权重 | 参数,M |
---|---|---|
resnext50_32x4d | imagenet / ssl / swsl | 22M |
resnext101_32x4d | ssl / swsl | 42M |
resnext101_32x8d | imagenet / instagram / ssl / swsl | 86M |
resnext101_32x16d | instagram / ssl / swsl | 191M |
resnext101_32x32d | 466M | |
resnext101_32x48d | 826M |
ResNeSt
编码器 | 权重 | 参数,M |
---|---|---|
timm-resnest14d | imagenet | 8M |
timm-resnest26d | imagenet | 15M |
timm-resnest50d | imagenet | 25M |
timm-resnest101e | imagenet | 46M |
timm-resnest200e | imagenet | 68M |
timm-resnest269e | imagenet | 108M |
timm-resnest50d_4s2x40d | imagenet | 28M |
timm-resnest50d_1s4x24d | imagenet | 23M |
Res2Ne(X)t
编码器 | 权重 | 参数,M |
---|---|---|
timm-res2net50_26w_4s | imagenet | 23M |
timm-res2net101_26w_4s | imagenet | 43M |
timm-res2net50_26w_6s | imagenet | 35M |
timm-res2net50_26w_8s | imagenet | 46M |
timm-res2net50_48w_2s | imagenet | 23M |
timm-res2net50_14w_8s | imagenet | 23M |
timm-res2next50 | imagenet | 22M |
RegNet(x/y)
编码器 | 权重 | 参数,M |
---|---|---|
timm-regnetx_002 | imagenet | 2M |
timm-regnetx_004 | imagenet | 4M |
timm-regnetx_006 | imagenet | 5M |
timm-regnetx_008 | imagenet | 6M |
timm-regnetx_016 | imagenet | 8M |
timm-regnetx_032 | imagenet | 14M |
timm-regnetx_040 | imagenet | 20M |
timm-regnetx_064 | imagenet | 24M |
timm-regnetx_080 | imagenet | 37M |
timm-regnetx_120 | imagenet | 43M |
timm-regnetx_160 | imagenet | 52M |
timm-regnetx_320 | imagenet | 105M |
timm-regnety_002 | imagenet | 2M |
timm-regnety_004 | imagenet | 3M |
timm-regnety_006 | imagenet | 5M |
timm-regnety_008 | imagenet | 5M |
timm-regnety_016 | imagenet | 10M |
timm-regnety_032 | imagenet | 17M |
timm-regnety_040 | imagenet | 19M |
timm-regnety_064 | imagenet | 29M |
timm-regnety_080 | imagenet | 37M |
timm-regnety_120 | imagenet | 49M |
timm-regnety_160 | imagenet | 80M |
timm-regnety_320 | imagenet | 141M |
SE-Net
编码器 | 权重 | 参数,M |
---|---|---|
senet154 | imagenet | 113M |
se_resnet50 | imagenet | 26M |
se_resnet101 | imagenet | 47M |
se_resnet152 | imagenet | 64M |
se_resnext50_32x4d | imagenet | 25M |
se_resnext101_32x4d | imagenet | 46M |
SK-ResNe(X)t
编码器 | 权重 | 参数,M |
---|---|---|
timm-skresnet18 | imagenet | 11M |
timm-skresnet34 | imagenet | 21M |
timm-skresnext50_32x4d | imagenet | 25M |
DenseNet
编码器 | 权重 | 参数,M |
---|---|---|
densenet121 | imagenet | 6M |
密集连接169 | imagenet | 12M |
密集连接201 | imagenet | 18M |
密集连接161 | imagenet | 26M |
Inception
编码器 | 权重 | 参数,M |
---|---|---|
inceptionresnetv2 | imagenet / imagenet+background | 54M |
inceptionv4 | imagenet / imagenet+background | 41M |
xception | imagenet | 22M |
EfficientNet
编码器 | 权重 | 参数,M |
---|---|---|
efficientnet-b0 | imagenet | 4M |
efficientnet-b1 | imagenet | 6M |
efficientnet-b2 | imagenet | 7M |
efficientnet-b3 | imagenet | 10M |
efficientnet-b4 | imagenet | 17M |
efficientnet-b5 | imagenet | 28M |
efficientnet-b6 | imagenet | 40M |
efficientnet-b7 | imagenet | 63M |
timm-efficientnet-b0 | imagenet / advprop / noisy-student | 4M |
timm-efficientnet-b1 | imagenet / advprop / noisy-student | 6M |
timm-efficientnet-b2 | imagenet / advprop / noisy-student | 7M |
timm-efficientnet-b3 | imagenet / advprop / noisy-student | 10M |
timm-efficientnet-b4 | imagenet / advprop / noisy-student | 17M |
timm-efficientnet-b5 | imagenet / advprop / noisy-student | 28M |
timm-efficientnet-b6 | imagenet / advprop / noisy-student | 40M |
timm-efficientnet-b7 | imagenet / advprop / noisy-student | 63M |
timm-efficientnet-b8 | imagenet / advprop | 84M |
timm-efficientnet-l2 | noisy-student | 474M |
MobileNet
编码器 | 权重 | 参数,M |
---|---|---|
mobilenet_v2 | imagenet | 2M |
DPN
编码器 | 权重 | 参数,M |
---|---|---|
dpn68 | imagenet | 11M |
dpn68b | imagenet+5k | 11M |
dpn92 | imagenet+5k | 34M |
dpn98 | imagenet | 58M |
dpn107 | imagenet+5k | 84M |
dpn131 | imagenet | 76M |
VGG
编码器 | 权重 | 参数,M |
---|---|---|
vgg11 | imagenet | 9M |
vgg11_bn | imagenet | 9M |
vgg13 | imagenet | 9M |
vgg13_bn | imagenet | 9M |
vgg16 | imagenet | 14M |
vgg16_bn | imagenet | 14M |
vgg19 | imagenet | 20M |
vgg19_bn | imagenet | 20M |
* ssl
, swsl
- 在ImageNet上的半监督和弱监督学习 (repo).
🔁 模型API
model.encoder
- 预训练骨干网络,用于提取不同空间分辨率的特征model.decoder
- 根据模型架构(Unet
/Linknet
/PSPNet
/FPN
)而定model.segmentation_head
- 最后的块,用于生成所需数量的掩码通道(包括可选的上采样和激活)model.classification_head
- 可选的块,在编码器顶部创建分类头model.forward(x)
- 依次将x
通过模型的编码器、解码器和分割头(以及指定的分类头)
输入通道
输入通道参数允许您创建处理具有任意数量通道的张量的模型。如果您使用从imagenet预训练的权重,则第一个卷积层的权重将被重用于1或2个通道的输入,对于输入通道>4,第一个卷积层的权重将被随机初始化。
model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))
辅助分类输出
所有模型都支持aux_params
参数,默认设置为None
。如果aux_params = None
,则不创建分类辅助输出,否则模型将不仅生成mask
输出,还生成形状为NC
的label
输出。分类头由GlobalPooling->Dropout(可选)->Linear->Activation(可选)层组成,可以通过aux_params
进行如下配置
aux_params=dict(
pooling='avg', # one of 'avg', 'max'
dropout=0.5, # dropout ratio, default is None
activation='sigmoid', # activation function, default is None
classes=4, # define number of output labels
)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
mask, label = model(x)
深度
深度参数指定编码器中下采样操作的数量,因此您可以通过指定较小的depth
来使模型变轻。
model = smp.Unet('resnet34', encoder_depth=4)
🛠️ 安装
PyPI版本
$ pip install segmentation-models-pytorch
从源码获取的最新版本
$ pip install git+https://github.com/qubvel/segmentation_models.pytorch
🏆 该库赢得的竞赛
Segmentation Models
包在图像分割竞赛中得到广泛应用。 在这里 您可以找到竞赛、获胜者的名字和他们的解决方案的链接。
🤝 贡献
运行测试
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider
生成表格
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py
📝 引用
@misc{Yakubovskiy:2019,
Author = {Pavel Yakubovskiy},
Title = {Segmentation Models Pytorch},
Year = {2020},
Publisher = {GitHub},
Journal = {GitHub repository},
Howpublished = {\url{https://github.com/qubvel/segmentation_models.pytorch}}
}
🛡️ 许可证
项目是在MIT许可证下分发的
项目详情
d3m_segmentation_models_pytorch-0.1.3.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 1b5e61d7e93ff3fc61cd9ab7b977676541b9c117af4b426a24fd76b520250831 |
|
MD5 | 4e5cb0dbd5f564367a4d4c7f00b7e05f |
|
BLAKE2b-256 | bcd40737266997d6f9b104d1548b5bdd088e29b8603b7c9e82f25891447514c4 |
d3m_segmentation_models_pytorch-0.1.3-py3-none-any.whl 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | b4d57f5b62725fb2c48d1db2e78c4ddd5df55ecd405390f45261b3bfbdf7e6e5 |
|
MD5 | 5d4ced09a880ebc5e2b8b3d9f79d05ca |
|
BLAKE2b-256 | f0f07fddd77c7a23ad31206e5bab1cbd8e65c88c268bbb24dd30c5876667d16c |