TabNet for fastai
This is an adaptation of TabNet (an attention-based network for tabular data) for the fastai (>=2.0) library. The original paper: https://arxiv.org/pdf/1908.07442.pdf.
Install
pip install fast_tabnet
How to use
model = TabNetModel(emb_szs, n_cont, out_sz, embed_p=0., y_range=None, n_d=8, n_a=8, n_steps=3, gamma=1.5, n_independent=2, n_shared=2, epsilon=1e-15, virtual_batch_size=128, momentum=0.02)
Parameters emb_szs, n_cont, out_sz, embed_p, y_range
are the same as for fastai TabularModel.
- n_d : int
Dimension of the prediction layer (usually between 4 and 64)
- n_a : int
Dimension of the attention layer (usually between 4 and 64)
- n_steps: int
Number of successive steps in the network (usually between 3 and 10)
- gamma : float
Float above 1, scaling factor for attention updates (usually between 1.0 and 2.0)
- momentum : float
Float value between 0 and 1 which will be used for momentum in all batch norm
- n_independent : int
Number of independent GLU layers in each GLU block (default 2)
- n_shared : int
Number of shared GLU layers in each GLU block (default 2)
- epsilon: float
Avoid log(0), this should be kept very low
Example
Below is an example from the fastai library, but the model in use is TabNet
from fastai.basics import *
from fastai.tabular.all import *
from fast_tabnet.core import *
# Fetch the ADULT_SAMPLE dataset via fastai's `untar_data` and read its CSV into a DataFrame.
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
# Hold out the last 1000 rows as a test set; all earlier rows are used for training/validation.
df_main,df_test = df.iloc[:-1000].copy(),df.iloc[-1000:].copy()
df_main.head()
<style scoped>
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
</style>
|
age |
workclass |
fnlwgt |
education |
education-num |
marital-status |
occupation |
relationship |
race |
sex |
capital-gain |
capital-loss |
hours-per-week |
native-country |
salary |
0 |
49 |
Private |
101320 |
Assoc-acdm |
12.0 |
Married-civ-spouse |
NaN |
Wife |
White |
Female |
0 |
1902 |
40 |
United-States |
>=50k |
1 |
44 |
Private |
236746 |
Masters |
14.0 |
Divorced |
Exec-managerial |
Not-in-family |
White |
Male |
10520 |
0 |
45 |
United-States |
>=50k |
2 |
38 |
Private |
96185 |
HS-grad |
NaN |
Divorced |
NaN |
Unmarried |
Black |
Female |
0 |
0 |
32 |
United-States |
<50k |
3 |
38 |
Self-emp-inc |
112847 |
Prof-school |
15.0 |
Married-civ-spouse |
Prof-specialty |
Husband |
Asian-Pac-Islander |
Male |
0 |
0 |
40 |
United-States |
>=50k |
4 |
42 |
Self-emp-not-inc |
82297 |
7th-8th |
NaN |
Married-civ-spouse |
Other-service |
Wife |
Black |
Female |
0 |
0 |
50 |
United-States |
<50k |
# Categorical and continuous input columns; "salary" is the target (see y_names below).
cat_names = ['workclass', 'education', 'marital-status', 'occupation',
'relationship', 'race', 'native-country', 'sex']
cont_names = ['age', 'fnlwgt', 'education-num']
# Preprocessing steps handed to TabularPandas.
procs = [Categorify, FillMissing, Normalize]
# Train/validation split computed over the row indices of df_main.
splits = RandomSplitter()(range_of(df_main))
to = TabularPandas(df_main, procs, cat_names, cont_names, y_names="salary",
y_block = CategoryBlock(), splits=splits)
# DataLoaders with batch size 32; show a batch from the validation split.
dls = to.dataloaders(bs=32)
dls.valid.show_batch()
|
workclass |
education |
marital-status |
occupation |
relationship |
race |
native-country |
sex |
education-num_na |
age |
fnlwgt |
education-num |
salary |
0 |
Private |
HS-grad |
Married-civ-spouse |
Other-service |
Wife |
White |
United-States |
Female |
False |
39.000000 |
196673.000115 |
9.0 |
<50k |
1 |
Private |
HS-grad |
Married-civ-spouse |
Craft-repair |
Husband |
White |
United-States |
Male |
False |
32.000000 |
198067.999771 |
9.0 |
<50k |
2 |
State-gov |
HS-grad |
Never-married |
Adm-clerical |
Own-child |
White |
United-States |
Female |
False |
18.999999 |
176633.999931 |
9.0 |
<50k |
3 |
Private |
Some-college |
Married-civ-spouse |
Prof-specialty |
Husband |
White |
United-States |
Male |
False |
67.999999 |
107626.998490 |
10.0 |
<50k |
4 |
Private |
Masters |
Never-married |
Exec-managerial |
Not-in-family |
Black |
United-States |
Male |
False |
29.000000 |
214925.000260 |
14.0 |
<50k |
5 |
Private |
HS-grad |
Married-civ-spouse |
Priv-house-serv |
Wife |
White |
United-States |
Female |
False |
22.000000 |
200109.000126 |
9.0 |
<50k |
6 |
Private |
Some-college |
Never-married |
Sales |
Own-child |
White |
United-States |
Female |
False |
18.000000 |
60980.998429 |
10.0 |
<50k |
7 |
Private |
Some-college |
Separated |
Adm-clerical |
Not-in-family |
White |
United-States |
Female |
False |
28.000000 |
334367.998199 |
10.0 |
<50k |
8 |
Private |
11th |
Married-civ-spouse |
Transport-moving |
Husband |
White |
United-States |
Male |
False |
49.000000 |
123584.001097 |
7.0 |
<50k |
9 |
Private |
Masters |
Never-married |
Prof-specialty |
Not-in-family |
White |
United-States |
Female |
False |
26.000000 |
397316.999922 |
14.0 |
<50k |
# Derive a new tabular object for the held-out rows from `to`, then process it
# (presumably with the same procs that were set up on `to` -- confirm against fastai docs).
to_tst = to.new(df_test)
to_tst.process()
to_tst.all_cols.head()
<style scoped>
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
</style>
|
workclass |
education |
marital-status |
occupation |
relationship |
race |
native-country |
sex |
education-num_na |
age |
fnlwgt |
education-num |
salary |
31561 |
5 |
2 |
5 |
9 |
3 |
3 |
40 |
2 |
1 |
-1.505833 |
-0.559418 |
-1.202170 |
0 |
31562 |
2 |
12 |
5 |
2 |
5 |
3 |
40 |
1 |
1 |
-1.432653 |
0.421241 |
-0.418032 |
0 |
31563 |
5 |
7 |
3 |
4 |
1 |
5 |
40 |
2 |
1 |
-0.115406 |
0.132868 |
-1.986307 |
0 |
31564 |
8 |
12 |
3 |
9 |
1 |
5 |
40 |
2 |
1 |
1.494561 |
0.749805 |
-0.418032 |
0 |
31565 |
1 |
12 |
1 |
1 |
5 |
3 |
40 |
2 |
1 |
-0.481308 |
7.529798 |
-0.418032 |
0 |
# Embedding sizes for the tabular object; passed as the first argument to TabNetModel below.
emb_szs = get_emb_sz(to)
This is how the model is used:
# Positional arguments are emb_szs, n_cont (number of continuous columns) and out_sz (dls.c).
model = TabNetModel(emb_szs, len(to.cont_names), dls.c, n_d=8, n_a=8, n_steps=5, mask_type='entmax');
# Learner with cross-entropy loss, the Adam optimizer and accuracy as the reported metric.
learn = Learner(dls, model, CrossEntropyLossFlat(), opt_func=Adam, lr=3e-2, metrics=[accuracy])
# Run the learning-rate finder; its suggestion is printed below.
learn.lr_find()
SuggestedLRs(lr_min=0.2754228591918945, lr_steep=1.9054607491852948e-06)
# Train for 5 epochs (the per-epoch results are shown below).
learn.fit_one_cycle(5)
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.446274 |
0.414451 |
0.817015 |
00:30 |
1 |
0.380002 |
0.393030 |
0.818916 |
00:30 |
2 |
0.371149 |
0.359802 |
0.832066 |
00:30 |
3 |
0.349027 |
0.352255 |
0.835868 |
00:30 |
4 |
0.355339 |
0.349360 |
0.836819 |
00:30 |
Tabnet interpretability
# feature importance for 2k rows
dl = learn.dls.test_dl(df.iloc[:2000], bs=256)
feature_importances = tabnet_feature_importances(learn.model, dl)
# per sample interpretation
dl = learn.dls.test_dl(df.iloc[:20], bs=20)
res_explain, res_masks = tabnet_explain(learn.model, dl)
# Bar chart of the importances, one bar per input column (dl.x_names), with vertical tick labels.
plt.xticks(rotation='vertical')
plt.bar(dl.x_names, feature_importances, color='g')
plt.show()
def plot_explain(masks, lbls, figsize=(12,12)):
    """Show `masks` as an image: one row per sample, one column per variable.

    `lbls` (e.g. `dls.x_names`) become the x-axis tick labels, rotated by 90 degrees.
    `figsize` is forwarded to `plt.figure`.
    """
    canvas = plt.figure(figsize=figsize)
    axes = canvas.add_axes([0.1, 0.1, 0.8, 0.8])
    # One tick per sample on the y axis ...
    sample_ticks = np.arange(0, len(masks), 1.0)
    plt.yticks(sample_ticks)
    # ... and one tick per variable on the x axis, labelled with `lbls`.
    variable_ticks = np.arange(0, len(masks[0]), 1.0)
    plt.xticks(variable_ticks)
    axes.set_xticklabels(lbls, rotation=90)
    plt.ylabel('Sample Number')
    plt.xlabel('Variable')
    plt.imshow(masks)
