TorchGeo:地理空间数据的集、采样器、转换和预训练模型
项目描述
TorchGeo 是一个与 PyTorch 相似的领域库,类似于 torchvision,提供针对地理空间数据的特定数据集、采样器、转换和预训练模型。
本库的目标是使其变得非常简单
- 以便机器学习专家能够处理地理空间数据,并且
- 以便遥感专家探索机器学习解决方案。
安装
推荐使用 pip 安装 TorchGeo
$ pip install torchgeo
文档
您可以在 ReadTheDocs 上找到 TorchGeo 的文档。这包括 API 文档、贡献说明和多个 教程。有关更多详细信息,请参阅我们的 论文、播客节目、教程和博客文章。
示例用法
以下部分提供了使用 TorchGeo 可以执行的基本示例。
首先,我们将导入以下部分中使用的各种类和函数
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.datasets import CDL, Landsat7, Landsat8, VHR10, stack_samples
from torchgeo.samplers import RandomGeoSampler
from torchgeo.trainers import SemanticSegmentationTask
地理空间数据集和采样器
许多遥感应用都涉及处理具有地理元数据的 地理空间数据集——这些数据集的数据种类繁多,可能具有挑战性。地理空间图像通常是多光谱的,每个卫星的谱带数量和空间分辨率都不同。此外,每个文件可能位于不同的坐标参考系统(CRS)中,需要将数据重新投影到匹配的 CRS。
在这个例子中,我们展示了如何使用TorchGeo轻松地处理地理空间数据,并从Landsat和Cropland Data Layer (CDL)数据组合中采样小图像块。首先,我们假设用户已经下载了Landsat 7和8的图像。由于Landsat 8的谱段比Landsat 7多,我们只会使用两个卫星共有的谱段。我们将创建一个包含Landsat 7和8所有图像的单个数据集,通过这两个数据集的并集来实现。
landsat7 = Landsat7(root="...", bands=["B1", ..., "B7"])
landsat8 = Landsat8(root="...", bands=["B2", ..., "B8"])
landsat = landsat7 | landsat8
接下来,我们将该数据集与CDL数据集进行交集运算。我们采用交集而不是并集,以确保我们只从具有Landsat和CDL数据的区域进行采样。请注意,我们可以自动下载并校验CDL数据。另外,请注意,这些数据集可能包含不同坐标参考系统(CRS)或分辨率的文件,但TorchGeo自动确保使用匹配的CRS和分辨率。
cdl = CDL(root="...", download=True, checksum=True)
dataset = landsat & cdl
现在可以使用PyTorch数据加载器使用此数据集。与基准数据集不同,地理空间数据集通常包含非常大的图像。例如,CDL数据集由一张覆盖整个美国大陆的单一图像组成。为了使用地理坐标从这些数据集中采样,TorchGeo定义了多个采样器。在这个例子中,我们将使用一个随机采样器,返回256 x 256像素的图像,并在每个epoch中采样10,000个样本。我们还使用自定义的收集函数将每个样本字典组合成一个样本的小批量。
sampler = RandomGeoSampler(dataset, size=256, length=10000)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)
现在可以将此数据加载器用于您正常的训练/评估流程。
for batch in dataloader:
image = batch["image"]
mask = batch["mask"]
# train a model, or make predictions using a pre-trained model
许多应用涉及根据此类地理空间元数据智能地组合数据集。例如,用户可能想要
- 将来自多个图像来源的数据集合并,并将它们视为等效的(例如,Landsat 7和8)
- 将不同地理空间位置的数据集合并(例如,切萨皮克NY和PA)
这些组合要求所有查询至少存在于一个数据集中,可以使用UnionDataset
创建。同样,用户可能想要
- 将图像和目标标签合并,并同时从两者中采样(例如,Landsat和CDL)
- 为多模态学习或数据融合合并多个图像来源的数据集(例如,Landsat和Sentinel)
这些组合要求所有查询都存在于两个数据集中,可以使用IntersectionDataset
创建。当您使用交集(&
)和并集(|
)运算符时,TorchGeo会自动为您组合这些数据集。
基准数据集
TorchGeo包括多个基准数据集——包含输入图像和目标标签的数据集。这包括用于图像分类、回归、语义分割、目标检测、实例分割、变化检测等任务的数据库。
如果您之前使用过torchvision,这些数据集应该非常熟悉。在这个例子中,我们将为西北工业大学(NWPU)的高分辨率十类(VHR-10)地理空间目标检测数据集创建一个数据集。该数据集可以像torchvision一样自动下载、校验和提取。
from torch.utils.data import DataLoader
from torchgeo.datamodules.utils import collate_fn_detection
from torchgeo.datasets import VHR10
# Initialize the dataset
dataset = VHR10(root="...", download=True, checksum=True)
# Initialize the dataloader with the custom collate function
dataloader = DataLoader(
dataset,
batch_size=128,
shuffle=True,
num_workers=4,
collate_fn=collate_fn_detection,
)
# Training loop
for batch in dataloader:
images = batch["image"] # list of images
boxes = batch["boxes"] # list of boxes
labels = batch["labels"] # list of labels
masks = batch["masks"] # list of masks
# train a model, or make predictions using a pre-trained model
TorchGeo的所有数据集都与PyTorch数据加载器兼容,这使得它们很容易集成到现有的训练工作流程中。TorchGeo中的基准数据集与torchvision中的类似数据集之间唯一的区别是,每个数据集返回一个包含每个PyTorch Tensor
键的字典。
预训练权重
预训练权重已被证明在计算机视觉的迁移学习任务中具有极大的益处。从业者通常使用在ImageNet数据集上预训练的模型,该数据集包含RGB图像。然而,遥感数据通常包含RGB图像之外的额外多光谱通道,这些通道在不同传感器中可能有所不同。TorchGeo是第一个支持在不同多光谱传感器上预训练的模型的库,并采用了torchvision的多权重API。当前可用的权重总结可以在文档中查看。要创建一个在Sentinel-2图像上预训练的timm Resnet-18模型,你可以这样做
import timm
from torchgeo.models import ResNet18_Weights
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10)
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
这些权重也可以通过下面的weights
参数直接用于本节中展示的TorchGeo Lightning模块。有关笔记本示例,请参阅这个教程。
使用Lightning实现可重复性
为了方便在文献中发布的成果之间进行直接比较,并进一步减少在TorchGeo数据集上运行实验所需的样板代码,我们创建了具有明确定义的训练-验证-测试划分的Lightning datamodules和针对各种任务(如分类、回归和语义分割)的trainers。这些datamodules展示了如何结合kornia库的增强,包括预处理转换(带有预计算的通道统计信息),并允许用户轻松地实验与数据本身相关的超参数(而不是建模过程)。在Inria Aerial Image Labeling数据集上训练语义分割模型就像导入几个模块和四行代码那么简单。
datamodule = InriaAerialImageLabelingDataModule(root="...", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(
model="unet",
backbone="resnet50",
weights=True,
in_channels=3,
num_classes=2,
loss="ce",
ignore_index=None,
lr=0.1,
patience=6,
)
trainer = Trainer(default_root_dir="...")
trainer.fit(model=task, datamodule=datamodule)
TorchGeo还支持使用LightningCLI的命令行界面训练。它可以以两种方式调用
# If torchgeo has been installed
torchgeo
# If torchgeo has been installed, or if it has been cloned to the current directory
python3 -m torchgeo
它支持命令行配置或YAML/JSON配置文件。有效选项可以在帮助消息中找到
# See valid stages
torchgeo --help
# See valid trainer options
torchgeo fit --help
# See valid model options
torchgeo fit --model.help ClassificationTask
# See valid data options
torchgeo fit --data.help EuroSAT100DataModule
使用以下配置文件
trainer:
max_epochs: 20
model:
class_path: ClassificationTask
init_args:
model: "resnet18"
in_channels: 13
num_classes: 10
data:
class_path: EuroSAT100DataModule
init_args:
batch_size: 8
dict_kwargs:
download: true
我们可以看到脚本的运行情况
# Train and validate a model
torchgeo fit --config config.yaml
# Validate-only
torchgeo validate --config config.yaml
# Calculate and report test accuracy
torchgeo test --config config.yaml --ckpt_path=...
如果你需要将其扩展以添加新功能,也可以将其导入到Python脚本中
from torchgeo.main import main
main(["fit", "--config", "config.yaml"])
有关详细信息,请参阅Lightning文档。
引用
如果你在本工作中使用了此软件,请引用我们的论文
@inproceedings{Stewart_TorchGeo_Deep_Learning_2022,
address = {Seattle, Washington},
author = {Stewart, Adam J. and Robinson, Caleb and Corley, Isaac A. and Ortiz, Anthony and Lavista Ferres, Juan M. and Banerjee, Arindam},
booktitle = {Proceedings of the 30th International Conference on Advances in Geographic Information Systems},
doi = {10.1145/3557915.3560953},
month = nov,
pages = {1--12},
publisher = {Association for Computing Machinery},
series = {SIGSPATIAL '22},
title = {{TorchGeo}: Deep Learning With Geospatial Data},
url = {https://dl.acm.org/doi/10.1145/3557915.3560953},
year = {2022}
}
贡献
本项目欢迎贡献和建议。如果你想提交pull request,请参阅我们的贡献指南以获取更多信息。
本项目已采用Microsoft Open Source Code of Conduct。有关更多信息,请参阅Code of Conduct FAQ或联系opencode@microsoft.com以获取任何额外的问题或评论。
项目详情
下载文件
下载适合您平台的文件。如果您不确定要选择哪一个,请了解有关安装包的更多信息。
源分布
构建版本
torchgeo-0.6.0.tar.gz的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | c5b073b3c9ac06cd68e45620bab3a78fb7637fa3563aae4f75f4781ba57aee5a |
|
MD5 | ab23b10e6f54fb0596136f46a6c5e765 |
|
BLAKE2b-256 | 8d3d6afd8f2e13ba938b5f5cec342eb921b6c02858e9f9d9b5eae7d860345ceb |
torchgeo-0.6.0-py3-none-any.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 41e9c0a8b56a28f9b81717fe02e7d33c13d65ba8e6de67a44d0800b3396e2227 |
|
MD5 | c9f361ab8af2375f4c486682d5f317da |
|
BLAKE2b-256 | d6b916346e153ba04c9e916e5081eff5f5f0d27865ece50f01baffb344317252 |