
keras-cor: Correlated Outputs Regularization

Adds a regularization term for cases where the features (columns, neurons) of a hidden layer or the output layer should be correlated. The vector of target correlation coefficients is computed before the optimization and compared with the correlation coefficients computed across the examples of each batch.
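The comparison described above can be sketched outside of Keras with NumPy. This is a minimal illustration under stated assumptions, not the keras-cor implementation: the helper names `pairwise_corr` and `corr_penalty` are hypothetical, the coefficient vector is assumed to list the pairs (i, j) with i < j in `np.triu_indices` order, and the penalty is assumed to be `cor_rate` times the mean squared deviation from the targets.

```python
import numpy as np


def pairwise_corr(h):
    """Pearson correlation between every pair of columns of `h`.

    `h` has shape (batch_size, n_outputs); each column is one neuron/feature.
    Returns the upper triangle (i < j) of the correlation matrix as a flat
    vector of length n_outputs * (n_outputs - 1) / 2 (ordering is an assumption).
    """
    rho = np.corrcoef(h, rowvar=False)       # (n_outputs, n_outputs) matrix
    i, j = np.triu_indices(h.shape[1], k=1)  # index pairs with i < j
    return rho[i, j]


def corr_penalty(h, target_corr, cor_rate=0.1):
    """Hypothetical penalty form (assumption): cor_rate times the mean
    squared deviation between batch correlations and target correlations."""
    diff = pairwise_corr(h) - np.asarray(target_corr)
    return cor_rate * np.mean(diff ** 2)


# The target vector is computed once, before training, from the targets ...
rng = np.random.default_rng(0)
y = rng.normal(size=(128, 3))
target_corr = pairwise_corr(y)

# ... and each batch of layer outputs is compared against it.
h = rng.normal(size=(32, 3))
print(corr_penalty(h, target_corr))
```

Inside a Keras layer, the same quantity would have to be computed with TensorFlow ops on every batch and added to the training loss (for example via the layer's `add_loss` method); how keras-cor does this internally is not shown here.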

Usage

See the demo notebook.

from keras_cor import CorrOutputsRegularizer
import tensorflow as tf

# Simple regression NN
def build_mymodel(input_dim, target_corr, cor_rate=0.1, 
                  activation="sigmoid", output_dim=3):
    inputs = tf.keras.Input(shape=(input_dim,))
    h = tf.keras.layers.Dense(units=output_dim)(inputs)
    h = tf.keras.layers.Activation(activation)(h)
    outputs = CorrOutputsRegularizer(target_corr, cor_rate)(h)  # <= HERE
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

# Generate toy dataset
BATCH_SZ = 128
INPUT_DIM = 64
OUTPUT_DIM = 3

X_train = tf.random.normal([BATCH_SZ, INPUT_DIM])
y_train = tf.random.normal([BATCH_SZ, OUTPUT_DIM])

# Normally you should compute `target_corr` from your target outputs `y_train`,
# e.g., target_corr = pearson_vec(y_train)
# However, you can also use subjective correlations (aka expert opinions), e.g.,
# one coefficient per pair of outputs: 3 pairs for OUTPUT_DIM = 3
target_corr = tf.constant([.5, -.4, .9])

# Optimization
model = build_mymodel(input_dim=INPUT_DIM, target_corr=target_corr, output_dim=OUTPUT_DIM)
model.compile(optimizer=tf.keras.optimizers.Adam(), loss="mean_squared_error")
history = model.fit(X_train, y_train, verbose=1, epochs=2)

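# Inference: correlation coefficients between the predicted outputs
yhat = model.predict(X_train)
# requires `pearson_vec`, which this snippet does not import
rhos = pearson_vec(yhat)
print(rhos)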
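Subjective targets such as `[.5, -.4, .9]` need not be mutually consistent: pairwise coefficients only describe a valid correlation matrix if that matrix is positive semi-definite. For three coefficients a, b, c the determinant is 1 - a^2 - b^2 - c^2 + 2abc, which for the values above is 1 - 0.25 - 0.16 - 0.81 + 2(0.5)(-0.4)(0.9) = -0.58 < 0, so no data set can match all three exactly and the regularizer can at best approximate them. The sketch below checks this with NumPy; the helper names are hypothetical and the pair ordering (i < j, in `np.triu_indices` order) is an assumption, not necessarily what keras-cor uses.

```python
import numpy as np


def corr_matrix_from_vector(vec, n_outputs):
    """Rebuild the symmetric correlation matrix from its upper triangle.

    Assumes `vec` lists the pairs (i, j), i < j, in np.triu_indices order,
    e.g. (0, 1), (0, 2), (1, 2) for three outputs.
    """
    m = np.eye(n_outputs)                   # unit diagonal
    i, j = np.triu_indices(n_outputs, k=1)
    m[i, j] = vec                           # upper triangle
    m[j, i] = vec                           # mirror to lower triangle
    return m


def is_valid_corr_vector(vec, n_outputs, tol=1e-10):
    """True if the implied correlation matrix is positive semi-definite."""
    m = corr_matrix_from_vector(np.asarray(vec, dtype=float), n_outputs)
    return bool(np.linalg.eigvalsh(m).min() >= -tol)


print(is_valid_corr_vector([.5, -.4, .9], 3))  # False: determinant is -0.58
print(is_valid_corr_vector([.5, .2, .3], 3))   # True
```

For three outputs the result does not depend on the pair ordering, because the determinant above is symmetric in a, b and c; for more outputs the ordering matters.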

Appendix

Installation

The keras-cor package is available on PyPI:

pip install keras-cor

Alternatively, install it directly from the git repo:

pip install git+ssh://git@github.com/ulf1/keras-cor.git

Install a virtual environment

python3 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt --no-cache-dir
pip install -r requirements-dev.txt --no-cache-dir
pip install -r requirements-demo.txt --no-cache-dir

(If your git repo is stored in a folder whose path contains whitespace, don't create the virtual environment in the subfolder .venv; use an absolute path without whitespace instead.)

Python commands

  • Jupyter for the examples: jupyter lab
  • Check syntax: flake8 --ignore=F401 --exclude=$(grep -v '^#' .gitignore | xargs | sed -e 's/ /,/g')
  • Run Unit Tests: PYTHONPATH=. pytest

Publish

python setup.py sdist 
twine upload -r pypi dist/*

Clean up

find . -type f -name "*.pyc" | xargs rm
find . -type d -name "__pycache__" | xargs rm -r
rm -r .pytest_cache
rm -r .venv

Support

Please open an issue for support.

Contributing

Please contribute using GitHub Flow: create a branch, add commits, and open a pull request.

Download files

Source Distribution

keras-cor-0.2.0.tar.gz (8.2 kB)

File details

Details for the file keras-cor-0.2.0.tar.gz.

File metadata

  • Download URL: keras-cor-0.2.0.tar.gz
  • Size: 8.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.9.6 requests/2.31.0 setuptools/59.6.0 requests-toolbelt/1.0.0 tqdm/4.65.0 CPython/3.10.6

File hashes

Hashes for keras-cor-0.2.0.tar.gz
Algorithm    Hash digest
SHA256       59de74c7cf8e7626e549a9567467f6a504b66a07631fb5b0897ecf5ea9d2dfef
MD5          d1e4f83a0d4a63377e1f7aba6b9e597b
BLAKE2b-256  354fe26253b905f97f5d944ca56a8434097a1b709be759f9f37ff9a1fa99c963
