
An end-to-end feature selection package whose runtime complexity is linear in the number of features.

Project description

GFS Network

Gumbel Feature Selection Network (GFS Network) is a deep learning model that selects the most important features from a given dataset. The model is based on the Gumbel-Sigmoid distribution.
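
For readers unfamiliar with the Gumbel-Sigmoid distribution, the sketch below shows one common way to sample from it: add the difference of two Gumbel samples to a logit, divide by a temperature, and apply a sigmoid. This is an illustration in plain Python, not the package's implementation; the names `sample_gumbel`, `gumbel_sigmoid`, and `temperature` are assumptions made for this example.

```python
import math
import random


def sample_gumbel(rng: random.Random) -> float:
    """Draw one sample from the standard Gumbel(0, 1) distribution."""
    u = rng.random()
    # Clamp away from 0 and 1 so both logarithms stay finite.
    u = min(max(u, 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))


def stable_sigmoid(z: float) -> float:
    """Sigmoid that avoids overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def gumbel_sigmoid(logit: float, temperature: float, rng: random.Random) -> float:
    """Relaxed (soft) binary gate in [0, 1] for a single feature.

    Hypothetical helper: the difference of two Gumbel samples is
    logistic noise, which perturbs the logit before the sigmoid.
    """
    noise = sample_gumbel(rng) - sample_gumbel(rng)
    return stable_sigmoid((logit + noise) / temperature)


# Features with a high logit get gates near 1, low logits near 0.
rng = random.Random(0)
for logit in (-4.0, 0.0, 4.0):
    gates = [gumbel_sigmoid(logit, temperature=0.5, rng=rng) for _ in range(1000)]
    print(f"logit={logit:+.1f} mean gate={sum(gates) / len(gates):.2f}")
```

Lower temperatures push the gates closer to hard 0/1 decisions, while higher temperatures keep them soft. In a feature selection setting, one such gate per input feature would typically be learned, which is consistent with a cost that grows linearly with the number of features.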

Installation

To install the package, you can use pip:

pip install gfs_network

Usage examples

Basic usage

from gfs_network import GFSNetwork
from sklearn.datasets import load_breast_cancer

breast = load_breast_cancer()
X = breast.data
y = breast.target

gfs = GFSNetwork()
X = gfs.fit_transform(X, y)

print(gfs.support_)
print(gfs.scores_)
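
The two printed attributes are easier to read when paired with feature names. The sketch below assumes that `support_` is a per-feature boolean mask (consistent with the `sum(gfs.support_)` call used later on this page) and that `scores_` holds one importance score per feature; the second point is an assumption, not something this page confirms. The helper `selected_features` and the values passed to it are hypothetical stand-ins, not real output from `GFSNetwork`.

```python
from typing import List, Sequence, Tuple


def selected_features(
    names: Sequence[str],
    support: Sequence[bool],
    scores: Sequence[float],
) -> List[Tuple[str, float]]:
    """Return (name, score) pairs for selected features, highest score first."""
    if not (len(names) == len(support) == len(scores)):
        raise ValueError("names, support and scores must have the same length")
    kept = [(n, float(s)) for n, keep, s in zip(names, support, scores) if keep]
    return sorted(kept, key=lambda pair: pair[1], reverse=True)


# Made-up stand-in values; in practice these would come from
# breast.feature_names, gfs.support_ and gfs.scores_.
names = ["f0", "f1", "f2", "f3"]
support = [True, False, True, False]
scores = [0.91, 0.12, 0.97, 0.05]

for name, score in selected_features(names, support, scores):
    print(f"{name}: {score:.2f}")
```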

Performance verification

from gfs_network import GFSNetwork
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import balanced_accuracy_score

DEVICE = "cpu"

breast = load_breast_cancer()
X = breast.data
y = breast.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)
orig_score = balanced_accuracy_score(y_test, clf.predict(X_test))

print(f"Original score: {orig_score:.3f}. Original features: {X.shape[1]}")
# Original score: 0.958. Original features: 30

gfs = GFSNetwork(verbose=True, device=DEVICE)
gfs.fit(X_train, y_train)

X_transformed = gfs.transform(X_train)
X_test_transformed = gfs.transform(X_test)

clf.fit(X_transformed, y_train)
y_pred = clf.predict(X_test_transformed)
score = balanced_accuracy_score(y_test, y_pred)
print(f"Score after feature selection: {score:.3f}. Selected features: {sum(gfs.support_)}")
# Score after feature selection: 0.958. Selected features: 3
