An easy-to-use API for closed-form continuous-time models in TensorFlow and PyTorch.

Project description

Closed-form Continuous-time Models

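Closed-form Continuous-time Neural Networks (CfCs) are sequential neural information processing units that model continuous-time dynamics without a numerical ODE solver: their state is computed from a closed-form approximation of liquid time-constant (LTC) dynamics.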

Paper (open access): https://www.nature.com/articles/s42256-022-00556-7

arXiv: https://arxiv.org/abs/2106.13898

Requirements

  • Python 3.6 or newer
  • TensorFlow 2.4 or newer
  • PyTorch
  • NumPy

To create a fresh Anaconda environment with the required dependencies:

conda env create --file environment.yml
conda activate cfc

Usage

Example

import numpy as np
from cfc_model.dense_model import SequentialModel
X = np.array([[1, 1, 1, 0], [1, 1, 0, 1], [1, 0, 0, 1], [1, 1, 0, 0],
              [1, 0, 1, 0], [1, 1, 0, 1], [1, 0, 0, 1], [1, 0, 1, 0]])
y = np.array([0, 0, 1, 1, 1, 0, 1, 1])
model = SequentialModel()
model.fit(X, y)
y_pred = model.predict([1, 1, 0, 1]) # y_pred equals 0

The following configuration flags are available:

  • no_gate: runs the CfC without the (1 - sigmoid) gating term
  • minimal: runs the direct-solution variant of the CfC
  • use_ltc: runs a liquid time-constant (LTC) network with a semi-implicit ODE solver instead of a CfC
  • use_mixed: mixes the CfC's RNN state with an LSTM to avoid vanishing gradients

If none of these flags is set, the full CfC model is used.
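The flags above select variants of the CfC state update. The following is a minimal pure-Python sketch of the full and no_gate updates, assuming the closed-form expression from the CfC paper, x(t) = sigmoid(-f t) * g + (1 - sigmoid(-f t)) * h, applied element-wise. It is not the cfc_model implementation: the helper names (dense, cfc_update, the params keys), the single-layer heads f, g and h fed directly with the concatenation [I; x] (the paper places a shared backbone network in front of them), and the reading of no_gate as dropping the (1 - sigmoid) factor are all assumptions made for illustration.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def dense(weights, bias, v, act=lambda z: z):
    """Hypothetical helper: one fully connected layer applied to the vector v."""
    return [act(sum(w * vj for w, vj in zip(row, v)) + b)
            for row, b in zip(weights, bias)]


def cfc_update(inp, x, t, params, no_gate=False):
    """Sketch of a CfC state update (an assumption, not the cfc_model code).

    full:    x(t) = sigmoid(-f*t) * g + (1 - sigmoid(-f*t)) * h
    no_gate: x(t) = sigmoid(-f*t) * g + h   (assumed: the (1 - sigmoid) factor is dropped)
    """
    z = list(inp) + list(x)                        # assumed head input: [I; x]
    f = dense(params["Wf"], params["bf"], z)       # time-constant head
    g = dense(params["Wg"], params["bg"], z, math.tanh)
    h = dense(params["Wh"], params["bh"], z, math.tanh)
    gate = [sigmoid(-fi * t) for fi in f]          # time-dependent gate in (0, 1)
    if no_gate:
        return [s * gi + hi for s, gi, hi in zip(gate, g, h)]
    return [s * gi + (1.0 - s) * hi for s, gi, hi in zip(gate, g, h)]


# Usage: random (untrained) weights, three inputs from the toy data with elapsed times t.
random.seed(1)
n_in, n_hidden = 4, 3
params = {}
for name in ("f", "g", "h"):
    params["W" + name] = [[random.gauss(0.0, 0.5) for _ in range(n_in + n_hidden)]
                          for _ in range(n_hidden)]
    params["b" + name] = [0.0] * n_hidden

x = [0.0] * n_hidden
for inp, t in [([1, 1, 1, 0], 1.0), ([1, 1, 0, 1], 0.5), ([1, 0, 0, 1], 2.0)]:
    x = cfc_update(inp, x, t, params)
print(x)
```

In the full update the gate interpolates between the two tanh heads, so each state component stays within (-1, 1) regardless of the elapsed time t; the no_gate variant adds h without that interpolation, so its output differs from the full update by exactly sigmoid(-f t) * h.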

Example with a configuration flag


import numpy as np
from cfc_model.dense_model import SequentialModel
X = np.array([[1, 1, 1, 0], [1, 1, 0, 1], [1, 0, 0, 1], [1, 1, 0, 0],
              [1, 0, 1, 0], [1, 1, 0, 1], [1, 0, 0, 1], [1, 0, 1, 0]])
y = np.array([0, 0, 1, 1, 1, 0, 1, 1])
model = SequentialModel()

# Runs an LTC with a semi-implicit ODE solver instead of a CfC
config = {"use_ltc": True}
model.fit(X, y, config=config)
y_pred = model.predict([1, 1, 0, 1]) # y_pred equals 0
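With use_ltc, the closed-form update is replaced by an LTC cell integrated with a semi-implicit ODE solver. The sketch below illustrates one fused semi-implicit Euler step, assuming the standard LTC dynamics dx/dt = -(1/tau + f(x, I)) * x + f(x, I) * A: the decay term is evaluated at the new state, the nonlinearity f at the old one. The function names (ltc_f, ltc_fused_step), the single sigmoid layer used for f, and all parameter values are hypothetical; the library's exact solver, number of unfolded steps and parameterization may differ.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def ltc_f(x, inp, w_rec, w_in, bias):
    """Hypothetical f(x, I): a single sigmoid layer, so f > 0 for every neuron."""
    return [sigmoid(sum(w * xj for w, xj in zip(w_rec[i], x))
                    + sum(u * ik for u, ik in zip(w_in[i], inp))
                    + bias[i])
            for i in range(len(x))]


def ltc_fused_step(x, f, A, tau, dt):
    """One fused semi-implicit Euler step of dx/dt = -(1/tau + f) * x + f * A.

    Solving x_new = x + dt * (-(1/tau + f) * x_new + f * A) for x_new gives
    x_new = (x + dt * f * A) / (1 + dt * (1/tau + f)).
    """
    return [(xi + dt * fi * ai) / (1.0 + dt * (1.0 / ti + fi))
            for xi, fi, ai, ti in zip(x, f, A, tau)]


# Usage: 3 neurons, 2 inputs, random (untrained) weights, deliberately large dt.
random.seed(0)
n, m = 3, 2
w_rec = [[random.gauss(0.0, 1.0) for _ in range(n)] for _ in range(n)]
w_in = [[random.gauss(0.0, 1.0) for _ in range(m)] for _ in range(n)]
bias = [0.0] * n
A = [1.0, 2.0, 0.5]   # per-neuron target levels (assumed non-negative)
tau = [1.0] * n       # time constants (assumed positive)
x = [0.0] * n
for _ in range(50):
    inp = [random.uniform(-1.0, 1.0) for _ in range(m)]
    x = ltc_fused_step(x, ltc_f(x, inp, w_rec, w_in, bias), A, tau, dt=10.0)
print(x)
```

Because f > 0, tau > 0 and A >= 0, each step divides a weighted mix of x and A by a denominator at least as large as the mixing weights, so a state that starts in [0, A] stays there even with dt = 10, where an explicit Euler step could overshoot.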

Cite

	title = {Closed-form continuous-time neural networks},
	journal = {Nature Machine Intelligence},
	author = {Hasani, Ramin and Lechner, Mathias and Amini, Alexander and Liebenwein, Lucas and Ray, Aaron and Tschaikowski, Max and Teschl, Gerald and Rus, Daniela},
	issn = {2522-5839},
	month = nov,
	year = {2022},
}

Download files

Source Distribution

cfc_model-1.0.5.tar.gz (23.0 MB)

Uploaded Source

Built Distribution

cfc_model-1.0.5-py3-none-any.whl (23.0 MB)

Uploaded Python 3

File details

Details for the file cfc_model-1.0.5.tar.gz.

File metadata

  • Download URL: cfc_model-1.0.5.tar.gz
  • Upload date:
  • Size: 23.0 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.15

File hashes

Hashes for cfc_model-1.0.5.tar.gz
Algorithm Hash digest
SHA256 b2d271efe6f106b4e3dd948ba1e854d66f0d4d7d34d6f91c4f69e27f85a53ac7
MD5 92084a12de730e17ba47cc7b2088b805
BLAKE2b-256 ddba983a81250d9a3fb704b11eb466a7b9e0680bbb0800a32ab10da7daa8d712

File details

Details for the file cfc_model-1.0.5-py3-none-any.whl.

File metadata

  • Download URL: cfc_model-1.0.5-py3-none-any.whl
  • Upload date:
  • Size: 23.0 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.0

File hashes

Hashes for cfc_model-1.0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 1e3ae4c52bbf78d34945186cdf524ed35ae4e295d2a381bc1467e0808ad1fadb
MD5 24b6e84447681a0c06040bc126c50907
BLAKE2b-256 8e2d31d03ea9ac73515cc9a810b9a91afaaec02c180a110f82cfc8e9eb71baba
