Skip to main content

Model wrapper for Pytorch, which can training, predict, evaluate, etc.

Project description

Usage Sample ''''''''''''

.. code:: python

    from model_wrapper import SplitClassModelWrapper

    classes = ['class1', 'class2', 'class3'...]
    X = [[...], [...],]
    y = [0, 0, 1, 2, 1...]

    model = ...
    wrapper = SplitClassModelWrapper(model, classes=classes)
    wrapper.train(X, y, val_size=0.2)

    X_test = [[...], [...],]
    y_test = [0, 1, 1, 2, 1...]
    result = wrapper.evaluate(X_test, y_test)
    # 0.953125

    result = wrapper.predict(X_test)
    # [0, 1]

    result = wrapper.predict_classes(X_test)
    # ['class1', 'class2']

    result = wrapper.predict_proba(X_test)
    # ([0, 1], array([0.99439645, 0.99190724], dtype=float32))

    result = wrapper.predict_classes_proba(X_test)
    # (['class1', 'class2'], array([0.99439645, 0.99190724], dtype=float32))

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

model-wrapper-0.1.6.tar.gz (12.7 kB view details)

Uploaded Source

File details

Details for the file model-wrapper-0.1.6.tar.gz.

File metadata

  • Download URL: model-wrapper-0.1.6.tar.gz
  • Upload date:
  • Size: 12.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for model-wrapper-0.1.6.tar.gz
Algorithm Hash digest
SHA256 c8a4c8f97d03b7dee3d469e391c07f76dd2a792abbf73c4e1c904d8e7a32290c
MD5 50e71bdb7d85e54c1b8f7232f11a2dcb
BLAKE2b-256 4771d513247bb725c67d5431f3a49a6b8e460bb2a47fa347189c7f631bc113ef

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page