TorchSeg: PyTorch的语义分割模型
项目描述
TorchSeg
TorchSeg是Segmentation Models PyTorch (smp)库的活跃维护和更新分支。
更新
此分支的目标是 1) 为原始库提供维护支持,2) 添加与现代语义分割相关的功能。自分支以来,此库已添加一些功能,可总结如下
- 改进了PyTorch Image Models (timm),用于具有特征提取功能的模型(852/1017=84%的timm模型)。这包括典型的CNN模型,如
ResNet
、EfficientNet
等,但现在扩展到包括现代架构,如ConvNext
、Swin
、PoolFormer
、MaxViT
等! - 支持预训练的视觉Transformer (ViT) 编码器。目前timm ViT不支持开箱即用的特征提取。然而,我们已添加了对提取中间transformer编码器层特征图的支撑,以实现此功能。我们支持100多个基于ViT的模型,包括
ViT
、DeiT
、FlexiViT
!
此外,我们还为提高软件标准执行了以下操作
- 更彻底的测试和CI
- 使用
black
、isort
、flake8
、mypy
进行格式化 - 减少对不再维护的库的依赖(现在只依赖于
torch
、timm
和einops
) - 减少维护的代码行数(移除了自定义的 utils、metrics 和 encoders),转而使用如
torchmetrics
和timm
的新库
特性
本库的主要特性包括
- 高级 API(只需两行代码即可创建神经网络)
- 9 种二值和多个类别的分割架构(包括 U-Net、DeepLabV3)
- 支持来自 timm 的 852/1017(约 84%)的可用 encoders
- 所有 encoders 都有预训练的权重,以实现更快的收敛和更好的效果
- 流行的分割损失函数
示例用法
TorchSeg 模型在其基本形式上只是 torch nn.Modules。它们可以按以下方式创建
import torchseg
model = torchseg.Unet(
encoder_name="resnet50",
encoder_weights=True,
in_channels=3
classes=3,
)
TorchSeg 有一个 encoder_params
功能,在定义 encoders 背骨时将额外的参数传递给 timm.create_model()
。可以指定不同的激活函数、归一化层等,如下所示。
您还可以定义一个 functools.partial
可调用对象作为激活/归一化层。有关更多信息,请参阅 timm 文档中的可用 激活函数 和 归一化层。您甚至可以在更改激活/归一化的同时使用预训练的权重!
model = torchseg.Unet(
encoder_name="resnet50",
encoder_weights=True,
in_channels=3
classes=3,
encoder_params={
"act_layer": "prelu",
"norm_layer": "layernorm"
}
)
一些模型如 Swin
和 ConvNext
在第一个块(stem)中执行 scale=4 的下采样,然后使用仅 depth=4
的块以 2 倍的速率进行下采样。这导致解码器之后的输出大小减半。要获得与输入相同大小的输出,可以传递 head_upsampling=2
,这将再次在上采样之前进行一次上采样。
model = torchseg.Unet(
"convnextv2_tiny",
in_channels=3,
classes=2,
encoder_weights=True,
encoder_depth=4,
decoder_channels=(256, 128, 64, 32),
head_upsampling=2
)
model = torchseg.Unet(
"swin_tiny_patch4_window7_224",
in_channels=3,
classes=2,
encoder_weights=True,
encoder_depth=4,
decoder_channels=(256, 128, 64, 32),
head_upsampling=2,
encoder_params={"img_size": 256} # need to define img size since swin is a ViT hybrid
)
model = torchseg.Unet(
"maxvit_small_tf_224",
in_channels=3,
classes=2,
encoder_weights=True,
encoder_depth=5,
decoder_channels=(256, 128, 64, 32, 16),
encoder_params={"img_size": 256}
)
TorchSeg 通过提取由 encoder_indices
和 encoder_depth
参数指定的中间变压器块特征来支持 timm 的预训练 ViT encoders。
您还需要为上采样特征层到解码器预期的分辨率定义 scale_factors
。对于 U-Net depth=5
,这将是对 scales=(8, 4, 2, 1, 0.5)
。对于 depth=4
,这将是对 scales=(4, 2, 1, 0.5)
,对于 depth=3
,这将是对 scales=(2, 1, 0.5)
,依此类推。
使用 timm 的另一个好处是,通过传递新的 img_size
,timm 会自动插值 ViT 位置嵌入以适应您的新图像大小,这会创建不同数量的补丁标记。
import torch
import torchseg
model = torchseg.Unet(
"vit_small_patch16_224",
in_channels=8,
classes=2,
encoder_depth=5,
encoder_indices=(2, 4, 6, 8, 10), # which intermediate blocks to extract features from
encoder_weights=True,
decoder_channels=(256, 128, 64, 32, 16),
encoder_params={ # additional params passed to timm.create_model and the vit encoder
"scale_factors": (8, 4, 2, 1, 0.5), # resize scale_factors for patch size 16 and 5 layers
"img_size": 256, # timm automatically interpolates the positional embeddings to your new image size
},
)
模型
架构(解码器)
- Unet [论文]
- Unet++ [论文]
- MAnet [论文]
- Linknet [论文]
- FPN [论文]
- PSPNet [论文]
- PAN [论文]
- DeepLabV3 [论文]
- DeepLabV3+ [论文]
编码器
TorchSeg 完全依赖于 timm 库来支持预训练的 encoders。这意味着 TorchSeg 支持任何具有 features_only
功能提取功能的 timm 模型。此外,我们支持任何具有 get_intermediate_layers
方法的 ViT 模型。这导致 timm 的总共有 852/1017(约 84%)的 encoders,包括 ResNet
、Swin
、ConvNext
、ViT
等!
以下列出以下支持的 encoders
import torchseg
torchseg.list_encoders()
我们还从 timm 中提取了具有 features_only
支持的每个模型的特征提取器元数据,在 output_stride=32
。这些元数据提供了有关中间层数量、每层的通道数、层名称以及下采样减少的信息。
import torchseg
metadata = torchseg.encoders.TIMM_ENCODERS["convnext_base"]
print(metadata)
"""
{
'channels': [128, 256, 512, 1024],
'indices': (0, 1, 2, 3),
'module': ['stages.0', 'stages.1', 'stages.2', 'stages.3'],
'reduction': [4, 8, 16, 32],
}
"""
metadata = torchseg.encoders.TIMM_ENCODERS["resnet50"]
print(metadata)
"""
{
'channels': [64, 256, 512, 1024, 2048],
'indices': (0, 1, 2, 3, 4),
'module': ['act1', 'layer1', 'layer2', 'layer3', 'layer4'],
'reduction': [2, 4, 8, 16, 32]
}
"""
模型 API
model.encoder
- 用于提取中间特征的预训练骨干model.decoder
- 处理中间特征到原始图像分辨率的网络(Unet
、DeepLabv3+
、FPN
)model.segmentation_head
- 生成掩码输出的最终块(包括可选的上采样和激活)model.classification_head
- 可选块,在编码器之上创建分类头model.forward(x)
- 依次通过模型编码器、解码器和分割头(如果指定,还包括分类头)传递x
输入通道
Timm编码器支持使用任意输入通道的预训练权重,如果输入通道数大于3,则通过重复权重来使用。例如,如果in_channels=6
,初始层的RGB ImageNet预训练权重将重复为RGBRGB
,以避免随机初始化。对于in_channels=7
,这将导致RGBRGBR
。下面是可视化此方法的图示。
辅助分类器
所有模型都支持使用aux_params
使用可选的辅助分类器头。如果aux_params != None
,则模型将生成除了mask
输出外还有形状为(N, C)
的label
输出。分类头由GlobalPooling->Dropout(可选)->Linear->Activation(可选)层组成,可以通过aux_params
如下配置
aux_params=dict(
pooling='avg', # one of 'avg', 'max'
dropout=0.5, # dropout ratio, default is None
activation=nn.Sigmoid(), # activation function, default is Identity
classes=4, # define number of output labels
)
model = torchseg.Unet('resnet18', classes=4, aux_params=aux_params)
mask, label = model(x)
深度
深度代表编码器中的下采样操作次数,因此可以通过指定更少的depth
来使模型更轻。默认为depth=5
。
请注意,一些模型(如ConvNext
和Swin
)只有4个中间特征块。因此,为了使用这些编码器,请设置encoder_depth=4
。这可以在上面的元数据中找到。
model = torchseg.Unet('resnet50', encoder_depth=4)
项目详情
2024年1月25日
下载文件