Skip to main content

Model wrapper for Pytorch, which can training, predict, evaluate, etc.

Project description

Usage Sample ''''''''''''

.. code:: python

    from model_wrapper import SplitClassifyModelWrapper

    classes = ['class1', 'class2', 'class3'...]
    X = [[...], [...],]
    y = [0, 0, 1, 2, 1...]

    model = ...
    wrapper = SplitClassifyModelWrapper(model, classes=classes)
    wrapper.train(X, y, val_size=0.2)

    X_test = [[...], [...],]
    y_test = [0, 1, 1, 2, 1...]
    result = wrapper.evaluate(X_test, y_test)
    # 0.953125

    result = wrapper.predict(X_test)
    # [0, 1]

    result = wrapper.predict_classes(X_test)
    # ['class1', 'class2']

    result = wrapper.predict_proba(X_test)
    # ([0, 1], array([0.99439645, 0.99190724], dtype=float32))

    result = wrapper.predict_classes_proba(X_test)
    # (['class1', 'class2'], array([0.99439645, 0.99190724], dtype=float32))

Project details


Release history Release notifications | RSS feed

This version

0.7.3

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

model-wrapper-0.7.3.tar.gz (17.5 kB view details)

Uploaded Source

File details

Details for the file model-wrapper-0.7.3.tar.gz.

File metadata

  • Download URL: model-wrapper-0.7.3.tar.gz
  • Upload date:
  • Size: 17.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for model-wrapper-0.7.3.tar.gz
Algorithm Hash digest
SHA256 c5e77457c3b4d4f46ee9a3812a112d08da9be8616d487d79aa94efe66c771e61
MD5 8c0dcac9d2d952981c469e84a8ed0404
BLAKE2b-256 cc5807be0ead4103c6f0c35cf45d7fc54341178449508c629d71112a022f4b69

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page