TabNet for fastai
This is an adaptation of TabNet (an attention-based network for tabular data) for the fastai (>=2.0) library. The original paper is https://arxiv.org/pdf/1908.07442.pdf. The implementation is taken from https://github.com/dreamquark-ai/tabnet.
Install
pip install fast_tabnet
How to use
model = TabNetModel(emb_szs, n_cont, out_sz, embed_p=0., y_range=None, n_d=8, n_a=8, n_steps=3, gamma=1.3, n_independent=2, n_shared=2, epsilon=1e-15, virtual_batch_size=128, momentum=0.02)
Parameters emb_szs, n_cont, out_sz, embed_p and y_range are the same as for fastai TabularModel.
- n_d : int Dimension of the prediction layer (usually between 4 and 64)
- n_a : int Dimension of the attention layer (usually between 4 and 64)
- n_steps : int Number of successive steps in the network (usually between 3 and 10)
- gamma : float Scaling factor for attention updates (usually between 1.0 and 2.0); larger values let a feature be reused in more steps
- momentum : float Value between 0 and 1 used as the momentum in all batch norm layers
- n_independent : int Number of independent (step-specific) GLU layers in each GLU block (default 2)
- n_shared : int Number of shared GLU layers in each GLU block, shared across all steps (default 2)
- epsilon : float Small constant that avoids log(0); it should be kept very low
- virtual_batch_size : int Batch size used for Ghost Batch Normalization (default 128)
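Two of these parameters correspond to formulas in the TabNet paper. gamma enters the prior scale P[i] = ∏_{j≤i} (γ − M[j]), which reduces the weight of features already selected by earlier attention masks M[j]. epsilon appears in the entropy-based sparsity regularizer −M·log(M + ε), where it keeps the logarithm finite when a mask entry is exactly zero. Below is a minimal pure-Python sketch of both formulas for a single sample (batch size 1). The helper names prior_scales and sparsity_loss are hypothetical. This is an illustration, not the fast_tabnet or dreamquark code, which operates on batched tensors.

```python
import math

def prior_scales(masks, gamma):
    # P[0] = 1 for every feature; afterwards P[i] = P[i-1] * (gamma - M[i]) elementwise
    prior = [1.0] * len(masks[0])
    history = [prior]
    for mask in masks:
        prior = [p * (gamma - m) for p, m in zip(prior, mask)]
        history.append(prior)
    return history

def sparsity_loss(masks, epsilon=1e-15):
    # Entropy of the masks averaged over steps; epsilon keeps log() defined when m == 0
    entropy = sum(-m * math.log(m + epsilon) for mask in masks for m in mask)
    return entropy / len(masks)

masks = [[1.0, 0.0, 0.0],   # step 1 attends only to feature 0
         [0.0, 0.5, 0.5]]   # step 2 splits attention between features 1 and 2
print(prior_scales(masks, gamma=1.0)[1])   # [0.0, 1.0, 1.0]: feature 0 is now blocked
print(prior_scales(masks, gamma=1.3)[1])   # feature 0 keeps a small prior of about 0.3
print(sparsity_loss(masks))                # about 0.3466, i.e. log(2) / 2
```

With gamma=1.0, a feature that received the full mask at one step gets a prior of 0 afterwards and cannot be selected again. Larger gamma values allow the same feature to be reused across steps.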
Example
Below is an example from the fastai library, but the model in use is TabNet.
from fastai.basics import *
from fastai.tabular.all import *
from fast_tabnet.core import *
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
df_main,df_test = df.iloc[:10000].copy(),df.iloc[10000:].copy()
df_main.head()
| | age | workclass | fnlwgt | education | education-num | marital-status | occupation | relationship | race | sex | capital-gain | capital-loss | hours-per-week | native-country | salary |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 49 | Private | 101320 | Assoc-acdm | 12.0 | Married-civ-spouse | NaN | Wife | White | Female | 0 | 1902 | 40 | United-States | >=50k |
| 1 | 44 | Private | 236746 | Masters | 14.0 | Divorced | Exec-managerial | Not-in-family | White | Male | 10520 | 0 | 45 | United-States | >=50k |
| 2 | 38 | Private | 96185 | HS-grad | NaN | Divorced | NaN | Unmarried | Black | Female | 0 | 0 | 32 | United-States | <50k |
| 3 | 38 | Self-emp-inc | 112847 | Prof-school | 15.0 | Married-civ-spouse | Prof-specialty | Husband | Asian-Pac-Islander | Male | 0 | 0 | 40 | United-States | >=50k |
| 4 | 42 | Self-emp-not-inc | 82297 | 7th-8th | NaN | Married-civ-spouse | Other-service | Wife | Black | Female | 0 | 0 | 50 | United-States | <50k |
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
splits = RandomSplitter()(range_of(df_main))
to = TabularPandas(df_main, procs, cat_names, cont_names, y_names="salary", splits=splits)
dbch = to.dataloaders()
dbch.valid.show_batch()
| | workclass | education | marital-status | occupation | relationship | race | age_na | fnlwgt_na | education-num_na | age | fnlwgt | education-num | salary |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Private | Prof-school | Married-civ-spouse | Prof-specialty | Husband | White | False | False | False | 35.000000 | 374524.001986 | 15.0 | >=50k |
| 1 | Federal-gov | Assoc-acdm | Married-civ-spouse | Adm-clerical | Husband | White | False | False | False | 44.000000 | 251305.001512 | 12.0 | >=50k |
| 2 | ? | 11th | Never-married | ? | Own-child | White | False | False | False | 17.000000 | 297117.002500 | 7.0 | <50k |
| 3 | ? | HS-grad | Married-civ-spouse | ? | Husband | White | False | False | False | 71.999999 | 117017.001678 | 9.0 | <50k |
| 4 | ? | Some-college | Never-married | ? | Own-child | White | False | False | False | 20.000001 | 95988.999714 | 10.0 | <50k |
| 5 | Private | Some-college | Divorced | Craft-repair | Unmarried | Black | False | False | False | 39.000000 | 214117.000147 | 10.0 | <50k |
| 6 | Private | Assoc-acdm | Married-civ-spouse | Prof-specialty | Husband | White | False | False | False | 30.000000 | 48520.003341 | 12.0 | <50k |
| 7 | Private | Some-college | Divorced | #na# | Unmarried | Black | False | False | False | 31.000000 | 377374.000758 | 10.0 | <50k |
| 8 | Local-gov | Prof-school | Never-married | Prof-specialty | Not-in-family | White | False | False | False | 53.000000 | 131258.000220 | 15.0 | >=50k |
| 9 | Private | Masters | Separated | Exec-managerial | Unmarried | White | False | False | False | 44.000000 | 79863.997225 | 14.0 | <50k |
to_tst = to.new(df_test)
to_tst.process()
to_tst.all_cols.head()
| | workclass | education | marital-status | occupation | relationship | race | age_na | fnlwgt_na | education-num_na | age | fnlwgt | education-num | salary |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 10000 | 5 | 10 | 3 | 2 | 1 | 2 | 1 | 1 | 1 | 0.472456 | 1.356619 | 1.173511 | 0 |
| 10001 | 5 | 12 | 3 | 15 | 1 | 4 | 1 | 1 | 1 | -0.927933 | 1.268716 | -0.423170 | 0 |
| 10002 | 5 | 2 | 1 | 9 | 2 | 5 | 1 | 1 | 1 | 1.062094 | 0.154883 | -1.221511 | 0 |
| 10003 | 5 | 12 | 7 | 2 | 5 | 5 | 1 | 1 | 1 | 0.546161 | -0.282602 | -0.423170 | 0 |
| 10004 | 6 | 9 | 3 | 5 | 1 | 5 | 1 | 1 | 1 | 0.767275 | 1.460190 | 0.375170 | 1 |
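Calling to.new(df_test) followed by process() applies the preprocessing fitted on the training data to the test rows. Categories become integer codes, missing continuous values are filled and flagged, and continuous columns are standardized with the training statistics. The sketch below mimics this on plain Python lists. Its assumptions: the helper names are hypothetical, None marks a missing value, the vocabulary reserves index 0 for '#na#' (as fastai's Categorify does), and the fill value is the median (FillMissing's default strategy). fastai's real procs operate on DataFrames and may differ in details such as the exact standard-deviation estimator.

```python
import statistics

def fit_categorify(values):
    # Vocabulary learned on the training split; index 0 is reserved for '#na#'
    return ['#na#'] + sorted({v for v in values if v is not None})

def apply_categorify(values, vocab):
    index = {v: i for i, v in enumerate(vocab)}
    # Missing or unseen categories fall back to code 0 ('#na#')
    return [index.get(v, 0) for v in values]

def fit_fill_normalize(values):
    # Fill value = median of observed training values; mean/std computed after filling
    median = statistics.median([v for v in values if v is not None])
    filled = [median if v is None else v for v in values]
    return median, statistics.mean(filled), statistics.pstdev(filled)

def apply_fill_normalize(values, median, mean, std):
    na_flags = [v is None for v in values]              # becomes the extra *_na column
    filled = [median if v is None else v for v in values]
    return [(v - mean) / std for v in filled], na_flags

# Fit on a tiny "training" column, then apply the same statistics to "test" rows
vocab = fit_categorify(['Private', '?', 'Private', 'Local-gov'])
print(apply_categorify(['Private', 'Never-worked'], vocab))   # [3, 0]

median, mean, std = fit_fill_normalize([20, 30, None, 50])
normalized, flags = apply_fill_normalize([30, None], median, mean, std)
print(median, mean, flags)   # 30 32.5 [False, True]
```

The key design point is that the test rows never influence the vocabulary, the fill value or the normalization statistics. An unseen category like 'Never-worked' therefore maps to code 0.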
emb_szs = get_emb_sz(to); print(emb_szs)
[(10, 6), (17, 8), (8, 5), (16, 8), (7, 5), (6, 4), (2, 2), (2, 2), (3, 3)]
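get_emb_sz derives each embedding width from the number of categories in a column, where the counts include the '#na#' token. The sizes printed above agree with fastai's default rule of thumb, emb_sz_rule, which computes min(600, round(1.6 * n**0.56)). It is reimplemented here for illustration, with the cardinalities copied from that output.

```python
def emb_sz_rule(n_cat: int) -> int:
    # Rule of thumb: width grows sub-linearly with the number of categories, capped at 600
    return min(600, round(1.6 * n_cat ** 0.56))

cardinalities = [10, 17, 8, 16, 7, 6, 2, 2, 3]
emb_szs = [(n, emb_sz_rule(n)) for n in cardinalities]
print(emb_szs)
```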
This is how the model is used:
model = TabNetModel(emb_szs, len(to.cont_names), 1, n_d=8, n_a=32, n_steps=1);
opt_func = partial(Adam, wd=0.01, eps=1e-5)
learn = Learner(dbch, model, MSELossFlat(), opt_func=opt_func, lr=3e-2, metrics=[accuracy])
learn.lr_find()
learn.fit_one_cycle(10)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.166065 | 0.146415 | 0.765500 | 00:04 |
1 | 0.139964 | 0.131110 | 0.765500 | 00:04 |
2 | 0.136636 | 0.122154 | 0.765500 | 00:04 |
3 | 0.131105 | 0.125905 | 0.765500 | 00:04 |
4 | 0.130018 | 0.121818 | 0.765500 | 00:04 |
5 | 0.125062 | 0.116067 | 0.765500 | 00:04 |
6 | 0.120265 | 0.115156 | 0.765500 | 00:04 |
7 | 0.118240 | 0.112878 | 0.765500 | 00:04 |
8 | 0.115416 | 0.111366 | 0.765500 | 00:04 |
9 | 0.113975 | 0.111448 | 0.765500 | 00:04 |
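The accuracy column stays at 0.765500 in every epoch while the losses improve. This follows from how fastai's accuracy metric works. It takes the argmax over the last dimension of the output, and with out_sz=1 there is only a single column. Every prediction is therefore class 0 (<50k), and the metric simply reports the share of that class in the validation set. The pure-Python sketch below reproduces the effect on toy numbers. It also shows a hypothetical threshold_accuracy that turns the single regression output into a class by comparing it with 0.5; in a fastai Learner, the same idea would have to be written with tensors.

```python
def argmax_accuracy(outputs, targets):
    # Mirrors an argmax-based accuracy: predicted class = index of the largest output
    preds = [max(range(len(row)), key=row.__getitem__) for row in outputs]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

def threshold_accuracy(outputs, targets, threshold=0.5):
    # Hypothetical metric for one regression output: class 1 if the output exceeds threshold
    preds = [int(row[0] > threshold) for row in outputs]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

outputs = [[0.1], [0.8], [0.3], [0.9]]   # one output column per row, as with out_sz=1
targets = [0, 1, 0, 1]
print(argmax_accuracy(outputs, targets))     # 0.5: argmax over one column is always 0
print(threshold_accuracy(outputs, targets))  # 1.0
```

Alternatively, setting out_sz=2 and training with CrossEntropyLossFlat should make the built-in accuracy meaningful. The trade-off is that this changes the loss used in this example.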
Example with Bayesian Optimization
I like to tune hyperparameters for tabular models with Bayesian Optimization. With this approach you can optimize your metric directly if it is sensitive enough. In our example it is not (accuracy stays constant during training), so we use the validation loss instead. You should also create a second validation set, because the first one is effectively used as a training set by the Bayesian Optimization.
You may need to install the optimizer: pip install bayesian-optimization
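The advice about a second validation set can be sketched as a three-way split of row indices. One part is used for training, one provides the validation loss that Bayesian Optimization works against, and a holdout part is only used to check the final configuration. three_way_split is a hypothetical helper, and the percentages and seed are arbitrary choices.

```python
import random

def three_way_split(n_rows, valid_pct=0.2, holdout_pct=0.2, seed=42):
    # Shuffle positional row indices reproducibly, then cut them into three disjoint parts
    idx = list(range(n_rows))
    random.Random(seed).shuffle(idx)
    n_hold = int(n_rows * holdout_pct)
    n_valid = int(n_rows * valid_pct)
    holdout = idx[:n_hold]
    valid = idx[n_hold:n_hold + n_valid]
    train = idx[n_hold + n_valid:]
    return train, valid, holdout

train, valid, holdout = three_way_split(10_000)
print(len(train), len(valid), len(holdout))  # 6000 2000 2000
```

The first two lists could be passed to TabularPandas as splits=(train, valid). The holdout rows could then be processed with to.new(df_main.iloc[holdout]) in the same way as the test set above.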
from functools import lru_cache
# The function we'll optimize; results are cached per integer combination
@lru_cache(1000)
def get_accuracy(n_d:int, n_a:int, n_steps:int):
    model = TabNetModel(emb_szs, len(to.cont_names), 1, n_d=int(n_d), n_a=int(n_a), n_steps=int(n_steps))
    learn = Learner(dbch, model, MSELossFlat(), opt_func=opt_func, lr=3e-2, metrics=[accuracy])
    learn.fit_one_cycle(5)
    # validate() returns [valid_loss, *metrics]; BayesianOptimization maximizes, so return minus the loss
    return -float(learn.validate(dl=learn.dls.valid)[0])
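This implementation of Bayesian Optimization doesn't work naturally with discrete values. That's why the optimizer searches over continuous base-2 exponents, and a wrapper rounds 2**exponent to the nearest integer. The underlying function is cached with lru_cache, so a repeated integer combination is not trained twice.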
def fit_accuracy(pow_n_d, pow_n_a, pow_n_steps):
    return get_accuracy(round(2**pow_n_d), round(2**pow_n_a), round(2**pow_n_steps))
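The caching trick can be checked without training anything. In this sketch, expensive_objective is a hypothetical stand-in for get_accuracy that uses a toy formula instead of model training, and wrapper plays the role of fit_accuracy. Two nearby continuous points proposed by the optimizer round to the same integers, so the second call is answered from the lru_cache.

```python
from functools import lru_cache

calls = []  # records which integer combinations were actually "trained"

@lru_cache(maxsize=1000)
def expensive_objective(n_d: int, n_a: int, n_steps: int) -> float:
    # Hypothetical stand-in for get_accuracy: a real version would train a model here
    calls.append((n_d, n_a, n_steps))
    return -abs(n_d - 8) - abs(n_a - 8) - abs(n_steps - 2)

def wrapper(pow_n_d: float, pow_n_a: float, pow_n_steps: float) -> float:
    # The optimizer proposes continuous exponents; map them to integers before the cached call
    return expensive_objective(round(2 ** pow_n_d), round(2 ** pow_n_a), round(2 ** pow_n_steps))

print(wrapper(3.0, 3.0, 1.0))     # 0
print(wrapper(3.02, 2.98, 0.97))  # rounds to (8, 8, 2) again, so the cache answers
print(calls)                      # [(8, 8, 2)]
```

Searching over base-2 exponents also puts the candidates on a logarithmic scale. As a result, small widths such as 4 and 8 are explored as densely as 64 and 128.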
from bayes_opt import BayesianOptimization
# Bounded region of parameter space
# Keys must match the argument names of fit_accuracy
pbounds = {'pow_n_d': (0, 8), 'pow_n_a': (0, 8), 'pow_n_steps': (0, 4)}
optimizer = BayesianOptimization(
    f=fit_accuracy,
    pbounds=pbounds,
)
optimizer.maximize(
    init_points=15,
    n_iter=100,
)
optimizer.max
{'target': -0.11236412078142166,
 'params': {'pow_n_a': 2.5840359360205936,
  'pow_n_d': 2.442317935141724,
  'pow_n_steps': 0.0}}
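optimizer.max reports the raw exponents rather than the hyperparameters themselves. Mapping them back with the same rounding as fit_accuracy recovers the configuration that was actually trained. to_hyperparams is a hypothetical helper, and the numbers are copied from the output above.

```python
def to_hyperparams(pow_n_d: float, pow_n_a: float, pow_n_steps: float) -> dict:
    # Same mapping as fit_accuracy: round 2**exponent to the nearest integer
    return {'n_d': round(2 ** pow_n_d), 'n_a': round(2 ** pow_n_a), 'n_steps': round(2 ** pow_n_steps)}

best = to_hyperparams(2.442317935141724, 2.5840359360205936, 0.0)
print(best)  # {'n_d': 5, 'n_a': 6, 'n_steps': 1}
```

The winning configuration could then be rebuilt with TabNetModel(emb_szs, len(to.cont_names), 1, **best). The target of about -0.1124 is the negated validation loss.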
Hashes for fast_tabnet-0.0.3-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | d92e879e0c6108678151be3b6fd19d7ea0b3c718e0ba2d3cbf0ec037189ed2e8 |
| MD5 | 51669cde2f7edeecab3316ec9d2d815d |
| BLAKE2b-256 | 2087249fc9803d989b82c8078655c6021a9ef2f9e06a0dc6e4aa60dfc68428a7 |