MSAdapter is a toolkit for running PyTorch models on Ascend.

Introduction

MSAdapter is a MindSpore tool that adapts the PyTorch interface. It is designed to let PyTorch code run efficiently on Ascend hardware without requiring PyTorch users to change their existing habits.

[Figure: MSAdapter architecture]

Install

MSAdapter has several prerequisites that must be installed first, including MindSpore, PIL (Pillow), and NumPy.

# for the latest stable version
pip install msadapter

# for the latest release candidate
pip install --upgrade --pre msadapter

Alternatively, you can install the latest development version directly from the OpenI repository:

pip3 install git+https://openi.pcl.ac.cn/OpenI/MSAdapter.git
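Before running the examples, it can help to confirm that the prerequisites are importable. The snippet below is a minimal sketch using only the standard library; the module list (`mindspore`, `PIL`, `numpy`, `msadapter`) is an assumption based on the prerequisites named above, and `missing_modules` is a hypothetical helper, not part of MSAdapter.

```python
import importlib.util

# Import names for the prerequisites listed above; PIL is provided by the Pillow distribution.
REQUIRED_MODULES = ["mindspore", "PIL", "numpy", "msadapter"]


def missing_modules(names):
    """Return the names that cannot be located on the current interpreter's import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("missing:", missing if missing else "none")
```

`find_spec` only locates the top-level package without importing it, so the check stays cheap even when MindSpore is installed.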

User guide

For data processing and model building, MSAdapter is used in the same way as PyTorch, with only the import statements changed; the model training code, however, needs to be customized, as shown in the following example.

  1. Data processing (only modify the import package)

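from msadapter.pytorch.utils.data import DataLoader
from msadapter.torchvision import datasets, transforms
# Assumed to mirror torchvision.transforms.InterpolationMode; needed for the Resize call below
from msadapter.torchvision.transforms import InterpolationMode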

transform = transforms.Compose([transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.2435, 0.2616])
                               ])
train_images = datasets.CIFAR10('./', train=True, download=True, transform=transform)
train_data = DataLoader(train_images, batch_size=128, shuffle=True, num_workers=2, drop_last=True)
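`ToTensor()` followed by `Normalize(mean, std)` rescales pixel values to [0, 1] and then standardizes each channel as (x - mean) / std. The NumPy sketch below reproduces those two steps on a single H x W x 3 `uint8` image, using the mean and std from the example above. It assumes MSAdapter's transforms follow torchvision's semantics, it skips the bicubic `Resize`, and `to_tensor_and_normalize` is a hypothetical helper for illustration only.

```python
import numpy as np

# Per-channel statistics copied from the transforms.Normalize call in the example above.
MEAN = np.array([0.4914, 0.4822, 0.4465])
STD = np.array([0.247, 0.2435, 0.2616])


def to_tensor_and_normalize(image_hwc):
    """Mimic ToTensor() then Normalize(MEAN, STD) for one H x W x 3 uint8 image."""
    # ToTensor: HWC layout with values in [0, 255] -> CHW layout with values in [0, 1]
    chw = image_hwc.astype(np.float32).transpose(2, 0, 1) / 255.0
    # Normalize: subtract the channel mean and divide by the channel std
    return (chw - MEAN[:, None, None]) / STD[:, None, None]


white = np.full((2, 2, 3), 255, dtype=np.uint8)
print(to_tensor_and_normalize(white)[:, 0, 0])  # one normalized value per channel
```

Note the layout change: the model receives channel-first (C, H, W) tensors, which matters when computing the flattened input size of the first linear layer.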
  2. Model construction (only modify the import package)

from msadapter.pytorch.nn import Module, Linear, Flatten

class MLP(Module):
    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        # Flatten turns each 3 x 224 x 224 sample from the data pipeline into 150528 features
        self.line1 = Linear(in_features=3 * 224 * 224, out_features=64)
        self.line2 = Linear(in_features=64, out_features=128, bias=False)
        self.line3 = Linear(in_features=128, out_features=10)

    def forward(self, inputs):
        x = self.flatten(inputs)
        x = self.line1(x)
        x = self.line2(x)
        x = self.line3(x)
        return x
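Because `Flatten` (assumed here to behave like `torch.nn.Flatten` with its default `start_dim=1`) keeps the batch dimension and multiplies the remaining dimensions together, `line1`'s `in_features` must equal C x H x W of each sample the `DataLoader` yields. The helpers below are hypothetical, written only to make that arithmetic and the per-layer parameter counts explicit.

```python
from math import prod


def flattened_features(sample_shape):
    """Features produced by Flatten for one sample, e.g. (C, H, W) -> C * H * W."""
    return prod(sample_shape)


def linear_param_count(in_features, out_features, bias=True):
    """Weights (in * out) plus one bias per output unit when bias=True."""
    return in_features * out_features + (out_features if bias else 0)


# Samples from the data-processing step: 3 colour channels resized to 224 x 224.
print(flattened_features((3, 224, 224)))        # 150528
print(linear_param_count(64, 128, bias=False))  # 8192, the bias-free line2
print(linear_param_count(128, 10))              # 1290, line3
```

If the `Resize` size or the dataset changes, recomputing `flattened_features` for the new sample shape gives the value `line1` needs.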

  3. Model training (custom training)

import msadapter.pytorch as torch
import msadapter.pytorch.nn as nn
import mindspore as ms

net = MLP()
net.train()
epochs = 500
criterion = nn.CrossEntropyLoss()
optimizer = ms.nn.SGD(net.trainable_params(), learning_rate=0.01, momentum=0.9, weight_decay=0.0005)

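# Define the training process: wrap the network with its loss, then with the optimizer,
# to get a cell that performs one training step per call
loss_net = ms.nn.WithLossCell(net, criterion)
train_net = ms.nn.TrainOneStepCell(loss_net, optimizer)

for epoch in range(epochs):
    for step, (X, y) in enumerate(train_data):
        # Forward pass, backward pass and parameter update; returns the batch loss
        res = train_net(X, y)
        print("epoch:{}, step:{}, loss:{:.6f}".format(epoch, step, res.asnumpy()))
# Save the trained parameters as a MindSpore checkpoint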
ms.save_checkpoint(net, "save_path.ckpt")
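Conceptually, each `train_net(X, y)` call runs the network, applies the criterion (`WithLossCell`), computes gradients, and lets the optimizer update the parameters (`TrainOneStepCell`). The NumPy sketch below spells those steps out for a single linear layer with softmax cross-entropy and textbook momentum SGD with L2 weight decay, reusing the `learning_rate`, `momentum` and `weight_decay` values from the example. It is an illustration, not MindSpore's implementation: MindSpore derives gradients by automatic differentiation, and the exact `ms.nn.SGD` update (dampening, Nesterov, which parameters receive weight decay) is defined in the MindSpore API documentation. `softmax_cross_entropy` and `train_one_step` are hypothetical names.

```python
import numpy as np


def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy over the batch and its gradient with respect to the logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    rows = np.arange(logits.shape[0])
    loss = -np.log(probs[rows, labels]).mean()
    grad = probs.copy()
    grad[rows, labels] -= 1.0  # d(loss)/d(logits) = (softmax - one_hot) / batch_size
    return loss, grad / logits.shape[0]


def train_one_step(params, velocity, x, y, lr=0.01, momentum=0.9, weight_decay=0.0005):
    """Forward, loss, backward and momentum-SGD update for one linear layer; returns the loss."""
    logits = x @ params["W"] + params["b"]            # network forward pass
    loss, dlogits = softmax_cross_entropy(logits, y)  # criterion applied to the network output
    grads = {"W": x.T @ dlogits, "b": dlogits.sum(axis=0)}
    for name in params:
        g = grads[name] + weight_decay * params[name]      # L2 weight decay folded into the gradient
        velocity[name] = momentum * velocity[name] + g     # momentum buffer
        params[name] = params[name] - lr * velocity[name]  # parameter update
    return loss


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(8, 4))
    y = np.array([0, 1, 2, 0, 1, 2, 0, 1])
    params = {"W": np.zeros((4, 3)), "b": np.zeros(3)}
    velocity = {name: np.zeros_like(value) for name, value in params.items()}
    for step in range(5):
        print(step, train_one_step(params, velocity, x, y))
```

Starting from zero weights, the first loss equals log(3) for three classes, and it should fall over repeated steps on the toy batch.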

License

MSAdapter is released under the Apache 2.0 license.

Download files


Source Distribution

msadapter-0.1.0.tar.gz (621.2 kB)

Uploaded Source

Built Distribution

msadapter-0.1.0-py3-none-any.whl (812.7 kB)

Uploaded Python 3

File details

Details for the file msadapter-0.1.0.tar.gz.

File metadata

  • Download URL: msadapter-0.1.0.tar.gz
  • Upload date:
  • Size: 621.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: Python-urllib/3.6

File hashes

Hashes for msadapter-0.1.0.tar.gz
Algorithm Hash digest
SHA256 8d3d2aa49450d5effe92efbe3a6a579228beba1bb27e876dff7405e6035cb93e
MD5 29d59cad094b7be4656d5052cc2551d3
BLAKE2b-256 9933289bf245c2d680dde0e0ab9b00c5e1e0b3b4a5d6c39694790526f4669a43


File details

Details for the file msadapter-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: msadapter-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 812.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: Python-urllib/3.6

File hashes

Hashes for msadapter-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 7906538f712b72c78ecdd40b2b4c4c4bbf59228f0d5fc90b380e90159c676b55
MD5 8c36ea061d402fae9e2632fbeb106f6d
BLAKE2b-256 81d35459b06f1ea941e2a786463d0d706ab440cf1d6d31fa077c552ce702daa3
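To check a downloaded release file against the SHA256 digests listed above, the standard-library `hashlib` module is enough. This is a minimal sketch; `sha256_of` and `verify_sha256` are hypothetical helpers, and it assumes the files sit in the current directory (for example after `pip download msadapter==0.1.0 --no-deps`). pip can also enforce digests itself through its hash-checking mode (`--hash=sha256:...` entries in a requirements file).

```python
import hashlib
import os

# SHA256 digests listed above for the msadapter 0.1.0 release files.
EXPECTED_SHA256 = {
    "msadapter-0.1.0.tar.gz": "8d3d2aa49450d5effe92efbe3a6a579228beba1bb27e876dff7405e6035cb93e",
    "msadapter-0.1.0-py3-none-any.whl": "7906538f712b72c78ecdd40b2b4c4c4bbf59228f0d5fc90b380e90159c676b55",
}


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read in chunks so large files are not loaded at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path, expected_hex):
    """True if the file's SHA256 digest matches the expected hex string (case-insensitive)."""
    return sha256_of(path) == expected_hex.lower()


if __name__ == "__main__":
    for filename, expected in EXPECTED_SHA256.items():
        if os.path.exists(filename):
            print(filename, "OK" if verify_sha256(filename, expected) else "MISMATCH")
```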

