
Some utilities and wrappers for Neural Network Models


nnwrapper

A lightweight toolbox of utilities and wrappers for neural network models.

Install

# published version from PyPI
pip install -U nnlite

# or the development version from GitHub
pip install -U git+https://github.com/huangyh09/nnlite

Quick Usage

import torch
import matplotlib.pyplot as plt
from functools import partial

import nnlite
from nnlite import NNWrapper

torch.manual_seed(0)
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'

## VAE model (one hidden layer, dim=64), loss, and optimizer
model = nnlite.models.VAE_base(1838, 32, hidden_dims=[64], device=dev)
criterion = partial(nnlite.models.Loss_VAE_Gaussian, beta=1e-3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.95)

## NNWrapper for model training
## train_loader / test_loader: torch.utils.data.DataLoader objects prepared beforehand
my_wrapper = NNWrapper(model, criterion, optimizer, device=dev)
my_wrapper.fit(train_loader, epoch=3000, validation_loader=None, verbose=False)
my_wrapper.predict(test_loader)

## plot the recorded training losses
plt.plot(my_wrapper.train_losses)
plt.show()
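The `partial(..., beta=1e-3)` call in the quick-usage example fixes the KL weight once, so the criterion is later called with only the data-dependent arguments. The sketch below illustrates that pattern with a generic beta-VAE Gaussian loss in pure Python. It is an assumption for illustration, not the package's `Loss_VAE_Gaussian`: the name `vae_gaussian_loss`, its argument order, and the use of a summed squared error as the reconstruction term are all hypothetical, and the real loss may differ in reduction or likelihood.

```python
import math
from functools import partial


def vae_gaussian_loss(x, x_hat, mu, logvar, beta=1.0):
    """Hypothetical beta-VAE loss: reconstruction error + beta * KL.

    x, x_hat:   lists of floats (input and its reconstruction)
    mu, logvar: lists of floats (mean and log-variance of q(z|x))
    """
    # Reconstruction term: summed squared error, a common stand-in for a
    # fixed-variance Gaussian likelihood (constants dropped)
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    # Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over latent dims
    kl = -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, logvar))
    return recon + beta * kl


# Bind beta once; the resulting criterion takes only (x, x_hat, mu, logvar)
criterion = partial(vae_gaussian_loss, beta=1e-3)

x = [1.0, 2.0, 3.0]
x_hat = [1.0, 2.0, 2.0]
mu = [0.0, 1.0]
logvar = [0.0, 0.0]
print(criterion(x, x_hat, mu, logvar))
```

A small `beta` such as `1e-3` down-weights the KL term relative to reconstruction, which is why it is convenient to fix it with `partial` rather than pass it on every call.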

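To illustrate what a `fit`/`predict` wrapper typically does behind a call like `my_wrapper.fit(train_loader, epoch=3000)`, here is a minimal pure-Python sketch. It is hypothetical and not nnlite's implementation: `TinyWrapper` fits a one-feature linear model with mini-batch gradient descent on squared error instead of training a PyTorch model, and `train_loader` is just a list of `(inputs, targets)` batches standing in for a DataLoader. It only mirrors the shape of the interface: an `epoch` count, a per-epoch `train_losses` list, and a `predict` method.

```python
class TinyWrapper:
    """Hypothetical, minimal stand-in for a training wrapper (not nnlite's code).

    Fits y = w * x + b with mini-batch gradient descent on mean squared error
    and records one average training loss per epoch in self.train_losses.
    """

    def __init__(self, lr=0.01):
        self.w, self.b = 0.0, 0.0
        self.lr = lr
        self.train_losses = []

    def fit(self, train_loader, epoch=100):
        for _ in range(epoch):
            total, n = 0.0, 0
            for xs, ys in train_loader:  # each item is one mini-batch
                preds = [self.w * x + self.b for x in xs]
                errs = [p - y for p, y in zip(preds, ys)]
                m = len(xs)
                # Gradients of the batch mean squared error w.r.t. w and b
                grad_w = 2.0 * sum(e * x for e, x in zip(errs, xs)) / m
                grad_b = 2.0 * sum(errs) / m
                self.w -= self.lr * grad_w
                self.b -= self.lr * grad_b
                # Accumulate squared errors computed before this batch's update
                total += sum(e * e for e in errs)
                n += m
            self.train_losses.append(total / n)
        return self

    def predict(self, loader):
        # Ignore targets and return one prediction per input
        return [self.w * x + self.b for xs, _ in loader for x in xs]


# Two mini-batches drawn from y = 2x + 1
train_loader = [([0.0, 1.0], [1.0, 3.0]), ([2.0, 3.0], [5.0, 7.0])]
wrapper = TinyWrapper(lr=0.05).fit(train_loader, epoch=1000)
print(wrapper.train_losses[0], wrapper.train_losses[-1])
print(wrapper.predict([([4.0], [9.0])]))
```

Keeping the loss history on the wrapper object is what makes a one-line plot like `plt.plot(my_wrapper.train_losses)` possible after training.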
Examples

See the examples folder in the GitHub repository.

Download files


Source Distribution

nnlite-0.0.2.tar.gz (2.0 MB)


File details

Details for the file nnlite-0.0.2.tar.gz.

File metadata

  • Download URL: nnlite-0.0.2.tar.gz
  • Upload date:
  • Size: 2.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.12.2

File hashes

Hashes for nnlite-0.0.2.tar.gz
Algorithm     Hash digest
SHA256        2750b90048d2629378313b45d086074c2321eb3166c47805b0252550a2b50cab
MD5           be66a721d752b016e59e0e3cda0aa8ac
BLAKE2b-256   d3bc1931dd8c2cb02a456cc00f55ba7e1fb8a769cb3bd90491491ec4fc98ff25
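A downloaded copy of the source distribution can be checked against the published SHA256 digest with Python's standard `hashlib` module. The helper names `sha256_of` and `verify` below are illustrative; the file is read in chunks so large archives do not need to fit in memory.

```python
import hashlib

# Published SHA256 digest for nnlite-0.0.2.tar.gz
EXPECTED_SHA256 = "2750b90048d2629378313b45d086074c2321eb3166c47805b0252550a2b50cab"


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, reading it in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 digest matches the expected hex digest."""
    return sha256_of(path) == expected.lower()
```

For example, `verify("nnlite-0.0.2.tar.gz")` should return `True` for an intact download. pip's hash-checking mode (`--require-hashes`, with a `--hash=sha256:...` entry in a requirements file) performs the same check at install time.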
