Deep Temporal Contrastive Clustering for time series

Project description

A PyTorch implementation of Deep Temporal Contrastive Clustering (DTCC), inspired by 07zy's older TensorFlow implementation, which depends on a deprecated TensorFlow version. This package reimplements the method in PyTorch.

Config

The easiest way to configure the package is via a YAML file. Example:

# torchdtcc config.yaml

model:
  input_dim:  1
  num_layers:  3
  num_clusters: 3
  hidden_dims:  [100, 50, 50]
  dilation_rates: [1, 4, 16]
  tau_I: 1.0
  tau_C: 1.0
  stable_svd: false

trainer:
  save_path: "dtcc_model.pth"
  learning_rate: 0.005
  weight_decay: 0
  lambda_cd: 0.001
  num_epochs: 200
  update_interval: 5

data:
  dataset_class: "torchdtcc.datasets.meat.MeatArffDataset"
  dataset_args:   # Arguments to initialize your dataset class
    files_path: "./data/meat/"
    # add more args as needed

  batch_size: 64


output:
  soft_clusters: "soft_clusters.npy"
  hard_clusters_argmax: "hard_clusters_argmax.npy"
  hard_clusters_kmeans: "hard_clusters_kmeans.npy"

device: "cuda"  # or "cpu"

Training

import yaml
from torchdtcc.dtcc.trainer import DTCCTrainer
from torchdtcc.dtcc.clustering import Clusterer
from torch.utils.data import DataLoader
from torchdtcc.datasets.meat.arff_meat import MeatArffDataset

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Prepare dataset
data_cfg = config.get("data", {})
dataset = MeatArffDataset(path=data_cfg['dataset_args']['files_path'])

model_cfg = config.get("model", {})
logger.info(f"STABLE SVD: {model_cfg['stable_svd']}")
 
trainer = DTCCTrainer.from_config(config, dataset)
save_path = config.get("trainer", {}).get("save_path", "")
model = trainer.run(save_path=save_path)

Usage

After training

...
# Assuming the training script above

dataloader = DataLoader(dataset, batch_size=data_cfg.get("batch_size", 64), shuffle=False)
clusterer = Clusterer(config["device"])
clusterer.set_model(model)
labels = clusterer.cluster(dataloader, method="kmeans")  # or "soft", "argmax"
print(f"resulting predictions:\n{labels}")
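The three methods are related: "soft" yields per-sample cluster probabilities, while "argmax" derives hard labels by taking the most probable cluster per sample. A minimal NumPy sketch of that relationship (the array values are illustrative, not real model output):

```python
import numpy as np

# Illustrative soft assignments for 4 samples over 3 clusters
# (stand-in values, not actual DTCC output).
soft = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.2, 0.1, 0.7],
])

# Hard labels: index of the most probable cluster per row.
hard = soft.argmax(axis=1)  # -> array([0, 1, 2, 2])
```

The `output` section of the config names `.npy` files for exactly these artifacts (`soft_clusters.npy`, `hard_clusters_argmax.npy`, `hard_clusters_kmeans.npy`).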

Load model

import yaml
from torchdtcc.dtcc.trainer import DTCCTrainer
from torchdtcc.dtcc.clustering import Clusterer
from torch.utils.data import DataLoader
from torchdtcc.datasets.meat.arff_meat import MeatArffDataset

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Prepare dataset and dataloader
data_cfg = config.get("data", {})
dataset = MeatArffDataset(path=data_cfg['dataset_args']['files_path'])
dataloader = DataLoader(dataset, batch_size=data_cfg.get("batch_size", 64), shuffle=False)

model_path = config.get("trainer", {}).get("save_path", "")
model_cfg = config.get("model", {})

clusterer = Clusterer()
clusterer.load_model(
    model_path=model_path,
    model_kwargs=model_cfg,
    device=config.get("device", "cuda")
)
labels = clusterer.cluster(dataloader, method="kmeans")  # or "soft", "argmax"
print(f"resulting predictions:\n{labels}")

Run evaluation

Run the clusterer and print the accuracy, NMI, ARI, and RI scores:

# assuming you have run some clustering example above
clusterer.evaluate(dataloader, method="kmeans")
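The internals of `evaluate` are not shown here, but the four scores can be reproduced with scikit-learn and SciPy roughly as follows (a sketch under that assumption, not the package's actual implementation; clustering accuracy needs a Hungarian matching between predicted and true labels):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import (
    adjusted_rand_score,
    normalized_mutual_info_score,
    rand_score,
)

def clustering_accuracy(y_true, y_pred):
    # Find the best one-to-one mapping between predicted and true
    # labels via the Hungarian algorithm, then score the match rate.
    n = int(max(y_true.max(), y_pred.max())) + 1
    cost = np.zeros((n, n), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        cost[p, t] += 1
    row, col = linear_sum_assignment(-cost)
    return cost[row, col].sum() / len(y_true)

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([1, 1, 0, 0, 2, 2])  # same partition, permuted labels

acc = clustering_accuracy(y_true, y_pred)           # 1.0
nmi = normalized_mutual_info_score(y_true, y_pred)  # 1.0
ari = adjusted_rand_score(y_true, y_pred)           # 1.0
ri = rand_score(y_true, y_pred)                     # 1.0
```

All four metrics are invariant to a permutation of cluster labels, which is why the permuted prediction above still scores perfectly.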

Use your own dataset

This is the base class for an augmented dataset. DTCC always expects an AugmentedDataset, so you need to subclass it and implement the augmentation method.

import torch
from torch.utils.data import Dataset
from abc import abstractmethod

class AugmentedDataset(Dataset):
    def __init__(self, dataframe, feature_cols, target_col):
        self.X = dataframe[feature_cols].values.astype('float32')
        self.y = dataframe[target_col].astype('int64').values

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = torch.tensor(self.X[idx])
        if x.ndim == 1:
            x = x.unsqueeze(-1)  # add feature dimension if only batch and seq_len provided
        y = torch.tensor(self.y[idx])
        return x, y
    
    @abstractmethod
    def augmentation(self, batch_x):
        pass

You can use the augmentations submodule to implement the augmentations. See the Meat dataset class as a reference:

from scipy.io import arff
import pandas as pd
from torchdtcc.augmentations.basic import jitter
from torchdtcc.augmentations.helper import torch_augmentation_wrapper
from torchdtcc.datasets.augmented_dataset import AugmentedDataset

class MeatArffDataset(AugmentedDataset):
    def __init__(self, path):
        data_train, _ = arff.loadarff(path + 'Meat_TRAIN.arff')
        data_test, _ = arff.loadarff(path + 'Meat_TEST.arff')
        df_train = pd.DataFrame(data_train)
        df_test = pd.DataFrame(data_test)
        df = pd.concat([df_train, df_test], ignore_index=True)
        for col in df.select_dtypes([object]):
            df[col] = df[col].apply(lambda x: x.decode('utf-8') if isinstance(x, bytes) else x)
        feature_cols = [col for col in df.columns if col.startswith('att')]
        target_col = 'target'
        # Ensure target is int
        df[target_col] = df[target_col].astype(int)
        super().__init__(df, feature_cols, target_col)

    def augmentation(self, batch_x):        
        # Ensure [batch, seq_len, features]
        assert batch_x.ndim == 3, f"Input must be 3D, got {batch_x.shape}"
        
        x_aug = torch_augmentation_wrapper(jitter, batch_x)
        # x_augx = torch_augmentation_wrapper(scaling, x_aug) # Import scaling first

        return x_aug
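The exact signature of `jitter` in `torchdtcc.augmentations.basic` is not reproduced here; a typical jitter augmentation simply adds i.i.d. Gaussian noise to every time step, which a generic sketch makes concrete (names and defaults are assumptions, not torchdtcc's API):

```python
import numpy as np

def jitter_sketch(x, sigma=0.03):
    # Typical jitter augmentation: perturb each value with Gaussian
    # noise of standard deviation sigma. Shape is preserved.
    # (Generic sketch, not torchdtcc's implementation.)
    return x + np.random.normal(loc=0.0, scale=sigma, size=x.shape)

batch = np.zeros((8, 60, 1))  # [batch, seq_len, features]
augmented = jitter_sketch(batch)
assert augmented.shape == batch.shape
```

Chaining several such augmentations (e.g. jitter followed by scaling, as hinted in the commented line above) is a common way to produce the second contrastive view.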

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchdtcc-0.0.4.tar.gz (19.7 kB)

Uploaded Source

Built Distribution

torchdtcc-0.0.4-py3-none-any.whl (22.5 kB)

Uploaded Python 3

File details

Details for the file torchdtcc-0.0.4.tar.gz.

File metadata

  • Download URL: torchdtcc-0.0.4.tar.gz
  • Upload date:
  • Size: 19.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.5

File hashes

Hashes for torchdtcc-0.0.4.tar.gz

  • SHA256: d17d5de108d3c3b382add40b97e50cfb2487c06717b776ce2d761303e63f7c3e
  • MD5: c369aad5895f1f24890e156033a1bd5f
  • BLAKE2b-256: ce380991ebf32b08bd50eedf3c01705e6cf03b52619ad0f43ec0e055c80fb387

File details

Details for the file torchdtcc-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: torchdtcc-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 22.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.5

File hashes

Hashes for torchdtcc-0.0.4-py3-none-any.whl

  • SHA256: c7544c53508c3309544d095ca981130b94299861aac878aaca73ac3e67ff2162
  • MD5: 54d87e01d6c6413ecff1e50e877ebfb5
  • BLAKE2b-256: f13017de0dd3f7c7f635da0c5762eddce70428c6ef3a6818cdae7118938fee31
