A model wrapper for PyTorch that handles training, prediction, evaluation, and related tasks.

Usage Sample
''''''''''''

.. code:: python

    from model_wrapper import SplitClassifyModelWrapper

    # Class names; the integer labels in y index into this list
    classes = ['class1', 'class2', 'class3', ...]
    X = [[...], [...],]
    y = [0, 0, 1, 2, 1, ...]

    model = ...
    wrapper = SplitClassifyModelWrapper(model, classes=classes)
    # val_size=0.2 sets aside 20% of the data for validation
    wrapper.train(X, y, val_size=0.2)

    X_test = [[...], [...],]
    y_test = [0, 1, 1, 2, 1, ...]
    result = wrapper.evaluate(X_test, y_test)
    # 0.953125

    # Predicted class indices
    result = wrapper.predict(X_test)
    # [0, 1]

    # Predicted class names
    result = wrapper.predict_classes(X_test)
    # ['class1', 'class2']

    # Predicted class indices with their probabilities
    result = wrapper.predict_proba(X_test)
    # ([0, 1], array([0.99439645, 0.99190724], dtype=float32))

    # Predicted class names with their probabilities
    result = wrapper.predict_classes_proba(X_test)
    # (['class1', 'class2'], array([0.99439645, 0.99190724], dtype=float32))
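
The sample outputs above suggest how the four ``predict*`` results relate to each other: an index, the class name at that index, and the probability of that class. The sketch below reproduces that relationship in plain Python. It is an illustration only, not the library's implementation: ``top_class`` is a hypothetical helper and the probability values are made up.

.. code:: python

    def top_class(probabilities, classes):
        """Return (index, name, probability) of the most likely class."""
        # Hypothetical helper, not part of model_wrapper
        index = max(range(len(probabilities)), key=probabilities.__getitem__)
        return index, classes[index], probabilities[index]

    classes = ['class1', 'class2', 'class3']
    # Made-up per-class probabilities for two samples
    batch = [[0.90, 0.07, 0.03], [0.10, 0.85, 0.05]]

    results = [top_class(row, classes) for row in batch]
    indices = [r[0] for r in results]  # same shape as the predict output
    labels = [r[1] for r in results]   # same shape as the predict_classes output
    scores = [r[2] for r in results]   # probabilities paired with them

    print(indices)  # [0, 1]
    print(labels)   # ['class1', 'class2']

Under this reading, ``predict_proba`` pairs ``indices`` with ``scores``, and ``predict_classes_proba`` pairs ``labels`` with ``scores``.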
