

Project description


🌿 Trava (the name initially stood for TrainValidation)

A framework that helps you train models, compare them, and track parameters & metrics along the way. It works with tabular data only.

pip install trava

Compare models and keep track of metrics with ease!

While working on a project, we often experiment with different models while looking at the same metrics. Some of those metrics can be logged as a single number, while others need graphs to make sense. It's also useful to save the metrics somewhere for future analysis; the list goes on.

So why not use a unified interface for all of that?

Here is Trava's way:

1). Declare metrics you want to calculate:

from sklearn.metrics import log_loss, precision_score, recall_score, roc_auc_score

# in this case, sk and sk_proba are just wrappers around sklearn's metrics,
# but you can use any metric implementation you want
scorers = [
  sk_proba(log_loss),
  sk_proba(roc_auc_score),
  sk(recall_score),
  sk(precision_score),
]

2). What do you want to do with the metrics?

# let's log the metrics
logger_handler = LoggerHandler(scorers=scorers)

3). Initialize Trava

trava = TravaSV(results_handlers=[logger_handler])

4). Fit your model using Trava

from sklearn.naive_bayes import GaussianNB

# prepare your data
X_train, X_test, y_train, y_test = ...

split_result = SplitResult(X_train=X_train, 
                           y_train=y_train,
                           X_test=X_test,
                           y_test=y_test)

trava.fit_predict(raw_split_data=split_result,
                  model_type=GaussianNB, # we pass model class and parameters separately
                  model_init_params={},  # to be able to track them properly
                  model_id='gnb') # just a unique identifier for this model

The fit_predict call does roughly the same as:

gnb = GaussianNB()
gnb.fit(split_result.X_train, split_result.y_train)
gnb.predict(split_result.X_test)

But now you don't need to care how the metrics you declared are calculated. You just get them in your console! By the way, those metrics definitely need to be improved. :]

Model evaluation nb
* Results for gnb model *
Train metrics:
log_loss:
16.755867191506482
roc_auc_score:
0.7746522424770221
recall_score:
0.10468384074941452
precision_score:
0.9122448979591836


Test metrics:
log_loss:
16.94514025416013
roc_auc_score:
0.829444814485889
recall_score:
0.026041666666666668
precision_score:
0.7692307692307693
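
To compare the result against another model, just call fit_predict again with the same split. Here's a minimal sketch; LogisticRegression and the max_iter value are arbitrary choices for illustration, not anything Trava requires.

from sklearn.linear_model import LogisticRegression

# fit a second model on the same split; the handlers report its metrics too
trava.fit_predict(raw_split_data=split_result,
                  model_type=LogisticRegression,
                  model_init_params={'max_iter': 1000},
                  model_id='logreg')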

After training multiple models, you can get all metrics for all of them by calling:

trava.results

Get the full picture and more examples by going through the guide notebook!

Built-in handlers:

  • LoggerHandler - logs metrics
  • PlotHandler - plots metrics
  • MetricsDictHandler - returns all metrics wrapped in a dict
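
Handlers can also be combined, so the same run can be logged and plotted at once. Here's a minimal sketch; it assumes PlotHandler accepts a scorers argument the same way LoggerHandler does.

# assumption: PlotHandler's constructor mirrors LoggerHandler's
logger_handler = LoggerHandler(scorers=scorers)
plot_handler = PlotHandler(scorers=scorers)

# every handler receives the results of each fit_predict call
trava = TravaSV(results_handlers=[logger_handler, plot_handler])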

Enable metrics autotracking. How cool is that?

Experiment tracking is a must in Data Science, so you shouldn't neglect it. You can integrate any tracking framework into Trava! Trava comes with an MLFlow tracker ready to go. It can autotrack:

  • model's parameters
  • any metric
  • plots
  • serialized models

MLFlow example:

# get tracker's instance
tracker = MLFlowTracker(scorers=scorers)
# initialize Trava with it
trava = TravaSV(tracker=tracker)
# fit your model as before
trava.fit_predict(raw_split_data=split_result,
                  model_type=GaussianNB,
                  model_id='gnb')

Done. All model parameters and metrics are now tracked! Tracking is also supported for:

  • cross-validation case with nested tracking
  • eval results for common boosting libraries (XGBoost, LightGBM, CatBoost)
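
If you use MLflow's default local tracking backend, the tracked runs end up in an mlruns directory and can be browsed in the MLflow UI (for example, by running mlflow ui from the project root).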

Check out the detailed notebooks on how to track metrics & parameters and plots & serialized models.

General information

  • highly customizable training & evaluation processes (see the FitPredictor class in trava/fit_predictor.py and its subclasses)
  • built-in train/test/validation split logic
  • extensions for common boosting libraries (for early stopping with validation sets)
  • tracks model parameters, metrics, plots, and serialized models; it's easy to integrate any tracking framework of your choice
  • you can also evaluate metrics after the fit_predict call if you forgot to add some metric
  • you can evaluate metrics even after your data and even the trained model have been unloaded (depends on the metric used, but true most of the time)
  • only supervised learning problems are supported for now, yet there is potential to extend it to unsupervised problems as well
  • unit-tested
  • I use it every day for my own needs, so I care about its quality and reliability

Only sklearn-style models are supported for the time being (i.e. models exposing the fit, predict, and predict_proba methods).
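
For reference, here's a minimal sketch of what "sklearn-style" means in this context. The class below is purely illustrative (it is not part of Trava or sklearn) and assumes a binary 0/1 target; any object exposing fit, predict and, for probability-based metrics, predict_proba should work.

import numpy as np

class MajorityClassModel:
    """Illustrative sklearn-style estimator: always predicts the majority class."""

    def fit(self, X, y):
        values, counts = np.unique(y, return_counts=True)
        self._majority = int(values[np.argmax(counts)])
        return self

    def predict(self, X):
        return np.full(len(X), self._majority)

    def predict_proba(self, X):
        # probability 1.0 for the majority class, 0.0 for the other class
        proba = np.zeros((len(X), 2))
        proba[:, self._majority] = 1.0
        return proba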

Requirements

pandas
numpy
python 3.7

It's also convenient to use the lib with sklearn (e.g. for taking metrics from there). Also, a couple of extensions are based on sklearn classes.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

trava-0.2.12.tar.gz (36.6 kB)

Uploaded Source

Built Distribution

trava-0.2.12-py3-none-any.whl (50.1 kB)

Uploaded Python 3

File details

Details for the file trava-0.2.12.tar.gz.

File metadata

  • Download URL: trava-0.2.12.tar.gz
  • Upload date:
  • Size: 36.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.8.5

File hashes

Hashes for trava-0.2.12.tar.gz
  • SHA256: e5a574ddcf1b7c8b6f18480a24c238f6b486c4d63784d4aec84fbf383bf7fca4
  • MD5: 9cd0fc882b4a376e89583a7606a50a54
  • BLAKE2b-256: 300affe3522fa0320a7d9cafcafc22021ceb4a396eefde12680d88e60147723f

See more details on using hashes here.

File details

Details for the file trava-0.2.12-py3-none-any.whl.

File metadata

  • Download URL: trava-0.2.12-py3-none-any.whl
  • Upload date:
  • Size: 50.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/3.10.0 pkginfo/1.7.0 requests/2.24.0 requests-toolbelt/0.9.1 tqdm/4.51.0 CPython/3.8.5

File hashes

Hashes for trava-0.2.12-py3-none-any.whl
  • SHA256: 7f903c59807630dd3269c0607047d9e9be830724a01d169fe790965f26358841
  • MD5: c393559f2e32df5f80776c47cf4a11ea
  • BLAKE2b-256: 10a5d3997b0171b3a8a8d23e74a57cf2a34eab462f344b2cd1c06e6d98537353

See more details on using hashes here.
