Library for full residual deep networks with attention layers

Library for Full Residual Deep Networks with Attention Layers (fullresattn)

The Python library for full residual deep networks with attention layers (fullresattn). The current version supports only the Keras deep learning package; support for other frameworks is planned.

Major modules

Model

  • fullresAttCoder: the main class for building a full residual deep network with optional attention layers, configured through its constructor arguments. See the help for the class and its member functions for details.
  • pmetrics: functions for regression metrics such as R-squared and RMSE.
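For reference, the two metrics that pmetrics is described as providing can be sketched in a few lines of NumPy. This is a minimal illustration of the standard formulas, R² = 1 − SS_res/SS_tot and RMSE = sqrt(mean squared error); the function names `rsquared` and `rmse` are hypothetical and are not necessarily the names or signatures exported by pmetrics.

```python
import numpy as np


def rsquared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot (hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=float).ravel()  # accepts (n,) or (n, 1)
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    ss_res = np.sum((y_true - y_pred) ** 2)          # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total sum of squares
    return float(1.0 - ss_res / ss_tot)


def rmse(y_true, y_pred):
    """Root mean squared error (hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


if __name__ == "__main__":
    y = [1.0, 2.0, 3.0, 4.0]
    p = [1.0, 2.0, 3.0, 5.0]
    print(rsquared(y, p), rmse(y, p))
```

Because both helpers flatten their inputs, they work on column vectors such as the reshaped `y` used in the example further down.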

Data

  • data: function to access either of two datasets by name:
    'sim': the simulated dataset, as a pandas DataFrame;
    'pm2.5': a real dataset of 2015 PM2.5 concentrations and relevant covariates for the Beijing-Tianjin-Tangshan area, sampled at a fraction of 0.8 from the original dataset (stratified by Julian day).
    See this function's help for details.
  • simdata: function to simulate the test dataset, generated according to the formula
    y = x1 + x2*np.sqrt(x3) + x4 + np.power(x5/500, 0.3) - x6 + np.sqrt(x7) + x8 + noise
    See this function's help for details.
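The formula used by simdata can be illustrated with a small stand-in generator. This is a sketch under stated assumptions, not the package's simdata: the source does not say how x1..x8 or the noise are distributed, so the uniform ranges and Gaussian noise below are assumptions, chosen only so that sqrt(x3), sqrt(x7) and (x5/500)**0.3 stay real-valued. The name `simulate_like_simdata` is hypothetical.

```python
import numpy as np


def simulate_like_simdata(n=1000, noise_sd=1.0, seed=0):
    """Hypothetical stand-in for fullresattn's simdata.

    ASSUMPTION: predictor and noise distributions are invented for
    illustration; only the formula for y comes from the package description.
    """
    rng = np.random.default_rng(seed)
    # Assumed: non-negative uniform predictors so every root/power is defined.
    x = rng.uniform(0.0, 10.0, size=(n, 8))
    x[:, 4] = rng.uniform(0.0, 1000.0, size=n)  # x5 on an assumed larger scale
    x1, x2, x3, x4, x5, x6, x7, x8 = x.T
    noise = rng.normal(0.0, noise_sd, size=n)   # assumed Gaussian noise
    y = (x1 + x2 * np.sqrt(x3) + x4 + np.power(x5 / 500, 0.3)
         - x6 + np.sqrt(x7) + x8 + noise)
    return x, y


if __name__ == "__main__":
    X, y = simulate_like_simdata(n=5)
    print(X.shape, y.shape)
```

To mirror the column layout used in the example below, the arrays can be wrapped as `pd.DataFrame(X, columns=['x1', ..., 'x8'])` with `y` added as a `'y'` column.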

Installation

Install the latest version with pip:

  pip install fullresattn

Notes on installation and use

Runtime requirements

fullresattn requires Keras with TensorFlow (or another backend that Keras supports), as well as pandas and NumPy.
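A quick way to confirm that these requirements are importable is the standard-library `importlib.util.find_spec`. The sketch below assumes the TensorFlow backend; if you use another Keras backend, substitute its package name. It checks importability only, not versions or whether Keras is actually configured with a working backend. The helper name `missing_requirements` is hypothetical.

```python
import importlib.util


def missing_requirements(names=("keras", "tensorflow", "pandas", "numpy")):
    """Return the top-level package names that cannot be found for import.

    ASSUMPTION: the default list presumes the TensorFlow backend.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_requirements()
    print("missing:", missing or "none")
```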

Use case

The package provides two examples of using the full residual deep network with optional attention layers. The example below uses the simulated dataset.

License

fullresattn is provided under the MIT license, which can be found in the LICENSE file. By using, distributing, or contributing to this project, you agree to the terms and conditions of this license.

Test call

Load the packages

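import keras
import fullresattn
from fullresattn.model import r2KAuto, r2K
from keras.callbacks import ModelCheckpoint
from sklearn import preprocessing
from sklearn.model_selection import train_test_split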

Load the simulated sample dataset

simdata=fullresattn.data('sim')

Preprocess the data

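# Predictors x1..x8 and the response y
tcol=['x'+str(i) for i in range(1,9)]
X=simdata[tcol].values
y=simdata['y'].values
y=y.reshape((y.shape[0],1))  # StandardScaler expects a 2-D array
# Standardize predictors and response to zero mean and unit variance
scX = preprocessing.StandardScaler().fit(X)
scy = preprocessing.StandardScaler().fit(y)
Xn=scX.transform(X)
yn=scy.transform(y)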
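Because the network is trained on the standardized response `yn`, its predictions are also in standardized units; in the example above, `scy.inverse_transform` maps them back to the original scale of y. The NumPy sketch below shows what the scaling and its inverse compute. It assumes the population standard deviation (ddof=0) and a scale of 1 for constant columns, which matches our understanding of `StandardScaler`'s defaults; the helper names are hypothetical.

```python
import numpy as np


def standardize(a):
    """Column-wise (a - mean) / std, returning the statistics needed to undo it."""
    a = np.asarray(a, dtype=float)
    mean = a.mean(axis=0)
    std = a.std(axis=0)                   # population std (ddof=0)
    std = np.where(std == 0, 1.0, std)    # leave constant columns unscaled
    return (a - mean) / std, mean, std


def unstandardize(z, mean, std):
    """Map standardized values (e.g. model predictions) back to original units."""
    return np.asarray(z, dtype=float) * std + mean


if __name__ == "__main__":
    a = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
    z, m, s = standardize(a)
    print(z.mean(axis=0), z.std(axis=0))
```

Forgetting the inverse step is a common source of misleading RMSE values: an error computed on `yn` is in standard-deviation units of y, not in the units of y itself.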

Sampling

# Hold out 20% for testing, then 20% of the remainder for validation
x_train, x_test, y_train,y_test = train_test_split(Xn,yn,test_size=0.2)
x_train, x_valid, y_train,y_valid = train_test_split(x_train,y_train,test_size=0.2)
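The two successive splits above leave roughly 64% of the rows for training (0.8 × 0.8), 16% for validation (0.8 × 0.2) and 20% for testing. The sketch below computes the resulting counts. It assumes the held-out count is rounded up with `ceil`, which is our understanding of how `train_test_split` sizes a float `test_size`; treat exact counts as approximate. The helper `split_sizes` is hypothetical.

```python
import math


def split_sizes(n, test_size=0.2, valid_size=0.2):
    """Row counts (train, valid, test) from two successive fractional splits.

    ASSUMPTION: held-out counts are rounded up, as we understand
    scikit-learn does for float test_size.
    """
    n_test = math.ceil(test_size * n)
    n_rest = n - n_test
    n_valid = math.ceil(valid_size * n_rest)
    n_train = n_rest - n_valid
    return n_train, n_valid, n_test


if __name__ == "__main__":
    print(split_sizes(1000))
```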

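Model: set a checkpoint that saves the model with the lowest validation loss

wtPath='/tmp/res_sim_wei.hdf5'
# Monitor validation loss ("val_loss") so the saved model is the best on the validation set
checkpointw=ModelCheckpoint(wtPath, monitor="val_loss", verbose=0,
                            save_best_only=True, mode="min")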
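With `save_best_only=True` and `mode="min"`, the checkpoint writes the file only in epochs where the monitored quantity reaches a new minimum. The pure-Python class below is a hypothetical illustration of that decision rule, not Keras's implementation; it ignores details such as file writing and the save frequency.

```python
class BestOnlyTracker:
    """Hypothetical illustration of the save_best_only rule: report a 'save'
    only when the monitored value improves on the best seen so far."""

    def __init__(self, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.best = None

    def should_save(self, value):
        # The first value always counts as an improvement.
        improved = (
            self.best is None
            or (self.mode == "min" and value < self.best)
            or (self.mode == "max" and value > self.best)
        )
        if improved:
            self.best = value
        return improved


if __name__ == "__main__":
    tracker = BestOnlyTracker(mode="min")
    losses = [0.9, 0.7, 0.8, 0.6, 0.65]
    saved = [epoch for epoch, loss in enumerate(losses) if tracker.should_save(loss)]
    print(saved)
```

For the loss sequence in the demo, files would be written in epochs 0, 1 and 3 only, so the file on disk always holds the best model seen so far.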

Call the model class

modelCls = fullresattn.model.fullresAttCoder(x_train.shape[1], [32,16,8,4], 'relu',
                 1, inresidual=True, reg=keras.regularizers.l1_l2(0), batchnorm=True,
                 outnres=None, defact='linear', outputtype=0, nAttlayers=4)
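The internals of fullresAttCoder are not described here, so the following is only a generic NumPy illustration of the two ideas in the class name: an identity (residual) shortcut around a dense layer, and a softmax attention weighting of features. It is not the package's implementation; the weights are random and the functions `residual_dense` and `feature_attention` are hypothetical.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def residual_dense(x, w, b):
    """Dense layer with ReLU plus an identity shortcut: relu(x @ w + b) + x.
    The shortcut requires w to be square (input width == output width)."""
    return relu(x @ w + b) + x


def feature_attention(x, w_att):
    """Scale each feature by a softmax weight computed from the input itself."""
    weights = softmax(x @ w_att, axis=-1)  # shape (batch, features); rows sum to 1
    return x * weights, weights


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(5, 8))              # batch of 5 samples, 8 features
    w = rng.normal(scale=0.1, size=(8, 8))
    b = np.zeros(8)
    w_att = rng.normal(scale=0.1, size=(8, 8))
    h = residual_dense(x, w, b)
    h_att, att = feature_attention(h, w_att)
    print(h_att.shape, att.sum(axis=1))
```

The shortcut lets each block fall back to the identity when its dense transform contributes little, which is the usual motivation for residual connections in deeper networks.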

Get the residual autoencoder network

resmodel = modelCls.resAutoNet()

Show the network topology and compile the model

resmodel.summary()
resmodel.compile(optimizer="adam", loss= 'mean_squared_error',metrics=['mean_squared_error',r2KAuto])

Train the model

fhist_res=resmodel.fit(x_train, y_train, batch_size=128, epochs=200, verbose=1, 
    shuffle=True,validation_data=(x_valid, y_valid),callbacks=[checkpointw])

Test performance

On the simulated dataset, the full residual model with 4 attention layers improved validation R2 by about 4% over the model with no attention layers.

Collaboration

For collaboration, please contact Dr. Lianfa Li (email: lspatial@gmail.com).

Download files

Source Distribution

fullresattn-0.1.0.tar.gz (2.5 MB)

Built Distribution

fullresattn-0.1.0-py3-none-any.whl (5.0 MB)
