Cognitive Model Discovery via Disentangled RNNs

Disentangled RNN (DisRNN) is a recurrent neural network architecture designed for discovering interpretable dynamical systems consistent with a dataset. It includes several architectural features that encourage simplicity, in the sense of having a small number of latent variables carrying independent information and updated in a sparse way.
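One standard way to encourage "a small number of latent variables ... updated in a sparse way" is to route each value through a noisy channel and penalize the information it carries. The sketch below illustrates that idea for a single scalar, under stated assumptions: the function name `noisy_bottleneck` and its parameters are hypothetical, the code uses only the Python standard library rather than JAX, and it is not taken from this package.

```python
import math
import random


def noisy_bottleneck(value, multiplier, sigma, rng):
    """Pass one scalar through a noisy channel; return (output, kl_penalty).

    Hypothetical helper for illustration only; not this package's API.
    """
    mean = multiplier * value
    # Sample from N(mean, sigma^2). Large sigma drowns the signal in noise.
    output = mean + sigma * rng.gauss(0.0, 1.0)
    # KL divergence from N(mean, sigma^2) to the unit Gaussian N(0, 1).
    kl_penalty = 0.5 * (mean**2 + sigma**2 - 1.0 - math.log(sigma**2))
    return output, kl_penalty


rng = random.Random(0)
# A "closed" channel: the output is pure unit noise, so the penalty is zero.
_, kl_closed = noisy_bottleneck(value=2.0, multiplier=0.0, sigma=1.0, rng=rng)
# An "open" channel: low noise transmits the value but pays a penalty.
_, kl_open = noisy_bottleneck(value=2.0, multiplier=1.0, sigma=0.1, rng=rng)
print(kl_closed, kl_open)
```

Adding such a penalty to the training loss means a channel stays open only when the information it carries improves the fit enough to justify the cost, which tends to leave few active latent variables.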

We have explored fitting these to behavioral data from humans and other animals performing simple learning and decision-making tasks, and found that the resulting systems perform well as cognitive models and can readily be interpreted. You can read more about this work in our paper Cognitive Model Discovery via Disentangled RNNs.

The code here allows generating synthetic datasets, packaging laboratory datasets, training DisRNNs with different hyperparameters as well as standard RNNs, and inspecting the fit networks.
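To make "generating synthetic datasets" concrete, here is a minimal sketch of the kind of data involved: a simple Q-learning agent choosing between two options with different reward probabilities. Everything in it is an assumption for illustration (the function name `simulate_q_learning_session`, the parameters `alpha` and `beta`, and the list-based output format); it does not use or reproduce this package's dataset utilities.

```python
import math
import random


def simulate_q_learning_session(n_trials, reward_probs, alpha=0.3, beta=3.0, seed=0):
    """Simulate one session of a Q-learning agent on a two-armed bandit.

    Hypothetical helper: names and defaults are illustrative assumptions.
    Returns parallel lists of choices (0 or 1) and rewards (0 or 1).
    """
    rng = random.Random(seed)
    q_values = [0.5, 0.5]
    choices, rewards = [], []
    for _ in range(n_trials):
        # Softmax over two actions reduces to a logistic of the value difference.
        p_choose_1 = 1.0 / (1.0 + math.exp(-beta * (q_values[1] - q_values[0])))
        choice = 1 if rng.random() < p_choose_1 else 0
        reward = 1 if rng.random() < reward_probs[choice] else 0
        # Delta-rule update applied to the chosen action's value only.
        q_values[choice] += alpha * (reward - q_values[choice])
        choices.append(choice)
        rewards.append(reward)
    return choices, rewards


choices, rewards = simulate_q_learning_session(n_trials=200, reward_probs=(0.2, 0.8))
print(len(choices), sum(rewards))
```

A network fit to sequences like these is trained to predict each choice from the preceding choices and rewards, so a successful fit should recover dynamics resembling the value update above.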

Exploring DisRNN in Colab

We provide several ipynb notebooks you can use to explore DisRNN. The links below will open these notebooks in Google Colab. We recommend creating a copy so that you will be able to edit the notebook (File -> Save a copy in Drive), and connecting your notebook to a GPU or TPU backend (Connect button in the top right -> Connect to a hosted runtime).

  • The Train GRU notebook demonstrates fitting a synthetic dataset using a gated recurrent unit (GRU) network. The GRU is a popular network architecture and, with the correct hyperparameters and a sufficiently large dataset, is expected to provide very good quality-of-fit in most situations.
  • The Train DisRNN notebook demonstrates fitting a synthetic dataset with a DisRNN network. It also demonstrates some of the tools available for inspecting the fit DisRNN and interpreting the resulting model.
  • The Train Multisubject DisRNN notebook demonstrates fitting a synthetic dataset containing data from multiple "individuals" which vary parametrically in their cognitive strategy. We use a "Multisubject DisRNN" to fit both similarities and differences using a single network. This combines ideas from DisRNN with prior ideas from the literature about disentangled subject embeddings (Dezfouli et al., 2019; Song et al., 2021).

Installing and running locally

These instructions assume you will be using a virtual environment created with conda.

  1. Create and activate the virtual environment
conda create --name disrnn_venv python=3.11
conda activate disrnn_venv
  2. Install the version of JAX suitable for your hardware

    • For CPU only: pip install -U "jax[cpu]"
    • For NVIDIA GPU: pip install -U "jax[cuda12]"
    • For other architectures: Consult the Official JAX Installation Guide.
  3. Clone the GitHub repo and install remaining requirements

git clone https://github.com/google-deepmind/disentangled_rnns.git
cd disentangled_rnns
pip install .
cd ..
  4. Test your setup using the example script
python example.py
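
Before running the example script, it can help to confirm that JAX imports and to see which backend it selected. The following is a small optional check, not part of the package: `describe_jax_backend` is a hypothetical helper name, while `jax.devices()` and `jax.default_backend()` are standard JAX calls.

```python
def describe_jax_backend():
    """Return a one-line summary of the JAX installation, or a hint if it is missing."""
    try:
        import jax
    except ImportError:
        return "JAX is not installed in this environment."
    # jax.devices() lists the devices of the default backend (cpu, gpu, or tpu).
    devices = jax.devices()
    return f"JAX {jax.__version__}, backend={jax.default_backend()}, devices={devices}"


print(describe_jax_backend())
```

If the reported backend is `cpu` on a machine with an NVIDIA GPU, the CPU-only build of JAX was probably installed; reinstalling with the `jax[cuda12]` extra from the install step is the usual remedy.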

Citing this work

If you use this code, please cite the following paper: Cognitive Model Discovery via Disentangled RNNs

@misc{miller_disRNN_2023,
  title = {Cognitive Model Discovery via Disentangled RNNs},
  author = {Miller, Kevin J and Eckstein, Maria and Botvinick, Matthew and Kurth-Nelson, Zeb},
  journal = {Neural Information Processing Systems},
  year = {2023},
}

License and disclaimer

Copyright 2023 DeepMind Technologies Limited

All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.

This is not an official Google product.
