Imbalanced Classification with Deep Reinforcement Learning.
imbDRL
This repository contains a (Double) Deep Q-Network implementation for binary classification on imbalanced datasets using TensorFlow 2.3 / 2.4 and TF Agents 0.6:
- The Double Deep Q-Network, as published by van Hasselt et al., uses a custom environment based on the work of Lin et al.
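To make the two ingredients concrete, here is a minimal pure-Python sketch that is not taken from this repository. It assumes the reward scheme as we understand it from Lin et al.: classifying a minority-class sample earns +1 (correct) or -1 (wrong), classifying a majority-class sample earns +λ or -λ, where λ is a weight such as the minority-to-majority ratio, and a training episode ends when a minority sample is misclassified. The target follows the Double DQN idea of van Hasselt et al.: the online network selects the next action and the target network evaluates it. The function names reward_and_done and double_dqn_target, and the convention that label 1 is the minority class, are hypothetical.

```python
from typing import List, Tuple

MINORITY = 1  # assumption: label 1 is the minority (positive) class


def reward_and_done(action: int, label: int, imb_ratio: float) -> Tuple[float, bool]:
    """Reward for classifying one sample, in the style of Lin et al. (assumed).

    imb_ratio plays the role of lambda, e.g. minority count / majority count.
    """
    correct = action == label
    if label == MINORITY:
        # Minority samples carry full weight; a miss terminates the episode.
        return (1.0, False) if correct else (-1.0, True)
    # Majority samples are down-weighted by lambda and never end the episode.
    return (imb_ratio, False) if correct else (-imb_ratio, False)


def double_dqn_target(reward: float, done: bool, q_online_next: List[float],
                      q_target_next: List[float], gamma: float = 0.99) -> float:
    """Double DQN target: y = r + gamma * Q_target(s', argmax_a Q_online(s', a))."""
    if done:
        # No bootstrapping from a terminal state.
        return reward
    # The online network selects the greedy next action...
    best_action = max(range(len(q_online_next)), key=lambda a: q_online_next[a])
    # ...and the target network evaluates it, which reduces overestimation bias.
    return reward + gamma * q_target_next[best_action]


if __name__ == "__main__":
    print(reward_and_done(action=0, label=0, imb_ratio=0.1))  # (0.1, False)
    # Online net prefers action 1, so the target uses q_target_next[1] = 4.0:
    # 1.0 + 0.5 * 4.0 = 3.0 (vanilla DQN would take max(q_target_next) = 6.0).
    print(double_dqn_target(1.0, False, [0.2, 0.5], [6.0, 4.0], gamma=0.5))  # 3.0
```

Decoupling action selection from action evaluation is the whole difference from vanilla DQN, which would use the maximum of q_target_next directly.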
Example scripts for the MNIST, Fashion-MNIST and Credit Card Fraud datasets can be found in the ./imbDRL/examples/ddqn/ folder.
Requirements
- Python 3.7+
- pip install -r requirements.txt
- Logs are saved in ./logs/ by default.
- Trained models are saved in ./models/ by default.
- Optional: a ./data/ folder located at the root of this repository.
  - This folder must contain creditcard.csv, downloaded from Kaggle, if you would like to use the Credit Card Fraud dataset.
  - Note: creditcard.csv needs to be split into separate train and test files. Please use the function imbDRL.utils.split_csv.
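The signature of imbDRL.utils.split_csv is not shown here, so the following is a hypothetical, stdlib-only sketch of what such a split typically does, not the package's implementation. It reads the CSV, groups rows by the label column (assumed to be "Class", as in the Kaggle file), and writes a stratified train file and test file so that both keep roughly the original fraud ratio. The function name stratified_split_csv, its parameters and the _train/_test output names are illustrative.

```python
import csv
import random
from collections import defaultdict
from pathlib import Path
from typing import Dict, List


def stratified_split_csv(path: str, label_col: str = "Class", test_size: float = 0.2,
                         seed: int = 42) -> Dict[str, int]:
    """Hypothetical helper: write <stem>_train.csv and <stem>_test.csv next to `path`."""
    src = Path(path)
    with src.open(newline="") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        # Group rows per class label so each class is split separately.
        by_label: Dict[str, List[dict]] = defaultdict(list)
        for row in reader:
            by_label[row[label_col]].append(row)

    rng = random.Random(seed)  # fixed seed makes the split reproducible
    train: List[dict] = []
    test: List[dict] = []
    for rows in by_label.values():
        rng.shuffle(rows)
        # Stratification: the rare class ends up in both files.
        n_test = round(len(rows) * test_size)
        test.extend(rows[:n_test])
        train.extend(rows[n_test:])
    rng.shuffle(train)
    rng.shuffle(test)

    counts: Dict[str, int] = {}
    for name, rows in (("train", train), ("test", test)):
        out = src.with_name(f"{src.stem}_{name}.csv")
        with out.open("w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
        counts[name] = len(rows)
    return counts
```

For example, stratified_split_csv("./data/creditcard.csv") would write ./data/creditcard_train.csv and ./data/creditcard_test.csv. Splitting per class rather than over the whole file matters here because fraud cases are so rare that a plain random split can leave the test file with very few of them.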
Getting started
Run any of the following scripts:
python .\imbDRL\examples\ddqn\train_credit.py
python .\imbDRL\examples\ddqn\train_famnist.py
python .\imbDRL\examples\ddqn\train_mnist.py
TensorBoard
To enable TensorBoard, run tensorboard --logdir logs
Tests and linting
Extra arguments are handled by the ./tox.ini file.
- Pytest: python -m pytest
- Flake8: flake8
- Coverage reports can be found in the generated ./htmlcov folder.
Appendix
The appendix can be found in the imbDRLAppendix repository.
Source Distribution
imbDRL-2021.1.22.1.tar.gz (16.5 kB)
Built Distribution
Hashes for imbDRL-2021.1.22.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 2113258e595bb8ea2a4410f9be342ebb9725ee67a7cadccc58659a38a61450a0
MD5 | 450aaa46bc2de6fb0eca84908c6837ee
BLAKE2b-256 | 850a95c2738959d6a6145b1c115c3cf5d94eeee510d7721200a1d544a14cc890
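A downloaded wheel can be checked against the digests listed for it. The sketch below uses Python's standard hashlib module and streams the file in chunks so large files do not need to fit in memory; the helper name file_digest is hypothetical, and EXPECTED_SHA256 is simply the SHA256 value copied from the listing.

```python
import hashlib
from pathlib import Path


def file_digest(path: str, algorithm: str = "sha256", chunk_size: int = 1 << 16) -> str:
    """Return the hex digest of a file, reading it in fixed-size chunks."""
    h = hashlib.new(algorithm)  # e.g. "sha256" or "md5"
    with Path(path).open("rb") as f:
        # iter() with a sentinel keeps reading until read() returns b"" (EOF).
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


# SHA256 listed for imbDRL-2021.1.22.1-py3-none-any.whl
EXPECTED_SHA256 = "2113258e595bb8ea2a4410f9be342ebb9725ee67a7cadccc58659a38a61450a0"
```

With the wheel in the current directory, file_digest("imbDRL-2021.1.22.1-py3-none-any.whl") == EXPECTED_SHA256 should hold. pip's hash-checking mode (a --hash=sha256:... entry in a requirements file) performs the same comparison automatically at install time.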