
Ascend Quick Migration Adapter Package

Project description

npu-adapter

This project supports migrating large models to Ascend hardware. It provides Ascend-side support for capabilities such as handling torch-level operator incompatibilities in specific scenarios, GPU-to-NPU adaptation of memory offloading, long-sequence splitting, and more, speeding up overall migration. Coverage does not yet include all scenarios; contributions are welcome, so that the common capabilities distilled from each adaptation scenario accumulate over time and build a more usable Ascend ecosystem.

Project Information

Dependencies

  • torch>=2.8.0
  • torch_npu>=2.8.0
  • yunchang>=0.6.3.post1

Installation

pip install npu-adapter

Quick Start

import npu_adapter

# Adapt to the accelerator device
npu_adapter.adapt_to_accelerator_device()

# Initialize distributed training
npu_adapter.init_distributed_adapter()

Feature Modules

1. Device Detection and Adaptation

has_npu()

Checks whether an NPU is available.

from npu_adapter import has_npu

if has_npu():
    print("NPU is available")
else:
    print("NPU is not available")
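For intuition, the availability check can be approximated by probing whether the torch_npu package is importable. This is a simplified sketch with an illustrative name, not the package's actual implementation, which may additionally query the driver or device count:

```python
import importlib.util

def has_npu_sketch() -> bool:
    """Hypothetical stand-in for has_npu(): report the NPU as available
    only when the torch_npu package can be imported. The real check may
    also verify that a device is actually present."""
    return importlib.util.find_spec("torch_npu") is not None
```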

adapt_to_accelerator_device()

Adapts to the accelerator device and prints which device is currently supported.

from npu_adapter import adapt_to_accelerator_device

adapt_to_accelerator_device()
# Output: load adapter(accelerator device, now support to run on Ascend NPU
# or:     load adapter(accelerator device, now support to run on Nvidia GPU

init_distributed_adapter()

Initializes the distributed-training adaptation, selecting the backend according to the device type.

from npu_adapter import init_distributed_adapter

init_distributed_adapter()
# NPU: backend="cpu:gloo,npu:hccl"
# GPU: backend="nccl"
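The backend choice shown in the comments reduces to a simple device switch. A hedged sketch of that selection logic (the function name is illustrative and not part of the package's API):

```python
def pick_distributed_backend(npu_available: bool) -> str:
    """Return the process-group backend string documented above:
    Gloo for CPU ops plus HCCL on Ascend NPU, NCCL on Nvidia GPU."""
    return "cpu:gloo,npu:hccl" if npu_available else "nccl"
```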

2. Deterministic Mode

deterministic_on()

Enables deterministic algorithms and sets random seeds.

from npu_adapter import deterministic_on

deterministic_on()
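A determinism switch of this kind typically pins every random source to a fixed seed. The sketch below covers only the Python-level sources and uses an illustrative name; the real deterministic_on() would additionally seed the torch/torch_npu generators and call torch.use_deterministic_algorithms:

```python
import os
import random

def deterministic_on_sketch(seed: int = 0) -> None:
    """Illustrative determinism setup for Python-level RNG sources only.
    A full implementation would also seed torch and enable its
    deterministic-algorithm mode."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
```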

3. Long-Sequence Attention

get_longcontext_attention()

Returns the long-sequence attention module, built on the yunchang implementation.


from npu_adapter import get_longcontext_attention

# Automatically picks the appropriate implementation for the device type
long_context_attn = get_longcontext_attention()
# NPU: LongContextAttention(ring_impl_type="basic_npu", attn_type=AttnType.NPU)
# GPU: LongContextAttention(ring_impl_type="basic", attn_type=AttnType.FA3)
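The device-dependent configuration in the comments can be summarized as a small selection table. This is a sketch only: AttnChoice and pick_long_attention are illustrative names, and the attn_type values are rendered as strings here rather than the yunchang AttnType enum:

```python
from typing import NamedTuple

class AttnChoice(NamedTuple):
    ring_impl_type: str
    attn_type: str

def pick_long_attention(npu_available: bool) -> AttnChoice:
    """Mirror the NPU/GPU selection documented above."""
    if npu_available:
        return AttnChoice(ring_impl_type="basic_npu", attn_type="NPU")
    return AttnChoice(ring_impl_type="basic", attn_type="FA3")
```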

4. Compiler Backend

get_compiler_backend()

Returns the compiler backend; on Ascend, fused operators are automatically enabled.

import torch
from npu_adapter import get_compiler_backend

# NPU: torchair.get_npu_backend()
# GPU: "inductor"
model = torch.compile(model, backend=get_compiler_backend())

To use MindieSDBackend: in the entry script, compile the transformer module as a whole, for example:

pipe = FluxPipeline.from_pretrained(...)
transformer = torch.compile(pipe.transformer, backend=MindieSDBackend())
setattr(pipe, "transformer", transformer)

It can also be applied selectively to a single Module:

@torch.compile(backend=MindieSDBackend())
class FluxSingleTransformerBlock(nn.Module):
    ...

Or applied to the forward function:

class FluxSingleTransformerBlock(nn.Module):
    @torch.compile(backend=MindieSDBackend())
    def forward(...):
        ...

5. Memory Format

contiguous_for_channels_last_3d_memory_format()

Adapts the memory format: NPU supports only the contiguous format, while GPU supports the channels_last_3d format.

import torch
from npu_adapter import contiguous_for_channels_last_3d_memory_format

tensor = torch.randn(2, 3, 4, 5, 6)
result = contiguous_for_channels_last_3d_memory_format(tensor)

6. Activation Functions

adapter_gelu()

GELU activation adaptation; on NPU the optimized torch_npu.fast_gelu implementation is used.

import torch
from npu_adapter import adapter_gelu

x = torch.randn(10, 10)
result = adapter_gelu(x)
# NPU: torch_npu.fast_gelu(x)
# GPU: 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
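The GPU fallback above is the tanh approximation of GELU. A scalar reference (pure Python, independent of the package) makes it easy to check against the exact Gaussian-CDF form:

```python
import math

def tanh_gelu(x: float) -> float:
    """Tanh approximation of GELU, matching the GPU fallback formula above."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def exact_gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
```

For moderate inputs the two forms agree to roughly three decimal places.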

7. Normalization

adapter_norm()

RMS normalization adaptation (method form, called with a module instance that provides weight and eps).

import torch
from npu_adapter import adapter_norm

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(10)
        self.eps = 1e-5

    def forward(self, x):
        return adapter_norm(self, x)

adpater_rmsnorm()

RMS normalization adaptation (function form).

import torch
from npu_adapter import adpater_rmsnorm

x = torch.randn(10, 10)
weight = torch.randn(10)
eps = 1e-5
result = adpater_rmsnorm(x, weight, eps)
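For reference, RMS normalization computes y_i = x_i * w_i / sqrt(mean(x^2) + eps). A pure-Python sketch over a flat list (illustrative only; adpater_rmsnorm itself operates on tensors):

```python
import math

def rmsnorm_ref(x, weight, eps=1e-5):
    """Reference RMS norm over a 1-D sequence: scale each element by
    1/sqrt(mean of squares + eps), then by its per-element weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]
```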

8. Rotary Position Embedding

apply_rotary_pos_emb_adapter()

Applies rotary position embeddings to the query and key tensors, automatically choosing the implementation for the device.

import torch
from npu_adapter import apply_rotary_pos_emb_adapter

q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
cos = torch.randn(2, 8, 16)
sin = torch.randn(2, 8, 16)

q_rotated, k_rotated = apply_rotary_pos_emb_adapter(q, k, cos, sin)
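The underlying math is the standard rotate-half formulation, q' = q*cos + rotate_half(q)*sin. The sketch below works over plain lists for clarity; whether the package uses exactly this layout is an assumption:

```python
def rotate_half(x):
    """Standard RoPE helper: map the (x1, x2) halves of a vector to (-x2, x1)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])

def apply_rope_ref(vec, cos, sin):
    """Elementwise q' = q * cos + rotate_half(q) * sin."""
    rotated = rotate_half(vec)
    return [v * c + r * s for v, r, c, s in zip(vec, rotated, cos, sin)]
```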

Testing

The project includes a complete test suite covering both NPU and GPU scenarios.

Run all tests

python test/test_adapter.py

Run with the unittest module

python -m unittest test.test_adapter -v

Run with the test runner

python test/run_tests.py

Publishing to PyPI

The project ships a quick publish script: scripts/publish_pypi.sh

  1. Prepare the environment-variable file
cp .env.publish.example .env.publish

At minimum, configure the following variables:

PYPI_REPOSITORY_URL=https://upload.pypi.org/legacy/
PYPI_USERNAME=__token__
PYPI_TOKEN=pypi-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
  2. Run the publish step
bash scripts/publish_pypi.sh

Common usage:

# Automatically bump the patch component of the version in pyproject.toml
VERSION_PART=patch bash scripts/publish_pypi.sh

# Publish an explicitly specified version
PACKAGE_VERSION=0.2.1 bash scripts/publish_pypi.sh

# Publish to TestPyPI
PYPI_REPOSITORY_URL=https://test.pypi.org/legacy/ bash scripts/publish_pypi.sh

Notes:

  • By default the script builds with python -m build --no-isolation, which avoids failures fetching build dependencies in offline or restricted network environments.
  • twine check is run by default to validate the package metadata.
  • The script reads PYPI_TOKEN first and falls back to PYPI_PASSWORD if it is unset.

Test Coverage

  • Utility functions: NPU/GPU detection, device adaptation, distributed initialization, deterministic mode
  • Long-sequence attention: retrieving the attention module in NPU and GPU scenarios
  • Compiler backend: retrieving the backend in NPU and GPU scenarios
  • Memory-format operations: contiguous memory-format adaptation
  • Tensor operations: GELU activation, RMS normalization
  • Rotary position embedding: applying rotary embeddings
  • Integration: end-to-end workflow tests

Contributing

Contributions are welcome! Please follow these steps:

  1. Fork this repository
  2. Create a feature branch (git checkout -b feature/AmazingFeature)
  3. Commit your changes (git commit -m 'Add some AmazingFeature')
  4. Push the branch (git push origin feature/AmazingFeature)
  5. Open a Pull Request

License

This project is released under an open-source license; see the LICENSE file for details.

Contact

Acknowledgements

