Cognitive Model Discovery via Disentangled RNNs

Disentangled RNN (DisRNN) is a recurrent neural network architecture designed for discovering interpretable dynamical systems consistent with a dataset. It includes several architectural features that encourage simplicity, in the sense of having a small number of latent variables that carry independent information and are updated sparsely.
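As a rough illustration of how a model can be pushed toward a few independent, sparsely used latents, the sketch below passes each latent through its own Gaussian-noise bottleneck and charges a KL-divergence penalty for keeping that channel informative, in the style of a variational autoencoder. This is a minimal NumPy sketch under that assumption, not the library's implementation; the function name `noisy_bottleneck` and the `log_sigma` parameterization are hypothetical.

```python
import numpy as np

def noisy_bottleneck(x, log_sigma, rng):
    """Pass latents through independent Gaussian-noise bottlenecks.

    Hypothetical sketch (not the library's code).
    x: latent values, shape (batch, n_latents).
    log_sigma: per-latent log noise scale, shape (n_latents,); this would be
        a learned parameter in a real model.
    Returns the noisy latents and the batch-averaged KL penalty per latent.
    """
    sigma = np.exp(log_sigma)
    # Add independent Gaussian noise to each latent channel.
    noisy = x + sigma * rng.standard_normal(x.shape)
    # KL( N(x, sigma^2) || N(0, 1) ) = 0.5 * (x^2 + sigma^2 - 1 - log sigma^2)
    kl = 0.5 * (x**2 + sigma**2 - 1.0 - 2.0 * log_sigma)
    return noisy, kl.mean(axis=0)
```

With `sigma = 1` and a zero-mean input the penalty is zero (the latent is effectively closed), while shrinking `sigma` so the latent can transmit information makes the penalty positive. Added to a prediction loss, a term of this kind means a latent stays open only if it improves the fit enough to pay for itself.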

We have explored fitting these to behavioral data from humans and other animals performing simple learning and decision-making tasks, and found that the resulting systems perform well as cognitive models and can readily be interpreted. You can read more about this work in our paper Cognitive Model Discovery via Disentangled RNNs.

The code here supports generating synthetic datasets, packaging laboratory datasets, training DisRNNs (with a range of hyperparameters) as well as standard RNNs, and inspecting the fitted networks.
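Synthetic datasets are typically produced by simulating a known agent on a task, so that a fitted network can be compared against ground truth. The sketch below simulates a Q-learning agent with softmax choice on a two-armed bandit whose reward probabilities drift over trials. It is a hypothetical example written for illustration; `simulate_q_learner` and its parameters (`alpha`, `beta`) are not the library's dataset API.

```python
import numpy as np

def simulate_q_learner(n_trials, alpha=0.3, beta=3.0, seed=0):
    """Simulate a Q-learning agent on a two-armed drifting bandit.

    Hypothetical example generator; names and defaults are illustrative.
    alpha: learning rate; beta: softmax inverse temperature.
    Returns integer arrays of choices (0/1) and rewards (0/1).
    """
    rng = np.random.default_rng(seed)
    q = np.zeros(2)                  # current action values
    p_reward = np.array([0.5, 0.5])  # reward probabilities, drift every trial
    choices = np.zeros(n_trials, dtype=int)
    rewards = np.zeros(n_trials, dtype=int)
    for t in range(n_trials):
        # Two-option softmax: P(choose 1) = sigmoid(beta * (q1 - q0)).
        p_choose_1 = 1.0 / (1.0 + np.exp(-beta * (q[1] - q[0])))
        choice = int(rng.random() < p_choose_1)
        reward = int(rng.random() < p_reward[choice])
        # Move the chosen action's value toward the observed reward.
        q[choice] += alpha * (reward - q[choice])
        # Gaussian random walk on reward probabilities, clipped to [0, 1].
        p_reward = np.clip(p_reward + rng.normal(0.0, 0.05, size=2), 0.0, 1.0)
        choices[t] = choice
        rewards[t] = reward
    return choices, rewards
```

A sequence model would then be trained to predict each choice from the preceding choices and rewards; because the generating agent is known, its fitted latents can be checked against the true action values.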

Exploring DisRNN in Colab

We provide several Jupyter (.ipynb) notebooks you can use to explore DisRNN. The links below open these notebooks in Google Colab. We recommend saving a copy so that you can edit the notebook (File -> Save a copy in Drive), and connecting your notebook to a GPU or TPU backend (Connect button in the top right -> Connect to a hosted runtime).

  • The Train GRU notebook demonstrates fitting a synthetic dataset using a gated recurrent unit (GRU) network. The GRU is a popular network architecture and, with appropriate hyperparameters and a sufficiently large dataset, is expected to provide a very good quality of fit in most situations.
  • The Train DisRNN notebook demonstrates fitting a synthetic dataset with a DisRNN network. It also demonstrates some of the tools available for inspecting the fitted DisRNN and interpreting the resulting model.
  • The Train Multisubject DisRNN notebook demonstrates fitting a synthetic dataset containing data from multiple "individuals" who vary parametrically in their cognitive strategy. We use a "Multisubject DisRNN" to capture both the similarities and the differences between individuals with a single network. This combines ideas from DisRNN with prior ideas from the literature about disentangled subject embeddings (Dezfouli et al., 2019; Song et al., 2021).
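For reference, a single update of a standard GRU cell is sketched below in NumPy. It follows the textbook formulation (reset gate `r`, update gate `z`, candidate state `c`) rather than the notebook's JAX code; `gru_step`, `init_params`, and the parameter dictionary layout are hypothetical, and libraries differ in small conventions, such as where the reset gate is applied and whether `z` weights the old or the new state.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x, h, params):
    """One step of a standard GRU cell (textbook formulation).

    x: input vector, shape (n_in,); h: previous hidden state, shape (n_hidden,).
    params: dict with W_* (n_hidden, n_in), U_* (n_hidden, n_hidden) and
        b_* (n_hidden,) for the reset (r), update (z) and candidate (c) paths.
    """
    r = sigmoid(params["W_r"] @ x + params["U_r"] @ h + params["b_r"])  # reset gate
    z = sigmoid(params["W_z"] @ x + params["U_z"] @ h + params["b_z"])  # update gate
    # Candidate state computed from the input and the reset-gated old state.
    c = np.tanh(params["W_c"] @ x + params["U_c"] @ (r * h) + params["b_c"])
    # Interpolate between the old state (weight z) and the candidate (1 - z).
    return z * h + (1.0 - z) * c

def init_params(n_in, n_hidden, seed=0):
    """Small random weights and zero biases for the three gate paths."""
    rng = np.random.default_rng(seed)
    params = {}
    for gate in ("r", "z", "c"):
        params[f"W_{gate}"] = rng.normal(0.0, 0.1, size=(n_hidden, n_in))
        params[f"U_{gate}"] = rng.normal(0.0, 0.1, size=(n_hidden, n_hidden))
        params[f"b_{gate}"] = np.zeros(n_hidden)
    return params
```

Applying `gru_step` once per trial, with that trial's inputs (for example, the previous choice and reward) as `x`, yields a hidden-state trajectory from which a readout layer would predict the next choice.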

Installing and running locally

These instructions assume you will be using a virtual environment created with conda.

  1. Create and activate the virtual environment
conda create --name disrnn_venv python=3.11
conda activate disrnn_venv
  2. Install the version of JAX suitable for your hardware

    • For CPU only: pip install -U "jax[cpu]"
    • For NVIDIA GPU: pip install -U "jax[cuda12]"
    • For other architectures: consult the official JAX installation guide.
  3. Clone the GitHub repo and install the remaining requirements

git clone https://github.com/google-deepmind/disentangled_rnns.git
cd disentangled_rnns
pip install .
  4. Test your setup using the example script
python example.py
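After installation, a quick way to confirm that the environment sees the packages is to check that they are importable and, if JAX is present, list the devices it can use. The sketch assumes the installed package imports as `disentangled_rnns`; `check_environment` is a hypothetical helper, while `importlib.util.find_spec` and `jax.devices()` are standard calls.

```python
import importlib.util

def check_environment(packages=("jax", "disentangled_rnns")):
    """Report which of the given top-level packages can be imported.

    Assumes the installed package imports as `disentangled_rnns`.
    Returns a dict mapping package name -> True/False.
    """
    return {name: importlib.util.find_spec(name) is not None for name in packages}

if __name__ == "__main__":
    status = check_environment()
    for name, ok in status.items():
        print(f"{name}: {'found' if ok else 'MISSING'}")
    if status.get("jax"):
        import jax
        # Lists the devices JAX can use (CPU, GPU or TPU devices).
        print(jax.devices())
```

If a GPU build of JAX was installed but only CPU devices are listed, the CUDA setup is the first thing to check.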

Citing this work

If you use this code, please cite the following paper: Cognitive Model Discovery via Disentangled RNNs

@inproceedings{miller_disRNN_2023,
  title = {Cognitive Model Discovery via Disentangled RNNs},
  author = {Miller, Kevin J and Eckstein, Maria and Botvinick, Matthew and Kurth-Nelson, Zeb},
  booktitle = {Neural Information Processing Systems},
  year = {2023},
}

License and disclaimer

Copyright 2023 DeepMind Technologies Limited

All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.

This is not an official Google product.

Download files

Source Distribution

disentangled_rnns-0.1.6.tar.gz (61.5 kB view details)

Uploaded Source

Built Distribution

disentangled_rnns-0.1.6-py3-none-any.whl (74.7 kB view details)

Uploaded Python 3

File details

Details for the file disentangled_rnns-0.1.6.tar.gz.

File metadata

  • Download URL: disentangled_rnns-0.1.6.tar.gz
  • Size: 61.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.13

File hashes

Hashes for disentangled_rnns-0.1.6.tar.gz
Algorithm Hash digest
SHA256 30c48015daf8064f4fdd731e154abd2ef8aff832839c871cd138473bfd1932ef
MD5 9c107786feae82e97e298e41c6fc0911
BLAKE2b-256 4ba37f2bb3608b3a813927c634cc5dd5809ea5218eb26d0f0348054f82e0c1ed

File details

Details for the file disentangled_rnns-0.1.6-py3-none-any.whl.

File hashes

Hashes for disentangled_rnns-0.1.6-py3-none-any.whl
Algorithm Hash digest
SHA256 549e7d954870ecaac5ba01087d69efffbf081cdf9d2c79722933cb343dbed226
MD5 abcf66f02c47c42806c7b44ebe473d60
BLAKE2b-256 989d61de22e706930781155acc97c6b37be2285e7b69685791ed4582579db89f
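To check a downloaded file against the published digests, compute its SHA256 hash locally and compare. The sketch below uses Python's standard `hashlib`; `sha256_of` and `verify` are hypothetical helper names, and the expected values are copied from the hash tables above.

```python
import hashlib
from pathlib import Path

def sha256_of(path, chunk_size=1 << 20):
    """Compute the SHA256 hex digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Expected SHA256 digests, copied from the hash tables on this page.
EXPECTED = {
    "disentangled_rnns-0.1.6.tar.gz":
        "30c48015daf8064f4fdd731e154abd2ef8aff832839c871cd138473bfd1932ef",
    "disentangled_rnns-0.1.6-py3-none-any.whl":
        "549e7d954870ecaac5ba01087d69efffbf081cdf9d2c79722933cb343dbed226",
}

def verify(path):
    """True if the file's SHA256 matches the expected digest for its file name.

    Raises KeyError if the file name is not one of the listed distributions.
    """
    return sha256_of(path) == EXPECTED[Path(path).name]
```

For example, `verify("disentangled_rnns-0.1.6.tar.gz")` returns `True` only if the local file matches the published digest. pip can also enforce digests automatically through its hash-checking mode (`--require-hashes`).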
