COLA - Competitive layers for deep learning.

COLA (COmpetitive LAyers) is a Python package that implements gradient-based competitive layers, which can be placed on top of deep learning models for unsupervised tasks.
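To make "gradient-based competitive layer" concrete, here is a minimal NumPy sketch of the general idea, under our own simplifying assumptions: k prototypes compete for each sample, the closest one wins, and the prototypes are moved by gradient descent on the mean squared distance between each sample and its winner. The function name, initialisation, and learning rate are hypothetical. COLA's own layers are built on TensorFlow and may use a different loss.

```python
import numpy as np


def competitive_step(X, P, lr=0.1):
    """One gradient step on the mean squared distance of each sample to its winner.

    X: (n_samples, d) data, P: (k_prototypes, d) prototypes.
    Returns the updated prototypes and the loss before the update.
    """
    D = ((X[:, None, :] - P[None, :, :]) ** 2).sum(axis=-1)  # squared distances (n, k)
    winners = D.argmin(axis=1)  # competition: the closest prototype wins each sample
    loss = D[np.arange(len(X)), winners].mean()
    grad = np.zeros_like(P)
    for j in range(len(P)):
        members = X[winners == j]  # samples won by prototype j (may be empty)
        # only winners receive a gradient: d/dP_j of (1/n) * sum ||x_i - P_j||^2
        grad[j] = 2.0 * (P[j] - members).sum(axis=0) / len(X)
    return P - lr * grad, loss


# toy data: two well-separated Gaussian blobs
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-2, 0.3, (50, 2)), rng.normal(2, 0.3, (50, 2))])
P = X[[0, 99]].copy()  # initialise each prototype on one sample
for _ in range(100):
    P, loss = competitive_step(X, P)
```

After training, each prototype sits close to the mean of the samples it wins, which is the behaviour a competitive layer is meant to reproduce with gradient-based optimisation.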

https://github.com/pietrobarbiero/deep-topological-learning/blob/master/deep_dual_figure.png

Theory

The theoretical foundations are presented in our paper (arXiv:2008.09477).

If you find COLA useful in your research, please consider citing the following paper:

@misc{barbiero2020topological,
    title={Topological Gradient-based Competitive Learning},
    author={Pietro Barbiero and Gabriele Ciravegna and Vincenzo Randazzo and Giansalvo Cirrincione},
    year={2020},
    eprint={2008.09477},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}

Examples

Dual Competitive Layer (DCL)

https://github.com/pietrobarbiero/deep-topological-learning/blob/master/test/test-results/circles_dynamic_dual.png
https://github.com/pietrobarbiero/deep-topological-learning/blob/master/test/test-results/circles_scatter_dual.png
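The dual layer differs from the vanilla one in how the prototypes are parameterised. As we read the accompanying paper, the dual approach is trained on the transposed data matrix X^T, so its weight matrix has one row per sample (consistent with the `n_samples` argument of `DualModel`) and every prototype is a learned linear combination of the samples. The NumPy sketch below illustrates that reading only; the helper names, initialisation, and learning rate are our own assumptions, not COLA's code.

```python
import numpy as np


def quantization_loss_and_grad(X, P):
    """Mean squared distance to the winning prototype and its gradient w.r.t. P."""
    D = ((X[:, None, :] - P[None, :, :]) ** 2).sum(axis=-1)  # (n, k)
    winners = D.argmin(axis=1)
    loss = D[np.arange(len(X)), winners].mean()
    grad_P = np.zeros_like(P)
    for j in range(len(P)):
        members = X[winners == j]
        grad_P[j] = 2.0 * (P[j] - members).sum(axis=0) / len(X)
    return loss, grad_P


def dual_prototypes(X, W):
    """Dual layer: X^T (d x n) times W (n x k), transposed so each row is a prototype."""
    return (X.T @ W).T


def dual_step(X, W, lr=1e-3):
    """One gradient step on the dual weights W (n_samples x k_prototypes)."""
    loss, grad_P = quantization_loss_and_grad(X, dual_prototypes(X, W))
    # chain rule: P = W^T X, hence dL/dW = X (dL/dP)^T, shape (n, k)
    return W - lr * X @ grad_P.T, loss


rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-2, 0.3, (50, 2)), rng.normal(2, 0.3, (50, 2))])
W = np.zeros((len(X), 2))
W[0, 0] = W[99, 1] = 1.0  # start each prototype exactly on one sample
losses = []
for _ in range(500):
    W, loss = dual_step(X, W)
    losses.append(loss)
P = dual_prototypes(X, W)  # final prototypes, shape (2, 2)
```

Because the optimised parameters are the mixing weights W rather than the prototype coordinates, the effective update on the prototypes is rescaled by X^T X, which is why this sketch uses a smaller learning rate than the vanilla one.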

Vanilla Competitive Layer (VCL)

https://github.com/pietrobarbiero/deep-topological-learning/blob/master/test/test-results/circles_dynamic_vanilla.png
https://github.com/pietrobarbiero/deep-topological-learning/blob/master/test/test-results/circles_scatter_vanilla.png

Using COLA

```python
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input
from cola import DualModel, plot_confusion_matrix, scatterplot, scatterplot_dynamic

X, y = ...  # load dataset: X has shape (n_samples, n_features)
d = X.shape[1]  # number of input features

# define custom TensorFlow layers
inputs = Input(shape=(d,), name='input')
...
outputs = ...

# instantiate the dual model
n = X.shape[0]  # number of samples
k = ...  # upper bound of the desired number of prototypes
model = DualModel(n_samples=n, k_prototypes=k, inputs=inputs, outputs=outputs, deep=False)

# train
optimizer = ...  # a TensorFlow/Keras optimizer
epochs = ...  # number of training epochs
model.compile(optimizer=optimizer)
model.fit(X, y, epochs=epochs)

# samples in the same space as the prototypes
# (presumably X itself when deep=False)
x_pred = ...

# plot prototype dynamics
plt.figure()
scatterplot_dynamic(X, model.prototypes_, y, valid=True)
plt.show()

# plot confusion matrix
# considering the prototypes estimated in the last epoch
plt.figure()
plot_confusion_matrix(x_pred, model.prototypes[-1], y)
plt.show()

# plot estimated topology
# considering the prototypes estimated in the last epoch
plt.figure()
scatterplot(x_pred, model.prototypes[-1], y, valid=True)
plt.show()
```
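The last plot draws an estimated topology, i.e. a graph over the prototypes. A classic way to obtain such a graph is competitive Hebbian learning: each sample adds an edge between its closest and second-closest prototype. The sketch below shows that rule for illustration only; it is not necessarily the procedure COLA uses, the helper name is hypothetical, and it assumes prototypes are stored one per row.

```python
import numpy as np


def competitive_hebbian_edges(X, P):
    """Return sorted edges (i, j), i < j, linking the two closest prototypes of each sample."""
    D = ((X[:, None, :] - P[None, :, :]) ** 2).sum(axis=-1)  # (n_samples, k_prototypes)
    first_two = np.argsort(D, axis=1)[:, :2]  # winner and runner-up per sample
    return sorted({(int(min(a, b)), int(max(a, b))) for a, b in first_two})


# three prototypes on a line; one sample between each neighbouring pair
prototypes = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
samples = np.array([[0.4, 0.1], [1.6, -0.1]])
edges = competitive_hebbian_edges(samples, prototypes)
print(edges)  # → [(0, 1), (1, 2)]
```

No edge links prototypes 0 and 2, because no sample has them as its two nearest prototypes; this is how the rule recovers the chain-like shape of the data.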

Authors

Pietro Barbiero

Licence

Copyright 2020 Pietro Barbiero.

Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0.

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

See the License for the specific language governing permissions and limitations under the License.

Download files

Source Distribution

deepcola-0.0.0.tar.gz (13.9 kB)

Built Distribution

deepcola-0.0.0-py3-none-any.whl (16.1 kB)

File details

Details for the file deepcola-0.0.0.tar.gz.

File metadata

  • Download URL: deepcola-0.0.0.tar.gz
  • Size: 13.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.4.2 requests/2.22.0 setuptools/45.2.0 requests-toolbelt/0.8.0 tqdm/4.30.0 CPython/3.8.2

File hashes

Hashes for deepcola-0.0.0.tar.gz
Algorithm Hash digest
SHA256 be10bde18d9c2170189889f63e844dcccbde8c28a8f2858adb537bc529af0e01
MD5 457bdffc8add1bf63cfb9a51bf59bcf3
BLAKE2b-256 00a792cf586b99fc1d6b8f63fbdacf8f6113b1368dde48d3145a4b3f3e3f0253


File details

Details for the file deepcola-0.0.0-py3-none-any.whl.

File metadata

  • Download URL: deepcola-0.0.0-py3-none-any.whl
  • Size: 16.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.4.2 requests/2.22.0 setuptools/45.2.0 requests-toolbelt/0.8.0 tqdm/4.30.0 CPython/3.8.2

File hashes

Hashes for deepcola-0.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 3a6b120d397ccb61fc35c659ce0694ec99e9042d1d89b6da126b5785c84d92dc
MD5 e08f258f8066a6bb9bda33ef0cc7decf
BLAKE2b-256 f7eb397fe1eeef2c425c6d98272efb001454de0f5933355c21d46463a04fa87c
