VisTabNet
This package introduces VisTabNet, a Vision Transformer-based Tabular Data Classifier.
Usage
import numpy as np
from sklearn.metrics import balanced_accuracy_score
from vistabnet import VisTabNetClassifier

# Load your data here. y_train and y_test must be label encoded (integers 0..n_classes-1), not one-hot encoded.
X_train, y_train, X_test, y_test = ...
# device="cuda:1" targets the second CUDA GPU; adjust it to match your hardware.
model = VisTabNetClassifier(input_features=X_train.shape[1], classes=len(np.unique(y_train)), device="cuda:1")
model.fit(X_train, y_train, eval_X=X_test, eval_y=y_test)
y_pred = model.predict(X_test)
acc = balanced_accuracy_score(y_test, y_pred)
print(f"Balanced accuracy: {acc}")
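The usage example expects integer class labels. If your targets are strings or one-hot vectors, the following is a minimal sketch of one way to prepare them with NumPy only. The helper names fit_classes, encode_labels, and one_hot_to_labels are hypothetical and not part of the vistabnet package. The label mapping is fitted on the training labels and reused for the test labels so that both splits share the same integer codes; len(classes) then gives the value for classes= and X_train.shape[1] the value for input_features=.

```python
import numpy as np


def fit_classes(y_train):
    """Hypothetical helper: sorted array of the distinct training labels."""
    return np.unique(np.asarray(y_train))


def encode_labels(y, classes):
    """Hypothetical helper: map each label to its index in `classes`."""
    y = np.asarray(y)
    idx = np.searchsorted(classes, y)
    # searchsorted returns an insertion point even for unseen labels,
    # so confirm every label actually matches a known class.
    safe = np.clip(idx, 0, len(classes) - 1)
    if not np.all(classes[safe] == y):
        raise ValueError("y contains labels not present in the training set")
    return idx


def one_hot_to_labels(Y):
    """Hypothetical helper: collapse one-hot rows to the active column index."""
    return np.asarray(Y).argmax(axis=1)


# Toy labels standing in for your own dataset.
y_train_raw = np.array(["cat", "dog", "cat", "bird", "dog", "bird"])
y_test_raw = np.array(["dog", "bird", "cat"])

classes = fit_classes(y_train_raw)              # sorted: bird, cat, dog
y_train = encode_labels(y_train_raw, classes)   # codes shared by both splits
y_test = encode_labels(y_test_raw, classes)

X_train = np.zeros((len(y_train), 4), dtype=np.float32)  # placeholder features
print(y_train, y_test, X_train.shape[1], len(classes))
```

Fitting the mapping once on the training split avoids a subtle bug where encoding each split separately assigns different codes to the same class when a split is missing a label.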
Installation
pip install vistabnet
Download files
Source Distribution
vistabnet-0.1.1.tar.gz (4.9 kB)
Built Distribution
Hashes for vistabnet-0.1.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 62a9328a81f233b952ffd66ebaa87231054d4a7b45975a83e2fc345c1095a450
MD5 | 9d4369a580e60101d71f0a0a7371098d
BLAKE2b-256 | 5311ba664258127809f738386480b1cc32a2fb442fa5009b88588170703cb756
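To check that a downloaded wheel matches the SHA256 digest listed above, a minimal sketch using Python's standard hashlib module is shown below. The function names sha256_of_file and verify_wheel are hypothetical, and the example assumes the wheel has already been downloaded to a local path of your choosing.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20):
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Read fixed-size chunks until read() returns b"" at end of file.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 listed above for vistabnet-0.1.1-py3-none-any.whl.
EXPECTED_SHA256 = "62a9328a81f233b952ffd66ebaa87231054d4a7b45975a83e2fc345c1095a450"


def verify_wheel(path, expected=EXPECTED_SHA256):
    """Return True if the file at `path` has the expected SHA256 digest."""
    return sha256_of_file(path) == expected.lower()
```

For example, verify_wheel("vistabnet-0.1.1-py3-none-any.whl") returns True only if the file in the current directory is byte-for-byte the published wheel. Reading in chunks keeps memory use constant regardless of file size.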