Skip to main content

Model wrapper for Pytorch, which can training, predict, evaluate, etc.

Project description

Usage Sample ''''''''''''

.. code:: python

    from model_wrapper import SplitClassModelWrapper

    classes = ['class1', 'class2', 'class3'...]
    X = [[...], [...],]
    y = [0, 0, 1, 2, 1...]

    model = ...
    wrapper = SplitClassModelWrapper(model, classes=classes)
    wrapper.train(X, y, val_size=0.2)

    X_test = [[...], [...],]
    y_test = [0, 1, 1, 2, 1...]
    result = wrapper.evaluate(X_test, y_test)
    # 0.953125

    result = wrapper.predict(X_test)
    # [0, 1]

    result = wrapper.predict_classes(X_test)
    # ['class1', 'class2']

    result = wrapper.predict_proba(X_test)
    # ([0, 1], array([0.99439645, 0.99190724], dtype=float32))

    result = wrapper.predict_classes_proba(X_test)
    # (['class1', 'class2'], array([0.99439645, 0.99190724], dtype=float32))

Project details


Release history Release notifications | RSS feed

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

model-wrapper-0.2.2.tar.gz (13.9 kB view details)

Uploaded Source

File details

Details for the file model-wrapper-0.2.2.tar.gz.

File metadata

  • Download URL: model-wrapper-0.2.2.tar.gz
  • Upload date:
  • Size: 13.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for model-wrapper-0.2.2.tar.gz
Algorithm Hash digest
SHA256 e16e139b1969af5c65f2bb0a195fc257f9f338d617dbd745ca272a59b1cd90bf
MD5 10ea372915563fe56f33e27a98148f76
BLAKE2b-256 4dc381ec4d50bd45ec9e55d0900e6a6a0da77887cf59ccc348fbc075f7c5f415

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page