
Project description

scikit-prune

Prune your sklearn models.

Deep learning libraries offer pruning techniques to keep models lightweight when they are stored on disk. It's a technique that makes a lot of sense: you often don't need float64 precision to represent the weights of a machine learning model.
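
To make the idea concrete, here is a minimal sketch (plain numpy, not this library) of what dropping from float64 to float32 does to the footprint of a weight matrix:

import numpy as np

# A weight matrix like the ones a fitted model carries around.
weights = np.random.rand(1000, 300)   # float64 by default
print(weights.nbytes)                 # 2400000 bytes

# Casting to float32 halves the storage, at the cost of some precision.
lighter = weights.astype(np.float32)
print(lighter.nbytes)                 # 1200000 bytes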

It got me thinking: would such a technique also work in scikit-learn?

Enter scikit-prune

As a demo, let's say that we're dealing with a text classification use case.

from sklearn.datasets import fetch_20newsgroups

text = fetch_20newsgroups()['data']

Then we might have a pipeline that computes sparse TF-IDF features from this text and turns them into a dense representation via truncated SVD.

from sklearn.pipeline import make_pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer

pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(300))
pipe.fit(text)

Then we can choose to save this pipeline to disk, preferably via a system like skops.

from skops.io import dump

dump(pipe, "piper-orig.skops")

This results in a 275 MB file on disk, which is actually kind of big. The most significant chunk of those megabytes is spent on the float64 numpy arrays that belong to the SVD object.
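
You can verify that claim yourself by peeking at the fitted SVD step (the step name below is the one make_pipeline generates automatically, and the exact sizes will depend on your data):

svd = pipe.named_steps["truncatedsvd"]

# The fitted SVD stores a dense (n_components, n_features) matrix.
print(svd.components_.dtype)    # float64
print(svd.components_.nbytes)   # bytes taken up by the largest array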

With this library, you can shrink that down a bit.

from skprune import prune 

dump(prune(pipe), "piper-lite.skops")

Now the file is a fair bit lighter: only 126 MB on disk, which is a step in the right direction. You can get it down even further by saving it as a ZIP file, which brings it closer to 41 MB.
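
For the ZIP step, one way to do it with nothing but the standard library (file names assumed from the example above):

import zipfile

# Wrap the pruned file in a compressed archive.
with zipfile.ZipFile("piper-lite.zip", "w", compression=zipfile.ZIP_DEFLATED) as zf:
    zf.write("piper-lite.skops")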

Caveats

This technique can certainly save a bunch of disk space, but at least theoretically, it can also lead to numerical mishaps when you apply the pruned pipeline. Always make sure you evaluate the pruned pipeline before doing anything in production with it!
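
A minimal sanity check, assuming a freshly fitted pipe like in the example above, is to compare outputs before and after pruning (computing the original outputs first, in case prune mutates the pipeline in place):

import numpy as np

# Transform with the original pipeline before pruning.
original_out = pipe.transform(text)
pruned_out = prune(pipe).transform(text)

# Expect small float32-sized differences, not exact equality.
print(np.abs(original_out - pruned_out).max())
print(np.allclose(original_out, pruned_out, atol=1e-4))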

It's also good to remember that your results may vary. In our example the TruncatedSVD component was the culprit because it carries a very large internal matrix. If your pipeline doesn't have very large matrices, you probably won't get big savings in disk space.
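
If you want to check where the bytes live in your own pipeline, a rough audit could look like the sketch below. It only counts dense numpy arrays stored directly on each step, so treat it as an estimate rather than exact accounting:

import numpy as np

# Sum the bytes of every numpy array each fitted step stores directly.
for name, step in pipe.named_steps.items():
    arrays = [v for v in vars(step).values() if isinstance(v, np.ndarray)]
    total = sum(a.nbytes for a in arrays)
    print(f"{name}: {total / 1e6:.1f} MB across {len(arrays)} numpy arrays")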

Download files


Source Distribution

scikit-prune-0.1.0.tar.gz (3.6 kB)

Built Distribution

scikit_prune-0.1.0-py2.py3-none-any.whl (3.7 kB)

File details

Details for the file scikit-prune-0.1.0.tar.gz.

File metadata

  • Download URL: scikit-prune-0.1.0.tar.gz
  • Upload date:
  • Size: 3.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.4

File hashes

Hashes for scikit-prune-0.1.0.tar.gz:

  • SHA256: 309a8c4f45f7872964a13d05a6169a8b9992c690de042ce9912691df4ecb742f
  • MD5: 56affe93a8933f91ef13bd94e116c5e9
  • BLAKE2b-256: d25de71341d8bbb0b29696bbf2bd15234e557687551905055a2546705a3a6026


File details

Details for the file scikit_prune-0.1.0-py2.py3-none-any.whl.

File hashes

Hashes for scikit_prune-0.1.0-py2.py3-none-any.whl:

  • SHA256: 481ee379129faf8608af42151d697fde69b90b86b21f77785a4ecb805777eed4
  • MD5: 3501ba06f93cb602adb67b15584f21a2
  • BLAKE2b-256: f87519cb748e489d321edc913fdb3355385aecbffb5a552c302b3634b32f8e01

