

Cognitive Model Discovery via Disentangled RNNs

Disentangled RNN (DisRNN) is a recurrent neural network architecture designed to discover interpretable dynamical systems consistent with a dataset. Several of its architectural features encourage simplicity: a small number of latent variables, each carrying independent information and each updated sparsely.
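
One way to make "a small number of latent variables" concrete is a noisy information bottleneck. The sketch below makes several assumptions for illustration:

- Each latent passes through a Gaussian noise channel.
- Training adds that channel's KL divergence from a unit Gaussian to the loss.
- A channel whose output is pure unit noise is "closed": it carries no information and costs nothing.

The function `noisy_bottleneck` and its parameters `scale` and `log_sigma` are hypothetical. They are not part of the package's API, and this is not presented as the package's implementation.

```python
import math
import random


def noisy_bottleneck(value, scale, log_sigma, rng):
    """Pass one latent value through a hypothetical Gaussian noise channel.

    The value is multiplied by `scale` and corrupted with Gaussian noise of
    standard deviation exp(log_sigma). In a trained model both parameters would
    be learned. Returns the noisy sample and the KL divergence of the channel's
    output distribution N(mu, sigma^2) from the unit Gaussian N(0, 1).
    """
    mu = scale * value
    sigma = math.exp(log_sigma)
    sample = mu + sigma * rng.gauss(0.0, 1.0)
    # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - 1 - log(sigma^2))
    kl = 0.5 * (mu ** 2 + sigma ** 2 - 1.0 - 2.0 * log_sigma)
    return sample, kl


rng = random.Random(0)
# Closed bottleneck: the value is scaled to zero and replaced by unit noise,
# so nothing about it gets through and the penalty is zero.
_, kl_closed = noisy_bottleneck(0.7, scale=0.0, log_sigma=0.0, rng=rng)
# Open bottleneck: the value passes through with little noise, at a cost.
_, kl_open = noisy_bottleneck(0.7, scale=1.0, log_sigma=-2.0, rng=rng)
print(kl_closed, kl_open)
```

Under this assumption, summing the KL terms into the loss makes every open channel costly. Latents that do not improve the fit therefore tend to be pushed toward the closed state.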

We have explored fitting these to behavioral data from humans and other animals performing simple learning and decision-making tasks, and found that the resulting systems perform well as cognitive models and can readily be interpreted. You can read more about this work in our paper Cognitive Model Discovery via Disentangled RNNs.

The code here supports generating synthetic datasets, packaging laboratory datasets, training DisRNNs (with a range of hyperparameters) and standard RNNs, and inspecting the fitted networks.
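
Here is a minimal sketch of what a synthetic dataset for a simple learning task might look like. It simulates a Q-learning agent on a two-armed bandit. It then arranges each session so that a network sees the previous trial's choice and reward and is trained to predict the next choice.

Everything here is a hypothetical illustration, not the package's dataset-generation API:

- the agent itself;
- its parameters (`alpha`, `beta`, `reward_probs`);
- the helpers `simulate_q_learning_session` and `to_training_pairs`;
- the input/target layout.

```python
import math
import random


def simulate_q_learning_session(n_trials, alpha, beta, reward_probs, rng):
    """Simulate one session of a hypothetical Q-learning agent on a two-armed bandit.

    Returns a list of (choice, reward) tuples, one per trial.
    """
    q = [0.0, 0.0]
    trials = []
    for _ in range(n_trials):
        # Softmax choice rule over the two action values, inverse temperature beta.
        p_right = 1.0 / (1.0 + math.exp(-beta * (q[1] - q[0])))
        choice = 1 if rng.random() < p_right else 0
        reward = 1 if rng.random() < reward_probs[choice] else 0
        # Move the chosen action's value toward the outcome with learning rate alpha.
        q[choice] += alpha * (reward - q[choice])
        trials.append((choice, reward))
    return trials


def to_training_pairs(trials):
    """Pair each trial's (choice, reward) input with the next trial's choice as target."""
    inputs = [list(t) for t in trials[:-1]]
    targets = [t[0] for t in trials[1:]]
    return inputs, targets


rng = random.Random(42)
dataset = [
    to_training_pairs(simulate_q_learning_session(
        200, alpha=0.3, beta=3.0, reward_probs=(0.2, 0.8), rng=rng))
    for _ in range(10)
]
inputs, targets = dataset[0]
print(len(inputs), len(targets))  # 199 199
```

Because the generating agent is known, a synthetic dataset like this lets you check whether a fitted network recovers something resembling the agent's value-update rule.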

Exploring DisRNN in Colab

We provide several Jupyter (.ipynb) notebooks you can use to explore DisRNN. The links below open these notebooks in Google Colab. We recommend saving a copy so that you can edit the notebook (File -> Save a copy in Drive), and connecting it to a GPU or TPU backend (Connect button in the top right -> Connect to a hosted runtime).

  • The Train GRU notebook demonstrates fitting a synthetic dataset using a gated recurrent unit (GRU) network. The GRU is a popular network architecture and, with appropriate hyperparameters and a sufficiently large dataset, is expected to provide a very good quality of fit in most situations.
  • The Train DisRNN notebook demonstrates fitting a synthetic dataset with a DisRNN network. It also demonstrates some of the tools available for inspecting the fitted DisRNN and interpreting the resulting model.
  • The Train Multisubject DisRNN notebook demonstrates fitting a synthetic dataset containing data from multiple "individuals" whose cognitive strategies vary parametrically. We use a "Multisubject DisRNN" to fit both their similarities and their differences with a single network. This combines ideas from DisRNN with prior work on disentangled subject embeddings (Dezfouli et al., 2019; Song et al., 2021).
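
One simple way to let a single network capture both shared structure and individual differences is a per-subject embedding. Each subject gets its own vector, which is fed to the network alongside every trial's inputs. The sketch below assumes that design for illustration only.

The helper names, the embedding size, and the concatenation scheme are hypothetical. They are not necessarily how Multisubject DisRNN conditions on subject identity.

```python
import random


def make_subject_embeddings(n_subjects, embedding_dim, rng):
    """Create a hypothetical table with one embedding vector per subject.

    In a trained model these vectors would be learned jointly with the network;
    here they are random placeholders.
    """
    return [[rng.gauss(0.0, 0.1) for _ in range(embedding_dim)]
            for _ in range(n_subjects)]


def condition_on_subject(trial_inputs, subject_id, embeddings):
    """Append the subject's embedding vector to every trial's input vector."""
    emb = embeddings[subject_id]
    return [list(x) + emb for x in trial_inputs]


rng = random.Random(0)
embeddings = make_subject_embeddings(n_subjects=5, embedding_dim=2, rng=rng)
session = [[1, 0], [0, 1], [1, 1]]  # (previous choice, previous reward) per trial
conditioned = condition_on_subject(session, subject_id=3, embeddings=embeddings)
print(len(conditioned[0]))  # 4
```

With this design, all subjects share the network weights and differ only in their embedding. Differences between individuals must then be expressed through a low-dimensional, inspectable vector.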

Installing and running locally

These instructions assume you will be using a virtual environment created with conda.

  1. Create and activate the virtual environment
conda create --name disrnn_venv python=3.11
conda activate disrnn_venv
  2. Install the version of JAX suitable for your hardware

    • For CPU only: pip install -U "jax[cpu]"
    • For NVIDIA GPU: pip install -U "jax[cuda12]"
    • For other architectures: consult the Official JAX Installation Guide.
  3. Clone the GitHub repo and install the remaining requirements

git clone https://github.com/google-deepmind/disentangled_rnns.git
cd disentangled_rnns
pip install .
  4. Test your setup using the example script
python example.py

Citing this work

If you use this code, please cite the following paper: Cognitive Model Discovery via Disentangled RNNs

@misc{miller_disRNN_2023,
  title = {Cognitive Model Discovery via Disentangled RNNs},
  author = {Miller, Kevin J and Eckstein, Maria and Botvinick, Matthew and Kurth-Nelson, Zeb},
  journal = {Neural Information Processing Systems},
  year = {2023},
}

License and disclaimer

Copyright 2023 DeepMind Technologies Limited

All software is licensed under the Apache License, Version 2.0 (Apache 2.0); you may not use this file except in compliance with the Apache 2.0 license. You may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0

All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode

Unless required by applicable law or agreed to in writing, all software and materials distributed here under the Apache 2.0 or CC-BY licenses are distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses.

This is not an official Google product.



Download files


Source Distribution

disentangled_rnns-0.1.4.tar.gz (57.0 kB)

Uploaded Source

Built Distribution


disentangled_rnns-0.1.4-py3-none-any.whl (71.0 kB)

Uploaded Python 3

File details

Details for the file disentangled_rnns-0.1.4.tar.gz.

File metadata

  • Download URL: disentangled_rnns-0.1.4.tar.gz
  • Size: 57.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for disentangled_rnns-0.1.4.tar.gz
Algorithm Hash digest
SHA256 140e5385565fba4eb410d9a5c635d83a728f659febfec8c8b0aa7a531e3537af
MD5 88a993d7831521a7f9f94e344de51ca3
BLAKE2b-256 21a099d9f3fb4f138af01b66b968af986f267c30261a0d94dbb586d1a674a062


File details

Details for the file disentangled_rnns-0.1.4-py3-none-any.whl.

File hashes

Hashes for disentangled_rnns-0.1.4-py3-none-any.whl
Algorithm Hash digest
SHA256 f8dcf3d09aab5ecac9bb48f0f77ad68033b3ac3c80f875e2e52df16f077e17a9
MD5 b618af43e948bcc5594aafbdbbc04a76
BLAKE2b-256 6f2ac1301baf6d1a61c5564d5964d3da57a3feb8e0198a0903d80b313a7a6f34

