Imbalanced Classification with Deep Reinforcement Learning.
imbDRL
This repository contains two implementations of binary classification on imbalanced datasets, built with TensorFlow 2.3 and TF Agents 0.6:
- The Double Deep Q-network (DDQN), as published in the paper by van Hasselt et al., uses a custom environment based on the paper by Lin et al.
- The Neural Epsilon Greedy agent is based on example code from the TF Agents team.
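As we read Lin et al., their imbalanced-classification environment treats each training sample as a state and the predicted label as the action. Minority-class samples are rewarded +/-1, majority-class samples +/-lambda (lambda being the minority-to-majority ratio), and an episode ends when a minority sample is misclassified. The following is a minimal sketch of that reward rule under those assumptions; the function names imbalance_ratio and step_reward are hypothetical, and imbDRL's environment may implement the details differently.

```python
from typing import Sequence, Tuple


def imbalance_ratio(labels: Sequence[int], minority_label: int = 1) -> float:
    """Return |minority| / |majority|, used as the reward scale for majority samples."""
    n_minority = sum(1 for y in labels if y == minority_label)
    n_majority = len(labels) - n_minority
    if n_majority == 0:
        raise ValueError("labels must contain at least one majority-class sample")
    return n_minority / n_majority


def step_reward(action: int, label: int, imb_ratio: float,
                minority_label: int = 1) -> Tuple[float, bool]:
    """Return (reward, episode_done) for predicting `action` on a sample labelled `label`.

    Hypothetical sketch of the Lin et al. reward scheme, not imbDRL's code.
    """
    correct = action == label
    if label == minority_label:
        # Minority samples: +/-1, and missing a minority sample ends the episode.
        return (1.0, False) if correct else (-1.0, True)
    # Majority samples: rewards scaled down by the imbalance ratio.
    return (imb_ratio, False) if correct else (-imb_ratio, False)


# Hypothetical usage: 1 minority sample vs. 4 majority samples.
ratio = imbalance_ratio([0, 0, 1, 0, 0])
print(ratio)                     # 0.25
print(step_reward(0, 1, ratio))  # (-1.0, True)
print(step_reward(0, 0, ratio))  # (0.25, False)
```

Scaling the majority-class reward by the ratio is what makes a minority-class error cost more than a majority-class one.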
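For reference, the Double DQN target from van Hasselt et al. decouples action selection from evaluation: the online network picks the greedy next action and the target network scores it. The plain-Python sketch below shows only that computation; the function name double_dqn_targets and its list-based inputs are hypothetical (not imbDRL's or TF Agents' API), and the discount factor and terminal handling used by imbDRL may differ.

```python
from typing import List, Sequence


def double_dqn_targets(rewards: Sequence[float], dones: Sequence[bool],
                       q_online_next: Sequence[Sequence[float]],
                       q_target_next: Sequence[Sequence[float]],
                       gamma: float = 0.99) -> List[float]:
    """Double DQN targets: y = r + gamma * Q_target(s', argmax_a Q_online(s', a))."""
    targets = []
    for r, done, q_on, q_tg in zip(rewards, dones, q_online_next, q_target_next):
        # The online network selects the next action...
        best_action = max(range(len(q_on)), key=q_on.__getitem__)
        # ...and the target network evaluates it; terminal steps do not bootstrap.
        bootstrap = 0.0 if done else gamma * q_tg[best_action]
        targets.append(r + bootstrap)
    return targets


targets = double_dqn_targets(
    rewards=[1.0, -1.0],
    dones=[False, True],
    q_online_next=[[0.2, 0.8], [0.5, 0.1]],
    q_target_next=[[0.9, 0.6], [0.4, 0.9]],
    gamma=0.5,
)
print(targets)  # approximately [1.3, -1.0]
```

In the first transition a vanilla DQN target would bootstrap from the target network's own maximum (0.9); Double DQN instead uses 0.6, the target network's value for the action the online network prefers, which is how it reduces overestimation.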
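TF Agents' neural epsilon-greedy agent trains a network to predict the reward of each action (arm) and acts greedily on those predictions, except with probability epsilon, when it picks an arm uniformly at random. The sketch below shows only that selection rule, assuming the arms are the class labels; reward_model, fake_model and epsilon_greedy_action are hypothetical names, not the TF Agents API.

```python
import random
from typing import Callable, Sequence


def epsilon_greedy_action(reward_model: Callable[[object], Sequence[float]],
                          observation: object, epsilon: float,
                          rng: random.Random) -> int:
    """Choose an arm (assumed here to be a class label) from per-arm reward estimates."""
    estimates = reward_model(observation)
    if rng.random() < epsilon:
        # Explore: pick any arm uniformly at random.
        return rng.randrange(len(estimates))
    # Exploit: pick the arm with the highest estimated reward.
    return max(range(len(estimates)), key=estimates.__getitem__)


def fake_model(observation: object) -> Sequence[float]:
    """Hypothetical stand-in for a trained reward network with two arms (labels 0 and 1)."""
    return [0.2, 0.9]


rng = random.Random(0)
print(epsilon_greedy_action(fake_model, None, epsilon=0.0, rng=rng))  # 1
```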
Example scripts on the MNIST, IMDB and Credit Card Fraud datasets for both implementations can be found in the ./imbDRL/examples folder.
Requirements
- Python 3.8
- pip install -r requirements.txt
- Optional: a ./data/ folder located at the root of this repository.
  - This folder must contain creditcard.csv, downloaded from Kaggle, if you would like to use the Credit Card Fraud dataset.
  - Note: creditcard.csv needs to be split into separate train and test files. Please use the function imbDRL.utils.split_csv.
- Logs will be saved to ./logs/ and trained models will be saved to ./models/.
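The signature and output file names of imbDRL.utils.split_csv are not documented here, so the following is only a hypothetical stand-in built on the standard library. Because the Credit Card Fraud data is heavily imbalanced, it performs a stratified split that keeps each class's proportion in both files. It assumes the label column is named Class (as in the Kaggle file) and a 20% test fraction; stratified_split and stratified_split_csv are illustrative names, and the repository's own function should be preferred for its expected format.

```python
import csv
import random
from collections import defaultdict
from typing import Dict, List, Tuple

Row = Dict[str, str]


def stratified_split(rows: List[Row], label_key: str = "Class",
                     test_frac: float = 0.2, seed: int = 42) -> Tuple[List[Row], List[Row]]:
    """Split rows into (train, test), keeping each class's share in both parts."""
    by_label: Dict[str, List[Row]] = defaultdict(list)
    for row in rows:
        by_label[row[label_key]].append(row)
    rng = random.Random(seed)
    train: List[Row] = []
    test: List[Row] = []
    for group in by_label.values():
        rng.shuffle(group)  # shuffle within each class before slicing
        n_test = round(len(group) * test_frac)
        test.extend(group[:n_test])
        train.extend(group[n_test:])
    return train, test


def stratified_split_csv(path: str, train_path: str, test_path: str, **kwargs) -> None:
    """Hypothetical helper: read a CSV with a header row and write stratified train/test CSVs."""
    with open(path, newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        fieldnames = reader.fieldnames
    train, test = stratified_split(rows, **kwargs)
    for out_path, part in ((train_path, train), (test_path, test)):
        with open(out_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(part)
```

Splitting per class, rather than shuffling the whole file once, guarantees that the rare fraud class appears in the test file in roughly the same proportion as in the training file.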
Getting started
- For the DDQN examples:
python .\imbDRL\examples\ddqn\train_cartpole.py
python .\imbDRL\examples\ddqn\train_credit.py
python .\imbDRL\examples\ddqn\train_image.py
- For the Bandit examples:
python .\imbDRL\examples\bandit\train_bandit_credit.py
python .\imbDRL\examples\bandit\train_bandit_image.py
python .\imbDRL\examples\bandit\train_bandit_imdb.py
TensorBoard
To enable TensorBoard, run: tensorboard --logdir logs
Tests and linting
Extra arguments are handled with the ./tox.ini file.
- Pytest:
python -m pytest
- Flake8:
flake8
- Coverage can be found in the ./htmlcov folder.
Hashes for imbDRL-2020.11.24.2-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | b583f5d746b83bafa7bfc2ff8481f9f582baf24363647e34c499bfaeb47fd140
MD5 | 3837d00aaa869a545f4ad83a2ef5462a
BLAKE2b-256 | da56b22f470fd1b9c41797557d5f0eea321360266ba3900f099dedd2b7a990fb