Expectation Reflection for classification
Expectation Reflection (ER) is a multiplicative optimization method. It trains the interaction weights from the features to the target according to the ratio of the target observations to their corresponding model expectations. This approach completely separates the model updates from the minimization of a cost function measuring goodness of fit, so the cost function can be used as the stopping criterion of the iteration. As a result, the method is well suited to problems with small sample sizes (but many features). Another benefit is that it has only one hyper-parameter.
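The following NumPy sketch makes the description more concrete with an ER-style loop. It is our reading of the idea above, not the package's code, and it rests on several assumptions. Targets are mapped from {0, 1} to ±1. The model expectation of a target is tanh of a linear local field. The weights are refit in each iteration by ridge-regularized least squares, with `regu` as the ridge strength. The mean squared difference between targets and expectations is the cost, and it is used only to decide when to stop. The functions `er_fit_sketch` and `er_predict_proba_sketch` are hypothetical names.

```python
import numpy as np


def er_fit_sketch(X, y, max_iter=100, regu=0.01):
    """Hypothetical ER-style fit for binary targets y in {0, 1} (illustration only)."""
    s = 2.0 * np.asarray(y, dtype=float) - 1.0   # map {0, 1} -> {-1, +1}
    X = np.asarray(X, dtype=float)
    n, m = X.shape
    mu = X.mean(axis=0)
    Xc = X - mu                                   # centered features
    # ridge-regularized feature covariance (assumption: regu acts as a ridge penalty)
    cov = Xc.T @ Xc / n + regu * np.eye(m)

    h = s.copy()                                  # initial target for the local field
    best_cost, best_h0, best_w = np.inf, 0.0, np.zeros(m)
    for _ in range(max_iter):
        # model update: refit intercept and weights to the current field target
        h0 = h.mean()
        w = np.linalg.solve(cov, Xc.T @ (h - h0) / n)
        field = h0 + Xc @ w
        # the cost (goodness of fit) is only used as the stopping criterion
        cost = np.mean((s - np.tanh(field)) ** 2)
        if cost >= best_cost:
            break
        best_cost, best_h0, best_w = cost, h0, w
        # multiplicative "reflection": field * observation / expectation,
        # i.e. s * field / tanh(field), where field / tanh(field) -> 1 as field -> 0
        ratio = np.ones_like(field)
        nz = np.abs(field) > 1e-12
        ratio[nz] = field[nz] / np.tanh(field[nz])
        h = s * ratio
    intercept = best_h0 - mu @ best_w             # undo the centering
    return intercept, best_w


def er_predict_proba_sketch(X, intercept, coef):
    """P(y = 1 | x) = (1 + tanh(field)) / 2 under the same assumptions."""
    return 0.5 * (1.0 + np.tanh(intercept + np.asarray(X, dtype=float) @ coef))


# usage on synthetic, linearly separable data
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = (X @ np.array([2.0, -1.0, 0.5, 0.0, 0.0]) > 0).astype(float)
b, w = er_fit_sketch(X, y, max_iter=100, regu=0.01)
p = er_predict_proba_sketch(X, b, w)
print('train accuracy:', np.mean((p > 0.5) == (y == 1)))
```

The sketch keeps the weights from the iteration with the lowest cost. The fitting step never minimizes that cost directly, which mirrors the separation described above.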
Installation
From PyPI
pip install expectation-reflection
From Repository
git clone https://github.com/danhtaihoang/expectation-reflection.git
Usage
- Import the expectation_reflection package into your Python script:
from expectation_reflection import classification as ER
- Select the model:
model = ER.model(max_iter=100, regu=0.01)  # example values for the iteration limit and the hyper-parameter
- Load your dataset.txt into the Python script. For binary classification, ER expects the target to take values in {0, 1}:
import numpy as np
Xy = np.loadtxt('dataset.txt')
- Select the features and the target from the dataset. If the target is the last column:
X, y = Xy[:, :-1], Xy[:, -1]
- Import train_test_split from sklearn to split the data into training and test sets:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=1)
- Train the model on the training set (X_train, y_train):
model.fit(X_train, y_train)
- Predict the outputs y_pred and their probabilities p_pred for the new inputs X_test:
y_pred = model.predict(X_test)
print('predicted output:', y_pred)
p_pred = model.predict_proba(X_test)
print('predicted probability:', p_pred)
- Intercept and interaction weights:
print('intercept:', model.intercept_)
print('interaction weights:', model.coef_)
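The usage steps above assume that dataset.txt is a whitespace-separated numeric table. Each row is one sample, the feature columns come first, and the {0, 1} target is in the last column. The sketch below writes a small synthetic file in that layout to a temporary directory and loads it back exactly as in the steps above, so the shapes can be checked before calling model.fit. The data are random and purely illustrative.

```python
import os
import tempfile

import numpy as np

# hypothetical synthetic data: 20 samples, 3 features, binary target
rng = np.random.default_rng(1)
features = rng.normal(size=(20, 3))
target = (features[:, 0] > 0).astype(int)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'dataset.txt')
    # features first, target in the last column
    np.savetxt(path, np.column_stack([features, target]))
    Xy = np.loadtxt(path)

X, y = Xy[:, :-1], Xy[:, -1]
print(X.shape, y.shape)            # (20, 3) (20,)
print(sorted(set(y.tolist())))     # a subset of [0.0, 1.0]
```

np.loadtxt returns floats, so the target comes back as 0.0 and 1.0 rather than integers.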
Hyper-Parameter Optimization
ER has only one hyper-parameter, regu, which can be optimized with GridSearchCV from sklearn:
from sklearn.model_selection import GridSearchCV
model = ER.model(max_iter=100)
regu = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.]
hyper_parameters = dict(regu=regu)
clf = GridSearchCV(model, hyper_parameters, cv=4)  # the iid argument was removed in scikit-learn 0.24
best_model = clf.fit(X_train, y_train)
- Best hyper-parameters:
print('best_hyper_parameters:',best_model.best_params_)
- Predict the outputs y_pred and their probabilities p_pred:
y_pred = best_model.best_estimator_.predict(X_test)
print('predicted output:', y_pred)
p_pred = best_model.best_estimator_.predict_proba(X_test)
print('predicted probability:', p_pred)
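With cv=4, GridSearchCV fits one model per candidate value of regu on each of four folds. It keeps the value with the best mean validation score; when no scoring argument is given, that score comes from the estimator's own score method. The sketch below reproduces this loop by hand in NumPy. To stay self-contained, it replaces ER.model with a stand-in estimator, a ridge least-squares classifier (ridge_fit, hypothetical), and uses accuracy as the score. It also uses plain contiguous folds, which differ from the stratified folds scikit-learn uses for classifiers.

```python
import numpy as np


def ridge_fit(X, y, regu):
    """Stand-in estimator (not ER): ridge least squares on targets mapped to +/-1."""
    Xb = np.column_stack([np.ones(len(X)), X])          # prepend intercept column
    penalty = regu * np.eye(Xb.shape[1])
    penalty[0, 0] = 0.0                                  # do not penalize the intercept
    return np.linalg.solve(Xb.T @ Xb + penalty, Xb.T @ (2.0 * y - 1.0))


def ridge_accuracy(beta, X, y):
    """Fraction of samples whose thresholded prediction matches the {0, 1} label."""
    pred = (beta[0] + X @ beta[1:] > 0).astype(float)
    return np.mean(pred == y)


def grid_search_cv(X, y, grid, cv=4):
    """Return the grid value with the highest mean validation accuracy, and all scores."""
    folds = np.array_split(np.arange(len(y)), cv)      # contiguous folds
    mean_scores = {}
    for regu in grid:
        scores = []
        for k in range(cv):
            val = folds[k]
            train = np.concatenate([folds[j] for j in range(cv) if j != k])
            beta = ridge_fit(X[train], y[train], regu)
            scores.append(ridge_accuracy(beta, X[val], y[val]))
        mean_scores[regu] = float(np.mean(scores))
    best = max(grid, key=lambda r: mean_scores[r])     # first best value wins ties
    return best, mean_scores


rng = np.random.default_rng(0)
X = rng.normal(size=(120, 4))
y = (X[:, 0] - X[:, 1] > 0).astype(float)
grid = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.]
best_regu, scores = grid_search_cv(X, y, grid, cv=4)
print('best_hyper_parameters:', {'regu': best_regu})
```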
Performance Evaluation
We can measure the performance of the model with metrics from sklearn:
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score,\
roc_auc_score,roc_curve,auc
acc = accuracy_score(y_test,y_pred)
print('accuracy:', acc)
precision = precision_score(y_test,y_pred)
print('precision:', precision)
recall = recall_score(y_test,y_pred)
print('recall:', recall)
f1score = f1_score(y_test,y_pred)
print('f1score:', f1score)
roc_auc = roc_auc_score(y_test,p_pred)  # use the probabilities p_pred, not the labels y_pred
print('roc auc:', roc_auc)
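For {0, 1} targets, the threshold metrics above all come from the counts of true positives, false positives and false negatives. The sketch below computes them by hand on a tiny made-up example, which can help when checking the scikit-learn output. It assumes class 1 is the positive class, which matches scikit-learn's default pos_label=1. The helper binary_metrics is hypothetical.

```python
import numpy as np


def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 for binary labels, with 1 as the positive class."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_pred == 1) & (y_true == 1))
    fp = np.sum((y_pred == 1) & (y_true == 0))
    fn = np.sum((y_pred == 0) & (y_true == 1))
    accuracy = np.mean(y_pred == y_true)
    precision = tp / (tp + fp)            # share of predicted positives that are correct
    recall = tp / (tp + fn)               # share of actual positives that are found
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f1


y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_hat = [1, 1, 1, 0, 1, 1, 0, 0]          # tp = 3, fn = 1, fp = 2, tn = 2
acc, prec, rec, f1 = binary_metrics(y_true, y_hat)
print(acc, prec, rec, f1)                 # accuracy 0.625, precision 0.6, recall 0.75, F1 = 2/3
```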
ROC AUC can also be calculated from the ROC curve:
fp,tp,thresholds = roc_curve(y_test, p_pred, drop_intermediate=False)
roc_auc = auc(fp,tp)
print('roc auc:', roc_auc)
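Both calculations above should agree. ROC AUC also equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, with ties counted as one half. The sketch below computes it directly from that definition by comparing every positive/negative pair. The helper auc_by_pairs is hypothetical, and it costs O(n_pos * n_neg), so it is only meant for checking small cases.

```python
import numpy as np


def auc_by_pairs(y_true, scores):
    """ROC AUC as P(score of a positive > score of a negative); ties count 1/2."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]      # every positive/negative pair
    return (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / (len(pos) * len(neg))


print(auc_by_pairs([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))   # 0.75
```

In this example, three of the four positive/negative pairs are ranked correctly. Only 0.35 versus 0.4 is ranked the wrong way.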
Citation
Please cite the following papers if you use this package in your work: