GFS Network
Gumbel Feature Selection Network (GFS Network) is a deep learning model for selecting the most important features of a dataset. The model is based on the Gumbel-Sigmoid distribution.
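This page does not spell out how the Gumbel-Sigmoid distribution is used inside the model. As a rough illustration of the general technique (not the gfs_network source), the sketch below draws a relaxed binary feature mask from per-feature logits: each entry is sigmoid((a + g1 - g2) / τ), where g1 and g2 are independent Gumbel(0, 1) noise and τ is a temperature. The function names, the default temperature and the `hard` option are hypothetical choices made for this example.

```python
import math
import random


def _gumbel_noise(rng):
    """One Gumbel(0, 1) draw via the inverse CDF: g = -log(-log(u))."""
    u = rng.random()  # u lies in [0, 1)
    while u == 0.0:  # log(0) is undefined, so redraw
        u = rng.random()
    return -math.log(-math.log(u))


def _sigmoid(x):
    """Logistic function, split by sign to avoid overflow in math.exp."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def gumbel_sigmoid_mask(logits, temperature=0.5, hard=False, rng=None):
    """Hypothetical helper: sample one relaxed binary mask from feature logits.

    Entry i is sigmoid((a_i + g1 - g2) / temperature) with g1, g2 ~ Gumbel(0, 1).
    Lower temperatures push entries toward 0 or 1. With hard=True each entry
    is thresholded at 0.5, which keeps feature i with probability sigmoid(a_i).
    """
    if rng is None:
        rng = random.Random()
    mask = []
    for a in logits:
        soft = _sigmoid((a + _gumbel_noise(rng) - _gumbel_noise(rng)) / temperature)
        if hard:
            mask.append(1.0 if soft > 0.5 else 0.0)
        else:
            mask.append(soft)
    return mask


# Usage with three made-up logits: strongly keep, undecided, strongly drop.
print(gumbel_sigmoid_mask([6.0, 0.0, -6.0], temperature=0.1, rng=random.Random(0)))
```

The hard threshold keeps feature i with probability sigmoid(a_i) because the difference of two independent Gumbel(0, 1) variables follows a standard logistic distribution. A trainable selector would learn the logits a_i; how GFS Network learns them is not described on this page.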
Installation
To install the package, you can use pip:
pip install gfs_network
Usage examples
Basic usage
from gfs_network import GFSNetwork
from sklearn.datasets import load_breast_cancer
breast = load_breast_cancer()
X = breast.data
y = breast.target
gfs = GFSNetwork()
gfs.fit(X, y)
print(gfs.support_)  # mask of the selected features
print(gfs.scores_)  # scores assigned to the features
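Assuming support_ follows the scikit-learn convention of one boolean entry per input column (the performance example counts selected features with sum(gfs.support_)), and that scores_ holds one score per column, both attributes can be mapped back to feature names. The helpers below are hypothetical and use plain lists with made-up values; with NumPy arrays and a boolean mask, the same selections would be breast.feature_names[gfs.support_] and X[:, gfs.support_], the latter presumably matching what gfs.transform(X) returns.

```python
def selected_features(feature_names, support):
    """Names whose entry in the boolean support mask is truthy."""
    return [name for name, keep in zip(feature_names, support) if keep]


def select_columns(rows, support):
    """Keep only the masked columns of a row-major 2D list (like X[:, support])."""
    keep = [j for j, flag in enumerate(support) if flag]
    return [[row[j] for j in keep] for row in rows]


def rank_by_score(feature_names, scores):
    """Feature names ordered from highest to lowest score."""
    ranked = sorted(zip(feature_names, scores), key=lambda pair: pair[1], reverse=True)
    return [name for name, _ in ranked]


# Hypothetical values standing in for gfs.support_ and gfs.scores_.
names = ["mean radius", "mean texture", "mean smoothness"]
support = [True, False, True]
scores = [0.9, 0.1, 0.6]
print(selected_features(names, support))  # ['mean radius', 'mean smoothness']
print(select_columns([[14.1, 19.3, 0.096], [20.6, 17.8, 0.084]], support))  # made-up rows
print(rank_by_score(names, scores))  # ['mean radius', 'mean smoothness', 'mean texture']
```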
Performance verification
from gfs_network import GFSNetwork
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import balanced_accuracy_score
DEVICE = "cpu"
breast = load_breast_cancer()
X = breast.data
y = breast.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)  # baseline: train on all original features
orig_score = balanced_accuracy_score(y_test, clf.predict(X_test))
print(f"Original score: {orig_score:.3f}. Original features: {X.shape[1]}")
# Original score: 0.958. Original features: 30
gfs = GFSNetwork(verbose=True, device=DEVICE)
gfs.fit(X_train, y_train)
X_transformed = gfs.transform(X_train)
X_test_transformed = gfs.transform(X_test)
clf.fit(X_transformed, y_train)
y_pred = clf.predict(X_test_transformed)
score = balanced_accuracy_score(y_test, y_pred)
print(f"Score after feature selection: {score}. Selected features: {sum(gfs.support_)}")
# Score after feature selection: 0.97. Selected features: 9
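The performance check scores the classifier with balanced_accuracy_score, which scikit-learn defines (with the default adjusted=False) as the average of the recall obtained on each class. It is presumably used here because load_breast_cancer is imbalanced (212 malignant vs. 357 benign samples), so plain accuracy would reward favoring the majority class. A minimal pure-Python sketch of the same quantity, with a hypothetical helper name and made-up labels:

```python
from collections import Counter


def balanced_accuracy(y_true, y_pred):
    """Hypothetical helper: mean per-class recall over the classes in y_true.

    For each class c, recall is the fraction of samples with true label c
    that were predicted as c.
    """
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return sum(hits[c] / n for c, n in totals.items()) / len(totals)


# Made-up labels: class 0 has 2 samples (1 correct), class 1 has 4 (all correct).
y_true = [0, 0, 1, 1, 1, 1]
y_pred = [0, 1, 1, 1, 1, 1]
print(balanced_accuracy(y_true, y_pred))  # 0.75
print(sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true))  # 0.8333333333333334
```

Plain accuracy (5 of 6 correct) looks better than the balanced score because the single error falls on the smaller class.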
Download files
Source distribution: gfs_network-0.1.5.tar.gz (4.1 kB)
Built distribution: gfs_network-0.1.5-py3-none-any.whl
Hashes for gfs_network-0.1.5-py3-none-any.whl:
Algorithm | Hash digest
---|---
SHA256 | 75a2a0a0c9a751c5927aed5bbfcb8bf0c943f973bbca732582b291472457da98
MD5 | 555a192d4dd0316fe53e100d7d44f549
BLAKE2b-256 | d4c297b9d84c563151d04d0f0d9686358d9587d6c94cd24d5654f8a744391206