
Imbalanced Classification with Deep Reinforcement Learning.

Project description

imbDRL


This repository contains a (Double) Deep Q-Network implementation for binary classification on imbalanced datasets, built with TensorFlow 2.3+ and TF Agents 0.6+. The Double DQN, as published by van Hasselt et al. (2015), uses a custom environment based on the paper by Lin, Chen & Qi (2019).
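The two ingredients named above, an imbalance-aware reward and the Double DQN target, can be sketched in a few lines of framework-free Python. This is a minimal illustration, not imbDRL's code: the function names `imbalanced_reward` and `double_dqn_targets` are hypothetical, NumPy is assumed, and the reward follows our reading of Lin, Chen & Qi (2019): ±1 for minority samples, ±λ for majority samples (here `imb_ratio`, the minority-to-majority ratio), with the episode ending when a minority sample is misclassified. The target is the standard Double DQN update, in which the online network selects the next action and the target network evaluates it.

```python
import numpy as np


def imbalanced_reward(label, action, imb_ratio, minority_label=1):
    """Hypothetical per-step reward in the style of Lin, Chen & Qi (2019).

    Returns (reward, terminal). Minority samples earn +/-1, majority samples
    +/-imb_ratio; misclassifying a minority sample ends the episode.
    """
    correct = action == label
    if label == minority_label:
        return (1.0, False) if correct else (-1.0, True)
    # Majority samples are down-weighted so they cannot dominate the return.
    return (imb_ratio, False) if correct else (-imb_ratio, False)


def double_dqn_targets(rewards, terminals, q_online_next, q_target_next, gamma=0.99):
    """Double DQN targets: r + gamma * Q_target(s', argmax_a Q_online(s', a))."""
    rewards = np.asarray(rewards, dtype=float)
    not_done = 1.0 - np.asarray(terminals, dtype=float)
    q_online_next = np.asarray(q_online_next, dtype=float)
    q_target_next = np.asarray(q_target_next, dtype=float)
    # Action selection uses the online network...
    best_actions = np.argmax(q_online_next, axis=1)
    # ...while the value of that action comes from the target network.
    next_values = q_target_next[np.arange(len(best_actions)), best_actions]
    # Terminal transitions do not bootstrap from the next state.
    return rewards + gamma * not_done * next_values


# Usage: two transitions, the second one terminal.
targets = double_dqn_targets(
    rewards=[1.0, -1.0],
    terminals=[False, True],
    q_online_next=[[0.0, 2.0], [1.0, 0.0]],
    q_target_next=[[5.0, 3.0], [4.0, 7.0]],
    gamma=0.5,
)
print(targets)
```

In the first transition the online network prefers action 1 while the target network would have preferred action 0; using the target network's value for the online network's choice is what reduces the overestimation bias of plain DQN's max operator.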

Example scripts for the MNIST, Fashion-MNIST, Credit Card Fraud and Titanic datasets can be found in the ./imbDRL/examples/ddqn/ folder.

Requirements

  • Python 3.7+
  • The required packages as listed in: requirements.txt
  • Logs are by default saved in ./logs/
  • Trained models are by default saved in ./models/
  • Optional: a ./data/ folder at the root of this repository.
    • This folder must contain creditcard.csv, downloaded from Kaggle, if you want to use the Credit Card Fraud dataset.
    • Note: creditcard.csv needs to be split into separate train and test files. Use the function imbDRL.utils.split_csv.
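imbDRL.utils.split_csv is the supported way to do this split; its exact signature and splitting strategy are not described here. Purely to illustrate the idea, a stratified split (each class keeps its share in both files, which matters when fraud cases are rare) can be written with the standard library alone. The function `stratified_split_csv`, its parameters, and the 20% default test fraction are hypothetical; `Class` is assumed to be the label column, as in the Kaggle file.

```python
import csv
import random
from collections import defaultdict


def stratified_split_csv(path, train_path, test_path, label_column="Class",
                         test_fraction=0.2, seed=42):
    """Hypothetical helper (not imbDRL's split_csv): split a CSV into train and
    test files while preserving the proportion of each label value."""
    with open(path, newline="") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        by_label = defaultdict(list)
        for row in reader:
            by_label[row[label_column]].append(row)

    rng = random.Random(seed)  # fixed seed for a reproducible split
    train_rows, test_rows = [], []
    for rows in by_label.values():
        rng.shuffle(rows)
        n_test = round(len(rows) * test_fraction)  # per-class test count
        test_rows.extend(rows[:n_test])
        train_rows.extend(rows[n_test:])

    # Interleave the classes again so the output files are not sorted by label.
    rng.shuffle(train_rows)
    rng.shuffle(test_rows)

    for out_path, rows in ((train_path, train_rows), (test_path, test_rows)):
        with open(out_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
    return len(train_rows), len(test_rows)
```

Note that this loads the whole file into memory, which is acceptable for a file the size of creditcard.csv but not for arbitrarily large data.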

Getting started

Install via pip:

  • pip install imbDRL
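Before running the examples, it can help to confirm that the installed TensorFlow and TF Agents meet the minimums stated above (TensorFlow 2.3+, TF Agents 0.6+). The script below is a hypothetical helper, not part of imbDRL. It assumes the pip distribution names `tensorflow` and `tf-agents` and uses `importlib.metadata`, which needs Python 3.8+ (on Python 3.7 the `importlib-metadata` backport would be required). Variant builds such as `tensorflow-cpu` are reported as missing.

```python
import re
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+

# Minimum (major, minor) versions taken from the project description.
MINIMUMS = {"tensorflow": (2, 3), "tf-agents": (0, 6)}


def parse_major_minor(v):
    """Return (major, minor) from a version string such as '2.3.1' or '0.6.0rc1'."""
    match = re.match(r"(\d+)\.(\d+)", v)
    if match is None:
        raise ValueError(f"Unrecognized version string: {v!r}")
    return int(match.group(1)), int(match.group(2))


def check_requirements(minimums=MINIMUMS):
    """Return a list of human-readable problems; an empty list means all OK."""
    problems = []
    for package, minimum in minimums.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            problems.append(f"{package} is not installed")
            continue
        if parse_major_minor(installed) < minimum:
            problems.append(f"{package} {installed} < {minimum[0]}.{minimum[1]}")
    return problems


if __name__ == "__main__":
    for problem in check_requirements():
        print("WARNING:", problem)
```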

Run any of the following scripts:

  • python .\imbDRL\examples\ddqn\train_credit.py
  • python .\imbDRL\examples\ddqn\train_famnist.py
  • python .\imbDRL\examples\ddqn\train_mnist.py
  • python .\imbDRL\examples\ddqn\train_titanic.py

TensorBoard

To view the logs in TensorBoard, run tensorboard --logdir logs

Tests and linting

Extra arguments for these tools are configured in the ./tox.ini file.

  • Pytest: python -m pytest
  • Flake8: flake8
  • Coverage: see the generated ./htmlcov folder

Appendix

The appendix can be found in the imbDRLAppendix repository.



Download files


Source Distribution

imbDRL-2021.1.26.1.tar.gz (17.3 kB)

Uploaded Source

Built Distribution

imbDRL-2021.1.26.1-py3-none-any.whl (26.9 kB)

Uploaded Python 3

File details

Details for the file imbDRL-2021.1.26.1.tar.gz.

File metadata

  • Download URL: imbDRL-2021.1.26.1.tar.gz
  • Upload date:
  • Size: 17.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/49.2.1 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.8.7

File hashes

Hashes for imbDRL-2021.1.26.1.tar.gz
  • SHA256: af08f83c774ca3b99cfa2d0f370e7632b205845f7585926278ea9ae41eb08607
  • MD5: fd5492ff5204340140f3a950a8c71cb8
  • BLAKE2b-256: 38cb7ff4102a3de82530382fc5eaeb1af8d827ec0bf02cda0b5302546fffbfa2


File details

Details for the file imbDRL-2021.1.26.1-py3-none-any.whl.

File metadata

  • Download URL: imbDRL-2021.1.26.1-py3-none-any.whl
  • Upload date:
  • Size: 26.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/49.2.1 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.8.7

File hashes

Hashes for imbDRL-2021.1.26.1-py3-none-any.whl
  • SHA256: a056095d182364f11bb39941db2964460d9d132bbc5aa441556415ea76836e0d
  • MD5: 1c3bacfecd35f9ce41f9fe54878dd3af
  • BLAKE2b-256: 11e37bf599b81bcdcc6ede6d0d430ba0560df712bb6e0c14294ddfc50d6c6997
