Skip to main content

Toy Neural Network Generator.

Project description

GitHub CI/CD

Toy Neural Network Generator

Installation

$ pip install tnng

Simple Model Generator

#!/usr/bin/env python
import torch
import torch.nn as nn
import torchex.nn as exnn
from tnng import Generator, MultiHeadLinkedListLayer

# Search space: every append() adds one stage holding the candidate layers
# for that position. The torchex (exnn) layers are lazy -- only the output
# size is given, and the input size is resolved once the model first runs.
m = MultiHeadLinkedListLayer()
m.append([exnn.Linear(64),
          exnn.Linear(128),
          exnn.Linear(256)])           # hidden layer 1: three width options
m.append([nn.ReLU(), nn.ELU()])       # activation 1: two options
m.append([exnn.Linear(16),
          exnn.Linear(32),
          exnn.Linear(64)])            # hidden layer 2: three width options
m.append([nn.ReLU(), nn.ELU()])       # activation 2: two options
m.append([exnn.Linear(10)])           # output layer: a single option

# Indexing the generator (g[idx]) yields the stages of one concrete
# architecture; the Model class consumes it.
g = Generator(m)

# Dummy input tensor of shape (128, 256).
x = torch.randn(128, 256)

class Model(nn.Module):
    """Sequential network assembled from architecture ``g[idx]``.

    Element 0 of every stage in ``g[idx]`` is collected into an
    ``nn.ModuleList`` (so its parameters are registered), and ``forward``
    applies those layers one after another.
    """

    def __init__(self, idx=0):
        super().__init__()
        chosen = []
        for stage in g[idx]:
            chosen.append(stage[0])
        self.model = nn.ModuleList(chosen)

    def forward(self, x):
        out = x
        for layer in self.model:
            out = layer(out)
        return out

m = Model(0)
o = m(x)

# Show the concrete layers chosen for architecture 0. Printing after the
# forward pass lets the lazy torchex layers report their resolved input
# sizes. Expected output:
#
# ModuleList(
#   (0): Linear(in_features=256, out_features=64, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=64, out_features=16, bias=True)
#   (3): ReLU()
#   (4): Linear(in_features=16, out_features=10, bias=True)
# )
print(m.model)

Multimodal Model Generator

#!/usr/bin/env python
import torch
import torch.nn as nn
import torchex.nn as exnn
from tnng import Generator, MultiHeadLinkedListLayer

# Two separate search spaces, one per head: `m` is later fed the vector
# input and `m1` the image input (presumably head order follows `m + m1`).
m = MultiHeadLinkedListLayer()
m1 = MultiHeadLinkedListLayer()

# Vector head. torchex layers are lazy: only output sizes are specified.
m.append([exnn.Linear(64),
          exnn.Linear(128),
          exnn.Linear(256)])
m.append([nn.ReLU(), nn.ELU()])
m.append([exnn.Linear(16),
          exnn.Linear(32),
          exnn.Linear(64)])
m.append([nn.ReLU(), nn.ELU()])

# Image head: conv -> pool -> activation -> conv -> pool -> flatten.
m1.append([exnn.Conv2d(16, 1),
           exnn.Conv2d(32, 1),
           exnn.Conv2d(64, 1)])
m1.append([nn.MaxPool2d(2), nn.AvgPool2d(2)])
m1.append([nn.ReLU(), nn.ELU(), nn.Identity()])
m1.append([exnn.Conv2d(32, 1),
           exnn.Conv2d(64, 1),
           exnn.Conv2d(128, 1)])
m1.append([nn.MaxPool2d(2), nn.AvgPool2d(2)])
m1.append([exnn.Flatten()])

# Combine both heads, then add the stages that (presumably) run on the
# merged features.
m = m + m1
m.append([exnn.Linear(128)])
m.append([nn.ReLU(), nn.ELU(), nn.Identity()])
m.append([exnn.Linear(10)])

g = Generator(m)
class Model(nn.Module):
    """Two-input model assembled from architecture ``g[idx]``.

    ``g[idx]`` is walked stage by stage:

    * a 2-element stage applies its first layer to ``x`` (skipped when that
      layer is ``None``) and its second layer to ``img``;
    * a ``[None]`` stage concatenates ``x`` and ``img`` along dim 1;
    * any other stage applies its first layer to ``x``.

    NOTE(review): this structure is inferred from how ``forward`` reads
    ``g[idx]``; confirm against ``tnng.Generator``.
    """

    def __init__(self, idx=0):
        super().__init__()
        self.model = g[idx]
        # Register every real layer so its parameters are visible to
        # .parameters()/.to()/state_dict(). Names are positional because
        # naming by str(layer) breaks: add_module rejects names containing
        # '.' (e.g. "ELU(alpha=1.0)" raises KeyError), and layers whose reprs
        # coincide would overwrite each other, dropping their parameters.
        for stage_idx, layers in enumerate(self.model):
            for head_idx, layer in enumerate(layers):
                if layer is not None:  # None entries are structural markers
                    self.add_module(f'stage{stage_idx}_head{head_idx}', layer)

    def forward(self, x, img):
        """Run ``x`` and ``img`` through their heads and return the result of ``x``'s path."""
        for stage in self.model:
            if len(stage) == 2:
                # Parallel stage: the x-head layer may be absent (None).
                if stage[0] is not None:
                    x = stage[0](x)
                img = stage[1](img)
            elif len(stage) == 1 and stage[0] is None:
                # Merge point: join both heads along dim 1.
                x = torch.cat((x, img), 1)
            else:
                x = stage[0](x)
        return x

# Dummy inputs: a (128, 256) tensor for the vector head and a
# (128, 3, 28, 28) tensor for the image head.
x = torch.randn(128, 256)
img = torch.randn(128, 3, 28, 28)

m = Model(idx=0)    # identical to Model(): architecture g[0]
o = m(x, img)
print(o.shape)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

tnng-0.3.1.tar.gz (4.8 kB) — view details

Uploaded Source

Built Distribution

tnng-0.3.1-py2.py3-none-any.whl (5.1 kB) — view details

Uploaded Python 2, Python 3

File details

Details for the file tnng-0.3.1.tar.gz.

File metadata

  • Download URL: tnng-0.3.1.tar.gz
  • Upload date:
  • Size: 4.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/42.0.2.post20191203 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.5

File hashes

Hashes for tnng-0.3.1.tar.gz
Algorithm Hash digest
SHA256 f7380b5df81d8ccf526a08c0e04df06cfb92364aa39f9ac887261df9e1cc3fc8
MD5 97d7124d8a0a32de9f4cffd0b4126ba3
BLAKE2b-256 4a1d21655d67eb98706a9643264834820484b745e3e3d901f04708df8fedcd7e

See more details on using hashes here.

File details

Details for the file tnng-0.3.1-py2.py3-none-any.whl.

File metadata

  • Download URL: tnng-0.3.1-py2.py3-none-any.whl
  • Upload date:
  • Size: 5.1 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/42.0.2.post20191203 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.5

File hashes

Hashes for tnng-0.3.1-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 d43980d7c61c0c534bcc973e1c3042200cfefcce03537ddfb0ba555d735abd1f
MD5 1b380b9d0b8b165686e7703bc6deb776
BLAKE2b-256 0f6068e601918cab51e8b0daba4e29f515c6702746cad7f4a8e238562a2178a9

See more details on using hashes here.

Supported by

AWS (cloud computing and security sponsor), Datadog (monitoring), Fastly (CDN), Google (download analytics), Pingdom (monitoring), Sentry (error logging), StatusPage (status page)