跳转到主要内容

TorchSeg: PyTorch的语义分割模型

项目描述

TorchSeg

TorchSeg是Segmentation Models PyTorch (smp)库的活跃维护和更新分支。

更新

此分支的目标是 1) 为原始库提供维护支持,2) 添加与现代语义分割相关的功能。自分支以来,此库已添加一些功能,可总结如下

  • 改进了PyTorch Image Models (timm),用于具有特征提取功能的模型(852/1017=84%的timm模型)。这包括典型的CNN模型,如ResNetEfficientNet等,但现在扩展到包括现代架构,如ConvNextSwinPoolFormerMaxViT等!
  • 支持预训练的视觉Transformer (ViT) 编码器。目前timm ViT不支持开箱即用的特征提取。然而,我们已添加了对提取中间transformer编码器层特征图的支撑,以实现此功能。我们支持100多个基于ViT的模型,包括ViTDeiTFlexiViT

此外,我们还为提高软件标准执行了以下操作

  • 更彻底的测试和CI
  • 使用blackisortflake8mypy进行格式化
  • 减少对不再维护的库的依赖(现在只依赖于 torchtimmeinops
  • 减少维护的代码行数(移除了自定义的 utils、metrics 和 encoders),转而使用如 torchmetricstimm 的新库

特性

本库的主要特性包括

  • 高级 API(只需两行代码即可创建神经网络)
  • 9 种二值和多个类别的分割架构(包括 U-Net、DeepLabV3)
  • 支持来自 timm 的 852/1017(约 84%)的可用 encoders
  • 所有 encoders 都有预训练的权重,以实现更快的收敛和更好的效果
  • 流行的分割损失函数

示例用法

TorchSeg 模型在其基本形式上只是 torch nn.Modules。它们可以按以下方式创建

import torchseg

model = torchseg.Unet(
    encoder_name="resnet50",
    encoder_weights=True,
    in_channels=3
    classes=3,
)

TorchSeg 有一个 encoder_params 功能,在定义 encoders 背骨时将额外的参数传递给 timm.create_model()。可以指定不同的激活函数、归一化层等,如下所示。

您还可以定义一个 functools.partial 可调用对象作为激活/归一化层。有关更多信息,请参阅 timm 文档中的可用 激活函数归一化层。您甚至可以在更改激活/归一化的同时使用预训练的权重!

model = torchseg.Unet(
    encoder_name="resnet50",
    encoder_weights=True,
    in_channels=3
    classes=3,
    encoder_params={
      "act_layer": "prelu",
      "norm_layer": "layernorm"
    }
)

一些模型如 SwinConvNext 在第一个块(stem)中执行 scale=4 的下采样,然后使用仅 depth=4 的块以 2 倍的速率进行下采样。这导致解码器之后的输出大小减半。要获得与输入相同大小的输出,可以传递 head_upsampling=2,这将再次在上采样之前进行一次上采样。

model = torchseg.Unet(
    "convnextv2_tiny",
    in_channels=3,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=(256, 128, 64, 32),
    head_upsampling=2
)

model = torchseg.Unet(
    "swin_tiny_patch4_window7_224",
    in_channels=3,
    classes=2,
    encoder_weights=True,
    encoder_depth=4,
    decoder_channels=(256, 128, 64, 32),
    head_upsampling=2,
    encoder_params={"img_size": 256}  # need to define img size since swin is a ViT hybrid
)

model = torchseg.Unet(
    "maxvit_small_tf_224",
    in_channels=3,
    classes=2,
    encoder_weights=True,
    encoder_depth=5,
    decoder_channels=(256, 128, 64, 32, 16),
    encoder_params={"img_size": 256}
)

TorchSeg 通过提取由 encoder_indicesencoder_depth 参数指定的中间变压器块特征来支持 timm 的预训练 ViT encoders。

您还需要为上采样特征层到解码器预期的分辨率定义 scale_factors。对于 U-Net depth=5,这将是对 scales=(8, 4, 2, 1, 0.5)。对于 depth=4,这将是对 scales=(4, 2, 1, 0.5),对于 depth=3,这将是对 scales=(2, 1, 0.5),依此类推。

使用 timm 的另一个好处是,通过传递新的 img_size,timm 会自动插值 ViT 位置嵌入以适应您的新图像大小,这会创建不同数量的补丁标记。

import torch
import torchseg


model = torchseg.Unet(
    "vit_small_patch16_224",
    in_channels=8,
    classes=2,
    encoder_depth=5,
    encoder_indices=(2, 4, 6, 8, 10),  # which intermediate blocks to extract features from
    encoder_weights=True,
    decoder_channels=(256, 128, 64, 32, 16),
    encoder_params={  # additional params passed to timm.create_model and the vit encoder
        "scale_factors": (8, 4, 2, 1, 0.5), # resize scale_factors for patch size 16 and 5 layers
        "img_size": 256,  # timm automatically interpolates the positional embeddings to your new image size
    },
)

模型

架构(解码器)

编码器

TorchSeg 完全依赖于 timm 库来支持预训练的 encoders。这意味着 TorchSeg 支持任何具有 features_only 功能提取功能的 timm 模型。此外,我们支持任何具有 get_intermediate_layers 方法的 ViT 模型。这导致 timm 的总共有 852/1017(约 84%)的 encoders,包括 ResNetSwinConvNextViT 等!

以下列出以下支持的 encoders

import torchseg

torchseg.list_encoders()

我们还从 timm 中提取了具有 features_only 支持的每个模型的特征提取器元数据,在 output_stride=32。这些元数据提供了有关中间层数量、每层的通道数、层名称以及下采样减少的信息。

import torchseg

metadata = torchseg.encoders.TIMM_ENCODERS["convnext_base"]
print(metadata)

"""
{
   'channels': [128, 256, 512, 1024],
   'indices': (0, 1, 2, 3),
   'module': ['stages.0', 'stages.1', 'stages.2', 'stages.3'],
   'reduction': [4, 8, 16, 32],
}
"""

metadata = torchseg.encoders.TIMM_ENCODERS["resnet50"]
print(metadata)

"""
{
   'channels': [64, 256, 512, 1024, 2048],
   'indices': (0, 1, 2, 3, 4),
   'module': ['act1', 'layer1', 'layer2', 'layer3', 'layer4'],
   'reduction': [2, 4, 8, 16, 32]
}
"""

模型 API

  • model.encoder - 用于提取中间特征的预训练骨干
  • model.decoder - 处理中间特征到原始图像分辨率的网络(UnetDeepLabv3+FPN
  • model.segmentation_head - 生成掩码输出的最终块(包括可选的上采样和激活)
  • model.classification_head - 可选块,在编码器之上创建分类头
  • model.forward(x) - 依次通过模型编码器、解码器和分割头(如果指定,还包括分类头)传递x
输入通道

Timm编码器支持使用任意输入通道的预训练权重,如果输入通道数大于3,则通过重复权重来使用。例如,如果in_channels=6,初始层的RGB ImageNet预训练权重将重复为RGBRGB,以避免随机初始化。对于in_channels=7,这将导致RGBRGBR。下面是可视化此方法的图示。


辅助分类器

所有模型都支持使用aux_params使用可选的辅助分类器头。如果aux_params != None,则模型将生成除了mask输出外还有形状为(N, C)label输出。分类头由GlobalPooling->Dropout(可选)->Linear->Activation(可选)层组成,可以通过aux_params如下配置

aux_params=dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.5,               # dropout ratio, default is None
    activation=nn.Sigmoid(),   # activation function, default is Identity
    classes=4,                 # define number of output labels
)
model = torchseg.Unet('resnet18', classes=4, aux_params=aux_params)
mask, label = model(x)
深度

深度代表编码器中的下采样操作次数,因此可以通过指定更少的depth来使模型更轻。默认为depth=5

请注意,一些模型(如ConvNextSwin)只有4个中间特征块。因此,为了使用这些编码器,请设置encoder_depth=4。这可以在上面的元数据中找到。

model = torchseg.Unet('resnet50', encoder_depth=4)

项目详情


2024年1月25日

下载文件

下载适合您平台文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

torchseg-0.0.1a4.tar.gz (51.6 kB 查看哈希)

上传时间

构建版本

torchseg-0.0.1a4-py3-none-any.whl (67.9 kB 查看哈希)

上传时间 Python 3

由以下支持