A library that provides a simple interface for creating new metrics and an easy-to-use toolkit for metric computation and checkpointing.
Project description
TorchEval
This library is currently in Alpha and does not yet have a stable release. The API may change and may not be backward compatible. If you have suggestions for improvements, please open a GitHub issue. We'd love to hear your feedback.
A library with a rich collection of PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training, and tools for PyTorch model evaluation.
Installing TorchEval
Requires Python >= 3.8 and PyTorch >= 1.11
From pip
pip install torcheval
For nightly builds
pip install --pre torcheval-nightly
From source
git clone https://github.com/pytorch/torcheval
cd torcheval
pip install -r requirements.txt
python setup.py install
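To check that the installation worked, a quick smoke test such as the one below can be run (a minimal sketch, not part of the official install steps; it simply exercises one of the metric functions used later in this README):
import torch
from torcheval.metrics.functional import multiclass_accuracy

# 3 of the 4 predicted labels match the targets, so this should print tensor(0.7500)
print(multiclass_accuracy(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 2])))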
Quick Start
There are more examples in the examples directory:
cd torcheval
python examples/simple_example.py
Documentation
Documentation can be found at pytorch.org/torcheval
Using TorchEval
TorchEval can be run on CPU, GPU, and in a multi-process or multi-GPU setting. Metrics are provided through two interfaces, functional and class based. The functional interfaces can be found in torcheval.metrics.functional and are appropriate for single-process settings. For multi-process or multi-GPU configurations, the class-based interfaces, found in torcheval.metrics, provide a much simpler experience. The class-based interfaces also allow you to defer some of the computation of the metric by calling update() multiple times before compute(). This can be advantageous even in a single-process setting, due to saved computation overhead.
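As a minimal illustration of the difference between the two interfaces (using hand-written tensors purely for demonstration):
import torch
from torcheval.metrics.functional import multiclass_accuracy
from torcheval.metrics import MulticlassAccuracy

input = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 2, 3])

# Functional interface: computes the metric immediately from the given tensors
print(multiclass_accuracy(input, target))  # tensor(0.5000)

# Class-based interface: accumulate state with update(), defer the computation
metric = MulticlassAccuracy()
metric.update(input, target)
metric.update(torch.tensor([1, 0]), torch.tensor([1, 1]))  # more batches can be added before compute()
print(metric.compute())  # accuracy over all data seen since the last reset()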
Single Process
For use in a single-process program, the simplest use case utilizes a functional metric. We simply import the metric function and feed in our outputs and targets. The example below shows a minimal PyTorch training loop that evaluates the multiclass accuracy of every fourth batch of data.
Functional version (computes the metric immediately)
import torch
from torcheval.metrics.functional import multiclass_accuracy
NUM_BATCHES = 16
BATCH_SIZE = 8
INPUT_SIZE = 10
NUM_CLASSES = 6
eval_frequency = 4
model = torch.nn.Sequential(torch.nn.Linear(INPUT_SIZE, NUM_CLASSES), torch.nn.ReLU())
optim = torch.optim.Adagrad(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
metric_history = []
for batch in range(NUM_BATCHES):
    input = torch.rand(size=(BATCH_SIZE, INPUT_SIZE))
    target = torch.randint(size=(BATCH_SIZE,), high=NUM_CLASSES)
    outputs = model(input)

    loss = loss_fn(outputs, target)
    optim.zero_grad()
    loss.backward()
    optim.step()

    # metric only computed every 4 batches,
    # data from previous three batches is lost
    if (batch + 1) % eval_frequency == 0:
        metric_history.append(multiclass_accuracy(outputs, target))
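Each entry of metric_history is a zero-dimensional tensor; for logging or plotting, one possible follow-up is to convert the entries to plain Python floats:
accuracies = [acc.item() for acc in metric_history]
print(accuracies)  # one float per evaluation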
Single Process with Deferred Computation
Class version (allows deferred computation of the metric)
import torch
from torcheval.metrics import MulticlassAccuracy
NUM_BATCHES = 16
BATCH_SIZE = 8
INPUT_SIZE = 10
NUM_CLASSES = 6
eval_frequency = 4
model = torch.nn.Sequential(torch.nn.Linear(INPUT_SIZE, NUM_CLASSES), torch.nn.ReLU())
optim = torch.optim.Adagrad(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
metric = MulticlassAccuracy()
metric_history = []
for batch in range(NUM_BATCHES):
    input = torch.rand(size=(BATCH_SIZE, INPUT_SIZE))
    target = torch.randint(size=(BATCH_SIZE,), high=NUM_CLASSES)
    outputs = model(input)

    loss = loss_fn(outputs, target)
    optim.zero_grad()
    loss.backward()
    optim.step()

    # metric only computed every 4 batches,
    # data from previous three batches is included
    metric.update(outputs, target)
    if (batch + 1) % eval_frequency == 0:
        metric_history.append(metric.compute())
        # remove old data so that the next call
        # to compute is only based off the next 4 batches
        metric.reset()
Multi-Process or Multi-GPU
Below is a minimal example of usage on multiple devices. In the normal torch.distributed paradigm, each device is allocated its own process and gets a unique numerical ID called a "global rank", counting up from 0.
Class version (also allows deferred computation and multi-processing)
import os

import torch
from torcheval.metrics.toolkit import sync_and_compute
from torcheval.metrics import MulticlassAccuracy

# Using torch.distributed
local_rank = int(os.environ["LOCAL_RANK"])  # rank on local machine, i.e. unique ID within a machine
global_rank = int(os.environ["RANK"])  # rank in global pool, i.e. unique ID within the entire process group
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes or "ranks" in the entire process group

device = torch.device(
    f"cuda:{local_rank}"
    if torch.cuda.is_available() and torch.cuda.device_count() >= world_size
    else "cpu"
)

metric = MulticlassAccuracy(device=device)
num_epochs, num_batches = 4, 8

for epoch in range(num_epochs):
    for i in range(num_batches):
        input = torch.randint(high=5, size=(10,), device=device)
        target = torch.randint(high=5, size=(10,), device=device)

        # Add data to metric locally
        metric.update(input, target)

        # metric.compute() will return the metric value from
        # all seen data on the local process since the last reset()
        local_compute_result = metric.compute()

        # sync_and_compute(metric) syncs metric data across all ranks and computes the metric value
        global_compute_result = sync_and_compute(metric)
        if global_rank == 0:
            print(global_compute_result)

    # metric.reset() clears the data on each process so that subsequent
    # calls to compute() only act on new data
    metric.reset()
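The snippet above reads LOCAL_RANK, RANK, and WORLD_SIZE from the environment, so it is intended to be launched with a distributed launcher such as torchrun, which sets those variables. For example, assuming the code is saved as eval_example.py (a placeholder name) and run on a single machine with 4 processes:
torchrun --nnodes=1 --nproc_per_node=4 eval_example.py
As in any torch.distributed program, the process group is typically initialized (for example with torch.distributed.init_process_group()) before any cross-rank synchronization such as sync_and_compute takes place.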
See the examples directory for more examples.
Contributing
We welcome PRs! See the CONTRIBUTING file.
License
TorchEval is BSD licensed, as found in the LICENSE file.