Skip to main content

Model wrapper for Pytorch, which can training, predict, evaluate, etc.

Project description

Usage Sample ''''''''''''

.. code:: python

    from model_wrapper import SplitClassifyModelWrapper

    classes = ['class1', 'class2', 'class3'...]
    X = [[...], [...],]
    y = [0, 0, 1, 2, 1...]

    model = ...
    wrapper = SplitClassifyModelWrapper(model, classes=classes)
    wrapper.train(X, y, val_size=0.2)

    X_test = [[...], [...],]
    y_test = [0, 1, 1, 2, 1...]
    result = wrapper.evaluate(X_test, y_test)
    # 0.953125

    result = wrapper.predict(X_test)
    # [0, 1]

    result = wrapper.predict_classes(X_test)
    # ['class1', 'class2']

    result = wrapper.predict_proba(X_test)
    # ([0, 1], array([0.99439645, 0.99190724], dtype=float32))

    result = wrapper.predict_classes_proba(X_test)
    # (['class1', 'class2'], array([0.99439645, 0.99190724], dtype=float32))

Project details


Release history Release notifications | RSS feed

This version

0.9.2

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

model-wrapper-0.9.2.tar.gz (19.9 kB view details)

Uploaded Source

File details

Details for the file model-wrapper-0.9.2.tar.gz.

File metadata

  • Download URL: model-wrapper-0.9.2.tar.gz
  • Upload date:
  • Size: 19.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for model-wrapper-0.9.2.tar.gz
Algorithm Hash digest
SHA256 b19c8f193f508f972ca41276ceb48613c2554184bd6e08d43b059862d7edee76
MD5 f9b9e35602b0ffa03a789ca2c744f4c3
BLAKE2b-256 f289aa71bab4e313912a306978108daafddce230e3f12c2a2b6d5e1a02fa3b94

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page