SERT Python Package
Sparse Encoder Representations from Transformers
sert is a Python module for machine learning built on top of TensorFlow. It is designed for deep learning on sets and targets predictive modeling tasks, both regression and classification, on time series and non-time-series data.
Many datasets, including standard tabular ones with p columns and N rows, can be represented as a set of observations in the format [row id, column name, value]. This is equivalent to pivoting a wide Nxp table into a long format:
Wide Dataframe:
Name | Apple | Banana | Cherry |
---|---|---|---|
Alice | 5 | 7 | NA |
Bob | 3 | 4 | 2 |
Long Dataframe:
Name | Fruit | Rating |
---|---|---|
Alice | Apple | 5 |
Alice | Banana | 7 |
Alice | Cherry | NA |
Bob | Apple | 3 |
Bob | Banana | 4 |
Bob | Cherry | 2 |
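The wide-to-long conversion in the tables above can be sketched with plain pandas. This is only an illustration of the reshaping step and is not part of the sert API; the variable names are chosen here for the example.

```python
import pandas as pd

# Wide table from the example above; None marks Alice's missing Cherry rating.
wide_df = pd.DataFrame({
    "Name": ["Alice", "Bob"],
    "Apple": [5, 3],
    "Banana": [7, 4],
    "Cherry": [None, 2],
})

# Unpivot: every (Name, Fruit) pair becomes its own row.
long_df = wide_df.melt(id_vars="Name", var_name="Fruit", value_name="Rating")

# Sort so the rows appear in the same order as the long table above.
long_df = long_df.sort_values(["Name", "Fruit"], ignore_index=True)
print(long_df)
```

Each row of the result is one observation in the [row id, column name, value] form described earlier.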
However, in some applications, storing data in a wide format is not feasible. For instance, with log data that's stored as [timestamp, user id, action], the data is inherently in a long format. Pivoting this data into a wide format isn't sensible due to the large number of unique values in the timestamp column.
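To make the point concrete, here is a small sketch with hypothetical log records (the timestamps, user ids, and actions are invented for illustration). Pivoting so that each timestamp becomes a column produces one column per unique timestamp, and most cells are empty.

```python
import pandas as pd

# Hypothetical log records in [timestamp, user id, action] form.
log = pd.DataFrame({
    "timestamp": pd.to_datetime([
        "2023-01-01 09:00:01", "2023-01-01 09:00:07",
        "2023-01-01 09:03:15", "2023-01-01 09:04:42",
    ]),
    "user_id": ["u1", "u2", "u1", "u2"],
    "action": ["login", "login", "search", "logout"],
})

# One column per unique timestamp: the table grows with every new event.
wide = log.pivot(index="user_id", columns="timestamp", values="action")
n_columns = wide.shape[1]

# Share of cells that hold no value after pivoting.
fraction_missing = wide.isna().to_numpy().mean()
print(n_columns, fraction_missing)
```

With only four events the pivoted table already has four columns and half of its cells are missing; a real log with millions of distinct timestamps would be far sparser.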
sert is designed to work with data in this long format, making it suitable for a variety of problems, including tabular, time series, and log data. The benefits of using long format data with sert are:
- Individual observations with missing values can be removed without removing the entire row. As such, sert can handle missing values without requiring imputation or the removal of observed data from the same row.
- It can be applied to datasets that are best represented in a long format, such as log data or multivariate non-aligned time series data.
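The first benefit can be illustrated with the fruit ratings from earlier. The sketch below uses plain pandas only and does not show how sert handles missing values internally; it simply contrasts dropping a missing observation in long format with dropping incomplete rows in wide format.

```python
import pandas as pd

long_df = pd.DataFrame({
    "Name": ["Alice", "Alice", "Alice", "Bob", "Bob", "Bob"],
    "Fruit": ["Apple", "Banana", "Cherry"] * 2,
    "Rating": [5, 7, None, 3, 4, 2],
})

# Long format: drop only the missing observation; Alice keeps her other ratings.
observed = long_df.dropna(subset=["Rating"])

# Wide format: dropping incomplete rows discards everything recorded for Alice.
wide_df = long_df.pivot(index="Name", columns="Fruit", values="Rating")
complete_rows = wide_df.dropna()

print(observed)
print(complete_rows)
```

In the long frame five of the six observations survive, whereas the wide frame retains only Bob's row.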
sert uses the transformer architecture to learn from sets by learning which observations to attend to. Because a transformer processes all observations in a set in parallel, sert scales to large multivariate time series datasets, unlike RNNs, which process time steps sequentially and cannot be parallelized across them.
Installation
To install sert, simply run:
pip install sert
Note: Dependencies will be installed automatically.
Dependencies
sert requires:
- NumPy
- Pandas
- tensorflow (>= 2.0.0)
- keras_nlp (>= 0.6.0)
Quick Start
Below is an example of how to use sert to classify irregular time series. For more examples and details on how to use the package for different problems, please refer to the vignettes.
import tensorflow as tf

from sert.datasets import ts_classification
from sert.models import TimeSERT
from sert.preprocessing import SeqDataPreparer

# Load the dataset
X_train, X_test, y_train, y_test = ts_classification()

# Prepare the data for SERT
token_cap = X_train.groupby('id').size().max()
processor = SeqDataPreparer(token_capacity=token_cap)
train_input = processor.fit_transform(
    X_train, index='id', times='time', names='var_name', values='value')
test_input = processor.transform(X_test)

# Instantiate the model
model = TimeSERT(num_var=1,
                 emb_dim=15,
                 num_head=3,
                 ffn_dim=5,
                 num_repeat=2,
                 num_out=y_train.shape[1],
                 task='classification')

# Compile the model
categorical_loss = tf.keras.losses.CategoricalCrossentropy()
model.compile(optimizer='adam', loss=categorical_loss, metrics=['accuracy'])

# Fit the model
model.fit(train_input, y_train, epochs=100, batch_size=250)

# Predictions
y_pred = model.predict(test_input)

# Evaluate the model
_, accuracy = model.evaluate(test_input, y_test)
print(f"Test Accuracy: {accuracy:.2f}")
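The token_cap line in the example above may be easier to follow on a toy frame. The sketch below builds a hypothetical long-format input with the same column names the Quick Start passes to fit_transform (id, time, var_name, value) and computes the same quantity with pandas. The data is invented, and how SeqDataPreparer uses token_capacity internally is not covered here.

```python
import pandas as pd

# Hypothetical irregular time series: series 1 has 3 observations, series 2 has 5.
X_toy = pd.DataFrame({
    "id": [1, 1, 1, 2, 2, 2, 2, 2],
    "time": [0.0, 0.4, 1.7, 0.1, 0.2, 0.9, 1.3, 2.5],
    "var_name": ["x"] * 8,
    "value": [0.3, 0.1, -0.2, 1.1, 0.9, 0.7, 0.8, 1.0],
})

# Number of observations recorded for each series.
obs_per_series = X_toy.groupby("id").size()

# Same expression as in the Quick Start: the longest series sets the capacity.
token_cap = obs_per_series.max()
print(obs_per_series.to_dict(), token_cap)
```

Here token_cap is 5, the number of observations in the longest series, so no series in the training data exceeds it.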
License
Distributed under the Apache License. See LICENSE for more information.
Contact
- Name: Amin Shoari Nejad
- Email: amin.shoarinejad@gmail.com