TabNet for fastai
This is an adaptation of TabNet (an attention-based network for tabular data) for the fastai (>=2.0) library. The original paper is at https://arxiv.org/pdf/1908.07442.pdf and the implementation is taken from https://github.com/dreamquark-ai/tabnet
Install
pip install fast_tabnet
How to use
model = TabNetModel(emb_szs, n_cont, out_sz, embed_p=0., y_range=None, n_d=8, n_a=8, n_steps=3, gamma=1.3, n_independent=2, n_shared=2, epsilon=1e-15, virtual_batch_size=128, momentum=0.02)
Parameters emb_szs, n_cont, out_sz, embed_p and y_range are the same as for the fastai TabularModel.
- n_d : int
Dimension of the prediction layer (usually between 4 and 64)
- n_a : int
Dimension of the attention layer (usually between 4 and 64)
- n_steps : int
Number of successive steps in the network (usually between 3 and 10)
- gamma : float
Float of at least 1.0, scaling factor for attention updates (usually between 1.0 and 2.0)
- momentum : float
Float value between 0 and 1 used as the momentum in all batch normalization layers
- n_independent : int
Number of independent GLU layers in each GLU block (default 2)
- n_shared : int
Number of shared GLU layers in each GLU block (default 2)
- epsilon : float
Small constant that avoids log(0); it should be kept very low
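To make gamma and epsilon more concrete, here is a small pure-Python sketch of the two formulas from the TabNet paper where they appear. It is based on my reading of the paper, not on fast_tabnet's code. The prior scale of each feature is P[i] = prod_{j=1..i} (gamma - M[j]), starting from P[0] = 1, so with gamma = 1 a feature fully used at one step is blocked at later steps, while a larger gamma allows it to be reused. epsilon appears inside the entropy-style sparsity penalty -M * log(M + epsilon), so that mask entries of exactly 0 do not hit log(0). The names prior_scales and sparsity_loss are hypothetical, and the loss here is computed for a single sample instead of being averaged over a batch.

```python
import math

def prior_scales(masks, gamma):
    """Return the prior scales P[0], P[1], ... for one sample.

    masks: list of per-step attention masks, each a list of floats in [0, 1].
    P[0] is all ones; P[i] = P[i-1] * (gamma - M[i]) element-wise.
    """
    prior = [1.0] * len(masks[0])
    history = [list(prior)]
    for mask in masks:
        prior = [p * (gamma - m) for p, m in zip(prior, mask)]
        history.append(list(prior))
    return history

def sparsity_loss(masks, epsilon=1e-15):
    """Entropy-style sparsity penalty for one sample, averaged over steps.

    Sums -M * log(M + epsilon) over features; epsilon keeps log() finite
    when a mask entry is exactly 0.
    """
    total = 0.0
    for mask in masks:
        total += sum(-m * math.log(m + epsilon) for m in mask)
    return total / len(masks)

# gamma = 1.0: feature 0, fully used at step 1, cannot be selected again
print(prior_scales([[1.0, 0.0]], gamma=1.0)[-1])  # [0.0, 1.0]
# gamma = 1.3: feature 0 keeps a reduced prior of about 0.3
print(prior_scales([[1.0, 0.0]], gamma=1.3)[-1])
# a one-hot mask is maximally sparse (loss near 0); a uniform mask is not
print(sparsity_loss([[1.0, 0.0]]), sparsity_loss([[0.5, 0.5]]))
```

In this sketch, dropping epsilon would make the one-hot example raise a ValueError from math.log(0).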
Example
Below is an example from the fastai library, but the model used is TabNet.
from fastai.basics import *
from fastai.tabular.all import *
from fast_tabnet.core import *
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
df_main,df_test = df.iloc[:10000].copy(),df.iloc[10000:].copy()
df_main.head()
| | age | workclass | fnlwgt | education | education-num | marital-status | occupation | relationship | race | sex | capital-gain | capital-loss | hours-per-week | native-country | salary |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 49 | Private | 101320 | Assoc-acdm | 12.0 | Married-civ-spouse | NaN | Wife | White | Female | 0 | 1902 | 40 | United-States | >=50k |
| 1 | 44 | Private | 236746 | Masters | 14.0 | Divorced | Exec-managerial | Not-in-family | White | Male | 10520 | 0 | 45 | United-States | >=50k |
| 2 | 38 | Private | 96185 | HS-grad | NaN | Divorced | NaN | Unmarried | Black | Female | 0 | 0 | 32 | United-States | <50k |
| 3 | 38 | Self-emp-inc | 112847 | Prof-school | 15.0 | Married-civ-spouse | Prof-specialty | Husband | Asian-Pac-Islander | Male | 0 | 0 | 40 | United-States | >=50k |
| 4 | 42 | Self-emp-not-inc | 82297 | 7th-8th | NaN | Married-civ-spouse | Other-service | Wife | Black | Female | 0 | 0 | 50 | United-States | <50k |
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
splits = RandomSplitter()(range_of(df_main))
to = TabularPandas(df_main, procs, cat_names, cont_names, y_names="salary", splits=splits)
dbch = to.dataloaders()
dbch.valid.show_batch()
| | workclass | education | marital-status | occupation | relationship | race | age_na | fnlwgt_na | education-num_na | age | fnlwgt | education-num | salary |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Local-gov | HS-grad | Married-civ-spouse | Craft-repair | Husband | Black | False | False | False | 45.000000 | 556652.007499 | 9.0 | >=50k |
| 1 | Private | Bachelors | Never-married | Sales | Not-in-family | White | False | False | False | 29.000000 | 176683.000072 | 13.0 | >=50k |
| 2 | Private | Bachelors | Married-civ-spouse | Exec-managerial | Husband | White | False | False | False | 29.000000 | 194939.999936 | 13.0 | <50k |
| 3 | Private | HS-grad | Married-civ-spouse | Transport-moving | Husband | White | False | False | False | 29.000000 | 52635.998841 | 9.0 | <50k |
| 4 | State-gov | Some-college | Married-civ-spouse | Machine-op-inspct | Husband | White | False | False | False | 49.000000 | 122177.000557 | 10.0 | >=50k |
| 5 | Private | 12th | Married-civ-spouse | Machine-op-inspct | Other-relative | Other | False | False | False | 28.000000 | 158737.000048 | 8.0 | <50k |
| 6 | Private | HS-grad | Married-civ-spouse | Machine-op-inspct | Husband | White | False | False | False | 55.999999 | 192868.999992 | 9.0 | >=50k |
| 7 | Self-emp-not-inc | HS-grad | Married-civ-spouse | Craft-repair | Husband | White | False | False | False | 56.999999 | 65080.002276 | 9.0 | >=50k |
| 8 | Local-gov | Masters | Married-civ-spouse | Prof-specialty | Husband | White | False | False | True | 50.000000 | 145165.999578 | 10.0 | >=50k |
| 9 | Private | Assoc-voc | Never-married | Tech-support | Not-in-family | White | False | False | False | 35.000000 | 186034.999925 | 11.0 | <50k |
to_tst = to.new(df_test)
to_tst.process()
to_tst.all_cols.head()
| | workclass | education | marital-status | occupation | relationship | race | age_na | fnlwgt_na | education-num_na | age | fnlwgt | education-num | salary |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 10000 | 5 | 10 | 3 | 2 | 1 | 2 | 1 | 1 | 1 | 0.456238 | 1.346622 | 1.164335 | 0 |
| 10001 | 5 | 12 | 3 | 15 | 1 | 4 | 1 | 1 | 1 | -0.930752 | 1.259253 | -0.419996 | 0 |
| 10002 | 5 | 2 | 1 | 9 | 2 | 5 | 1 | 1 | 1 | 1.040233 | 0.152193 | -1.212162 | 0 |
| 10003 | 5 | 12 | 7 | 2 | 5 | 5 | 1 | 1 | 1 | 0.529237 | -0.282632 | -0.419996 | 0 |
| 10004 | 6 | 9 | 3 | 5 | 1 | 5 | 1 | 1 | 1 | 0.748235 | 1.449564 | 0.372169 | 1 |
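The extra age_na, fnlwgt_na and education-num_na columns and the standardized continuous values in the tables above come from the FillMissing and Normalize procs. The sketch below is a rough pure-Python approximation, assuming median filling (which I believe is FillMissing's default strategy) and standardization with the training set's mean and standard deviation. It is not fastai's code, and fill_missing and normalize are hypothetical names. The point it illustrates is that the statistics are computed on the training data and then reused for new data, which is what to.new(df_test) relies on.

```python
import statistics

def fill_missing(train_col, col):
    """Fill None values in `col` with the training median; also return `_na` flags."""
    median = statistics.median(v for v in train_col if v is not None)
    filled = [median if v is None else v for v in col]
    na_flags = [v is None for v in col]
    return filled, na_flags

def normalize(train_col, col):
    """Standardize `col` with the mean and (population) std of the training column."""
    mean = statistics.mean(train_col)
    std = statistics.pstdev(train_col)
    return [(v - mean) / std for v in col]

# None stands in for pandas NaN in this toy example
train_age = [49.0, 44.0, None, 38.0, 42.0]
test_age = [None, 30.0]

# the fill value and the normalization statistics come from the training column only
train_filled, train_na = fill_missing(train_age, train_age)
test_filled, test_na = fill_missing(train_age, test_age)
print(train_na, test_na)  # [False, False, True, False, False] [True, False]
print(normalize(train_filled, test_filled))
```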
emb_szs = get_emb_sz(to); print(emb_szs)
[(10, 6), (17, 8), (8, 5), (16, 8), (7, 5), (6, 4), (2, 2), (2, 2), (3, 3)]
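The embedding sizes follow fastai's default rule of thumb, which to my knowledge is emb_sz_rule(n_cat) = min(600, round(1.6 * n_cat**0.56)), where n_cat is the number of categories of a column (including fastai's #na# placeholder). The sketch below re-implements that formula in plain Python rather than calling fastai, and recomputes the printed list from the cardinalities alone.

```python
def emb_sz_rule(n_cat):
    """fastai-style rule of thumb for the embedding width of a categorical column."""
    return min(600, round(1.6 * n_cat ** 0.56))

# number of categories per column, taken from the printed emb_szs
cardinalities = [10, 17, 8, 16, 7, 6, 2, 2, 3]
recomputed = [(n, emb_sz_rule(n)) for n in cardinalities]
print(recomputed)
# [(10, 6), (17, 8), (8, 5), (16, 8), (7, 5), (6, 4), (2, 2), (2, 2), (3, 3)]
```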
This is how the model is used:
model = TabNetModel(emb_szs, len(to.cont_names), 1, n_d=8, n_a=32, n_steps=1);
opt_func = partial(Adam, wd=0.01, eps=1e-5)
learn = Learner(dbch, model, MSELossFlat(), opt_func=opt_func, lr=3e-2, metrics=[accuracy])
learn.lr_find()
learn.fit_one_cycle(10)
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.161420 | 0.163181 | 0.757500 | 00:04 |
| 1 | 0.140478 | 0.127033 | 0.757500 | 00:04 |
| 2 | 0.132842 | 0.117864 | 0.757500 | 00:04 |
| 3 | 0.126220 | 0.115803 | 0.757500 | 00:04 |
| 4 | 0.125338 | 0.117127 | 0.757500 | 00:03 |
| 5 | 0.123562 | 0.119050 | 0.757500 | 00:04 |
| 6 | 0.121530 | 0.117025 | 0.757500 | 00:04 |
| 7 | 0.116976 | 0.114524 | 0.757500 | 00:04 |
| 8 | 0.113542 | 0.114590 | 0.757500 | 00:04 |
| 9 | 0.111071 | 0.114707 | 0.757500 | 00:04 |
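The accuracy column stays at 0.757500 in every epoch. A likely explanation: fastai's accuracy metric takes the argmax over the last dimension of the predictions, and with out_sz=1 there is only one column, so the predicted class is always 0 and the metric reduces to the share of class-0 rows (presumably <50k, the majority class) in the validation set. The pure-Python sketch below mimics that computation; argmax_last and accuracy are illustrative re-implementations, not fastai's functions. If a meaningful accuracy is needed, two outputs with CrossEntropyLossFlat would be the usual fastai setup for binary classification; that variant is not run here.

```python
def argmax_last(row):
    """Index of the largest value in a row (what argmax over the last axis does per sample)."""
    return max(range(len(row)), key=lambda i: row[i])

def accuracy(preds, targets):
    """Mimics an argmax-based accuracy: predicted class index vs. target, then the mean."""
    hits = [argmax_last(p) == t for p, t in zip(preds, targets)]
    return sum(hits) / len(hits)

# out_sz=1: every prediction is a one-element row, so argmax is always 0
preds = [[0.9], [0.1], [0.7], [0.2]]
targets = [1, 0, 1, 0]
print(accuracy(preds, targets))  # 0.5, i.e. the share of class-0 targets, whatever preds are
```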
Example with Bayesian Optimization
I like to tune hyperparameters for tabular models with Bayesian Optimization. You can optimize your metric directly with this approach if the metric is sensitive enough (in our example it is not, so we use the validation loss instead). You should also create a second validation set, because the first one effectively becomes the training set for Bayesian Optimization: the hyperparameters are fitted to it.
You may need to install the optimizer: pip install bayesian-optimization
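One way to get that second validation set, sketched in plain Python under my own assumptions: split the row indices three ways before tuning. three_way_split is a hypothetical helper. The train and first validation indices could then be passed as splits to TabularPandas, and the holdout rows kept aside for a single final check of the chosen hyperparameters (in this example, df_test could also play that role).

```python
import random

def three_way_split(n_rows, valid_pct=0.2, holdout_pct=0.2, seed=42):
    """Shuffle row indices and split them into train / valid / holdout lists.

    valid: what the Bayesian Optimization objective scores against.
    holdout: used only once, to check the finally selected hyperparameters.
    """
    idxs = list(range(n_rows))
    random.Random(seed).shuffle(idxs)
    n_valid = int(n_rows * valid_pct)
    n_holdout = int(n_rows * holdout_pct)
    valid = idxs[:n_valid]
    holdout = idxs[n_valid:n_valid + n_holdout]
    train = idxs[n_valid + n_holdout:]
    return train, valid, holdout

train_idx, valid_idx, holdout_idx = three_way_split(10_000)
print(len(train_idx), len(valid_idx), len(holdout_idx))  # 6000 2000 2000
```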
from functools import lru_cache
# The function we'll optimize. Despite its name, it returns the *negative* validation loss,
# because BayesianOptimization maximizes its target (lower loss -> higher target)
@lru_cache(1000)
def get_accuracy(n_d:int, n_a:int, n_steps:int):
    model = TabNetModel(emb_szs, len(to.cont_names), 1, n_d=n_d, n_a=n_a, n_steps=n_steps)
    learn = Learner(dbch, model, MSELossFlat(), opt_func=opt_func, lr=3e-2, metrics=[accuracy])
    learn.fit_one_cycle(5)
    return -float(learn.validate(dl=learn.dls.valid)[0])
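This implementation of Bayesian Optimization doesn't naturally work with discrete values. That's why we use a wrapper: the continuous suggestions are truncated to integer exponents (the actual hyperparameters are powers of two), and lru_cache returns the stored result when the same integer combination comes up again, instead of retraining the model.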
def fit_accuracy(pow_n_d, pow_n_a, pow_n_steps):
    pow_n_d, pow_n_a, pow_n_steps = map(int, (pow_n_d, pow_n_a, pow_n_steps))
    return get_accuracy(2**pow_n_d, 2**pow_n_a, 2**pow_n_steps)
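The effect of the wrapper can be checked without training anything. In this sketch, fake_objective and fit_fake are hypothetical stand-ins for get_accuracy and fit_accuracy: the toy score is arbitrary, and the calls list records which integer configurations would actually have been trained.

```python
from functools import lru_cache

calls = []  # configurations that were actually evaluated (cache misses)

@lru_cache(1000)
def fake_objective(n_d, n_a, n_steps):
    """Stand-in for get_accuracy: pretend to train and return a toy score."""
    calls.append((n_d, n_a, n_steps))
    return -(abs(n_d - 8) + abs(n_a - 32) + n_steps)  # arbitrary toy score

def fit_fake(pow_n_d, pow_n_a, pow_n_steps):
    # int() truncates the optimizer's continuous suggestions to integer exponents
    pow_n_d, pow_n_a, pow_n_steps = map(int, (pow_n_d, pow_n_a, pow_n_steps))
    return fake_objective(2**pow_n_d, 2**pow_n_a, 2**pow_n_steps)

fit_fake(3.2, 5.9, 0.4)  # -> n_d=8, n_a=32, n_steps=1 (evaluated)
fit_fake(3.7, 5.1, 0.9)  # same integers -> served from the cache
fit_fake(2.5, 5.0, 1.5)  # -> n_d=4, n_a=32, n_steps=2 (evaluated)
print(calls)                             # [(8, 32, 1), (4, 32, 2)]
print(fake_objective.cache_info().hits)  # 1
```

Because int() truncates, the upper bound of each range in pbounds (8 for pow_n_d and pow_n_a, 4 for pow_n_steps) is only reached when the optimizer proposes exactly that value.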
from bayes_opt import BayesianOptimization
# Bounded region of parameter space
pbounds = {'pow_n_d': (0, 8), 'pow_n_a': (0, 8), 'pow_n_steps': (0, 4)}
optimizer = BayesianOptimization(
f=fit_accuracy,
pbounds=pbounds,
)
optimizer.maximize(
init_points=15,
n_iter=100,
)
| iter | target | pow_n_a | pow_n_d | pow_n_... |
-------------------------------------------------------------
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 1.376397 | 0.227991 | 0.757500 | 00:07 |
| 1 | 0.307311 | 0.188101 | 0.757500 | 00:06 |
| 2 | 0.192308 | 0.174029 | 0.757500 | 00:06 |
| 3 | 0.180625 | 0.168215 | 0.757500 | 00:07 |
| 4 | 0.171093 | 0.168311 | 0.757500 | 00:07 |

| 1 | -0.1683 | 2.099 | 2.108 | 2.532 |
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.156191 | 0.145624 | 0.757500 | 00:04 |
| 1 | 0.135885 | 0.131468 | 0.757500 | 00:04 |
| 2 | 0.124489 | 0.116068 | 0.757500 | 00:04 |
| 3 | 0.120778 | 0.115556 | 0.757500 | 00:04 |
| 4 | 0.118399 | 0.114798 | 0.757500 | 00:04 |

| 2 | -0.1148 | 5.582 | 0.5914 | 0.394 |
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.732101 | 0.201414 | 0.757500 | 00:09 |
| 1 | 0.213341 | 0.182902 | 0.757500 | 00:09 |
| 2 | 0.157790 | 0.154676 | 0.757500 | 00:09 |
| 3 | 0.143525 | 0.134003 | 0.757500 | 00:09 |
| 4 | 0.137171 | 0.128810 | 0.757500 | 00:09 |

| 3 | -0.1288 | 0.6418 | 3.424 | 3.649 |
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.255437 | 0.176615 | 0.757500 | 00:06 |
| 1 | 0.164086 | 0.158516 | 0.757500 | 00:07 |
| 2 | 0.149184 | 0.139764 | 0.757500 | 00:06 |
| 3 | 0.137243 | 0.126479 | 0.757500 | 00:06 |
| 4 | 0.132500 | 0.125504 | 0.757500 | 00:06 |

| 4 | -0.1255 | 6.121 | 1.372 | 2.897 |
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.834591 | 0.252279 | 0.757500 | 00:06 |
| 1 | 0.233243 | 0.190753 | 0.757500 | 00:06 |
| 2 | 0.174514 | 0.163240 | 0.757500 | 00:06 |
| 3 | 0.160865 | 0.149085 | 0.757500 | 00:06 |
| 4 | 0.153380 | 0.142670 | 0.757500 | 00:06 |

| 5 | -0.1427 | 7.183 | 5.46 | 2.131 |
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.280760 | 0.184326 | 0.757500 | 00:05 |
| 1 | 0.151150 | 0.149422 | 0.757500 | 00:05 |
| 2 | 0.136892 | 0.126405 | 0.757500 | 00:05 |
| 3 | 0.129048 | 0.124096 | 0.757500 | 00:05 |
| 4 | 0.129486 | 0.122428 | 0.757500 | 00:04 |

| 6 | -0.1224 | 0.5754 | 2.298 | 1.447 |
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 2.923816 | 0.290585 | 0.757500 | 00:09 |
| 1 | 0.635441 | 0.237105 | 0.757500 | 00:09 |
| 2 | 0.272063 | 0.170947 | 0.757500 | 00:09 |
| 3 | 0.179265 | 0.156215 | 0.757500 | 00:09 |
| 4 | 0.159060 | 0.151041 | 0.757500 | 00:09 |

| 7 | -0.151 | 6.365 | 7.881 | 3.652 |
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 1.436597 | 0.213113 | 0.757500 | 00:09 |
| 1 | 0.350264 | 0.189146 | 0.757500 | 00:09 |
| 2 | 0.187943 | 0.162571 | 0.757500 | 00:09 |
| 3 | 0.165730 | 0.154995 | 0.757500 | 00:09 |
| 4 | 0.155386 | 0.149732 | 0.757500 | 00:09 |

| 8 | -0.1497 | 5.544 | 5.838 | 3.925 |
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.430938 | 0.227863 | 0.757500 | 00:09 |
| 1 | 0.209979 | 0.177186 | 0.757500 | 00:09 |
| 2 | 0.179570 | 0.164046 | 0.757500 | 00:09 |
| 3 | 0.170003 | 0.161813 | 0.757500 | 00:09 |
| 4 | 0.168120 | 0.159528 | 0.757500 | 00:10 |

| 9 | -0.1595 | 4.231 | 1.842 | 3.959 |
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.196750 | 0.168031 | 0.757500 | 00:04 |
| 1 | 0.155173 | 0.152989 | 0.757500 | 00:04 |
| 2 | 0.144540 | 0.126592 | 0.757500 | 00:04 |
| 3 | 0.133649 | 0.126462 | 0.757500 | 00:04 |
| 4 | 0.124242 | 0.119457 | 0.757500 | 00:04 |

| 10 | -0.1195 | 7.513 | 6.718 | 0.3416 |
optimizer.max
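In the bayes_opt versions I know, optimizer.max is a dict with 'target' and 'params' keys. The parameters are still in pow_* space, so they have to be mapped back to TabNet hyperparameters the same way fit_accuracy does. decode_params is a hypothetical helper, and the input below is copied by hand from iteration 2 of the log above (the best target among the ten iterations shown); it is not output produced by optimizer.max.

```python
def decode_params(params):
    """Map BayesianOptimization's pow_* suggestions back to TabNet hyperparameters,
    using the same int-truncation and power-of-two mapping as fit_accuracy."""
    return {name[len('pow_'):]: 2 ** int(value) for name, value in params.items()}

# iteration 2 of the log above, written in the {'target', 'params'} shape
best = {'target': -0.1148,
        'params': {'pow_n_a': 5.582, 'pow_n_d': 0.5914, 'pow_n_steps': 0.394}}
print(decode_params(best['params']))  # {'n_a': 32, 'n_d': 1, 'n_steps': 1}
```

The decoded dict could then be unpacked into the model, e.g. TabNetModel(emb_szs, len(to.cont_names), 1, **decode_params(optimizer.max['params'])).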