pnnx is an open standard for PyTorch model interoperability.

Project description

PNNX

PyTorch Neural Network eXchange (PNNX) is an open standard for PyTorch model interoperability. PNNX provides an open model format for PyTorch. It defines a computation graph as well as high-level operators that strictly match PyTorch.

Rationale

PyTorch is currently one of the most popular machine learning frameworks, and trained models need to be deployed to a wide variety of hardware and software environments as conveniently as possible.

Before PNNX, we had the following methods:

  1. export to ONNX, and deploy with ONNX Runtime
  2. export to ONNX, convert the ONNX model to an inference-framework-specific format, and deploy with TensorRT/OpenVINO/ncnn/etc.
  3. export to TorchScript, and deploy with libtorch

As far as we know, ONNX is an open standard with the ability to express PyTorch models, and people usually use it as an intermediate representation between PyTorch and the inference platform. However, ONNX still has the following serious problems, which make PNNX necessary:

  1. ONNX does not have a human-readable and editable file representation, making it difficult for users to modify the computation graph or add custom operators.
  2. The operator definitions of ONNX do not correspond exactly to PyTorch. When exporting some PyTorch operators, ONNX often inserts glue operators, which makes the computation graph inconsistent with PyTorch and may hurt inference efficiency.
  3. ONNX operator definitions carry a large number of additional parameters designed for compatibility with various ML frameworks. These parameters increase the burden on hardware and software inference implementations.

PNNX tries to define a set of operators, and a simple and easy-to-use format, that correspond exactly to the Python API of PyTorch, so that converting and interoperating with PyTorch models becomes more convenient.

Features

  1. Human readable and editable format
  2. Plain model binary stored in a zip archive
  3. One-to-one mapping between PNNX operators and the PyTorch Python API
  4. Preserve math expression as one operator
  5. Preserve torch function as one operator
  6. Preserve miscellaneous module as one operator
  7. Inference via exported PyTorch python code
  8. Tensor shape propagation
  9. Model optimization
  10. Custom operator support

Build TorchScript to PNNX converter

  1. Install the PyTorch and TorchVision C++ libraries
  2. Build PNNX with CMake (a minimal sketch follows)
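
A minimal build sketch, assuming a local PNNX source tree and existing libtorch/torchvision installs; the paths and exact CMake variables are illustrative, so check the project README for your checkout:

cd pnnx
mkdir build && cd build
# let CMake find your libtorch (and torchvision) installs
cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="/path/to/libtorch;/path/to/torchvision" ..
cmake --build . -j 4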

Usage

  1. Export your model to TorchScript
import torch
import torchvision.models as models

# note: newer torchvision releases replace pretrained=True with the weights argument,
# e.g. models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
net = models.resnet18(pretrained=True)
net = net.eval()

x = torch.rand(1, 3, 224, 224)

# You could try disabling checking when tracing raises error
# mod = torch.jit.trace(net, x, check_trace=False)
mod = torch.jit.trace(net, x)

mod.save("resnet18.pt")
  2. Convert TorchScript to PNNX
pnnx resnet18.pt inputshape=[1,3,224,224]

Normally, you will get seven files:

resnet18.pnnx.param: PNNX graph definition

resnet18.pnnx.bin: PNNX model weights

resnet18_pnnx.py: PyTorch script for inference, i.e. the Python code for model construction and weight initialization

resnet18.pnnx.onnx: PNNX model in ONNX format

resnet18.ncnn.param: ncnn graph definition

resnet18.ncnn.bin: ncnn model weights

resnet18_ncnn.py: pyncnn script for inference

  3. Visualize PNNX with Netron

Open https://netron.app/ in your browser, and drag resnet18.pnnx.param into it.

  4. PNNX command line options
Usage: pnnx [model.pt] [(key=value)...]
  pnnxparam=model.pnnx.param
  pnnxbin=model.pnnx.bin
  pnnxpy=model_pnnx.py
  pnnxonnx=model.pnnx.onnx
  ncnnparam=model.ncnn.param
  ncnnbin=model.ncnn.bin
  ncnnpy=model_ncnn.py
  fp16=1
  optlevel=2
  device=cpu/gpu
  inputshape=[1,3,224,224],...
  inputshape2=[1,3,320,320],...
  customop=/home/nihui/.cache/torch_extensions/fused/fused.so,...
  moduleop=models.common.Focus,models.yolo.Detect,...
Sample usage: pnnx mobilenet_v2.pt inputshape=[1,3,224,224]
              pnnx yolov5s.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320] device=gpu moduleop=models.common.Focus,models.yolo.Detect

Parameters:

pnnxparam (default="*.pnnx.param", * is the model name): PNNX graph definition file

pnnxbin (default="*.pnnx.bin"): PNNX model weights

pnnxpy (default="*_pnnx.py"): PyTorch script for inference, including model construction and weight initialization code

pnnxonnx (default="*.pnnx.onnx"): PNNX model in onnx format

ncnnparam (default="*.ncnn.param"): ncnn graph definition

ncnnbin (default="*.ncnn.bin"): ncnn model weights

ncnnpy (default="*_ncnn.py"): pyncnn script for inference

fp16 (default=1): save the ncnn and onnx weights in fp16 data type

optlevel (default=2): graph optimization level

Option | Optimization level
0 | do not apply optimization
1 | optimization for inference
2 | more aggressive optimization for inference

device (default="cpu"): device type for the inputs of the TorchScript model, cpu or gpu

inputshape (Optional): shapes of the model inputs, used to resolve tensor shapes in the model graph. For example, [1,3,224,224] for a model with one input, or [1,3,224,224],[1,3,224,224] for a model with two inputs.

inputshape2 (Optional): shapes of alternative model inputs; the format is identical to inputshape. It is usually used together with inputshape to resolve dynamic shapes (-1) in the model graph.

customop (Optional): list of Torch extensions (dynamic libraries) for custom operators, separated by ",". For example, /home/nihui/.cache/torch_extensions/fused/fused.so,...

moduleop (Optional): list of modules to keep as one big operator, separated by ",". For example, models.common.Focus,models.yolo.Detect

The pnnx.param format

example

7767517
4 3
pnnx.Input      input       0 1 0
nn.Conv2d       conv_0      1 1 0 1 bias=1 dilation=(1,1) groups=1 in_channels=12 kernel_size=(3,3) out_channels=16 padding=(0,0) stride=(1,1) @bias=(16)f32 @weight=(16,12,3,3)f32
nn.Conv2d       conv_1      1 1 1 2 bias=1 dilation=(1,1) groups=1 in_channels=16 kernel_size=(2,2) out_channels=20 padding=(2,2) stride=(2,2) @bias=(20)f32 @weight=(20,16,2,2)f32
pnnx.Output     output      1 0 2

overview

[magic]
  • magic number : 7767517
[operator count] [operand count]
  • operator count : count of the operator lines that follow
  • operand count : count of all operands

operator line

[type] [name] [input count] [output count] [input operands] [output operands] [operator params]
  • type : type name, such as Conv2d, ReLU, etc.
  • name : name of this operator
  • input count : count of the operands this operator takes as input
  • output count : count of the operands this operator produces as output
  • input operands : list of input operand (blob) names, separated by spaces
  • output operands : list of output operand (blob) names, separated by spaces
  • operator params : list of key=value pairs, separated by spaces; operator weights are prefixed by the @ symbol, tensor shapes by the # symbol, and input parameter keys by the $ symbol
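
The format is simple enough to parse by hand. Below is a minimal Python sketch of a pnnx.param reader, covering only the plain key=value fields described above; it is an illustration, not the official loader:

def parse_pnnx_param(path):
    # magic number, then operator/operand counts, then one operator per line
    with open(path) as f:
        magic = int(f.readline())                        # expect 7767517
        op_count, operand_count = map(int, f.readline().split())
        ops = []
        for _ in range(op_count):
            fields = f.readline().split()
            op_type, name = fields[0], fields[1]
            n_in, n_out = int(fields[2]), int(fields[3])
            inputs = fields[4:4 + n_in]
            outputs = fields[4 + n_in:4 + n_in + n_out]
            params = dict(kv.split('=', 1) for kv in fields[4 + n_in + n_out:])
            ops.append((op_type, name, inputs, outputs, params))
    return magic, operand_count, ops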

The pnnx.bin format

The pnnx.bin file is a zip archive in store-only mode (no compression).

Each weight binary inside it is named after the operator name and weight name.

For example, nn.Conv2d conv_0 1 1 0 1 bias=1 dilation=(1,1) groups=1 in_channels=12 kernel_size=(3,3) out_channels=16 padding=(0,0) stride=(1,1) @bias=(16) @weight=(16,12,3,3) would pull conv_0.weight and conv_0.bias into the pnnx.bin zip archive.

Weight binaries can be listed or modified with any archive application, e.g. 7-Zip.
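
Because it is a plain zip archive, the weights can also be inspected from Python; a small sketch (the file name follows the resnet18 example above, and entry names depend on the operator names in your graph):

import zipfile

with zipfile.ZipFile('resnet18.pnnx.bin') as archive:
    # each entry is one raw weight blob named <operator>.<weight>, e.g. conv_0.weight
    for name in archive.namelist():
        print(name)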


PNNX operator

PNNX always preserves operators exactly as the PyTorch Python API provides them.

Here is the netron visualization comparison among ONNX, TorchScript and PNNX with the original PyTorch python code shown.

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.attention = nn.MultiheadAttention(embed_dim=256, num_heads=32)

    def forward(self, x):
        x, _ = self.attention(x, x, x)
        return x
(netron screenshots: MultiheadAttention.onnx / MultiheadAttention.pt / MultiheadAttention.pnnx)

PNNX expression operator

PNNX tries to preserve the expression as written in the PyTorch Python code.

Here is the netron visualization comparison among ONNX, TorchScript and PNNX with the original PyTorch python code shown.

import torch

def foo(x, y):
    return torch.sqrt((2 * x + y) / 12)
(netron screenshots: math.onnx / math.pt / math.pnnx)

PNNX torch function operator

PNNX tries to preserve torch functions and Tensor member functions as single operators, as the PyTorch Python API provides them.

Here is the netron visualization comparison among ONNX, TorchScript and PNNX with the original PyTorch python code shown.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        x = F.normalize(x, eps=1e-3)
        return x
(netron screenshots: function.onnx / function.pt / function.pnnx)

PNNX module operator

Users can ask PNNX to keep a module as one big operator when it has complex logic.

The process is optional and can be enabled via the moduleop command line option.

After pass_level0, all modules are printed to the terminal output, and you can then pick the interesting ones as module operators.

############# pass_level0
inline module = models.common.Bottleneck
inline module = models.common.C3
inline module = models.common.Concat
inline module = models.common.Conv
inline module = models.common.Focus
inline module = models.common.SPP
inline module = models.yolo.Detect
inline module = utils.activations.SiLU
pnnx yolov5s.pt inputshape=[1,3,640,640] moduleop=models.common.Focus,models.yolo.Detect

Here is the netron visualization comparison among ONNX, TorchScript and PNNX with the original PyTorch python code shown.

import torch
import torch.nn as nn

class Focus(nn.Module):
    # Focus wh information into c-space
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        # Conv is the yolov5 models.common.Conv block; this snippet is an excerpt
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
(netron screenshots: focus.onnx / focus.pt / focus.pnnx / focus.pnnx with module operator)

PNNX python inference

A Python script is generated by default when converting TorchScript to PNNX.

This script is the Python code representation of PNNX and can be used for model inference.

It contains some utility functions for loading the weight binaries from pnnx.bin.

You can even export TorchScript again from this generated code!

The original PyTorch model code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.linear_0 = nn.Linear(in_features=128, out_features=256, bias=True)
        self.linear_1 = nn.Linear(in_features=256, out_features=4, bias=True)

    def forward(self, x):
        x = self.linear_0(x)
        x = F.leaky_relu(x, 0.15)
        x = self.linear_1(x)
        return x
And the generated *_pnnx.py script, which rebuilds the model and loads weights from pnnx.bin:

import os
import numpy as np
import tempfile, zipfile
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.linear_0 = nn.Linear(bias=True, in_features=128, out_features=256)
        self.linear_1 = nn.Linear(bias=True, in_features=256, out_features=4)

        archive = zipfile.ZipFile('../../function.pnnx.bin', 'r')
        self.linear_0.bias = self.load_pnnx_bin_as_parameter(archive, 'linear_0.bias', (256), 'float32')
        self.linear_0.weight = self.load_pnnx_bin_as_parameter(archive, 'linear_0.weight', (256,128), 'float32')
        self.linear_1.bias = self.load_pnnx_bin_as_parameter(archive, 'linear_1.bias', (4), 'float32')
        self.linear_1.weight = self.load_pnnx_bin_as_parameter(archive, 'linear_1.weight', (4,256), 'float32')
        archive.close()

    def load_pnnx_bin_as_parameter(self, archive, key, shape, dtype):
        return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype))

    def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype):
        fd, tmppath = tempfile.mkstemp()
        with os.fdopen(fd, 'wb') as tmpf, archive.open(key) as keyfile:
            tmpf.write(keyfile.read())
        m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy()
        os.remove(tmppath)
        return torch.from_numpy(m)

    def forward(self, v_x_1):
        v_7 = self.linear_0(v_x_1)
        v_input_1 = F.leaky_relu(input=v_7, negative_slope=0.150000)
        v_12 = self.linear_1(v_input_1)
        return v_12
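
A quick usage sketch for the generated script; the module name here is illustrative (per the naming rule above, this example would be saved as function_pnnx.py), and the relative pnnx.bin path inside the generated code must resolve from your working directory:

import torch
from function_pnnx import Model  # hypothetical module name

net = Model()
net.eval()

x = torch.rand(1, 128)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 4])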

PNNX shape propagation

Users can ask PNNX to resolve all tensor shapes in the model graph and to constant-fold common expressions once tensor shapes are known.

The process is optional and can be enabled via the inputshape command line option.

pnnx shufflenet_v2_x1_0.pt inputshape=[1,3,224,224]
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x
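
With a known input shape, the x.size() call above need not happen at runtime, and the view arguments can be folded into constants. Conceptually (the folded numbers below are illustrative, not taken from the actual shufflenet trace):

# before shape propagation: sizes are read from the tensor at runtime
batchsize, num_channels, height, width = x.size()
x = x.view(batchsize, groups, num_channels // groups, height, width)

# after shape propagation with a known input shape (illustrative constants)
x = x.view(1, 2, 58, 28, 28)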
(netron screenshots: graph without shape propagation vs. with shape propagation)

PNNX model optimization

(netron screenshots: ONNX / TorchScript / PNNX without optimization / PNNX with optimization)

PNNX custom operator

Custom operators implemented as Torch extensions can be preserved as single operators in PNNX: register the operator in C++ (as below) and pass the built dynamic library to pnnx via the customop option.

import os

import torch
from torch.utils.cpp_extension import load

module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
    'upfirdn2d',
    sources=[
        os.path.join(module_path, 'upfirdn2d.cpp'),
        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
    ],
    is_python_module=False
)

def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    pad_x0 = pad[0]
    pad_x1 = pad[1]
    pad_y0 = pad[0]
    pad_y1 = pad[1]

    kernel_h, kernel_w = kernel.shape
    batch, channel, in_h, in_w = input.shape

    input = input.reshape(-1, in_h, in_w, 1)

    out_h = (in_h * up + pad_y0 + pad_y1 - kernel_h) // down + 1
    out_w = (in_w * up + pad_x0 + pad_x1 - kernel_w) // down + 1

    out = torch.ops.upfirdn2d_op.upfirdn2d(input, kernel, up, up, down, down, pad_x0, pad_x1, pad_y0, pad_y1)

    out = out.view(-1, channel, out_h, out_w)

    return out
The matching C++ extension source (upfirdn2d.cpp):

#include <torch/extension.h>

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int64_t up_x, int64_t up_y, int64_t down_x, int64_t down_y,
                        int64_t pad_x0, int64_t pad_x1, int64_t pad_y0, int64_t pad_y1) {
    // operator body
}

TORCH_LIBRARY(upfirdn2d_op, m) {
    m.def("upfirdn2d", upfirdn2d);
}
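
After building the extension, pass the resulting dynamic library to pnnx via the customop option when converting; the model name and .so path here are illustrative:

pnnx model.pt inputshape=[1,3,224,224] customop=/home/nihui/.cache/torch_extensions/upfirdn2d/upfirdn2d.so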

Supported PyTorch operator status

torch.nn | Is Supported | Export to ncnn
nn.AdaptiveAvgPool1d | ✔ | ✔
nn.AdaptiveAvgPool2d | ✔ | ✔
nn.AdaptiveAvgPool3d | ✔ | ✔
nn.AdaptiveMaxPool1d | ✔ | ✔
nn.AdaptiveMaxPool2d | ✔ | ✔
nn.AdaptiveMaxPool3d | ✔ | ✔
nn.AlphaDropout | ✔ | ✔
nn.AvgPool1d | ✔ | ✔*
nn.AvgPool2d | ✔ | ✔*
nn.AvgPool3d | ✔ | ✔*
nn.BatchNorm1d | ✔ | ✔
nn.BatchNorm2d | ✔ | ✔
nn.BatchNorm3d | ✔ | ✔
nn.Bilinear | |
nn.CELU | ✔ | ✔
nn.ChannelShuffle | ✔ | ✔
nn.ConstantPad1d | ✔ | ✔
nn.ConstantPad2d | ✔ | ✔
nn.ConstantPad3d | ✔ | ✔
nn.Conv1d | ✔ | ✔
nn.Conv2d | ✔ | ✔
nn.Conv3d | ✔ | ✔
nn.ConvTranspose1d | ✔ | ✔
nn.ConvTranspose2d | ✔ | ✔
nn.ConvTranspose3d | ✔ | ✔
nn.CosineSimilarity | |
nn.Dropout | ✔ | ✔
nn.Dropout2d | ✔ | ✔
nn.Dropout3d | ✔ | ✔
nn.ELU | ✔ | ✔
nn.Embedding | ✔ | ✔
nn.EmbeddingBag | |
nn.Flatten | ✔ |
nn.Fold | ✔ | ✔
nn.FractionalMaxPool2d | |
nn.FractionalMaxPool3d | |
nn.GELU | ✔ | ✔
nn.GLU | ✔ | ✔
nn.GroupNorm | ✔ | ✔
nn.GRU | ✔ | ✔
nn.GRUCell | |
nn.Hardshrink | ✔ |
nn.Hardsigmoid | ✔ | ✔
nn.Hardswish | ✔ | ✔
nn.Hardtanh | ✔ | ✔
nn.Identity | |
nn.InstanceNorm1d | ✔ |
nn.InstanceNorm2d | ✔ | ✔
nn.InstanceNorm3d | ✔ |
nn.LayerNorm | ✔ | ✔
nn.LazyBatchNorm1d | |
nn.LazyBatchNorm2d | |
nn.LazyBatchNorm3d | |
nn.LazyConv1d | |
nn.LazyConv2d | |
nn.LazyConv3d | |
nn.LazyConvTranspose1d | |
nn.LazyConvTranspose2d | |
nn.LazyConvTranspose3d | |
nn.LazyLinear | |
nn.LeakyReLU | ✔ | ✔
nn.Linear | ✔ | ✔
nn.LocalResponseNorm | ✔ | ✔
nn.LogSigmoid | ✔ | ✔
nn.LogSoftmax | ✔ | ✔
nn.LPPool1d | ✔ |
nn.LPPool2d | ✔ |
nn.LSTM | ✔ | ✔
nn.LSTMCell | |
nn.MaxPool1d | ✔ | ✔
nn.MaxPool2d | ✔ | ✔
nn.MaxPool3d | ✔ | ✔
nn.MaxUnpool1d | |
nn.MaxUnpool2d | |
nn.MaxUnpool3d | |
nn.Mish | ✔ | ✔
nn.MultiheadAttention | ✔ | ✔*
nn.PairwiseDistance | |
nn.PixelShuffle | ✔ | ✔
nn.PixelUnshuffle | ✔ | ✔
nn.PReLU | ✔ | ✔
nn.ReflectionPad1d | ✔ | ✔
nn.ReflectionPad2d | ✔ | ✔
nn.ReLU | ✔ | ✔
nn.ReLU6 | ✔ | ✔
nn.ReplicationPad1d | ✔ | ✔
nn.ReplicationPad2d | ✔ | ✔
nn.ReplicationPad3d | ✔ |
nn.RNN | ✔ | ✔*
nn.RNNBase | |
nn.RNNCell | |
nn.RReLU | ✔ |
nn.SELU | ✔ | ✔
nn.Sigmoid | ✔ | ✔
nn.SiLU | ✔ | ✔
nn.Softmax | ✔ | ✔
nn.Softmax2d | ✔ | ✔
nn.Softmin | ✔ |
nn.Softplus | ✔ |
nn.Softshrink | ✔ |
nn.Softsign | ✔ |
nn.SyncBatchNorm | |
nn.Tanh | ✔ | ✔
nn.Tanhshrink | ✔ |
nn.Threshold | ✔ |
nn.Transformer | |
nn.TransformerDecoder | |
nn.TransformerDecoderLayer | |
nn.TransformerEncoder | |
nn.TransformerEncoderLayer | |
nn.Unflatten | |
nn.Unfold | ✔ | ✔
nn.Upsample | ✔ | ✔
nn.UpsamplingBilinear2d | ✔ | ✔
nn.UpsamplingNearest2d | ✔ | ✔
nn.ZeroPad2d | ✔ | ✔
torch.nn.functional | Is Supported | Export to ncnn
F.adaptive_avg_pool1d | ✔ | ✔
F.adaptive_avg_pool2d | ✔ | ✔
F.adaptive_avg_pool3d | ✔ | ✔
F.adaptive_max_pool1d | ✔ | ✔
F.adaptive_max_pool2d | ✔ | ✔
F.adaptive_max_pool3d | ✔ | ✔
F.affine_grid | ✔ |
F.alpha_dropout | ✔ | ✔
F.avg_pool1d | ✔ | ✔*
F.avg_pool2d | ✔ | ✔*
F.avg_pool3d | ✔ | ✔*
F.batch_norm | ✔ | ✔
F.bilinear | |
F.celu | ✔ |
F.conv1d | ✔ | ✔
F.conv2d | ✔ | ✔
F.conv3d | ✔ | ✔
F.conv_transpose1d | ✔ | ✔
F.conv_transpose2d | ✔ | ✔
F.conv_transpose3d | ✔ | ✔
F.cosine_similarity | |
F.dropout | ✔ | ✔
F.dropout2d | ✔ | ✔
F.dropout3d | ✔ | ✔
F.elu | ✔ | ✔
F.elu_ | ✔ | ✔
F.embedding | ✔ | ✔
F.embedding_bag | |
F.feature_alpha_dropout | ✔ | ✔
F.fold | ✔ | ✔
F.fractional_max_pool2d | |
F.fractional_max_pool3d | |
F.gelu | ✔ | ✔
F.glu | ✔ | ✔
F.grid_sample | ✔ | ✔
F.group_norm | ✔ | ✔
F.gumbel_softmax | |
F.hardshrink | ✔ |
F.hardsigmoid | ✔ | ✔
F.hardswish | ✔ | ✔
F.hardtanh | ✔ | ✔
F.hardtanh_ | ✔ | ✔
F.instance_norm | ✔ | ✔
F.interpolate | ✔ | ✔
F.layer_norm | ✔ | ✔
F.leaky_relu | ✔ | ✔
F.leaky_relu_ | ✔ | ✔
F.linear | ✔ | ✔*
F.local_response_norm | ✔ | ✔
F.logsigmoid | ✔ | ✔
F.log_softmax | ✔ | ✔
F.lp_pool1d | ✔ |
F.lp_pool2d | ✔ |
F.max_pool1d | ✔ | ✔
F.max_pool2d | ✔ | ✔
F.max_pool3d | ✔ | ✔
F.max_unpool1d | |
F.max_unpool2d | |
F.max_unpool3d | |
F.mish | ✔ | ✔
F.normalize | ✔ | ✔
F.one_hot | |
F.pad | ✔ | ✔
F.pairwise_distance | |
F.pdist | |
F.pixel_shuffle | ✔ | ✔
F.pixel_unshuffle | ✔ | ✔
F.prelu | ✔ | ✔
F.relu | ✔ | ✔
F.relu_ | ✔ | ✔
F.relu6 | ✔ | ✔
F.rrelu | ✔ |
F.rrelu_ | ✔ |
F.scaled_dot_product_attention | ✔ |
F.selu | ✔ | ✔
F.sigmoid | ✔ | ✔
F.silu | ✔ | ✔
F.softmax | ✔ | ✔
F.softmin | ✔ |
F.softplus | ✔ |
F.softshrink | ✔ |
F.softsign | ✔ |
F.tanh | ✔ | ✔
F.tanhshrink | ✔ |
F.threshold | ✔ |
F.threshold_ | ✔ |
F.unfold | ✔ | ✔
F.upsample | ✔ | ✔
F.upsample_bilinear | ✔ | ✔
F.upsample_nearest | ✔ | ✔

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
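
In most cases, pip should pick the appropriate wheel for your platform automatically:

pip install pnnx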

Source Distributions

No source distribution files are available for this release. See the tutorial on generating distribution archives.

Built Distributions

File | Size | Python | Platform
pnnx-20231218-cp311-cp311-win_amd64.whl | 12.9 MB | CPython 3.11 | Windows x86-64
pnnx-20231218-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 17.7 MB | CPython 3.11 | manylinux: glibc 2.17+ x86-64
pnnx-20231218-cp311-cp311-macosx_10_9_x86_64.whl | 19.7 MB | CPython 3.11 | macOS 10.9+ x86-64
pnnx-20231218-cp310-cp310-win_amd64.whl | 12.9 MB | CPython 3.10 | Windows x86-64
pnnx-20231218-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 17.7 MB | CPython 3.10 | manylinux: glibc 2.17+ x86-64
pnnx-20231218-cp310-cp310-macosx_10_9_x86_64.whl | 19.7 MB | CPython 3.10 | macOS 10.9+ x86-64
pnnx-20231218-cp39-cp39-win_amd64.whl | 12.9 MB | CPython 3.9 | Windows x86-64
pnnx-20231218-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 17.7 MB | CPython 3.9 | manylinux: glibc 2.17+ x86-64
pnnx-20231218-cp39-cp39-macosx_10_9_x86_64.whl | 19.7 MB | CPython 3.9 | macOS 10.9+ x86-64
pnnx-20231218-cp38-cp38-win_amd64.whl | 12.9 MB | CPython 3.8 | Windows x86-64
pnnx-20231218-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 17.7 MB | CPython 3.8 | manylinux: glibc 2.17+ x86-64
pnnx-20231218-cp38-cp38-macosx_10_9_x86_64.whl | 19.7 MB | CPython 3.8 | macOS 10.9+ x86-64
pnnx-20231218-cp37-cp37m-win_amd64.whl | 12.9 MB | CPython 3.7m | Windows x86-64
pnnx-20231218-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | 17.7 MB | CPython 3.7m | manylinux: glibc 2.17+ x86-64
pnnx-20231218-cp37-cp37m-macosx_10_9_x86_64.whl | 19.7 MB | CPython 3.7m | macOS 10.9+ x86-64

