
Imbalanced Classification with Deep Reinforcement Learning.


imbDRL



This repository contains a (Double) Deep Q-Network implementation of binary classification on imbalanced datasets, built with TensorFlow 2.3+ and TF-Agents 0.6+. The Double DQN, as published by van Hasselt et al. (2015), uses a custom environment based on the work of Lin, Chen & Qi (2019).
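The two ideas named above, an imbalance-aware reward and the Double DQN target, can be sketched in plain Python. This is a minimal illustration, not imbDRL's code. It assumes the reward scheme described by Lin, Chen & Qi (2019): a prediction on a minority-class sample earns +1 if correct and -1 if wrong, a prediction on a majority-class sample earns +λ or -λ, with λ set to the ratio of minority to majority samples, and misclassifying a minority sample ends the episode. The function names `imbalance_reward` and `double_dqn_targets` are hypothetical.

```python
from typing import List, Sequence, Tuple


def imbalance_reward(action: int, label: int, minority_label: int,
                     lam: float) -> Tuple[float, bool]:
    """Reward for classifying one sample (assumed scheme after Lin, Chen & Qi, 2019).

    Returns (reward, terminal). Minority samples are worth +/-1, majority
    samples +/-lam, and misclassifying a minority sample ends the episode.
    """
    correct = action == label
    if label == minority_label:
        return (1.0, False) if correct else (-1.0, True)
    return (lam, False) if correct else (-lam, False)


def double_dqn_targets(rewards: Sequence[float],
                       terminals: Sequence[bool],
                       q_online_next: Sequence[Sequence[float]],
                       q_target_next: Sequence[Sequence[float]],
                       gamma: float) -> List[float]:
    """Double DQN targets (van Hasselt et al., 2015).

    The online network selects the next action and the target network
    evaluates it: y = r + gamma * Q_target(s', argmax_a Q_online(s', a)).
    Terminal transitions use y = r.
    """
    targets = []
    for r, done, q_on, q_tg in zip(rewards, terminals, q_online_next, q_target_next):
        if done:
            targets.append(r)
            continue
        # Action selection by the online network ...
        best_action = max(range(len(q_on)), key=lambda a: q_on[a])
        # ... action evaluation by the target network.
        targets.append(r + gamma * q_tg[best_action])
    return targets


# Example: 1 minority (fraud) sample per 99 majority samples gives lam = 1 / 99.
lam = 1 / 99
print(imbalance_reward(action=1, label=1, minority_label=1, lam=lam))  # (1.0, False)
print(imbalance_reward(action=0, label=1, minority_label=1, lam=lam))  # (-1.0, True)
print(double_dqn_targets([lam, -1.0], [False, True],
                         [[0.2, 0.5], [0.9, 0.1]],
                         [[0.3, 0.4], [0.7, 0.6]], gamma=0.5))
```

Scaling majority-class rewards by λ keeps the many easy majority samples from dominating the return, while the decoupled selection and evaluation in the target reduces the overestimation of Q-values that plain DQN suffers from.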

Example scripts for the MNIST, Fashion-MNIST, Credit Card Fraud and Titanic datasets can be found in the ./imbDRL/examples/ddqn/ folder.

Requirements

  • Python 3.7+
  • The packages listed in requirements.txt
  • Logs are saved in ./logs/ by default
  • Trained models are saved in ./models/ by default
  • Optional: a ./data/ folder at the root of this repository.
    • This folder must contain creditcard.csv, downloaded from Kaggle, if you would like to use the Credit Card Fraud dataset.
    • Note: creditcard.csv needs to be split into separate train and test files. Please use the function imbDRL.utils.split_csv.
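For a heavily imbalanced dataset such as Credit Card Fraud, it helps if both files keep the original class proportions. The helper below is a hypothetical, standard-library-only sketch of such a stratified split; it is not imbDRL.utils.split_csv, whose signature and behaviour are not documented here. It assumes the label column is named `Class`, as in the Kaggle file.

```python
import csv
import random
from collections import defaultdict
from typing import Dict, List


def stratified_split_csv(path: str, train_path: str, test_path: str,
                         label_col: str = "Class", test_frac: float = 0.2,
                         seed: int = 42) -> Dict[str, int]:
    """Split a CSV into train/test files, keeping each class's proportion.

    Hypothetical helper for illustration; not imbDRL.utils.split_csv.
    """
    # Group all rows by their label value.
    with open(path, newline="") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        rows_by_label: Dict[str, List[dict]] = defaultdict(list)
        for row in reader:
            rows_by_label[row[label_col]].append(row)

    # Shuffle each class separately and take the same fraction from each.
    rng = random.Random(seed)
    train_rows: List[dict] = []
    test_rows: List[dict] = []
    for rows in rows_by_label.values():
        rng.shuffle(rows)
        n_test = round(len(rows) * test_frac)  # per-class test count
        test_rows.extend(rows[:n_test])
        train_rows.extend(rows[n_test:])

    # Write both splits with the original header.
    for out_path, rows in ((train_path, train_rows), (test_path, test_rows)):
        with open(out_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
    return {"train": len(train_rows), "test": len(test_rows)}
```

Splitting per class guarantees that the rare positive cases appear in both files, which a plain random split of a very skewed file does not.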

Getting started

Install via pip:

  • pip install imbDRL

Run any of the following scripts:

  • python .\imbDRL\examples\ddqn\train_credit.py
  • python .\imbDRL\examples\ddqn\train_famnist.py
  • python .\imbDRL\examples\ddqn\train_mnist.py
  • python .\imbDRL\examples\ddqn\train_titanic.py

TensorBoard

To view the logs in TensorBoard, run tensorboard --logdir logs

Tests and linting

Extra arguments for these tools are configured in the ./tox.ini file.

  • Pytest: python -m pytest
  • Flake8: flake8
  • Coverage: the HTML report is generated in the ./htmlcov folder

Appendix

The appendix can be found in the imbDRLAppendix repository.



Download files


Source Distribution

  • imbDRL-2021.1.26.1.tar.gz (17.3 kB)

Built Distribution

  • imbDRL-2021.1.26.1-py3-none-any.whl (26.9 kB, Python 3)
