A compact neural network model definition grammar
NEURLINK
A compact grammar for neural network definition based on PyTorch.
Basics
Neurlink takes a sequence of `nndef` specification lines and builds a flexibly connected neural network from them. Each `nndef` line follows the pattern `((C1, S1), ..., (Cn, Sn), NerveClass[input_selector, tag](...))`, which says that a nerve takes the inputs selected by `input_selector` and outputs `n` tensors, the i-th of which has channel size `Ci` and spatial shape `Si`.
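To make the line pattern concrete, here is a small sketch that splits one `nndef` line into its output dimensions and its nerve. The helper `split_nndef` is hypothetical and not part of neurlink; it assumes a line is a plain Python tuple whose last element is the nerve and whose other elements are `(C, S)` pairs, with a string standing in for a real nerve object.

```python
from typing import Any, List, Tuple


def split_nndef(line: Tuple[Any, ...]) -> Tuple[List[Tuple[Any, Any]], Any]:
    """Split an nndef line into its (channels, spatial) pairs and its trailing nerve.

    Hypothetical helper: assumes a line is a plain tuple whose last element is
    the nerve and whose other elements are (C, S) pairs.
    """
    *dims, nerve = line  # everything before the last element is an output spec
    for dim in dims:
        if not (isinstance(dim, tuple) and len(dim) == 2):
            raise ValueError(f"expected a (channels, spatial) pair, got {dim!r}")
    return dims, nerve


# A string stands in for a real Nerve instance; this line declares two outputs.
dims, nerve = split_nndef(((64, 2), (128, 4), "SomeNerve"))
print(dims)   # [(64, 2), (128, 4)]
print(nerve)  # SomeNerve
```

A line with a single output, such as `((64, 2), Conv2d_ReLU_BN(7))`, simply yields a one-element list of dimensions.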
input selector
Every `nndef` node can select its inputs from any of the preceding lines. `input_selector` can be omitted, in which case the node takes the output of the preceding `nndef` line; it can also be an integer index, a `slice`, or a `list`, so that multiple input nodes can be specified. Each layer is implemented as a subclass of `Nerve` and can read the specifications of its inputs from the attribute `self.input_links`.
`tag` can optionally be specified so that you can refer to an `nndef` specification by a string alias.
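A minimal sketch of how an `input_selector` could be resolved to line indices, covering the default, an integer, a slice, a list, and a string tag. This is not neurlink's implementation: the function `resolve_inputs`, the `tags` mapping (tag to line index), the choice to treat integers and slices like ordinary Python indexing into the preceding lines (so `-1` is the previous line), and the idea that tags may appear inside a selector are all assumptions.

```python
from typing import Dict, List, Sequence, Union

# Hypothetical selector type: None, an int, a slice, a tag string, or a list of ints/tags.
Selector = Union[None, int, slice, str, Sequence[Union[int, str]]]


def resolve_inputs(current: int, selector: Selector, tags: Dict[str, int]) -> List[int]:
    """Return the absolute indices of the lines that feed line number `current`.

    Hypothetical sketch: ints and slices index into the list of preceding
    lines like ordinary Python indexing; tag strings are looked up in `tags`.
    """
    previous = list(range(current))  # indices of every line before `current`
    if selector is None:
        if not previous:
            raise ValueError("the first line has no preceding line to default to")
        return [previous[-1]]  # default: the immediately preceding line
    if isinstance(selector, str):
        index = tags[selector]
        if index >= current:
            raise ValueError(f"tag {selector!r} does not refer to a preceding line")
        return [index]
    if isinstance(selector, int):
        return [previous[selector]]
    if isinstance(selector, slice):
        return previous[selector]
    # A list of selectors: resolve each item and concatenate the results in order.
    return [i for item in selector for i in resolve_inputs(current, item, tags)]


tags = {"stem": 1}
print(resolve_inputs(5, None, tags))             # [4]
print(resolve_inputs(5, -2, tags))               # [3]
print(resolve_inputs(5, slice(1, 3), tags))      # [1, 2]
print(resolve_inputs(5, [0, "stem", -1], tags))  # [0, 1, 4]
```

Resolving everything to absolute indices up front keeps the graph easy to inspect, whichever selector form a line uses.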
output dimensions
Output dimensions specify channels and spatial shapes separately. A spatial shape written as an integer or float (or a tuple of them) is a downsampling ratio relative to the `base_shape` (specified by `nv.Input`). It can also be written as a string, which is evaluated to an absolute shape. Writing the dimensions out gives a direct view of the computation flow of the entire network, and shape transformations are handled automatically, so you don't have to work out the stride or padding each layer needs.
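To illustrate the ratio convention, here is a sketch that turns a spatial spec into an absolute shape and derives the stride needed between two consecutive lines. The helper names are hypothetical; the sketch assumes ratios divide `base_shape` evenly and uses the common `(k - 1) // 2` padding for odd kernels. Neurlink's own derivation may differ.

```python
import ast
from typing import Tuple, Union

Number = Union[int, float]
# Hypothetical spatial spec: a ratio, a tuple of ratios, or a string holding an absolute shape.
Spatial = Union[Number, Tuple[Number, ...], str]


def absolute_shape(spatial: Spatial, base_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Turn a spatial spec into an absolute shape (assumes ratios divide evenly)."""
    ndim = len(base_shape)
    if isinstance(spatial, str):
        value = ast.literal_eval(spatial)  # e.g. "(1, 1)" -> (1, 1)
        return tuple(value) if isinstance(value, tuple) else (value,) * ndim
    ratios = spatial if isinstance(spatial, tuple) else (spatial,) * ndim
    return tuple(int(b / r) for b, r in zip(base_shape, ratios))


def derive_stride(in_spatial: Spatial, out_spatial: Spatial,
                  base_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Stride that maps the input shape to the output shape (downsampling only)."""
    in_shape = absolute_shape(in_spatial, base_shape)
    out_shape = absolute_shape(out_spatial, base_shape)
    return tuple(i // o for i, o in zip(in_shape, out_shape))


def same_padding(kernel_size: int) -> int:
    """Padding (k - 1) // 2; for an odd kernel this yields ceil(n / stride) outputs."""
    return (kernel_size - 1) // 2


def conv_output_size(n: int, kernel_size: int, stride: int, padding: int) -> int:
    """Standard output-size formula for a convolution or pooling window."""
    return (n + 2 * padding - kernel_size) // stride + 1


base = (224, 224)
print(absolute_shape(2, base))         # (112, 112)
print(absolute_shape(32, base))        # (7, 7)
print(absolute_shape("(1, 1)", base))  # (1, 1)
# The 7x7 stem conv goes from ratio 1 to ratio 2, so it needs stride 2.
stride = derive_stride(1, 2, base)[0]
print(stride, conv_output_size(224, 7, stride, same_padding(7)))  # 2 112
```

The stride falls out of the ratio of consecutive spatial specs, which is why only the target ratio needs to be written in each line.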
Example (resnet)
```python
# Assumes nv, build, BottleNeck, Conv2d_ReLU_BN, MaxPool2d, AvgPool2d and Conv2d
# have been imported from neurlink (import paths omitted here).
def resnet50(num_classes: int = 1000, **block_keywords):
    expansion = 4
    block = BottleNeck(**block_keywords, expansion=expansion)
    return build(
        [
            ((3, 1), nv.Input()),
            ((64, 2), Conv2d_ReLU_BN(7)),  # 7x7 conv, stride 2
            ((64, 4), MaxPool2d(3)),  # 3x3 maxpool, stride 2
            [((64 * expansion, 4), block)] * 3,  # 3 residual blocks, no downsampling
            [((128 * expansion, 8), block)] * 4,  # 4 residual blocks; downsampling (x2) happens in the first block
            [((256 * expansion, 16), block)] * 6,  # 256 * expansion is the actual output channel count; the bottleneck shape is determined inside the block
            [((512 * expansion, 32), block)] * 3,  # finally downsampled to 1/32 of the input; the global downsampling ratio is easy to read off
            ((512 * expansion, "(1, 1)"), AvgPool2d()),  # an average pooling layer that downsamples to an absolute shape
            ((num_classes, "(1, 1)"), Conv2d(1)),  # 1x1 conv acting as the final linear layer
        ]
    )
```
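The list passed to `build` mixes single lines with nested `[line] * n` lists. A minimal sketch of how such a spec could be flattened into one ordered sequence of lines, assuming `build` flattens nested lists in order (not confirmed against neurlink's source). `flatten_spec` is a hypothetical name, and strings stand in for the nerve objects since only the dimensions matter here.

```python
from typing import Any, List


def flatten_spec(spec: List[Any]) -> List[Any]:
    """Flatten nested lists of nndef lines into one ordered list (hypothetical sketch)."""
    flat: List[Any] = []
    for item in spec:
        if isinstance(item, list):
            flat.extend(flatten_spec(item))  # e.g. [line] * 3 contributes three lines
        else:
            flat.append(item)
    return flat


expansion = 4
# The resnet50 spec from above, with strings in place of the nerve objects.
spec = [
    ((3, 1), "Input"),
    ((64, 2), "Conv2d_ReLU_BN"),
    ((64, 4), "MaxPool2d"),
    [((64 * expansion, 4), "block")] * 3,
    [((128 * expansion, 8), "block")] * 4,
    [((256 * expansion, 16), "block")] * 6,
    [((512 * expansion, 32), "block")] * 3,
    ((512 * expansion, "(1, 1)"), "AvgPool2d"),
    ((1000, "(1, 1)"), "Conv2d"),
]

lines = flatten_spec(spec)
print(len(lines))  # 21
# The largest relative ratio is the global downsampling factor: 224 / 32 = 7.
print(max(r for (_, r), _ in lines if isinstance(r, int)))  # 32
```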
```python
# in tests/test_build_models.py
import torch
import neurlink

model = neurlink.resnet18()
x = torch.randn((2, 3, 224, 224))
out = model(x, output_intermediate=True)  # outputs of every line, starting with the input
for y in out:
    print(y.shape)
```
```console
$ python tests/test_build_models.py
torch.Size([2, 3, 224, 224])
torch.Size([2, 64, 112, 112])
torch.Size([2, 64, 56, 56])
torch.Size([2, 64, 56, 56])
torch.Size([2, 64, 56, 56])
torch.Size([2, 128, 28, 28])
torch.Size([2, 128, 28, 28])
torch.Size([2, 256, 14, 14])
torch.Size([2, 256, 14, 14])
torch.Size([2, 512, 7, 7])
torch.Size([2, 512, 7, 7])
torch.Size([2, 512, 1, 1])
torch.Size([2, 1000, 1, 1])
```