
Imbalanced Classification with Deep Reinforcement Learning.


imbDRL



This repository contains two implementations of binary classification on imbalanced datasets, built with TensorFlow 2.3 and TF Agents 0.6:

  1. A Double Deep Q-Network (DDQN), as published by van Hasselt et al., trained in a custom environment based on the paper by Lin et al.

  2. A Neural Epsilon Greedy agent, based on code from the TF Agents team.
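The two approaches above are only named here. To make the moving parts concrete, the NumPy sketch below shows three pieces: the reward scheme we understand Lin et al. to use for imbalanced classification (minority samples earn ±1 and a missed minority sample ends the episode, while majority samples earn ±ρ, the minority/majority ratio), the Double DQN target of van Hasselt et al., and the epsilon-greedy arm selection that a neural epsilon-greedy bandit performs on top of its predicted rewards. This is not imbDRL's or TF Agents' code; the function names `imbalance_ratio`, `step_reward`, `ddqn_targets` and `epsilon_greedy` are hypothetical, and label 1 is assumed to be the minority class.

```python
import numpy as np

# Hypothetical sketch, not imbDRL's implementation. Label 1 = minority class (assumption).


def imbalance_ratio(labels, minority_label=1):
    """Ratio of minority to majority samples (rho), assuming binary labels."""
    labels = np.asarray(labels)
    n_min = np.sum(labels == minority_label)
    n_maj = np.sum(labels != minority_label)
    return n_min / n_maj


def step_reward(action, label, rho, minority_label=1):
    """Return (reward, episode_done) for classifying one sample."""
    if label == minority_label:
        # Minority samples are worth +/-1; misclassifying one ends the episode.
        return (1.0, False) if action == label else (-1.0, True)
    # Majority samples are down-weighted by the imbalance ratio rho.
    return (rho, False) if action == label else (-rho, False)


def ddqn_targets(rewards, dones, q_online_next, q_target_next, gamma=0.99):
    """Double DQN targets: the online net picks a', the target net evaluates it."""
    best_actions = np.argmax(q_online_next, axis=1)
    next_values = q_target_next[np.arange(len(best_actions)), best_actions]
    # Terminal transitions (done == 1) do not bootstrap from the next state.
    return rewards + gamma * (1.0 - dones) * next_values


def epsilon_greedy(predicted_rewards, epsilon, rng):
    """Explore a random arm with probability epsilon, else exploit the best arm."""
    if rng.random() < epsilon:
        return int(rng.integers(len(predicted_rewards)))
    return int(np.argmax(predicted_rewards))


if __name__ == "__main__":
    rho = imbalance_ratio([0, 0, 0, 1])
    print(step_reward(action=0, label=1, rho=rho))
    print(ddqn_targets(np.array([1.0, -1.0]), np.array([0.0, 1.0]),
                       np.array([[0.2, 0.8], [0.5, 0.1]]),
                       np.array([[0.3, 0.6], [0.4, 0.9]]), gamma=0.5))
```

Letting the online network choose the next action while the target network scores it is what distinguishes Double DQN from plain DQN and reduces its overestimation of Q-values.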

Example scripts for the MNIST, IMDB and Credit Card Fraud datasets, for both implementations, can be found in the ./imbDRL/examples folder.

Requirements

  • Python 3.8
  • pip install -r requirements.txt
  • Optional: a ./data/ folder at the root of this repository.
    • This folder must contain creditcard.csv, downloaded from Kaggle, if you would like to use the Credit Card Fraud dataset.
    • Note: creditcard.csv needs to be split into separate train and test files. Please use the function imbDRL.utils.split_csv.
  • Logs are saved to ./logs/ and trained models to ./models/.
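The split itself is done by imbDRL.utils.split_csv, whose parameters are not described here. As an independent illustration of what such a split typically involves, the sketch below writes a stratified train/test split using only the standard library, so the rare fraud class appears in both files in roughly its original proportion. The function name `stratified_csv_split`, the 20% default test fraction and the output paths are assumptions; the `Class` label column matches the Kaggle file.

```python
import csv
import random
from collections import defaultdict

# Hypothetical helper, not imbDRL.utils.split_csv.


def stratified_csv_split(src, train_path, test_path, label_col="Class",
                         test_fraction=0.2, seed=42):
    """Split src into train/test CSVs, keeping the class ratio in both files."""
    with open(src, newline="") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        by_label = defaultdict(list)
        for row in reader:
            by_label[row[label_col]].append(row)

    rng = random.Random(seed)  # fixed seed for a reproducible split
    train_rows, test_rows = [], []
    for rows in by_label.values():
        # Shuffle and cut each class separately (stratification).
        rng.shuffle(rows)
        n_test = round(len(rows) * test_fraction)
        test_rows.extend(rows[:n_test])
        train_rows.extend(rows[n_test:])

    for path, rows in ((train_path, train_rows), (test_path, test_rows)):
        with open(path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
    return len(train_rows), len(test_rows)


if __name__ == "__main__":
    print(stratified_csv_split("data/creditcard.csv",
                               "data/credit0.csv", "data/credit1.csv"))
```

Splitting each class separately matters here because fraud cases are a tiny fraction of the data; a purely random cut could leave the test file with very few of them.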

Getting started

  • For the DDQN examples:
    • python .\imbDRL\examples\ddqn\train_cartpole.py
    • python .\imbDRL\examples\ddqn\train_credit.py
    • python .\imbDRL\examples\ddqn\train_image.py
  • For the Bandit examples:
    • python .\imbDRL\examples\bandit\train_bandit_credit.py
    • python .\imbDRL\examples\bandit\train_bandit_image.py
    • python .\imbDRL\examples\bandit\train_bandit_imdb.py

TensorBoard

To enable TensorBoard, run tensorboard --logdir logs.

Tests and linting

Extra arguments for these tools are configured in the ./tox.ini file.

  • Pytest: python -m pytest
  • Flake8: flake8
  • The HTML coverage report can be found in the ./htmlcov folder.

Download files


  • Source distribution: imbDRL-2020.11.24.2.tar.gz (19.6 kB)
  • Built distribution (Python 3 wheel): imbDRL-2020.11.24.2-py3-none-any.whl (35.4 kB)
