
Action segmentation framework built with PyTorch Lightning


Lightning Action


A modern action segmentation framework built with PyTorch Lightning for behavioral analysis.

Features

  • Modern Architecture: Built with PyTorch Lightning for scalable and reproducible training
  • Multiple Backbones: Support for TemporalMLP, RNN (LSTM/GRU), and Dilated TCN architectures
  • Command-line Interface: Easy-to-use CLI for training and inference
  • Comprehensive Logging: Built-in metrics tracking and visualization with TensorBoard
  • Extensive Testing: Full test coverage for reliable development

Installation

Prerequisites

  • Python 3.10+
  • PyTorch with CUDA support (for GPU training; optional for keypoint models, required for video models)

Install from Source

git clone https://github.com/paninski-lab/lightning-action.git
cd lightning-action
pip install -e .
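You can check that the editable install is visible to Python with the standard library's `importlib.metadata`. This is a minimal sketch. The distribution name `lightning-action` is inferred from the wheel filename under Download files, and `installed_version` is a hypothetical helper, not part of the package.

```python
# Post-install check (sketch): report the installed version of a distribution.
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "lightning-action") -> Optional[str]:
    """Return the installed version string for `dist_name`, or None if absent.

    The default name is an assumption based on the published wheel filename.
    """
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_version()
    print(f"lightning-action {v}" if v else "lightning-action is not installed")
```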

Dependencies

Core dependencies include:

  • pytorch-lightning - Training framework
  • torch - Deep learning backend
  • numpy - Numerical computing
  • pandas - Data manipulation
  • scikit-learn - Machine learning utilities
  • tensorboard - Experiment tracking

Quick Start

The instructions below are for keypoint-based models. For video-based models, see docs/video_pipeline_quickstart.md.

1. Prepare Your Data

Organize your data in the following structure:

data/
├── markers/
│   ├── experiment1.csv
│   ├── experiment2.csv
│   └── ...
├── labels/
│   ├── experiment1.csv
│   ├── experiment2.csv
│   └── ...
└── features/  # optional, hand-crafted featurization of markers or other video representations
    ├── experiment1.csv
    ├── experiment2.csv
    └── ...
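Before training, it helps to confirm that every input file has a matching label file. The sketch below assumes that files are paired by experiment name (the CSV file stem), as the tree above suggests. `check_layout` is a hypothetical helper, not part of Lightning Action.

```python
from pathlib import Path
from typing import Dict, List


def check_layout(data_path: str, input_dir: str = "markers",
                 label_dir: str = "labels") -> Dict[str, List[str]]:
    """Compare experiment names (CSV stems) between the input and label folders.

    Pairing files by stem is an assumption drawn from the example layout.
    A missing folder is treated as empty.
    """
    root = Path(data_path)
    inputs = {p.stem for p in (root / input_dir).glob("*.csv")}
    labels = {p.stem for p in (root / label_dir).glob("*.csv")}
    return {
        "missing_labels": sorted(inputs - labels),  # inputs with no label file
        "orphan_labels": sorted(labels - inputs),   # labels with no input file
    }


if __name__ == "__main__":
    print(check_layout("data"))
```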

2. Create a Configuration File

Create a YAML configuration file (see configs/segmenter_example.yaml):

data:
  data_path: /path/to/your/data
  input_dir: markers
  transforms:  # optional, defaults to ZScore
    - ZScore

model:
  input_size: 10
  output_size: 4
  backbone: temporalmlp
  num_hid_units: 256
  num_layers: 2
  
optimizer:
  type: Adam
  lr: 1e-3
  
training:
  num_epochs: 100
  batch_size: 32
  device: cpu  # or 'gpu'
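The values of `input_size` and `output_size` must agree with your data. The sketch below suggests values by counting CSV columns. It makes several assumptions that this README does not confirm: a single header row, one leading frame-index column, and a labels file with one column per class. If your files differ, adjust `index_cols` and `header_rows`. `count_columns` and `suggest_sizes` are hypothetical helpers.

```python
import csv
from pathlib import Path


def count_columns(csv_path: str, index_cols: int = 1, header_rows: int = 1) -> int:
    """Count data columns in a CSV.

    Assumes `header_rows` header lines (the last one is used) and
    `index_cols` leading index columns. Both are assumptions about the format.
    """
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        header = []
        for _ in range(header_rows):
            header = next(reader)  # keep the last header row
    return len(header) - index_cols


def suggest_sizes(data_path: str, experiment: str,
                  input_dir: str = "markers", label_dir: str = "labels") -> dict:
    """Suggest model.input_size / model.output_size from one experiment's files."""
    root = Path(data_path)
    return {
        "input_size": count_columns(str(root / input_dir / f"{experiment}.csv")),
        "output_size": count_columns(str(root / label_dir / f"{experiment}.csv")),
    }
```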

3. Train a Model

Using the CLI:

litaction train --config configs/my_config.yaml --output-dir runs/my_experiment

Using the Python API:

from lightning_action.api import Model

# Load model from config
model = Model.from_config('configs/my_config.yaml')

# Train model
model.train(output_dir='runs/my_experiment')
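To train several configurations in one pass, you can script the CLI command shown above. This sketch only assembles the documented `litaction train --config ... --output-dir ...` arguments and runs them with `subprocess`. Two details are assumptions: one YAML file per configuration, and a run directory named after each config file. `train_command` and `train_all` are hypothetical helpers.

```python
import shlex
import subprocess
from pathlib import Path
from typing import List


def train_command(config: str, output_dir: str) -> List[str]:
    """Build the argv for the documented `litaction train` command."""
    return ["litaction", "train", "--config", config, "--output-dir", output_dir]


def train_all(config_dir: str, runs_dir: str, dry_run: bool = True) -> List[List[str]]:
    """Create one run per YAML config, with the run directory named after the config stem.

    With dry_run=True (the default) the commands are only printed, not executed.
    """
    cmds = [train_command(str(cfg), str(Path(runs_dir) / cfg.stem))
            for cfg in sorted(Path(config_dir).glob("*.yaml"))]
    for cmd in cmds:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop on the first failed run
    return cmds
```

Because every run writes to its own subdirectory of `runs_dir`, you can point TensorBoard at `runs_dir` to compare the runs.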

4. Generate Predictions

Using the CLI:

litaction predict --model-dir runs/my_experiment --data-dir /path/to/data --input-dir markers --output-dir predictions/

Using the Python API:

from lightning_action.api import Model

# Load trained model
model = Model.from_dir('runs/my_experiment')

# Generate predictions
model.predict(
    data_path='/path/to/data',
    input_dir='markers',
    output_dir='predictions/'
)
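Action segmentation output is often easier to inspect as segments than as per-frame values. This README does not describe the prediction file format. The sketch below therefore assumes that you have already loaded per-frame class scores as rows, with one score per class. `argmax_rows` and `frames_to_segments` are hypothetical helpers: the first picks the top class for each frame, and the second run-length encodes the result.

```python
from typing import List, Sequence, Tuple


def argmax_rows(probs: Sequence[Sequence[float]]) -> List[int]:
    """Index of the highest score in each row (ties go to the first class)."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


def frames_to_segments(labels: Sequence[int]) -> List[Tuple[int, int, int]]:
    """Collapse per-frame labels into (start, end_exclusive, label) runs."""
    segments = []
    start = 0
    for i in range(1, len(labels) + 1):
        # Close the current run at the end of the sequence or when the label changes.
        if i == len(labels) or labels[i] != labels[start]:
            segments.append((start, i, labels[start]))
            start = i
    return segments


if __name__ == "__main__":
    frames = argmax_rows([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]])
    print(frames_to_segments(frames))
```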

See configs/README.md for detailed configuration options.

Monitoring Training with TensorBoard

Lightning Action automatically logs training metrics to TensorBoard. To visualize your training progress:

  1. Launch TensorBoard after starting training:

    tensorboard --logdir /path/to/your/runs/directory
    
  2. Set the correct logdir: Use the deepest directory that contains all your model directories. For example:

    # If your models are in:
    # runs/experiment1/
    # runs/experiment2/
    # runs/baseline/
    
    # Launch TensorBoard with:
    tensorboard --logdir runs/
    
  3. Open your browser and navigate to http://localhost:6006 to view the TensorBoard dashboard.

  4. Available metrics include:

    • Training and validation loss
    • Training and validation accuracy
    • Training and validation F1 score
    • Learning rate schedules
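The "deepest directory that contains all your model directories" in step 2 is the common ancestor of the run folders. The standard library can compute it. `tensorboard_logdir` is a hypothetical helper in this sketch.

```python
import os
from typing import Iterable


def tensorboard_logdir(model_dirs: Iterable[str]) -> str:
    """Return the deepest directory that contains every model directory."""
    dirs = [os.path.abspath(d) for d in model_dirs]
    if not dirs:
        raise ValueError("need at least one model directory")
    return os.path.commonpath(dirs)


if __name__ == "__main__":
    logdir = tensorboard_logdir(["runs/experiment1", "runs/experiment2", "runs/baseline"])
    print(f"tensorboard --logdir {logdir}")
```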

Tip: Keep TensorBoard running while training multiple experiments to compare results in real-time.
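This README does not say how the logged F1 score is averaged across classes. As an illustration only, the sketch below computes a macro-averaged F1 (the unweighted mean of per-class F1) over frame labels. The library may average differently or exclude a background class. Treat `macro_f1` as a hypothetical reference, not as a reimplementation of the logged metric.

```python
from typing import Sequence


def macro_f1(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Unweighted mean of per-class F1 over classes present in y_true or y_pred."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn  # F1 = 2TP / (2TP + FP + FN)
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores) if scores else 0.0
```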


Contributing

See CONTRIBUTING.md for guidelines on setting up a development environment, code style, and submitting pull requests.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use this framework in your research, please cite:

@article{blau2024study,
  title={A study of animal action segmentation algorithms across supervised, unsupervised, and semi-supervised learning paradigms},
  author={Blau, Ari and Schaffer, Evan S and Mishra, Neeli and Miska, Nathaniel J and {International Brain Laboratory} and Paninski, Liam and Whiteway, Matthew R},
  journal={Neurons, behavior, data analysis, and theory},
  volume={2024},
  pages={10--51628},
  year={2024}
}

Acknowledgments

This framework is built upon the work of:

Download files

Download the file for your platform.

Source Distribution

lightning_action-1.1.0.tar.gz (68.7 kB)

Uploaded Source

Built Distribution

lightning_action-1.1.0-py3-none-any.whl (89.6 kB)

Uploaded Python 3

File details

Details for the file lightning_action-1.1.0.tar.gz.

File metadata

  • Download URL: lightning_action-1.1.0.tar.gz
  • Upload date:
  • Size: 68.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.4.1 CPython/3.11.15 Linux/6.17.0-1013-azure

File hashes

Hashes for lightning_action-1.1.0.tar.gz
Algorithm Hash digest
SHA256 2bd21a9b27ca854086ebb29a14cf639a16b782de4f4433d492390b41626390cd
MD5 76c76c13e0071df1c731f6c98964720f
BLAKE2b-256 28d18ed02938e1a2a0214a4c03c4ff410659b32798a4bac2d7fba324355a2164

File details

Details for the file lightning_action-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: lightning_action-1.1.0-py3-none-any.whl
  • Upload date:
  • Size: 89.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.4.1 CPython/3.11.15 Linux/6.17.0-1013-azure

File hashes

Hashes for lightning_action-1.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 75804909d79bdbb5983581abb2aeb6bedefbf0cf9b7e4fad915e487046ec0408
MD5 2d12b78ff4269200fea97441c0c77ac4
BLAKE2b-256 db797ad9c8984b2ac2a04abaaa88ff57038f6cfbcff4385628d29aa7017cff1d
