
Model wrapper for PyTorch that supports training, prediction, evaluation, and more.


Usage Example
'''''''''''''

.. code:: python

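    from model_wrapper import SplitClassModelWrapper

    classes = ['class1', 'class2', 'class3', ...]  # class names, indexed by label
    X = [[...], [...], ...]                        # feature rows
    y = [0, 0, 1, 2, 1, ...]                       # integer labels, one per row of X

    model = ...  # your PyTorch classifier (a torch.nn.Module)
    model_wrapper = SplitClassModelWrapper(model, classes=classes)
    # val_size=0.2: hold out 20% of (X, y) for validation during training
    model_wrapper.train(X, y, val_size=0.2)

    X_test = [[...], [...], ...]
    y_test = [0, 1, 1, 2, 1, ...]
    result = model_wrapper.evaluate(X_test, y_test)
    # a single score for the test set, e.g. 0.953125

    result = model_wrapper.predict(X_test)
    # predicted class indices, e.g. [0, 1]

    result = model_wrapper.predict_classes(X_test)
    # predicted class names, e.g. ['class1', 'class2']

    result = model_wrapper.predict_proba(X_test)
    # (indices, probabilities), e.g.
    # ([0, 1], array([0.99439645, 0.99190724], dtype=float32))

    result = model_wrapper.predict_classes_proba(X_test)
    # (class names, probabilities), e.g.
    # (['class1', 'class2'], array([0.99439645, 0.99190724], dtype=float32))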
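
The ``predict_proba`` and ``predict_classes_proba`` outputs above pair each predicted class index (or name) with a probability. As a minimal sketch, assuming the wrapper scores each row with a softmax over the model's logits and picks the argmax, such pairs could be derived as follows. ``softmax`` and ``decode_logits`` are hypothetical helpers written for illustration, not part of ``model_wrapper``, and the logits are made-up values; the package's actual implementation may differ.

.. code:: python

    import math
    from typing import List, Sequence, Tuple


    def softmax(row: Sequence[float]) -> List[float]:
        """Numerically stable softmax over one row of logits."""
        m = max(row)  # subtract the max so math.exp cannot overflow
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        return [e / total for e in exps]


    def decode_logits(
        logits: Sequence[Sequence[float]], classes: Sequence[str]
    ) -> Tuple[List[int], List[str], List[float]]:
        """Hypothetical helper (not part of model_wrapper).

        Returns the argmax index per row, the matching class name,
        and the softmax probability of that predicted class.
        """
        indices: List[int] = []
        names: List[str] = []
        probs: List[float] = []
        for row in logits:
            p = softmax(row)
            best = max(range(len(p)), key=p.__getitem__)  # argmax
            indices.append(best)
            names.append(classes[best])
            probs.append(p[best])
        return indices, names, probs


    classes = ['class1', 'class2', 'class3']
    logits = [[6.0, 0.5, 0.1], [0.2, 5.5, 0.3]]  # made-up scores for two rows
    indices, names, probs = decode_logits(logits, classes)
    print(indices)  # [0, 1]
    print(names)    # ['class1', 'class2']
    print(probs)    # probability of each predicted class

Here ``indices`` and ``names`` play the roles of ``predict`` and ``predict_classes``; pairing either with ``probs`` mirrors the tuples from ``predict_proba`` and ``predict_classes_proba``, although the sample output above shows the probabilities as a NumPy ``float32`` array rather than a plain list.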



Download files


Source Distribution

model-wrapper-0.0.1.tar.gz (11.5 kB)

File details

Details for the file model-wrapper-0.0.1.tar.gz.

File metadata

  • Download URL: model-wrapper-0.0.1.tar.gz
  • Upload date:
  • Size: 11.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for model-wrapper-0.0.1.tar.gz
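=========== ================================================================
Algorithm   Hash digest
=========== ================================================================
SHA256      4a47c3a7cb55c790b7f5ba961f6c1ae4093f6e3a102109d631babe5cc23c47fc
MD5         871e1d8c3a84a5d631ba643ef33ad398
BLAKE2b-256 cd75dec15cf8c5de437b23fe473668d8a9fa21f89b4ac0a67a54f8c9daedc1ba
=========== ================================================================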

