Deep Temporal Contrastive Clustering for time series

Project description

PyTorch implementation of Deep Temporal Contrastive Clustering

Implementation of Deep Temporal Contrastive Clustering (DTCC), inspired by 07zy's older TensorFlow-based approach, which unfortunately relies on a deprecated TensorFlow version. I therefore decided to implement it myself in PyTorch.

Config

The easiest way to configure it is with a YAML file. Example:

# torchdtcc config.yaml

model:
  input_dim: 1
  num_layers: 3
  num_clusters: 3
  hidden_dims: [100, 50, 50]
  dilation_rates: [1, 4, 16]
  tau_I: 1.0
  tau_C: 1.0
  stable_svd: false

trainer:
  save_path: "dtcc_model.pth"
  learning_rate: 0.005
  weight_decay: 0
  lambda_cd: 0.001
  num_epochs: 200
  update_interval: 5

data:
  dataset_class: "torchdtcc.datasets.meat.MeatArffDataset"
  dataset_args:   # Arguments to initialize your dataset class
    files_path: "./data/meat/"
    # add more args as needed

  batch_size: 64

output:
  soft_clusters: "soft_clusters.npy"
  hard_clusters_argmax: "hard_clusters_argmax.npy"
  hard_clusters_kmeans: "hard_clusters_kmeans.npy"

device: "cuda"  # or "cpu"

Training

import yaml
from torchdtcc.dtcc.trainer import DTCCTrainer
from torchdtcc.dtcc.clustering import Clusterer
from torch.utils.data import DataLoader
from torchdtcc.datasets.meat.arff_meat import MeatArffDataset

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Prepare dataset
data_cfg = config.get("data", {})
dataset = MeatArffDataset(path=data_cfg['dataset_args']['files_path'])

model_cfg = config.get("model", {})
logger.info(f"STABLE SVD: {model_cfg['stable_svd']}")

trainer = DTCCTrainer.from_config(config, dataset)
save_path = config.get("trainer", {}).get("save_path", "")
model = trainer.run(save_path=save_path)

Usage

After training

...
# Assuming the training script above

dataloader = DataLoader(dataset, batch_size=data_cfg.get("batch_size", 64), shuffle=False)
clusterer = Clusterer(config["device"])
clusterer.set_model(model)
labels = clusterer.cluster(dataloader, method="kmeans")  # or "soft", "argmax"
print(f"resulting predictions:\n{labels}")

Load model

import yaml
from torchdtcc.dtcc.trainer import DTCCTrainer
from torchdtcc.dtcc.clustering import Clusterer
from torch.utils.data import DataLoader
from torchdtcc.datasets.meat.arff_meat import MeatArffDataset

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Prepare dataset and dataloader
data_cfg = config.get("data", {})
dataset = MeatArffDataset(path=data_cfg['dataset_args']['files_path'])
dataloader = DataLoader(dataset, batch_size=data_cfg.get("batch_size", 64), shuffle=False)

model_path = config.get("trainer", {}).get("save_path", "")
model_cfg = config.get("model", {})

clusterer = Clusterer()
clusterer.load_model(
    model_path=model_path,
    model_kwargs=model_cfg,
    device=config.get("device", "cuda")
)
labels = clusterer.cluster(dataloader, method="kmeans")  # or "soft", "argmax"
print(f"resulting predictions:\n{labels}")

Run evaluation

To run the clusterer and print the accuracy, NMI, ARI, and RI scores:

# assuming you have run some clustering example above
clusterer.evaluate(dataloader, method="kmeans")
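If you want the raw scores rather than printed output, the same metrics can be computed directly from the predictions with scikit-learn and SciPy. This is a sketch using the standard formulations (Hungarian-matched accuracy, NMI, ARI, RI), not necessarily what evaluate does internally; it assumes the dataset yields integer ground-truth labels as in the examples above:

import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, rand_score

def cluster_accuracy(y_true, y_pred):
    # Hungarian matching between cluster ids and class ids
    d = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((d, d), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        w[p, t] += 1
    row, col = linear_sum_assignment(-w)  # maximize matched counts
    return w[row, col].sum() / y_pred.size

y_true = np.array([int(y) for _, y in dataset])
y_pred = np.asarray(clusterer.cluster(dataloader, method="kmeans"))

print("ACC:", cluster_accuracy(y_true, y_pred))
print("NMI:", normalized_mutual_info_score(y_true, y_pred))
print("ARI:", adjusted_rand_score(y_true, y_pred))
print("RI: ", rand_score(y_true, y_pred))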

Use your own dataset

This is the definition of the base class for an augmented dataset. DTCC always expects an AugmentedDataset, and you need to implement the augmentation yourself.

import torch
from torch.utils.data import Dataset
from abc import abstractmethod

class AugmentedDataset(Dataset):
    def __init__(self, dataframe, feature_cols, target_col):
        self.X = dataframe[feature_cols].values.astype('float32')
        self.y = dataframe[target_col].astype('int64').values

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = torch.tensor(self.X[idx])
        if x.ndim == 1:
            x = x.unsqueeze(-1)  # add feature dimension for univariate series: [seq_len] -> [seq_len, 1]
        y = torch.tensor(self.y[idx])
        return x, y
    
    @abstractmethod
    def augmentation(self, batch_x):
        pass

You can use the augmentations submodule to implement the augmentations. See the meat class as a reference:

from scipy.io import arff
import pandas as pd
from torchdtcc.augmentations.basic import jitter
from torchdtcc.augmentations.helper import torch_augmentation_wrapper
from torchdtcc.datasets.augmented_dataset import AugmentedDataset

class MeatArffDataset(AugmentedDataset):
    def __init__(self, path):
        data_train, _ = arff.loadarff(path + 'Meat_TRAIN.arff')
        data_test, _ = arff.loadarff(path + 'Meat_TEST.arff')
        df_train = pd.DataFrame(data_train)
        df_test = pd.DataFrame(data_test)
        df = pd.concat([df_train, df_test], ignore_index=True)
        for col in df.select_dtypes([object]):
            df[col] = df[col].apply(lambda x: x.decode('utf-8') if isinstance(x, bytes) else x)
        feature_cols = [col for col in df.columns if col.startswith('att')]
        target_col = 'target'
        # Ensure target is int
        df[target_col] = df[target_col].astype(int)
        super().__init__(df, feature_cols, target_col)

    def augmentation(self, batch_x):        
        # Ensure [batch, seq_len, features]
        assert batch_x.ndim == 3, f"Input must be 3D, got {batch_x.shape}"
        
        x_aug = torch_augmentation_wrapper(jitter, batch_x)
        # x_augx = torch_augmentation_wrapper(scaling, x_aug) # Import scaling first

        return x_aug
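
For your own data the pattern is the same: load a DataFrame, pick the feature and target columns, and implement augmentation. A minimal sketch for a CSV source (the class name, file layout, and column names are hypothetical):

import pandas as pd
from torchdtcc.augmentations.basic import jitter
from torchdtcc.augmentations.helper import torch_augmentation_wrapper
from torchdtcc.datasets.augmented_dataset import AugmentedDataset

class MyCsvDataset(AugmentedDataset):
    def __init__(self, csv_path):
        # Hypothetical layout: one column per time step (t0, t1, ...) plus a 'label' column
        df = pd.read_csv(csv_path)
        feature_cols = [col for col in df.columns if col.startswith("t")]
        df["label"] = df["label"].astype(int)
        super().__init__(df, feature_cols, "label")

    def augmentation(self, batch_x):
        # Ensure [batch, seq_len, features]
        assert batch_x.ndim == 3, f"Input must be 3D, got {batch_x.shape}"
        return torch_augmentation_wrapper(jitter, batch_x)

Pointing data.dataset_class and data.dataset_args in the YAML at this class and its constructor arguments should then let the configured pipeline instantiate it, matching the comment in the example config.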