# Plot the per-sample explanation masks computed above, labelled with the input column names.
plot_explain(res_explain, dl.x_names)
Hyperparameter search with Bayesian Optimization
If your dataset isn't huge, you can tune hyperparameters for tabular models with Bayesian Optimization. You can optimize your metric directly with this approach if the metric is sensitive enough; if it is not, optimize validation loss instead (the example below maximizes validation accuracy). Also, you should create a second validation set, because the first one is effectively used as a training set for Bayesian Optimization.
You may need to install the optimizer: pip install bayesian-optimization
from functools import lru_cache
# The function we'll optimize
@lru_cache(maxsize=1000)
def get_accuracy(n_d:Int, n_a:Int, n_steps:Int):
    """Train a TabNet model with the given sizes and return its validation accuracy.

    `n_d`, `n_a` and `n_steps` are forwarded to `TabNetModel`. Results are memoized
    with `lru_cache`, so calling again with the same integer triple does not retrain.
    Relies on the `emb_szs`, `to` and `dls` objects created earlier in this walkthrough.
    """
    model = TabNetModel(emb_szs, len(to.cont_names), dls.c, n_d=n_d, n_a=n_a, n_steps=n_steps, gamma=1.5)
    # Same optimizer as in the training example above.
    learn = Learner(dls, model, CrossEntropyLossFlat(), opt_func=Adam, lr=3e-2, metrics=[accuracy])
    learn.fit_one_cycle(5)
    # validate() returns the loss followed by the metrics; index 1 is `accuracy`.
    return float(learn.validate(dl=learn.dls.valid)[1])
This implementation of Bayesian Optimization doesn't work naturally with discrete values. That's why we use a wrapper with lru_cache.
def fit_accuracy(pow_n_d, pow_n_a, pow_n_steps):
    """Objective for the optimizer: map continuous exponents to power-of-two sizes.

    Each argument is truncated to an int and used as an exponent of 2, so nearby
    continuous proposals collapse onto the same (cached) `get_accuracy` call.
    """
    sizes = [2**int(exponent) for exponent in (pow_n_d, pow_n_a, pow_n_steps)]
    return get_accuracy(*sizes)
