
A Wrapper for PyTorch Models

Project description

TorchWrapper

A wrapper class for PyTorch models that provides `fit` and `predict` methods familiar to users of Keras and scikit-learn.

It removes the need to write your own training and evaluation loops for basic models.
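For comparison, here is a minimal sketch of the kind of hand-written training loop that a Keras-style `fit` method saves you from writing. It is not TorchWrapper's actual implementation: the function name `train_loop` is hypothetical, and it assumes each batch from the DataLoader is an `(inputs, targets)` pair.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_loop(model, dataloader, optimizer, criterion, epochs=1):
    """Hypothetical hand-written loop; returns the mean training loss of each epoch."""
    history = []
    for _ in range(epochs):
        model.train()  # enable training-mode behaviour (dropout, batch norm)
        total, count = 0.0, 0
        for inputs, targets in dataloader:  # assumes (inputs, targets) batches
            optimizer.zero_grad()  # clear gradients left over from the last step
            loss = criterion(model(inputs), targets)
            loss.backward()  # backpropagate the loss
            optimizer.step()  # update the parameters
            total += loss.item() * inputs.size(0)
            count += inputs.size(0)
        history.append(total / count)
    return history


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(64, 3)
    y = x @ torch.tensor([[1.0], [-2.0], [0.5]])  # a simple linear target
    loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
    net = nn.Linear(3, 1)
    opt = torch.optim.Adam(net.parameters(), lr=0.1)
    losses = train_loop(net, loader, opt, nn.MSELoss(), epochs=50)
    print(f"first epoch loss {losses[0]:.4f}, last epoch loss {losses[-1]:.4f}")
```

Every project that trains this way repeats the same zero-grad, forward, backward, step sequence, which is the repetition a wrapper's `fit` is meant to hide.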

Quick Start

```python
# import PyTorch and the wrapper
import torch
from torchwrapper import Wrapper

# create your model (any torch.nn.Module), optimizer, and loss function
model = Model()
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.MSELoss()

# wrap the model
model = Wrapper(model)

# train the network on a torch.utils.data.DataLoader
model.fit(dataloader, optimizer, criterion, epochs=50)
```
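The Quick Start leaves `Model` and `dataloader` undefined. A minimal sketch of both, assuming a small regression task, might look like the following. The layer sizes, the random data, and the batch size are hypothetical, and the sketch assumes `fit` consumes batches of `(inputs, targets)` pairs, which the description above does not spell out.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class Model(nn.Module):
    """Hypothetical two-layer regression network: 4 features in, 1 value out."""

    def __init__(self, in_features=4, hidden=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        return self.net(x)


# Hypothetical random data: 100 samples with 4 features and 1 target each.
features = torch.randn(100, 4)
targets = torch.randn(100, 1)
dataloader = DataLoader(TensorDataset(features, targets), batch_size=32, shuffle=True)

model = Model()
inputs, labels = next(iter(dataloader))
print(model(inputs).shape)  # torch.Size([32, 1])
```

Any `nn.Module` whose output shape matches the targets should work with `torch.nn.MSELoss`; the two-layer network is only a placeholder.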

Once the model is trained, you can generate predictions from a PyTorch DataLoader:

```python
preds = model.predict(dataloader)
```

`predict` returns the predictions as a NumPy array.
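A minimal sketch of how a `predict` method like this might work: switch to inference mode, run every batch without tracking gradients, and concatenate the outputs into one NumPy array. `predict_sketch` is a hypothetical name, not part of TorchWrapper, and the handling of batches (bare input tensors or `(inputs, targets)` pairs) is an assumption; the real package may differ.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def predict_sketch(model, dataloader):
    """Hypothetical predict: run the model over every batch, return one NumPy array."""
    model.eval()  # switch dropout / batch norm to inference mode
    outputs = []
    with torch.no_grad():  # gradients are not needed for inference
        for batch in dataloader:
            # Assumption: a batch is either the inputs alone or (inputs, targets).
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            outputs.append(model(inputs).cpu())
    return torch.cat(outputs).numpy()


if __name__ == "__main__":
    net = nn.Linear(4, 1)
    data = TensorDataset(torch.randn(10, 4), torch.randn(10, 1))
    preds = predict_sketch(net, DataLoader(data, batch_size=4))
    print(type(preds), preds.shape)  # <class 'numpy.ndarray'> (10, 1)
```

Because the rows come back in DataLoader order, pass a loader with `shuffle=False` (the default) if you need predictions aligned with your original data.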


Download files


Source Distribution

torchwrapper-0.1.4.tar.gz (2.6 kB)


File details

Details for the file torchwrapper-0.1.4.tar.gz.

File metadata

  • Download URL: torchwrapper-0.1.4.tar.gz
  • Upload date:
  • Size: 2.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.21.0 setuptools/40.6.3 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.7.1

File hashes

Hashes for torchwrapper-0.1.4.tar.gz
  • SHA256: 5da3399eb7eac3c1873ed47d096e15ef89d732a16f4bbf43c1967e03e6318392
  • MD5: 4b9bfc2a4f916c53be88f82c4654561d
  • BLAKE2b-256: ca1360635cd719ef59972b10657c642b99b8b01d94adb6106f94e9e0f282826f
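To check a downloaded archive against the SHA256 digest listed above, you can hash the file with Python's standard `hashlib` module. This is a minimal sketch; the helper name `sha256_matches` and the local file path used in the example call are assumptions.

```python
import hashlib


def sha256_matches(path, expected_hex, chunk_size=65536):
    """Return True if the SHA256 digest of the file at `path` equals `expected_hex`."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Read in fixed-size chunks so large files need not fit in memory.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_hex.lower()
```

For example, if the archive was saved to the current directory, `sha256_matches("torchwrapper-0.1.4.tar.gz", "5da3399eb7eac3c1873ed47d096e15ef89d732a16f4bbf43c1967e03e6318392")` should return `True` for an intact download.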

