TensorFlow Wavelet Layers
tensorflow-wavelets implements custom Keras layers for neural networks:
- Discrete Wavelet Transform (DWT) layer
- Dual-Tree Complex Wavelet Transform (DTCWT) layer
- Discrete Multiwavelet Transform (DMWT) layer
Installation
pip install tensorflow-wavelets
Usage
import tensorflow_wavelets.Layers.DWT as DWT
import tensorflow_wavelets.Layers.DTCWT as DTCWT
import tensorflow_wavelets.Layers.DMWT as DMWT
# Custom Activation function Layer
import tensorflow_wavelets.Layers.Threshold as Threshold
Examples
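DWT(name="haar", concat=0)
"name" is a wavelet name as listed by pywt.wavelist(family), e.g. "haar" or "db4".
concat=0 returns the four sub-bands as four separate channels at roughly half the input resolution; concat=1 tiles them into a single channel.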
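To make the two concat layouts concrete, here is a minimal NumPy sketch of a single-level 2-D Haar transform on one 28x28 image. It assumes an orthonormal Haar filter, a particular sub-band order, and a 2x2 tiling for concat=1; these layout choices are illustrative assumptions, not taken from the library, which works on batched tensors and supports any PyWavelets wavelet.

```python
import numpy as np

def haar_dwt2(img):
    """One-level orthonormal 2-D Haar DWT of an (H, W) array with even H and W."""
    a = img[0::2, 0::2]  # top-left pixel of each 2x2 block
    b = img[0::2, 1::2]  # top-right
    c = img[1::2, 0::2]  # bottom-left
    d = img[1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2.0  # approximation: low-pass in both directions
    lh = (a + b - c - d) / 2.0  # detail: top row minus bottom row
    hl = (a - b + c - d) / 2.0  # detail: left column minus right column
    hh = (a - b - c + d) / 2.0  # detail: diagonal
    return ll, lh, hl, hh

def dwt_like_output(img, concat):
    """Mimic the two output layouts (assumed ordering, for illustration only)."""
    ll, lh, hl, hh = haar_dwt2(img)
    if concat == 0:
        # Four sub-bands stacked as channels: (H/2, W/2, 4)
        return np.stack([ll, lh, hl, hh], axis=-1)
    # concat=1: tile the sub-bands into one 2x2 mosaic: (H, W, 1)
    top = np.concatenate([ll, lh], axis=1)
    bottom = np.concatenate([hl, hh], axis=1)
    return np.concatenate([top, bottom], axis=0)[..., np.newaxis]

img = np.arange(28 * 28, dtype=float).reshape(28, 28)
print(dwt_like_output(img, 0).shape)  # (14, 14, 4)
print(dwt_like_output(img, 1).shape)  # (28, 28, 1)
```

With Haar the concat=1 mosaic has the same 28x28 size as the input; with longer filters such as db4 the sub-bands come out slightly larger than half the input, as the db4 summary shows.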
from tensorflow import keras
import tensorflow_wavelets.Layers.DWT as DWT
nb_classes = 10  # number of output classes
model = keras.Sequential()
model.add(keras.Input(shape=(28, 28, 1)))
model.add(DWT.DWT(name="haar", concat=0))  # (28, 28, 1) -> (14, 14, 4)
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(nb_classes, activation="softmax"))
model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dwt_9_haar (DWT) (None, 14, 14, 4) 0
_________________________________________________________________
flatten_9 (Flatten) (None, 784) 0
_________________________________________________________________
dense_9 (Dense) (None, 10) 7850
=================================================================
Total params: 7,850
Trainable params: 7,850
Non-trainable params: 0
_________________________________________________________________
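The shapes and parameter count in this summary follow from simple arithmetic. The plain-Python sketch below redoes it, assuming nb_classes = 10 as the Dense output shape indicates.

```python
# Shape bookkeeping for the haar / concat=0 model above
side = 28 // 2                 # one Haar DWT level halves each spatial dimension
channels = 4                   # concat=0: four sub-bands stacked as channels
flat = side * side * channels  # length of the Flatten output
nb_classes = 10
dense_params = flat * nb_classes + nb_classes  # weight matrix + bias vector
print(flat, dense_params)  # 784 7850
```

The DWT layer itself reports 0 parameters, so all 7,850 trainable parameters belong to the Dense layer.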
DWT(name="db4", concat=1)
model = keras.Sequential()
model.add(keras.Input(shape=(28, 28, 1)))
model.add(DWT.DWT(name="db4", concat=1))  # sub-bands tiled into one channel
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dwt_db4 (DWT) (None, 34, 34, 1) 0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
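Why 34 rather than 28? A plausible explanation, assuming the layer sizes its sub-bands the way PyWavelets does in its default non-periodic ("symmetric") mode, is the coefficient-length rule floor((N + L - 1) / 2), where N is the input length and L is the filter length (2 taps for haar, 8 for db4). The sketch below reproduces both summaries under that assumption; `coeff_len` is a hypothetical helper, not part of tensorflow-wavelets.

```python
def coeff_len(n, filter_len):
    """Sub-band length for one DWT level, following PyWavelets' rule for
    non-periodization modes (assumed to match the layer's padding)."""
    return (n + filter_len - 1) // 2

haar_side = coeff_len(28, 2)  # haar has 2 taps
db4_side = coeff_len(28, 8)   # db4 has 8 taps
print(haar_side)      # 14 -> (None, 14, 14, 4) with concat=0
print(2 * db4_side)   # 34 -> 2x2 mosaic of 17x17 sub-bands with concat=1
```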
DMWT
Functional API example with SURE thresholding
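from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_wavelets.Layers.DMWT as DMWT
import tensorflow_wavelets.Layers.Threshold as Threshold
x_inp = layers.Input(shape=(512, 512, 1))
x = DMWT.DMWT("ghm")(x_inp)  # forward GHM multiwavelet transform
x = Threshold.Threshold(algo='sure', mode='hard')(x)  # mode: "soft" or "hard"
x = DMWT.IDMWT("ghm")(x)  # inverse transform back to 512x512
model = keras.Model(x_inp, x, name="MyModel")
model.summary()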
Model: "MyModel"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 512, 512, 1)] 0
_________________________________________________________________
dmwt (DMWT) (None, 1024, 1024, 1) 0
_________________________________________________________________
sure_threshold (SureThreshol (None, 1024, 1024, 1) 0
_________________________________________________________________
idmwt (IDMWT) (None, 512, 512, 1) 0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
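The algo='sure' option presumably selects the threshold with Stein's Unbiased Risk Estimate (SURE), as in the classic SureShrink rule. The NumPy sketch below shows the textbook version of that rule together with hard and soft thresholding. It assumes a single global threshold and a known noise level `sigma` (default 1.0); it is not the layer's actual implementation. The SURE formula is derived for soft thresholding; here the same threshold is simply reused for hard thresholding.

```python
import numpy as np

def sure_threshold(coeffs, sigma=1.0):
    """Threshold minimizing Stein's Unbiased Risk Estimate for soft thresholding.

    For coefficients normalized by the noise level, z = |coeffs| / sigma:
        SURE(t) = n - 2 * #{i : z_i <= t} + sum_i min(z_i, t)**2
    SURE rises between consecutive z_i and drops at each z_i, so the minimum
    over t >= 0 is attained at t = 0 or at one of the z_i.
    """
    z = np.abs(np.ravel(coeffs)) / sigma
    n = z.size
    candidates = np.concatenate(([0.0], np.sort(z)))
    risks = [n - 2 * np.count_nonzero(z <= t) + np.sum(np.minimum(z, t) ** 2)
             for t in candidates]
    return sigma * candidates[int(np.argmin(risks))]

def soft_threshold(x, t):
    """Shrink every coefficient toward zero by t."""
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)

def hard_threshold(x, t):
    """Zero coefficients with magnitude <= t and keep the rest unchanged."""
    return np.where(np.abs(x) > t, x, 0.0)

coeffs = np.array([0.1, -0.2, 10.0])
t = sure_threshold(coeffs)
print(t)  # 0.2
hard = hard_threshold(coeffs, t)  # keeps only the large coefficient 10.0
soft = soft_threshold(coeffs, t)  # small ones zeroed, 10.0 shrunk to about 9.8
```

In practice `sigma` is often estimated from the finest detail sub-band, e.g. as median(|d|) / 0.6745.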
Download files
Source distribution: tensorflow-wavelets-1.0.27.tar.gz (19.6 kB)
Algorithm | Hash digest
---|---
SHA256 | e232bf2600ee4c58c9142ede92c9922e35888466b8849e7de313c3eee3dda501
MD5 | 37dd6368aa243ad4a3b92848c11f54cc
BLAKE2b-256 | 0f8410780520958b37687c944580009fec19025f5bac880e6a130c925750b8d2
Built distribution: tensorflow_wavelets-1.0.27-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 77d25a13474b71b373b4b399a9291ed4d58f6b9d15c362ffb3967ef5c9fbb925
MD5 | 4acbcfb1dd2530c26ab09e1021f21596
BLAKE2b-256 | cbf626cd831ca70726209afbdeb99ac0a6610f9acd4b83dca3fe8302f2d0c7e3