MSAdapter是一个工具包,用于支持在Ascend上运行的PyTorch模型。
项目描述
简介
MSAdapter是MindSpore工具,用于适配PyTorch接口,旨在使PyTorch代码在Ascend上高效运行,同时不改变原始PyTorch用户的使用习惯。
安装
MSAdapter有一些先决条件需要安装,包括MindSpore、PIL、NumPy。
# for last stable version
pip install msadapter
# for latest release candidate
pip install --upgrade --pre msadapter
或者,您可以直接从OpenI拉取最新或开发版本进行安装
pip3 install git+https://openi.pcl.ac.cn/OpenI/MSAdapter.git
用户指南
对于数据处理和模型构建,MSAdapter可以使用与PyTorch相同的方式,而代码中的模型训练部分需要定制,如下面的示例所示。
数据处理(仅修改导入包)
from msadapter.pytorch.utils.data import DataLoader
from msadapter.torchvision import datasets, transforms
transform = transforms.Compose([transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.2435, 0.2616])
])
train_images = datasets.CIFAR10('./', train=True, download=True, transform=transform)
train_data = DataLoader(train_images, batch_size=128, shuffle=True, num_workers=2, drop_last=True)
模型构建(仅修改导入包)
from msadapter.pytorch.nn import Module, Linear, Flatten
class MLP(Module):
def __init__(self):
super(MLP, self).__init__()
self.flatten = Flatten()
self.line1 = Linear(in_features=1024, out_features=64)
self.line2 = Linear(in_features=64, out_features=128, bias=False)
self.line3 = Linear(in_features=128, out_features=10)
def forward(self, inputs):
x = self.flatten(inputs)
x = self.line1(x)
x = self.line2(x)
x = self.line3(x)
return x
3. 模型训练(自定义训练)
import msadapter.pytorch as torch
import msadapter.pytorch.nn as nn
import mindspore as ms
net = MLP()
net.train()
epochs = 500
criterion = nn.CrossEntropyLoss()
optimizer = ms.nn.SGD(net.trainable_params(), learning_rate=0.01, momentum=0.9, weight_decay=0.0005)
# Define the training process
loss_net = ms.nn.WithLossCell(net, criterion)
train_net = ms.nn.TrainOneStepCell(loss_net, optimizer)
for i in range(epochs):
for X, y in train_data:
res = train_net(X, y)
print("epoch:{}, loss:{:.6f}".format(i, res.asnumpy()))
# Save model
ms.save_checkpoint(net, "save_path.ckpt")
许可证
MSAdapter采用Apache 2.0许可证发布。
项目详情
下载文件
下载适用于您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。
源分布
msadapter-0.1.0.tar.gz (621.2 kB 查看哈希值)
构建分发
msadapter-0.1.0-py3-none-any.whl (812.7 kB 查看哈希值)
关闭
msadapter-0.1.0.tar.gz的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 8d3d2aa49450d5effe92efbe3a6a579228beba1bb27e876dff7405e6035cb93e |
|
MD5 | 29d59cad094b7be4656d5052cc2551d3 |
|
BLAKE2b-256 | 9933289bf245c2d680dde0e0ab9b00c5e1e0b3b4a5d6c39694790526f4669a43 |
关闭
msadapter-0.1.0-py3-none-any.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 7906538f712b72c78ecdd40b2b4c4c4bbf59228f0d5fc90b380e90159c676b55 |
|
MD5 | 8c36ea061d402fae9e2632fbeb106f6d |
|
BLAKE2b-256 | 81d35459b06f1ea941e2a786463d0d706ab440cf1d6d31fa077c552ce702daa3 |