Transformer Network for Time-Series and Wearable Sensor Data
Sensor Transformer (SeT)
An adaptation of the Vision Transformer (ViT) to time-series and wearable sensor data, implemented in TensorFlow.
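ViT splits an image into fixed-size patches and treats each flattened patch as a token. The NumPy sketch below shows the analogous step for sensor data. It is not taken from the package. It assumes that `segment_size` plays the role of ViT's patch size: the signal is cut into non-overlapping windows of `segment_size` time steps, and each window is flattened and linearly projected to `d_model` dimensions. The function name `segment_and_embed` and the random weights are hypothetical and only illustrate the tensor shapes.

```python
import numpy as np


def segment_and_embed(x, segment_size, d_model, rng):
    """Hypothetical ViT-style tokenization of a batch of sensor signals.

    x: array of shape (batch, signal_length, num_channels).
    Returns tokens of shape (batch, num_segments + 1, d_model); the extra
    token is a class token prepended as in ViT.
    """
    batch, signal_length, num_channels = x.shape
    if signal_length % segment_size != 0:
        raise ValueError("signal_length must be divisible by segment_size")
    num_segments = signal_length // segment_size

    # Cut the time axis into non-overlapping segments of consecutive time
    # steps and flatten each one: (batch, num_segments, segment_size * num_channels)
    segments = x.reshape(batch, num_segments, segment_size * num_channels)

    # Linear projection of every segment to d_model.
    # Random weights stand in for learned parameters (illustration only).
    w = rng.standard_normal((segment_size * num_channels, d_model)) * 0.02
    tokens = segments @ w

    # Prepend a class token (zeros here, learnable in a real model).
    cls = np.zeros((batch, 1, d_model))
    tokens = np.concatenate([cls, tokens], axis=1)

    # Add positional embeddings (random here, learnable in a real model).
    pos = rng.standard_normal((1, num_segments + 1, d_model)) * 0.02
    return tokens + pos


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 128, 3))  # 8 instances, 128 time steps, 3 channels
tokens = segment_and_embed(x, segment_size=16, d_model=64, rng=rng)
print(tokens.shape)  # (8, 9, 64): 128 / 16 = 8 segments plus 1 class token
```

In a trained model the projection, class token, and positional embeddings would be learned. The resulting token sequence is presumably what the Transformer layers configured by `num_layers`, `num_heads`, and `d_model` in the usage example consume.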
Usage
```python
import argparse

import tensorflow as tf
from sensortransformer import set_network

# Command-line arguments describing the data. argparse maps "--signal-length"
# to args.signal_length and "--segment-size" to args.segment_size.
parser = argparse.ArgumentParser()
parser.add_argument("--signal-length", type=int, required=True)  # time steps per instance
parser.add_argument("--segment-size", type=int, required=True)
parser.add_argument("--num_channels", type=int, required=True)   # number of sensor channels
parser.add_argument("--num_classes", type=int, required=True)    # number of target classes
args = parser.parse_args()

# tf.data datasets. Instances must be of shape
#   x = (batch, signal_length, num_channels), y = (batch, num_classes)
ds_train, ds_test = ...

# Build the Sensor Transformer.
model = set_network.SensorTransformer(
    signal_length=args.signal_length,
    segment_size=args.segment_size,
    channels=args.num_channels,
    num_classes=args.num_classes,
    # Transformer hyperparameters
    num_layers=4,
    d_model=64,
    num_heads=4,
    mlp_dim=64,
    dropout=0.1,
)

# from_logits=True: the loss applies the softmax itself; labels are one-hot.
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
)

model.fit(ds_train, epochs=50, verbose=1)
model.evaluate(ds_test)
```
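The example leaves the construction of `ds_train` and `ds_test` open. The model is compiled with `CategoricalCrossentropy`, so labels must be one-hot vectors of length `num_classes` rather than integer class ids. The sketch below prepares arrays in the expected layout using synthetic random data, purely for illustration. The helper name `make_synthetic_arrays` is hypothetical and not part of the package.

```python
import numpy as np


def make_synthetic_arrays(num_instances, signal_length, num_channels, num_classes, seed=0):
    """Hypothetical helper: random signals with one-hot labels in the layout
    x = (instances, signal_length, num_channels), y = (instances, num_classes)."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((num_instances, signal_length, num_channels)).astype("float32")
    labels = rng.integers(0, num_classes, size=num_instances)  # integer class ids
    y = np.eye(num_classes, dtype="float32")[labels]           # one-hot encoding
    return x, y


x_train, y_train = make_synthetic_arrays(256, signal_length=128, num_channels=3, num_classes=5)
print(x_train.shape, y_train.shape)  # (256, 128, 3) (256, 5)
```

The arrays can then be wrapped into batched datasets, for example with `tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)`, which yields batches of shape `(batch, signal_length, num_channels)` and `(batch, num_classes)`. If your labels are integer ids, `tf.keras.utils.to_categorical` performs the same one-hot conversion.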
Hashes for sensortransformer-0.1.6-py3-none-any.whl

Algorithm | Hash digest
---|---
SHA256 | 711e334390d5abcc9a4826f2d314e8e5d6839e7d5d4c2731110f55a1df28676e
MD5 | ba69f28446f7a0235a3321539ce3ed5f
BLAKE2b-256 | 3dbd3b89f5aecd3e603288674560fecd7dfd15b7242aa973cd2f1165b3bd8151