from bayes_opt import BayesianOptimization
# Bounded region of parameter space
# The bounds are exponents: fit_accuracy turns each value x into 2**int(x).
pbounds = {'pow_n_d': (0, 8), 'pow_n_a': (0, 8), 'pow_n_steps': (0, 4)}
optimizer = BayesianOptimization(
f=fit_accuracy,
pbounds=pbounds,
)
# Maximize the value returned by fit_accuracy: 15 initial points, then 100 further iterations.
optimizer.maximize(
init_points=15,
n_iter=100,
)
| iter | target | pow_n_a | pow_n_d | pow_n_... |
-------------------------------------------------------------
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.404888 |
0.432834 |
0.793885 |
00:10 |
1 |
0.367979 |
0.384840 |
0.818600 |
00:09 |
2 |
0.366444 |
0.372005 |
0.819708 |
00:09 |
3 |
0.362771 |
0.366949 |
0.823511 |
00:10 |
4 |
0.353682 |
0.367132 |
0.823511 |
00:10 |
| [0m 1 [0m | [0m 0.8235 [0m | [0m 0.9408 [0m | [0m 1.898 [0m | [0m 1.652 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.393301 |
0.449742 |
0.810836 |
00:08 |
1 |
0.379140 |
0.413773 |
0.815589 |
00:07 |
2 |
0.355790 |
0.388907 |
0.822560 |
00:07 |
3 |
0.349984 |
0.362671 |
0.828739 |
00:07 |
4 |
0.348000 |
0.360150 |
0.827313 |
00:07 |
| [95m 2 [0m | [95m 0.8273 [0m | [95m 4.262 [0m | [95m 5.604 [0m | [95m 0.2437 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.451572 |
0.434189 |
0.781210 |
00:12 |
1 |
0.423763 |
0.413420 |
0.805450 |
00:12 |
2 |
0.398922 |
0.408688 |
0.814164 |
00:12 |
3 |
0.390981 |
0.392398 |
0.808935 |
00:12 |
4 |
0.376418 |
0.382250 |
0.817174 |
00:12 |
| [0m 3 [0m | [0m 0.8172 [0m | [0m 7.233 [0m | [0m 6.471 [0m | [0m 2.508 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.403187 |
0.413986 |
0.798162 |
00:07 |
1 |
0.398544 |
0.390102 |
0.820184 |
00:07 |
2 |
0.390569 |
0.389703 |
0.825253 |
00:07 |
3 |
0.375426 |
0.385706 |
0.826996 |
00:07 |
4 |
0.370446 |
0.383366 |
0.831115 |
00:06 |
| [95m 4 [0m | [95m 0.8311 [0m | [95m 5.935 [0m | [95m 1.241 [0m | [95m 0.3809 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.464145 |
0.458641 |
0.751267 |
00:18 |
1 |
0.424691 |
0.436968 |
0.788023 |
00:18 |
2 |
0.431576 |
0.436581 |
0.775824 |
00:18 |
3 |
0.432143 |
0.437062 |
0.759506 |
00:18 |
4 |
0.429915 |
0.438332 |
0.758555 |
00:18 |
| [0m 5 [0m | [0m 0.7586 [0m | [0m 2.554 [0m | [0m 0.4992 [0m | [0m 3.111 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.470359 |
0.475826 |
0.748891 |
00:12 |
1 |
0.411564 |
0.409433 |
0.797053 |
00:12 |
2 |
0.392718 |
0.397363 |
0.809727 |
00:12 |
3 |
0.387564 |
0.380033 |
0.814322 |
00:12 |
4 |
0.374153 |
0.378258 |
0.818916 |
00:12 |
| [0m 6 [0m | [0m 0.8189 [0m | [0m 4.592 [0m | [0m 2.138 [0m | [0m 2.824 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.547042 |
0.588752 |
0.754119 |
00:18 |
1 |
0.491731 |
0.469795 |
0.771863 |
00:18 |
2 |
0.454340 |
0.433961 |
0.775190 |
00:18 |
3 |
0.424386 |
0.432385 |
0.782953 |
00:18 |
4 |
0.397645 |
0.406420 |
0.805767 |
00:19 |
| [0m 7 [0m | [0m 0.8058 [0m | [0m 6.186 [0m | [0m 7.016 [0m | [0m 3.316 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.485245 |
0.487635 |
0.751109 |
00:18 |
1 |
0.450832 |
0.446423 |
0.750317 |
00:18 |
2 |
0.448203 |
0.449419 |
0.755228 |
00:18 |
3 |
0.430258 |
0.443562 |
0.744297 |
00:18 |
4 |
0.429821 |
0.437173 |
0.761565 |
00:18 |
| [0m 8 [0m | [0m 0.7616 [0m | [0m 2.018 [0m | [0m 1.316 [0m | [0m 3.675 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.458425 |
0.455733 |
0.751584 |
00:12 |
1 |
0.439781 |
0.467807 |
0.751109 |
00:12 |
2 |
0.420331 |
0.432216 |
0.775190 |
00:12 |
3 |
0.421012 |
0.421412 |
0.782319 |
00:12 |
4 |
0.401828 |
0.413434 |
0.801014 |
00:12 |
| [0m 9 [0m | [0m 0.801 [0m | [0m 2.051 [0m | [0m 1.958 [0m | [0m 2.332 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.546997 |
0.506728 |
0.761407 |
00:18 |
1 |
0.489712 |
0.439324 |
0.799588 |
00:18 |
2 |
0.448558 |
0.448419 |
0.786122 |
00:18 |
3 |
0.436869 |
0.435375 |
0.801648 |
00:18 |
4 |
0.417128 |
0.421093 |
0.798321 |
00:18 |
| [0m 10 [0m | [0m 0.7983 [0m | [0m 5.203 [0m | [0m 7.719 [0m | [0m 3.407 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.380781 |
0.463409 |
0.786439 |
00:07 |
1 |
0.359212 |
0.461147 |
0.798321 |
00:07 |
2 |
0.351414 |
0.368950 |
0.822719 |
00:07 |
3 |
0.347257 |
0.367056 |
0.829373 |
00:07 |
4 |
0.337212 |
0.362375 |
0.830799 |
00:07 |
| [0m 11 [0m | [0m 0.8308 [0m | [0m 6.048 [0m | [0m 4.376 [0m | [0m 0.08141 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.430772 |
0.430897 |
0.767744 |
00:12 |
1 |
0.402611 |
0.432137 |
0.764259 |
00:12 |
2 |
0.407579 |
0.409651 |
0.812104 |
00:12 |
3 |
0.374988 |
0.391822 |
0.816698 |
00:12 |
4 |
0.378011 |
0.389278 |
0.816065 |
00:12 |
| [0m 12 [0m | [0m 0.8161 [0m | [0m 7.083 [0m | [0m 1.385 [0m | [0m 2.806 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.402018 |
0.412051 |
0.812262 |
00:09 |
1 |
0.372804 |
0.464937 |
0.811629 |
00:09 |
2 |
0.368274 |
0.384675 |
0.820184 |
00:09 |
3 |
0.364502 |
0.371920 |
0.820659 |
00:09 |
4 |
0.348998 |
0.369445 |
0.823828 |
00:09 |
| [0m 13 [0m | [0m 0.8238 [0m | [0m 4.812 [0m | [0m 3.785 [0m | [0m 1.396 [0m |
| [0m 14 [0m | [0m 0.8172 [0m | [0m 7.672 [0m | [0m 6.719 [0m | [0m 2.72 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.476033 |
0.442598 |
0.803549 |
00:12 |
1 |
0.405236 |
0.414015 |
0.788973 |
00:11 |
2 |
0.406291 |
0.451269 |
0.789449 |
00:11 |
3 |
0.391013 |
0.393243 |
0.816065 |
00:12 |
4 |
0.374160 |
0.377635 |
0.821451 |
00:12 |
| [0m 15 [0m | [0m 0.8215 [0m | [0m 6.464 [0m | [0m 7.954 [0m | [0m 2.647 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.390142 |
0.390678 |
0.810995 |
00:06 |
1 |
0.381717 |
0.382202 |
0.813055 |
00:06 |
2 |
0.368564 |
0.378705 |
0.823828 |
00:06 |
3 |
0.358858 |
0.368329 |
0.823511 |
00:07 |
4 |
0.353392 |
0.363913 |
0.825887 |
00:06 |
| [0m 16 [0m | [0m 0.8259 [0m | [0m 0.1229 [0m | [0m 7.83 [0m | [0m 0.3708 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.381215 |
0.422651 |
0.800697 |
00:06 |
1 |
0.377345 |
0.380863 |
0.815906 |
00:06 |
2 |
0.366631 |
0.370579 |
0.822877 |
00:06 |
3 |
0.362745 |
0.366619 |
0.823352 |
00:07 |
4 |
0.356861 |
0.364835 |
0.825887 |
00:07 |
| [0m 17 [0m | [0m 0.8259 [0m | [0m 0.03098 [0m | [0m 3.326 [0m | [0m 0.007025[0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.404604 |
0.443035 |
0.824461 |
00:07 |
1 |
0.361872 |
0.388880 |
0.823669 |
00:06 |
2 |
0.375164 |
0.369968 |
0.825095 |
00:06 |
3 |
0.352091 |
0.363823 |
0.827947 |
00:06 |
4 |
0.335458 |
0.362544 |
0.829373 |
00:07 |
| [0m 18 [0m | [0m 0.8294 [0m | [0m 7.81 [0m | [0m 7.976 [0m | [0m 0.0194 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.679292 |
0.677299 |
0.248891 |
00:05 |
1 |
0.675403 |
0.678406 |
0.248891 |
00:05 |
2 |
0.673259 |
0.673374 |
0.248891 |
00:06 |
3 |
0.674996 |
0.673514 |
0.248891 |
00:07 |
4 |
0.668813 |
0.673671 |
0.248891 |
00:07 |
| [0m 19 [0m | [0m 0.2489 [0m | [0m 0.4499 [0m | [0m 0.138 [0m | [0m 0.001101[0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.524201 |
0.528132 |
0.729880 |
00:30 |
1 |
0.419377 |
0.403198 |
0.812104 |
00:31 |
2 |
0.399398 |
0.418890 |
0.812421 |
00:31 |
3 |
0.381651 |
0.391744 |
0.819075 |
00:31 |
4 |
0.368742 |
0.377904 |
0.822085 |
00:31 |
| [0m 20 [0m | [0m 0.8221 [0m | [0m 0.0 [0m | [0m 6.575 [0m | [0m 4.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.681083 |
0.682397 |
0.248891 |
00:05 |
1 |
0.672935 |
0.679371 |
0.248891 |
00:06 |
2 |
0.675200 |
0.673466 |
0.248891 |
00:06 |
3 |
0.674251 |
0.673356 |
0.248891 |
00:06 |
4 |
0.668861 |
0.673186 |
0.248891 |
00:06 |
| [0m 21 [0m | [0m 0.2489 [0m | [0m 8.0 [0m | [0m 0.0 [0m | [0m 0.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.407246 |
0.432203 |
0.801331 |
00:10 |
1 |
0.385086 |
0.399513 |
0.811312 |
00:11 |
2 |
0.377365 |
0.384121 |
0.816065 |
00:12 |
3 |
0.366855 |
0.371010 |
0.823194 |
00:12 |
4 |
0.361931 |
0.368933 |
0.825095 |
00:12 |
| [0m 22 [0m | [0m 0.8251 [0m | [0m 0.0 [0m | [0m 4.502 [0m | [0m 2.193 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.493623 |
0.476921 |
0.766001 |
00:30 |
1 |
0.441126 |
0.443774 |
0.776774 |
00:31 |
2 |
0.424523 |
0.437125 |
0.783904 |
00:31 |
3 |
0.402457 |
0.408628 |
0.795944 |
00:31 |
4 |
0.439420 |
0.431756 |
0.788973 |
00:32 |
| [0m 23 [0m | [0m 0.789 [0m | [0m 8.0 [0m | [0m 3.702 [0m | [0m 4.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.515615 |
0.513919 |
0.751109 |
00:31 |
1 |
0.462674 |
0.495322 |
0.751584 |
00:31 |
2 |
0.465430 |
0.483685 |
0.751267 |
00:31 |
3 |
0.481308 |
0.495375 |
0.755070 |
00:31 |
4 |
0.481324 |
0.491275 |
0.754911 |
00:32 |
| [0m 24 [0m | [0m 0.7549 [0m | [0m 6.009 [0m | [0m 0.0 [0m | [0m 4.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.422837 |
0.403953 |
0.819392 |
00:06 |
1 |
0.380753 |
0.367345 |
0.826838 |
00:06 |
2 |
0.353045 |
0.365174 |
0.830006 |
00:07 |
3 |
0.348628 |
0.364282 |
0.826362 |
00:07 |
4 |
0.343561 |
0.361509 |
0.829214 |
00:07 |
| [0m 25 [0m | [0m 0.8292 [0m | [0m 3.522 [0m | [0m 8.0 [0m | [0m 0.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.807766 |
1.307279 |
0.481622 |
00:31 |
1 |
0.513308 |
0.499470 |
0.783587 |
00:32 |
2 |
0.445906 |
0.492620 |
0.798004 |
00:31 |
3 |
0.385094 |
0.399986 |
0.807509 |
00:32 |
4 |
0.387228 |
0.384739 |
0.817015 |
00:31 |
| [0m 26 [0m | [0m 0.817 [0m | [0m 0.0 [0m | [0m 8.0 [0m | [0m 4.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.442076 |
0.491338 |
0.755387 |
00:31 |
1 |
0.441078 |
0.443674 |
0.760773 |
00:31 |
2 |
0.417575 |
0.418758 |
0.792142 |
00:31 |
3 |
0.410825 |
0.417581 |
0.788498 |
00:34 |
4 |
0.403407 |
0.410941 |
0.798321 |
00:46 |
| [0m 27 [0m | [0m 0.7983 [0m | [0m 0.0 [0m | [0m 0.0 [0m | [0m 4.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.407006 |
0.419679 |
0.792142 |
00:08 |
1 |
0.390913 |
0.392631 |
0.810520 |
00:08 |
2 |
0.365560 |
0.394330 |
0.817491 |
00:08 |
3 |
0.378459 |
0.387244 |
0.820659 |
00:08 |
4 |
0.375275 |
0.385417 |
0.828897 |
00:08 |
| [0m 28 [0m | [0m 0.8289 [0m | [0m 3.379 [0m | [0m 2.848 [0m | [0m 0.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.430604 |
0.469592 |
0.781210 |
00:45 |
1 |
0.423074 |
0.429704 |
0.797529 |
00:45 |
2 |
0.400120 |
0.393398 |
0.810995 |
00:45 |
3 |
0.382361 |
0.390651 |
0.816065 |
00:46 |
4 |
0.389520 |
0.401878 |
0.807193 |
00:46 |
| [0m 29 [0m | [0m 0.8072 [0m | [0m 0.0 [0m | [0m 2.588 [0m | [0m 4.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.396348 |
0.397454 |
0.806717 |
00:08 |
1 |
0.383342 |
0.386023 |
0.819550 |
00:07 |
2 |
0.369493 |
0.374401 |
0.820025 |
00:07 |
3 |
0.356015 |
0.366535 |
0.826204 |
00:08 |
4 |
0.341073 |
0.365241 |
0.826204 |
00:08 |
| [0m 30 [0m | [0m 0.8262 [0m | [0m 1.217 [0m | [0m 5.622 [0m | [0m 0.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.809077 |
0.480591 |
0.782795 |
00:45 |
1 |
0.571318 |
0.497731 |
0.739068 |
00:45 |
2 |
0.514562 |
0.461726 |
0.781527 |
00:45 |
3 |
0.439822 |
0.451722 |
0.787231 |
00:44 |
4 |
0.419881 |
0.422125 |
0.801648 |
00:45 |
| [0m 31 [0m | [0m 0.8016 [0m | [0m 8.0 [0m | [0m 8.0 [0m | [0m 4.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.410521 |
0.435045 |
0.810044 |
00:08 |
1 |
0.363098 |
0.378001 |
0.821926 |
00:08 |
2 |
0.359525 |
0.364477 |
0.827788 |
00:08 |
3 |
0.354005 |
0.366507 |
0.821610 |
00:08 |
4 |
0.347293 |
0.362657 |
0.829373 |
00:08 |
| [0m 32 [0m | [0m 0.8294 [0m | [0m 5.864 [0m | [0m 8.0 [0m | [0m 0.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.498376 |
0.436696 |
0.794043 |
00:16 |
1 |
0.411699 |
0.435537 |
0.801331 |
00:16 |
2 |
0.385327 |
0.396916 |
0.820184 |
00:16 |
3 |
0.382020 |
0.389856 |
0.813371 |
00:16 |
4 |
0.373869 |
0.377804 |
0.820817 |
00:15 |
| [0m 33 [0m | [0m 0.8208 [0m | [0m 1.776 [0m | [0m 8.0 [0m | [0m 2.212 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.404653 |
0.440106 |
0.772180 |
00:11 |
1 |
0.377931 |
0.393715 |
0.817332 |
00:11 |
2 |
0.373221 |
0.379273 |
0.826838 |
00:11 |
3 |
0.359682 |
0.362844 |
0.828422 |
00:11 |
4 |
0.340384 |
0.363072 |
0.828897 |
00:11 |
| [0m 34 [0m | [0m 0.8289 [0m | [0m 5.777 [0m | [0m 2.2 [0m | [0m 1.31 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.520308 |
0.503207 |
0.749208 |
00:45 |
1 |
0.472501 |
0.451469 |
0.780418 |
00:45 |
2 |
0.454686 |
0.429175 |
0.784854 |
00:45 |
3 |
0.400800 |
0.413727 |
0.795469 |
00:44 |
4 |
0.405604 |
0.409770 |
0.801648 |
00:45 |
| [0m 35 [0m | [0m 0.8016 [0m | [0m 2.748 [0m | [0m 5.915 [0m | [0m 4.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.504537 |
0.501541 |
0.750317 |
00:45 |
1 |
0.465937 |
0.477715 |
0.773289 |
00:45 |
2 |
0.435364 |
0.481415 |
0.766635 |
00:45 |
3 |
0.425434 |
0.442198 |
0.772814 |
00:45 |
4 |
0.425779 |
0.458947 |
0.771863 |
00:45 |
| [0m 36 [0m | [0m 0.7719 [0m | [0m 6.251 [0m | [0m 2.532 [0m | [0m 4.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.420782 |
0.420721 |
0.791350 |
00:10 |
1 |
0.403576 |
0.408376 |
0.800222 |
00:10 |
2 |
0.390236 |
0.393624 |
0.820342 |
00:11 |
3 |
0.377777 |
0.389657 |
0.821610 |
00:11 |
4 |
0.382809 |
0.386011 |
0.820976 |
00:11 |
| [0m 37 [0m | [0m 0.821 [0m | [0m 5.093 [0m | [0m 0.172 [0m | [0m 1.64 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.393575 |
0.397811 |
0.812262 |
00:08 |
1 |
0.378272 |
0.381915 |
0.815748 |
00:08 |
2 |
0.364799 |
0.369214 |
0.824620 |
00:08 |
3 |
0.355757 |
0.364554 |
0.826996 |
00:08 |
4 |
0.342090 |
0.362723 |
0.824303 |
00:08 |
| [0m 38 [0m | [0m 0.8243 [0m | [0m 8.0 [0m | [0m 5.799 [0m | [0m 0.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.393693 |
0.396980 |
0.822085 |
00:11 |
1 |
0.361231 |
0.393146 |
0.813847 |
00:11 |
2 |
0.345645 |
0.379510 |
0.823986 |
00:11 |
3 |
0.349778 |
0.367077 |
0.826679 |
00:11 |
4 |
0.342390 |
0.362027 |
0.827788 |
00:11 |
| [0m 39 [0m | [0m 0.8278 [0m | [0m 1.62 [0m | [0m 3.832 [0m | [0m 1.151 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.832737 |
0.491002 |
0.771546 |
00:43 |
1 |
0.627948 |
0.553552 |
0.764734 |
00:43 |
2 |
0.498901 |
0.467162 |
0.791984 |
00:46 |
3 |
0.431196 |
0.444576 |
0.785646 |
00:43 |
4 |
0.399745 |
0.427060 |
0.796578 |
00:42 |
| [0m 40 [0m | [0m 0.7966 [0m | [0m 2.198 [0m | [0m 8.0 [0m | [0m 4.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.511301 |
0.514401 |
0.751267 |
00:43 |
1 |
0.447332 |
0.445157 |
0.751109 |
00:43 |
2 |
0.451125 |
0.438327 |
0.750951 |
00:42 |
3 |
0.445883 |
0.443266 |
0.751267 |
00:42 |
4 |
0.444816 |
0.438459 |
0.764100 |
00:42 |
| [0m 41 [0m | [0m 0.7641 [0m | [0m 8.0 [0m | [0m 1.03 [0m | [0m 4.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.408504 |
0.413275 |
0.797212 |
00:15 |
1 |
0.392707 |
0.399085 |
0.805767 |
00:15 |
2 |
0.379938 |
0.395550 |
0.817807 |
00:15 |
3 |
0.375288 |
0.383186 |
0.820817 |
00:15 |
4 |
0.360417 |
0.375098 |
0.823194 |
00:16 |
| [0m 42 [0m | [0m 0.8232 [0m | [0m 0.0 [0m | [0m 2.504 [0m | [0m 2.135 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.399371 |
0.415196 |
0.801014 |
00:07 |
1 |
0.367804 |
0.392020 |
0.810995 |
00:06 |
2 |
0.362288 |
0.385124 |
0.820659 |
00:07 |
3 |
0.344728 |
0.371339 |
0.823669 |
00:07 |
4 |
0.345769 |
0.362059 |
0.829373 |
00:07 |
| [0m 43 [0m | [0m 0.8294 [0m | [0m 0.0 [0m | [0m 5.441 [0m | [0m 0.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.397157 |
0.431003 |
0.803866 |
00:06 |
1 |
0.394964 |
0.396448 |
0.810361 |
00:06 |
2 |
0.378584 |
0.387943 |
0.820659 |
00:07 |
3 |
0.371601 |
0.386186 |
0.818283 |
00:07 |
4 |
0.369759 |
0.384339 |
0.827630 |
00:07 |
| [0m 44 [0m | [0m 0.8276 [0m | [0m 4.636 [0m | [0m 1.476 [0m | [0m 0.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.408654 |
0.426806 |
0.791191 |
00:12 |
1 |
0.394184 |
0.406586 |
0.786439 |
00:12 |
2 |
0.369625 |
0.372680 |
0.822560 |
00:12 |
3 |
0.349444 |
0.368142 |
0.823828 |
00:12 |
4 |
0.351684 |
0.363406 |
0.826204 |
00:12 |
| [0m 45 [0m | [0m 0.8262 [0m | [0m 0.0 [0m | [0m 7.071 [0m | [0m 2.071 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.400293 |
0.416098 |
0.811629 |
00:08 |
1 |
0.377387 |
0.433395 |
0.807034 |
00:08 |
2 |
0.368131 |
0.395448 |
0.796420 |
00:08 |
3 |
0.367750 |
0.376879 |
0.817174 |
00:08 |
4 |
0.362124 |
0.371432 |
0.821134 |
00:08 |
| [0m 46 [0m | [0m 0.8211 [0m | [0m 4.26 [0m | [0m 6.934 [0m | [0m 1.79 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.404579 |
0.437443 |
0.814797 |
00:07 |
1 |
0.375342 |
0.380416 |
0.824937 |
00:07 |
2 |
0.365835 |
0.377617 |
0.812738 |
00:07 |
3 |
0.354619 |
0.364503 |
0.827471 |
00:07 |
4 |
0.340603 |
0.363488 |
0.827947 |
00:07 |
| [0m 47 [0m | [0m 0.8279 [0m | [0m 6.579 [0m | [0m 6.485 [0m | [0m 0.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.384890 |
0.440342 |
0.812579 |
00:08 |
1 |
0.371483 |
0.387200 |
0.813847 |
00:09 |
2 |
0.365951 |
0.378071 |
0.818283 |
00:09 |
3 |
0.362965 |
0.369994 |
0.821610 |
00:09 |
4 |
0.356483 |
0.365151 |
0.826521 |
00:09 |
| [0m 48 [0m | [0m 0.8265 [0m | [0m 8.0 [0m | [0m 4.293 [0m | [0m 1.74 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.386308 |
0.389250 |
0.815431 |
00:08 |
1 |
0.368402 |
0.389338 |
0.814956 |
00:09 |
2 |
0.362211 |
0.377196 |
0.824778 |
00:09 |
3 |
0.356135 |
0.362951 |
0.829531 |
00:09 |
4 |
0.341577 |
0.362476 |
0.830799 |
00:09 |
| [0m 49 [0m | [0m 0.8308 [0m | [0m 7.909 [0m | [0m 7.827 [0m | [0m 1.323 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.426044 |
0.422882 |
0.791984 |
00:08 |
1 |
0.375756 |
0.381810 |
0.817491 |
00:09 |
2 |
0.363932 |
0.375904 |
0.818916 |
00:09 |
3 |
0.349442 |
0.365052 |
0.823986 |
00:09 |
4 |
0.344509 |
0.363027 |
0.830323 |
00:09 |
| [0m 50 [0m | [0m 0.8303 [0m | [0m 4.946 [0m | [0m 1.246 [0m | [0m 1.589 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.388522 |
0.431909 |
0.820976 |
00:07 |
1 |
0.372532 |
0.448644 |
0.751109 |
00:07 |
2 |
0.358823 |
0.373322 |
0.823669 |
00:07 |
3 |
0.352838 |
0.362424 |
0.831591 |
00:07 |
4 |
0.352949 |
0.361356 |
0.831432 |
00:07 |
| [95m 51 [0m | [95m 0.8314 [0m | [95m 5.664 [0m | [95m 2.626 [0m | [95m 0.003048[0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.389195 |
0.390032 |
0.817332 |
00:06 |
1 |
0.369993 |
0.382199 |
0.819708 |
00:07 |
2 |
0.362801 |
0.373282 |
0.826521 |
00:06 |
3 |
0.359760 |
0.363597 |
0.824303 |
00:06 |
4 |
0.344525 |
0.362097 |
0.828897 |
00:07 |
| [0m 52 [0m | [0m 0.8289 [0m | [0m 1.287 [0m | [0m 3.505 [0m | [0m 0.06804 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.394940 |
0.403165 |
0.814639 |
00:06 |
1 |
0.371518 |
0.452118 |
0.806876 |
00:06 |
2 |
0.364734 |
0.377214 |
0.824461 |
00:06 |
3 |
0.347968 |
0.365335 |
0.823511 |
00:07 |
4 |
0.345476 |
0.363670 |
0.827155 |
00:07 |
| [0m 53 [0m | [0m 0.8272 [0m | [0m 1.606 [0m | [0m 7.998 [0m | [0m 0.2009 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.500322 |
0.519261 |
0.757129 |
00:10 |
1 |
0.413270 |
0.423630 |
0.801965 |
00:11 |
2 |
0.380234 |
0.395588 |
0.813371 |
00:12 |
3 |
0.361677 |
0.378123 |
0.817174 |
00:12 |
4 |
0.374629 |
0.373772 |
0.820025 |
00:12 |
| [0m 54 [0m | [0m 0.82 [0m | [0m 4.579 [0m | [0m 5.017 [0m | [0m 2.928 [0m |
| [0m 55 [0m | [0m 0.8259 [0m | [0m 0.02565 [0m | [0m 3.699 [0m | [0m 0.9808 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.452787 |
0.443697 |
0.768695 |
00:11 |
1 |
0.428332 |
0.415454 |
0.800697 |
00:11 |
2 |
0.396522 |
0.402850 |
0.807668 |
00:12 |
3 |
0.424802 |
0.414648 |
0.783587 |
00:12 |
4 |
0.385055 |
0.392359 |
0.801489 |
00:12 |
| [0m 56 [0m | [0m 0.8015 [0m | [0m 1.927 [0m | [0m 5.92 [0m | [0m 2.53 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.435597 |
0.438222 |
0.810836 |
00:19 |
1 |
0.399920 |
0.531189 |
0.770754 |
00:19 |
2 |
0.403408 |
0.409382 |
0.804816 |
00:18 |
3 |
0.363519 |
0.383823 |
0.815906 |
00:19 |
4 |
0.360030 |
0.377621 |
0.819708 |
00:19 |
| [0m 57 [0m | [0m 0.8197 [0m | [0m 0.7796 [0m | [0m 4.576 [0m | [0m 3.952 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.388445 |
0.420243 |
0.800539 |
00:07 |
1 |
0.372912 |
0.369659 |
0.827630 |
00:07 |
2 |
0.354443 |
0.366757 |
0.828105 |
00:07 |
3 |
0.352468 |
0.366038 |
0.822560 |
00:07 |
4 |
0.347822 |
0.362001 |
0.829690 |
00:07 |
| [0m 58 [0m | [0m 0.8297 [0m | [0m 3.525 [0m | [0m 4.198 [0m | [0m 0.02314 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.444627 |
0.431739 |
0.787072 |
00:12 |
1 |
0.392637 |
0.412985 |
0.799747 |
00:12 |
2 |
0.369733 |
0.396133 |
0.802440 |
00:12 |
3 |
0.365821 |
0.373095 |
0.820342 |
00:12 |
4 |
0.370486 |
0.371560 |
0.819392 |
00:12 |
| [0m 59 [0m | [0m 0.8194 [0m | [0m 6.711 [0m | [0m 3.848 [0m | [0m 2.395 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.396831 |
0.389045 |
0.809569 |
00:07 |
1 |
0.371171 |
0.375065 |
0.818600 |
00:07 |
2 |
0.350309 |
0.371795 |
0.824620 |
00:07 |
3 |
0.359700 |
0.363041 |
0.828739 |
00:07 |
4 |
0.345735 |
0.361556 |
0.830799 |
00:07 |
| [0m 60 [0m | [0m 0.8308 [0m | [0m 4.914 [0m | [0m 7.944 [0m | [0m 0.9998 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.422853 |
0.412691 |
0.804341 |
00:09 |
1 |
0.375209 |
0.394692 |
0.817174 |
00:09 |
2 |
0.365574 |
0.380376 |
0.820184 |
00:08 |
3 |
0.359143 |
0.363607 |
0.831115 |
00:08 |
4 |
0.347991 |
0.361650 |
0.827947 |
00:08 |
| [0m 61 [0m | [0m 0.8279 [0m | [0m 7.962 [0m | [0m 6.151 [0m | [0m 1.119 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.388147 |
0.405885 |
0.810678 |
00:07 |
1 |
0.367743 |
0.391867 |
0.807826 |
00:07 |
2 |
0.366964 |
0.362980 |
0.828739 |
00:07 |
3 |
0.363402 |
0.363396 |
0.829531 |
00:07 |
4 |
0.351094 |
0.362245 |
0.829214 |
00:07 |
| [0m 62 [0m | [0m 0.8292 [0m | [0m 2.583 [0m | [0m 6.996 [0m | [0m 0.008348[0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.432724 |
0.390516 |
0.815114 |
00:11 |
1 |
0.407100 |
0.401564 |
0.811153 |
00:11 |
2 |
0.359414 |
0.384463 |
0.820342 |
00:11 |
3 |
0.358061 |
0.371844 |
0.826362 |
00:12 |
4 |
0.345357 |
0.362986 |
0.831115 |
00:12 |
| [0m 63 [0m | [0m 0.8311 [0m | [0m 0.2249 [0m | [0m 8.0 [0m | [0m 2.823 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.418990 |
0.420463 |
0.808618 |
00:07 |
1 |
0.389830 |
0.398110 |
0.816223 |
00:08 |
2 |
0.382975 |
0.387620 |
0.814956 |
00:06 |
3 |
0.384093 |
0.379607 |
0.819392 |
00:06 |
4 |
0.358019 |
0.371140 |
0.823828 |
00:08 |
| [0m 64 [0m | [0m 0.8238 [0m | [0m 5.764 [0m | [0m 5.509 [0m | [0m 1.482 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.385147 |
0.399680 |
0.806242 |
00:09 |
1 |
0.376032 |
0.381131 |
0.822560 |
00:09 |
2 |
0.363870 |
0.378227 |
0.822402 |
00:09 |
3 |
0.351089 |
0.368790 |
0.826838 |
00:09 |
4 |
0.340404 |
0.361807 |
0.829214 |
00:09 |
| [0m 65 [0m | [0m 0.8292 [0m | [0m 1.048 [0m | [0m 2.939 [0m | [0m 1.922 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.526628 |
0.507236 |
0.746673 |
00:17 |
1 |
0.460229 |
0.455675 |
0.765684 |
00:19 |
2 |
0.417427 |
0.421368 |
0.785963 |
00:19 |
3 |
0.462800 |
0.458844 |
0.773923 |
00:19 |
4 |
0.449479 |
0.456627 |
0.783587 |
00:19 |
| [0m 66 [0m | [0m 0.7836 [0m | [0m 3.68 [0m | [0m 3.977 [0m | [0m 3.919 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.621678 |
0.559080 |
0.749049 |
00:10 |
1 |
0.457104 |
0.473610 |
0.758397 |
00:11 |
2 |
0.416287 |
0.416622 |
0.764575 |
00:13 |
3 |
0.388107 |
0.403844 |
0.811945 |
00:13 |
4 |
0.384231 |
0.396397 |
0.813055 |
00:13 |
| [0m 67 [0m | [0m 0.8131 [0m | [0m 5.907 [0m | [0m 0.9452 [0m | [0m 2.168 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.410105 |
0.416171 |
0.808618 |
00:11 |
1 |
0.381669 |
0.400109 |
0.809094 |
00:13 |
2 |
0.377539 |
0.403879 |
0.803074 |
00:13 |
3 |
0.374653 |
0.389122 |
0.808618 |
00:13 |
4 |
0.366356 |
0.380526 |
0.814005 |
00:13 |
| [0m 68 [0m | [0m 0.814 [0m | [0m 7.981 [0m | [0m 2.796 [0m | [0m 2.78 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.380539 |
0.420856 |
0.815589 |
00:07 |
1 |
0.368232 |
0.424276 |
0.803866 |
00:05 |
2 |
0.358194 |
0.378501 |
0.827155 |
00:06 |
3 |
0.353427 |
0.362224 |
0.829690 |
00:07 |
4 |
0.344443 |
0.361554 |
0.826838 |
00:07 |
| [0m 69 [0m | [0m 0.8268 [0m | [0m 7.223 [0m | [0m 3.762 [0m | [0m 0.6961 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.445924 |
0.488581 |
0.752693 |
00:17 |
1 |
0.410709 |
0.400962 |
0.813688 |
00:18 |
2 |
0.373518 |
0.393235 |
0.820184 |
00:18 |
3 |
0.364160 |
0.378920 |
0.820817 |
00:17 |
4 |
0.357551 |
0.371629 |
0.825412 |
00:17 |
| [0m 70 [0m | [0m 0.8254 [0m | [0m 0.009375[0m | [0m 5.081 [0m | [0m 3.79 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.410323 |
0.442670 |
0.814797 |
00:08 |
1 |
0.379279 |
0.405500 |
0.807034 |
00:08 |
2 |
0.387576 |
0.392448 |
0.819708 |
00:09 |
3 |
0.371622 |
0.389167 |
0.823035 |
00:09 |
4 |
0.374182 |
0.386964 |
0.825095 |
00:09 |
| [0m 71 [0m | [0m 0.8251 [0m | [0m 3.293 [0m | [0m 2.76 [0m | [0m 1.061 [0m |
| [0m 72 [0m | [0m 0.8308 [0m | [0m 4.589 [0m | [0m 7.13 [0m | [0m 0.003179[0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.391509 |
0.399966 |
0.806400 |
00:06 |
1 |
0.366694 |
0.405719 |
0.823828 |
00:07 |
2 |
0.359751 |
0.375496 |
0.822877 |
00:07 |
3 |
0.347678 |
0.361711 |
0.830799 |
00:07 |
4 |
0.336896 |
0.361922 |
0.828580 |
00:07 |
| [0m 73 [0m | [0m 0.8286 [0m | [0m 7.118 [0m | [0m 5.204 [0m | [0m 0.5939 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.435485 |
0.414883 |
0.808143 |
00:08 |
1 |
0.373591 |
0.417138 |
0.814005 |
00:09 |
2 |
0.369590 |
0.375724 |
0.820184 |
00:09 |
3 |
0.370829 |
0.368655 |
0.829531 |
00:09 |
4 |
0.346463 |
0.366307 |
0.825412 |
00:09 |
| [0m 74 [0m | [0m 0.8254 [0m | [0m 6.837 [0m | [0m 7.988 [0m | [0m 1.055 [0m |
| [0m 75 [0m | [0m 0.8259 [0m | [0m 0.6629 [0m | [0m 7.012 [0m | [0m 0.03222 [0m |
| [0m 76 [0m | [0m 0.8311 [0m | [0m 5.177 [0m | [0m 1.457 [0m | [0m 0.5857 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.455741 |
0.530008 |
0.794994 |
00:11 |
1 |
0.421961 |
0.423317 |
0.805292 |
00:11 |
2 |
0.405799 |
0.405729 |
0.807351 |
00:12 |
3 |
0.383895 |
0.395092 |
0.816857 |
00:12 |
4 |
0.378882 |
0.386044 |
0.818758 |
00:12 |
| [0m 77 [0m | [0m 0.8188 [0m | [0m 3.308 [0m | [0m 4.533 [0m | [0m 2.048 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.412261 |
0.410950 |
0.798954 |
00:08 |
1 |
0.382314 |
0.408471 |
0.803390 |
00:09 |
2 |
0.347488 |
0.387500 |
0.815273 |
00:09 |
3 |
0.343408 |
0.372050 |
0.821451 |
00:09 |
4 |
0.344963 |
0.366158 |
0.822719 |
00:09 |
| [0m 78 [0m | [0m 0.8227 [0m | [0m 0.4036 [0m | [0m 7.997 [0m | [0m 1.439 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.412856 |
0.410016 |
0.799113 |
00:08 |
1 |
0.410852 |
0.416405 |
0.788498 |
00:08 |
2 |
0.373897 |
0.384385 |
0.824303 |
00:09 |
3 |
0.353164 |
0.366129 |
0.822719 |
00:09 |
4 |
0.353253 |
0.362269 |
0.826362 |
00:09 |
| [0m 79 [0m | [0m 0.8264 [0m | [0m 3.438 [0m | [0m 7.982 [0m | [0m 1.829 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.419316 |
0.408936 |
0.798162 |
00:08 |
1 |
0.393826 |
0.390526 |
0.820184 |
00:08 |
2 |
0.372879 |
0.374823 |
0.822719 |
00:08 |
3 |
0.358019 |
0.370913 |
0.820342 |
00:08 |
4 |
0.346020 |
0.362252 |
0.829690 |
00:08 |
| [0m 80 [0m | [0m 0.8297 [0m | [0m 6.88 [0m | [0m 2.404 [0m | [0m 1.666 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.433481 |
0.437320 |
0.790082 |
00:18 |
1 |
0.415280 |
0.402946 |
0.814164 |
00:18 |
2 |
0.365575 |
0.376285 |
0.822877 |
00:18 |
3 |
0.363206 |
0.371865 |
0.820501 |
00:18 |
4 |
0.356401 |
0.370252 |
0.823828 |
00:18 |
| [0m 81 [0m | [0m 0.8238 [0m | [0m 0.03221 [0m | [0m 1.306 [0m | [0m 3.909 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.393150 |
0.420964 |
0.783745 |
00:07 |
1 |
0.375065 |
0.371380 |
0.823986 |
00:07 |
2 |
0.362952 |
0.387037 |
0.813688 |
00:07 |
3 |
0.347245 |
0.370225 |
0.824937 |
00:07 |
4 |
0.348406 |
0.361420 |
0.830640 |
00:07 |
| [0m 82 [0m | [0m 0.8306 [0m | [0m 1.575 [0m | [0m 2.689 [0m | [0m 0.8684 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.395530 |
0.397430 |
0.818600 |
00:06 |
1 |
0.358679 |
0.396773 |
0.818283 |
00:07 |
2 |
0.349305 |
0.372877 |
0.823828 |
00:07 |
3 |
0.347346 |
0.363006 |
0.828422 |
00:07 |
4 |
0.335652 |
0.362567 |
0.830957 |
00:07 |
| [0m 83 [0m | [0m 0.831 [0m | [0m 2.765 [0m | [0m 5.439 [0m | [0m 0.04047 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.380831 |
0.386072 |
0.814322 |
00:05 |
1 |
0.369778 |
0.392521 |
0.810520 |
00:07 |
2 |
0.368286 |
0.383131 |
0.816857 |
00:07 |
3 |
0.356585 |
0.367839 |
0.821768 |
00:07 |
4 |
0.344722 |
0.366639 |
0.825253 |
00:07 |
| [0m 84 [0m | [0m 0.8253 [0m | [0m 0.1961 [0m | [0m 4.123 [0m | [0m 0.02039 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.502605 |
0.463703 |
0.772180 |
00:11 |
1 |
0.407231 |
0.404386 |
0.804499 |
00:13 |
2 |
0.406849 |
0.411254 |
0.817491 |
00:12 |
3 |
0.379910 |
0.389118 |
0.817174 |
00:12 |
4 |
0.370964 |
0.379835 |
0.821293 |
00:12 |
| [0m 85 [0m | [0m 0.8213 [0m | [0m 7.937 [0m | [0m 7.939 [0m | [0m 2.895 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.402574 |
0.431237 |
0.786755 |
00:11 |
1 |
0.380494 |
0.392253 |
0.817966 |
00:11 |
2 |
0.363284 |
0.393907 |
0.815748 |
00:11 |
3 |
0.355680 |
0.368488 |
0.822560 |
00:12 |
4 |
0.362978 |
0.367755 |
0.823511 |
00:12 |
| [0m 86 [0m | [0m 0.8235 [0m | [0m 0.06921 [0m | [0m 5.7 [0m | [0m 2.778 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.448572 |
0.451255 |
0.789766 |
00:10 |
1 |
0.417093 |
0.411838 |
0.808143 |
00:11 |
2 |
0.400799 |
0.400185 |
0.816223 |
00:12 |
3 |
0.378641 |
0.385082 |
0.820342 |
00:12 |
4 |
0.365278 |
0.380320 |
0.818441 |
00:11 |
| [0m 87 [0m | [0m 0.8184 [0m | [0m 7.965 [0m | [0m 5.261 [0m | [0m 2.661 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.399453 |
0.429288 |
0.783112 |
00:07 |
1 |
0.368987 |
0.376985 |
0.825729 |
00:07 |
2 |
0.358226 |
0.372103 |
0.830165 |
00:07 |
3 |
0.348940 |
0.362069 |
0.831115 |
00:07 |
4 |
0.341437 |
0.361470 |
0.830482 |
00:07 |
| [0m 88 [0m | [0m 0.8305 [0m | [0m 2.792 [0m | [0m 7.917 [0m | [0m 0.761 [0m |
| [0m 89 [0m | [0m 0.8294 [0m | [0m 7.995 [0m | [0m 7.186 [0m | [0m 0.1199 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.390914 |
0.403780 |
0.799905 |
00:07 |
1 |
0.375241 |
0.400406 |
0.821610 |
00:07 |
2 |
0.359710 |
0.373233 |
0.826046 |
00:07 |
3 |
0.351108 |
0.367255 |
0.823986 |
00:07 |
4 |
0.348230 |
0.362827 |
0.830482 |
00:07 |
| [0m 90 [0m | [0m 0.8305 [0m | [0m 2.526 [0m | [0m 3.741 [0m | [0m 0.1186 [0m |
| [0m 91 [0m | [0m 0.8294 [0m | [0m 0.03285 [0m | [0m 5.742 [0m | [0m 0.9747 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.384485 |
0.413711 |
0.816857 |
00:06 |
1 |
0.358383 |
0.382125 |
0.813055 |
00:06 |
2 |
0.349632 |
0.380644 |
0.825570 |
00:07 |
3 |
0.350005 |
0.363857 |
0.833492 |
00:07 |
4 |
0.341342 |
0.362533 |
0.829848 |
00:07 |
| [0m 92 [0m | [0m 0.8298 [0m | [0m 4.112e-0[0m | [0m 8.0 [0m | [0m 0.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.461806 |
0.454314 |
0.752218 |
00:11 |
1 |
0.429429 |
0.452721 |
0.768061 |
00:13 |
2 |
0.408792 |
0.422315 |
0.785013 |
00:12 |
3 |
0.400290 |
0.411134 |
0.805292 |
00:12 |
4 |
0.396974 |
0.409337 |
0.804658 |
00:12 |
| [0m 93 [0m | [0m 0.8047 [0m | [0m 4.563 [0m | [0m 0.6868 [0m | [0m 2.461 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.464920 |
0.440214 |
0.772655 |
00:11 |
1 |
0.401611 |
0.400240 |
0.805767 |
00:11 |
2 |
0.391905 |
0.414512 |
0.798638 |
00:13 |
3 |
0.391561 |
0.407271 |
0.805608 |
00:13 |
4 |
0.387596 |
0.397586 |
0.814639 |
00:13 |
| [0m 94 [0m | [0m 0.8146 [0m | [0m 4.697 [0m | [0m 3.412 [0m | [0m 2.514 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.417917 |
0.423588 |
0.755070 |
00:05 |
1 |
0.390158 |
0.396916 |
0.817649 |
00:07 |
2 |
0.374562 |
0.389764 |
0.828422 |
00:07 |
3 |
0.368317 |
0.385268 |
0.822719 |
00:07 |
4 |
0.373367 |
0.385092 |
0.827471 |
00:07 |
| [0m 95 [0m | [0m 0.8275 [0m | [0m 5.04 [0m | [0m 0.4492 [0m | [0m 0.5899 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.368991 |
0.388251 |
0.822402 |
00:07 |
1 |
0.363111 |
0.399670 |
0.814797 |
00:08 |
2 |
0.370390 |
0.377726 |
0.818441 |
00:10 |
3 |
0.353493 |
0.368603 |
0.823828 |
00:10 |
4 |
0.355493 |
0.368281 |
0.823669 |
00:09 |
| [0m 96 [0m | [0m 0.8237 [0m | [0m 0.6025 [0m | [0m 2.712 [0m | [0m 1.166 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.426161 |
0.417567 |
0.793726 |
00:07 |
1 |
0.397633 |
0.392852 |
0.818124 |
00:08 |
2 |
0.371073 |
0.371678 |
0.821768 |
00:09 |
3 |
0.351716 |
0.361607 |
0.831749 |
00:10 |
4 |
0.347291 |
0.359881 |
0.832066 |
00:09 |
| [95m 97 [0m | [95m 0.8321 [0m | [95m 6.389 [0m | [95m 3.648 [0m | [95m 1.016 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.408937 |
0.440755 |
0.785805 |
00:12 |
1 |
0.369800 |
0.379622 |
0.819867 |
00:12 |
2 |
0.356048 |
0.370119 |
0.827630 |
00:11 |
3 |
0.354271 |
0.366091 |
0.826362 |
00:11 |
4 |
0.359324 |
0.366020 |
0.827313 |
00:12 |
| [0m 98 [0m | [0m 0.8273 [0m | [0m 0.5927 [0m | [0m 1.715 [0m | [0m 2.847 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.443831 |
0.469775 |
0.753802 |
00:11 |
1 |
0.411636 |
0.423362 |
0.787864 |
00:11 |
2 |
0.411669 |
0.421874 |
0.797529 |
00:13 |
3 |
0.392242 |
0.395377 |
0.816382 |
00:12 |
4 |
0.383947 |
0.390270 |
0.818441 |
00:12 |
| [0m 99 [0m | [0m 0.8184 [0m | [0m 6.602 [0m | [0m 2.266 [0m | [0m 2.535 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.413688 |
0.469193 |
0.811153 |
00:08 |
1 |
0.383908 |
0.401520 |
0.815589 |
00:08 |
2 |
0.363908 |
0.375178 |
0.824937 |
00:08 |
3 |
0.360694 |
0.366536 |
0.828739 |
00:08 |
4 |
0.341851 |
0.362886 |
0.830165 |
00:08 |
| [0m 100 [0m | [0m 0.8302 [0m | [0m 3.099 [0m | [0m 6.058 [0m | [0m 1.058 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.404324 |
0.402310 |
0.807668 |
00:09 |
1 |
0.383628 |
0.428126 |
0.780577 |
00:09 |
2 |
0.360917 |
0.375198 |
0.826996 |
00:09 |
3 |
0.349922 |
0.367114 |
0.828264 |
00:09 |
4 |
0.338715 |
0.363762 |
0.830165 |
00:09 |
| [0m 101 [0m | [0m 0.8302 [0m | [0m 4.268 [0m | [0m 5.28 [0m | [0m 1.397 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.381317 |
0.410546 |
0.815431 |
00:06 |
1 |
0.355632 |
0.424577 |
0.802440 |
00:06 |
2 |
0.361193 |
0.377675 |
0.818124 |
00:07 |
3 |
0.355751 |
0.363205 |
0.827313 |
00:07 |
4 |
0.338802 |
0.362403 |
0.827471 |
00:07 |
| [0m 102 [0m | [0m 0.8275 [0m | [0m 4.735 [0m | [0m 3.398 [0m | [0m 0.01904 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.381150 |
0.398499 |
0.813213 |
00:07 |
1 |
0.370997 |
0.413040 |
0.798162 |
00:07 |
2 |
0.364154 |
0.369066 |
0.823352 |
00:07 |
3 |
0.354434 |
0.362925 |
0.825095 |
00:07 |
4 |
0.348736 |
0.362125 |
0.824461 |
00:07 |
| [0m 103 [0m | [0m 0.8245 [0m | [0m 0.0 [0m | [0m 6.371 [0m | [0m 9.721e-0[0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.388378 |
0.412839 |
0.810361 |
00:06 |
1 |
0.377533 |
0.380432 |
0.822560 |
00:06 |
2 |
0.356242 |
0.372250 |
0.824937 |
00:07 |
3 |
0.346274 |
0.364236 |
0.830323 |
00:07 |
4 |
0.349273 |
0.362597 |
0.830482 |
00:07 |
| [0m 104 [0m | [0m 0.8305 [0m | [0m 1.683 [0m | [0m 6.381 [0m | [0m 0.8564 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.410286 |
0.413506 |
0.801648 |
00:07 |
1 |
0.386787 |
0.397328 |
0.812262 |
00:07 |
2 |
0.365268 |
0.390419 |
0.825570 |
00:07 |
3 |
0.368566 |
0.386955 |
0.828264 |
00:07 |
4 |
0.362751 |
0.383124 |
0.830482 |
00:07 |
| [0m 105 [0m | [0m 0.8305 [0m | [0m 6.387 [0m | [0m 2.333 [0m | [0m 0.4877 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.432347 |
0.450028 |
0.780894 |
00:17 |
1 |
0.414766 |
0.402565 |
0.809569 |
00:17 |
2 |
0.382495 |
0.382281 |
0.812421 |
00:18 |
3 |
0.366852 |
0.373158 |
0.822085 |
00:18 |
4 |
0.353692 |
0.368471 |
0.824461 |
00:18 |
| [0m 106 [0m | [0m 0.8245 [0m | [0m 0.9452 [0m | [0m 6.902 [0m | [0m 3.991 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.397604 |
0.407966 |
0.814797 |
00:06 |
1 |
0.386755 |
0.382698 |
0.800063 |
00:07 |
2 |
0.361672 |
0.373610 |
0.823194 |
00:07 |
3 |
0.345655 |
0.363278 |
0.829848 |
00:07 |
4 |
0.349931 |
0.362171 |
0.830957 |
00:07 |
| [0m 107 [0m | [0m 0.831 [0m | [0m 5.475 [0m | [0m 5.721 [0m | [0m 0.02659 [0m |
| [0m 108 [0m | [0m 0.8308 [0m | [0m 4.597 [0m | [0m 7.994 [0m | [0m 0.1318 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.391720 |
0.395161 |
0.819550 |
00:06 |
1 |
0.372859 |
0.387859 |
0.810203 |
00:07 |
2 |
0.367385 |
0.376952 |
0.812262 |
00:07 |
3 |
0.346810 |
0.365312 |
0.827155 |
00:07 |
4 |
0.341221 |
0.363528 |
0.829056 |
00:07 |
| [0m 109 [0m | [0m 0.8291 [0m | [0m 2.391 [0m | [0m 4.915 [0m | [0m 0.8695 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.391635 |
0.402138 |
0.804658 |
00:08 |
1 |
0.402425 |
0.468013 |
0.803866 |
00:09 |
2 |
0.368443 |
0.372927 |
0.823669 |
00:09 |
3 |
0.361579 |
0.364913 |
0.826679 |
00:09 |
4 |
0.347989 |
0.363416 |
0.828580 |
00:09 |
| [0m 110 [0m | [0m 0.8286 [0m | [0m 5.711 [0m | [0m 7.031 [0m | [0m 1.157 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.393225 |
0.414461 |
0.812738 |
00:07 |
1 |
0.377641 |
0.381910 |
0.820659 |
00:07 |
2 |
0.363979 |
0.375042 |
0.826838 |
00:07 |
3 |
0.349867 |
0.363547 |
0.830165 |
00:07 |
4 |
0.350035 |
0.364254 |
0.830006 |
00:07 |
| [0m 111 [0m | [0m 0.83 [0m | [0m 3.867 [0m | [0m 7.351 [0m | [0m 0.7373 [0m |
| [0m 112 [0m | [0m 0.8275 [0m | [0m 5.568 [0m | [0m 0.8565 [0m | [0m 0.9522 [0m |
| [0m 113 [0m | [0m 0.8311 [0m | [0m 5.553 [0m | [0m 1.45 [0m | [0m 0.0 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.404215 |
0.421293 |
0.817807 |
00:10 |
1 |
0.380753 |
0.388533 |
0.803074 |
00:09 |
2 |
0.362475 |
0.381199 |
0.820342 |
00:09 |
3 |
0.353416 |
0.368431 |
0.824620 |
00:09 |
4 |
0.335854 |
0.365448 |
0.825570 |
00:09 |
| [0m 114 [0m | [0m 0.8256 [0m | [0m 3.949 [0m | [0m 1.125 [0m | [0m 1.109 [0m |
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.397351 |
0.482881 |
0.766160 |
00:08 |
1 |
0.373190 |
0.372999 |
0.819075 |
00:08 |
2 |
0.354075 |
0.368455 |
0.826046 |
00:08 |
3 |
0.350617 |
0.362793 |
0.830323 |
00:08 |
4 |
0.347381 |
0.361980 |
0.831274 |
00:08 |
| [0m 115 [0m | [0m 0.8313 [0m | [0m 7.258 [0m | [0m 4.724 [0m | [0m 1.639 [0m |
=============================================================
optimizer.max['target']
0.8428390622138977
# Turn each best parameter value v reported by the optimizer into 2**int(v).
{key: 2**int(value)
for key, value in optimizer.max['params'].items()}
{'pow_n_a': 2, 'pow_n_d': 64, 'pow_n_steps': 1}
Out of memory dataset
If your dataset is so big it doesn't fit in memory, you can load a chunk of it each epoch.
# The demo reads the whole CSV; df_main plays the role of the too-large dataset.
df = pd.read_csv(path/'adult.csv')
# Keep the last 1000 rows aside as a fixed validation set.
df_main,df_valid = df.iloc[:-1000].copy(),df.iloc[-1000:].copy()
# choose a size that fits in memory
dataset_size = 1000
# replace this body with your own chunk-loading code
def load_chunk():
    """Return an independent copy of `dataset_size` rows sampled from `df_main`."""
    sampled_rows = df_main.sample(dataset_size)
    return sampled_rows.copy()
# First chunk; the TabularPandas object built from it is reused for later chunks.
df_small = load_chunk()
cat_names = ['workclass', 'education', 'marital-status', 'occupation',
'relationship', 'race', 'native-country', 'sex']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
# NOTE(review): `splits` is computed here but not used below (splits=None is passed).
splits = RandomSplitter()(range_of(df_small))
# splits=None: the chunk is not divided into train/valid parts.
to = TabularPandas(df_small, procs, cat_names, cont_names, y_names="salary", y_block = CategoryBlock(),
splits=None, do_setup=True)
# Build the validation loader once from the held-out rows.
# NOTE(review): `.train` is used because no split was made — presumably all rows live there; confirm.
to_valid = to.new(df_valid)
to_valid.process()
val_dl = TabDataLoader(to_valid.train)
len(to.train)
1000
class ReloadCallback(Callback):
    """Swap a freshly loaded training chunk into the learner at the start of each epoch.

    Uses the module-level `load_chunk`, `to` and `val_dl` defined earlier in
    this example; the validation loader is kept unchanged across epochs.
    """

    # fastai >= 2.0 names this event `before_epoch`; a method called only
    # `begin_epoch` is not a Learner event there and would never be invoked.
    def before_epoch(self):
        df_small = load_chunk()
        # Wrap the new chunk with the configuration of `to` and run its procs
        # (presumably re-using the state fitted on the first chunk — confirm).
        to_new = to.new(df_small)
        to_new.process()
        trn_dl = TabDataLoader(to_new.train)
        # NOTE: `.cuda()` assumes a GPU is available.
        self.learn.dls = DataLoaders(trn_dl, val_dl).cuda()

    # Keep the previous method name available for existing callers.
    begin_epoch = before_epoch
# Initial DataLoaders built from the first chunk.
dls = to.dataloaders()
emb_szs = get_emb_sz(to)
model = TabNetModel(emb_szs, len(to.cont_names), dls.c, n_d=8, n_a=32, n_steps=1);
opt_func = partial(Adam, wd=0.01, eps=1e-5)
learn = Learner(dls, model, CrossEntropyLossFlat(), opt_func=opt_func, lr=3e-2, metrics=[accuracy])
# Register the callback that replaces learn.dls with a newly loaded chunk.
learn.add_cb(ReloadCallback());
learn.fit_one_cycle(10)
epoch |
train_loss |
valid_loss |
accuracy |
time |
0 |
0.587740 |
0.550544 |
0.756000 |
00:01 |
1 |
0.545411 |
0.515772 |
0.782000 |
00:01 |
2 |
0.484289 |
0.468586 |
0.813000 |
00:01 |
3 |
0.447111 |
0.435774 |
0.817000 |
00:01 |
4 |
0.449050 |
0.394715 |
0.819000 |
00:01 |
5 |
0.428863 |
0.382005 |
0.835000 |
00:01 |
6 |
0.382100 |
0.404258 |
0.826000 |
00:01 |
7 |
0.383915 |
0.376179 |
0.833000 |
00:01 |
8 |
0.389460 |
0.367857 |
0.834000 |
00:01 |
9 |
0.376486 |
0.367577 |
0.834000 |
00:01 |