
A Deep Learning Framework for ECG Processing Tasks Based on PyTorch


torch_ecg


ECG Deep Learning Framework Implemented using PyTorch.

Documentation (under development).

The system design is depicted in the figure below.

[Figure: overall system design of torch_ecg]

Installation

torch_ecg requires Python 3.6+ and is available through pip:

python -m pip install torch-ecg

One can download the development version hosted at GitHub via

git clone https://github.com/DeepPSP/torch_ecg.git
cd torch_ecg
python -m pip install .

or use pip directly via

python -m pip install git+https://github.com/DeepPSP/torch_ecg.git

Main Modules

Augmenters


Augmenters are classes (subclasses of torch Module) that perform data augmentation in a uniform way and are managed by the AugmenterManager (also a subclass of torch Module). Augmenters and the manager share a common signature for the forward method:

forward(self, sig:Tensor, label:Optional[Tensor]=None, *extra_tensors:Sequence[Tensor], **kwargs:Any) -> Tuple[Tensor, ...]:

The following augmenters are implemented:

  1. baseline wander (adding sinusoidal and gaussian noises)
  2. cutmix
  3. mixup
  4. random flip
  5. random masking
  6. random renormalize
  7. stretch-or-compress (scaling)
  8. label smooth (not actually for data augmentation, but has similar behavior)

Usage example (this example uses all augmenters except cutmix, each with default config):

import torch
from torch_ecg.cfg import CFG
from torch_ecg.augmenters import AugmenterManager

config = CFG(
    random=False,
    fs=500,
    baseline_wander={},
    label_smooth={},
    mixup={},
    random_flip={},
    random_masking={},
    random_renormalize={},
    stretch_compress={},
)
am = AugmenterManager.from_config(config)
sig, label, mask = torch.rand(2,12,5000), torch.rand(2,26), torch.rand(2,5000,1)
sig, label, mask = am(sig, label, mask)

Augmenters can be stochastic along the batch dimension and/or the channel dimension (ref. the get_indices method of the Augmenter base class).
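
The idea can be illustrated with a plain-torch sketch (random_batch_indices below is a hypothetical helper mirroring what get_indices does, not the library's actual implementation):

import torch

def random_batch_indices(batch_size: int, prob: float) -> torch.Tensor:
    # hypothetical helper: each sample in the batch is selected
    # independently with probability ``prob``
    return torch.nonzero(torch.rand(batch_size) < prob).flatten()

sig = torch.rand(8, 12, 5000)  # batch of 8 twelve-lead signals
idx = random_batch_indices(sig.shape[0], prob=0.5)
sig[idx] = -sig[idx]  # sign-flip only the selected samples, as random flip does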


Preprocessors


Preprocessors act on numpy arrays. Similar to augmenters, they are managed by a manager (PreprocManager):

import torch
from torch_ecg.cfg import CFG
from torch_ecg._preprocessors import PreprocManager

config = CFG(
    random=False,
    resample={"fs": 500},
    bandpass={},
    normalize={},
)
ppm = PreprocManager.from_config(config)
sig = torch.rand(12,80000).numpy()
sig, fs = ppm(sig, 200)

The following preprocessors are implemented:

  1. baseline removal (detrend)
  2. normalize (z-score, min-max, naïve)
  3. bandpass
  4. resample
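
Individual preprocessors can also be used on their own. A minimal sketch (the BandPass parameter names and the (sig, fs) call convention are assumptions mirroring the manager example above):

import numpy as np
from torch_ecg._preprocessors import BandPass

bp = BandPass(lowcut=0.5, highcut=45)  # assumed parameter names
sig = np.random.randn(12, 8000)        # 12-lead signal sampled at 200 Hz
sig, fs = bp(sig, 200)                 # assumed (sig, fs) -> (sig, fs) convention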

For more examples, see the README file of the preprocessors module.


Databases


This module includes classes that handle the I/O of ECG signals and labels in an ECG database, and that maintain its metadata (statistics, paths, plots, list of records, etc.). This module is migrated and improved from DeepPSP/database_reader.

After migration, all database readers should be tested again. Testing progress:

| Database | Source    | Tested             |
| -------- | --------- | ------------------ |
| AFDB     | PhysioNet | :heavy_check_mark: |
| ApneaECG | PhysioNet | :x:                |
| CinC2017 | PhysioNet | :x:                |
| CinC2018 | PhysioNet | :x:                |
| CinC2020 | PhysioNet | :heavy_check_mark: |
| CinC2021 | PhysioNet | :heavy_check_mark: |
| LTAFDB   | PhysioNet | :x:                |
| LUDB     | PhysioNet | :heavy_check_mark: |
| MITDB    | PhysioNet | :heavy_check_mark: |
| SHHS     | NSRR      | :x:                |
| CPSC2018 | CPSC      | :heavy_check_mark: |
| CPSC2019 | CPSC      | :heavy_check_mark: |
| CPSC2020 | CPSC      | :heavy_check_mark: |
| CPSC2021 | CPSC      | :heavy_check_mark: |
| SPH      | Figshare  | :heavy_check_mark: |

NOTE that these classes should not be confused with a torch Dataset, which is strongly related to the task (or the model). However, one can build Datasets based on these classes, for example the Dataset for The 4th China Physiological Signal Challenge 2021 (CPSC2021).

One can use the built-in Datasets in torch_ecg.databases.datasets as follows

from copy import deepcopy

from torch_ecg.databases.datasets.cinc2021 import CINC2021Dataset, CINC2021TrainCfg

config = deepcopy(CINC2021TrainCfg)
config.db_dir = "some/path/to/db"
dataset = CINC2021Dataset(config, training=True, lazy=False)
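
The resulting dataset can be wrapped in a vanilla torch DataLoader (a minimal sketch, assuming the configured task yields (signal, label) pairs):

from torch.utils.data import DataLoader

dl = DataLoader(dataset, batch_size=32, shuffle=True)
for sig, label in dl:
    ...  # forward pass, loss computation, etc.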


Implemented Neural Network Architectures

  1. CRNN, both for classification and sequence tagging (segmentation)
  2. U-Net
  3. RR-LSTM

A typical signature of the instantiation (__init__) function of a model is as follows

__init__(self, classes:Sequence[str], n_leads:int, config:Optional[CFG]=None, **kwargs:Any) -> None

If a config is not specified, the default config (stored in the model_configs module) will be used.

Quick Example

A quick example is as follows:

import torch
from torch_ecg.utils.utils_nn import adjust_cnn_filter_lengths
from torch_ecg.model_configs import ECG_CRNN_CONFIG
from torch_ecg.models.ecg_crnn import ECG_CRNN

config = adjust_cnn_filter_lengths(ECG_CRNN_CONFIG, fs=400)
# change the default CNN backbone
# bottleneck with global context attention variant of Nature Communications ResNet
config.cnn.name="resnet_nature_comm_bottle_neck_gc"

classes = ["NSR", "AF", "PVC", "SPB"]
n_leads = 12
model = ECG_CRNN(classes, n_leads, config)

model(torch.rand(2, 12, 4000))  # signal length 4000, batch size 2

This creates a model that classifies 12-lead ECGs into the 4 classes "NSR", "AF", "PVC", "SPB". One can check the size of a model in terms of the number of parameters via

model.module_size

or in terms of memory consumption via

model.module_size_

Custom Model

One can adjust the configs to create a custom model. For example, the building blocks of the 4 stages of a TResNet backbone are basic, basic, bottleneck, bottleneck. If one wants to change the second block to a bottleneck block with squeeze-and-excitation (SE) attention, then

from copy import deepcopy

from torch_ecg.models.ecg_crnn import ECG_CRNN
from torch_ecg.model_configs import (
    ECG_CRNN_CONFIG,
    tresnetF, resnet_bottle_neck_se,
)

my_resnet = deepcopy(tresnetF)
my_resnet.building_block[1] = "bottleneck"
my_resnet.block[1] = resnet_bottle_neck_se

The convolutions in a TResNet are anti-aliasing convolutions. If one further wants to change them to normal convolutions, then

for b in my_resnet.block:
    b.conv_type = None

or change them to separable convolutions via

for b in my_resnet.block:
    b.conv_type = "separable"

Finally, replace the default CNN backbone via

my_model_config = deepcopy(ECG_CRNN_CONFIG)
my_model_config.cnn.name = "my_resnet"
my_model_config.cnn.my_resnet = my_resnet

model = ECG_CRNN(["NSR", "AF", "PVC", "SPB"], 12, my_model_config)
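
As in the quick example, a forward pass on random data sanity-checks the custom model:

model(torch.rand(2, 12, 4000))  # batch size 2, 12 leads, signal length 4000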


CNN Backbones


Implemented

  1. VGG
  2. ResNet (including vanilla ResNet, ResNet-B, ResNet-C, ResNet-D, ResNeXT, TResNet, Stanford ResNet, Nature Communications ResNet, etc.)
  3. MultiScopicNet (CPSC2019 SOTA)
  4. DenseNet (CPSC2020 SOTA)
  5. Xception

In general, variants of ResNet are the most commonly used architectures, as can be inferred from CinC2020 and CinC2021.

Ongoing

  1. MobileNet
  2. DarkNet
  3. EfficientNet

TODO

  1. HarDNet
  2. HO-ResNet
  3. U-Net++
  4. U-Squared Net
  5. etc.

More details and a list of references can be found in the README file of this module.


Components


This module consists of frequently used components such as loggers, trainers, etc.

Loggers

The following loggers are implemented and are manipulated uniformly by a manager:

  1. CSV logger
  2. text logger
  3. tensorboard logger
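
A hypothetical sketch of the pattern (the import path, the from_config constructor, the config fields, and the log_message method are all assumptions mirroring the other managers in this package):

from torch_ecg.cfg import CFG
from torch_ecg.components.loggers import LoggerManager  # assumed import path

config = CFG(log_dir="./logs", csv={}, txt={}, tensorboard={})  # hypothetical fields
lm = LoggerManager.from_config(config)  # assumed, mirroring the other managers
lm.log_message("training started")      # hypothetical method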

Outputs

The Output classes implemented in this module serve as containers for ECG downstream task model outputs, including

  • ClassificationOutput
  • MultiLabelClassificationOutput
  • SequenceTaggingOutput
  • WaveDelineationOutput
  • RPeaksDetectionOutput

Each has some required fields (keys) and can hold an arbitrary number of custom fields. These classes are useful for the computation of metrics.
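
A minimal sketch (assuming classes, prob, and pred are the required fields of ClassificationOutput, and that the class lives in torch_ecg.components.outputs):

import numpy as np
from torch_ecg.components.outputs import ClassificationOutput  # assumed import path

out = ClassificationOutput(
    classes=["NSR", "AF"],        # class names
    prob=np.array([[0.9, 0.1]]),  # per-class probabilities, batch size 1
    pred=np.array([0]),           # predicted class indices
    note="anything",              # custom fields are also accepted
)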

Metrics

This module has the following pre-defined (built-in) Metrics classes:

  • ClassificationMetrics
  • RPeaksDetectionMetrics
  • WaveDelineationMetrics

These metrics are computed according to either Wikipedia or published literature.

Trainer

An abstract base class BaseTrainer implements some common steps in building a training pipeline (workflow). A few task-specific methods are left as abstractmethods, for example the method

evaluate(self, data_loader:DataLoader) -> Dict[str, float]

for evaluation on the validation set during training and perhaps further for model selection and early stopping.
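
A skeleton subclass might look as follows (a sketch; the import path is an assumption, and the remaining abstractmethods are elided):

from typing import Dict

from torch.utils.data import DataLoader
from torch_ecg.components.trainer import BaseTrainer  # assumed import path

class MyTrainer(BaseTrainer):
    def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
        # run the model on the validation set and collect metrics
        self.model.eval()
        ...  # inference and task-specific metric computation
        return {"accuracy": 0.0}  # placeholder metrics dict

    # the remaining abstractmethods (_setup_dataloaders, run_one_step,
    # batch_dim, etc.) must be implemented as well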


Other Useful Tools


R peaks detection algorithms

This is a collection of traditional (non-deep-learning) R-peak detection algorithms, collected from WFDB and BioSPPy.
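
For example (a sketch; the import path and the (sig, fs) -> indices convention of xqrs_detect are assumptions):

import numpy as np
from torch_ecg.utils.rpeaks import xqrs_detect  # assumed import path

sig = np.random.randn(5000)        # a single-lead signal
rpeaks = xqrs_detect(sig, fs=500)  # sample indices of detected R peaks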


Usage Examples


See case studies in the benchmarks folder.

A large part of the case studies are migrated from other DeepPSP repositories. Some are implemented in the old fashion, inconsistent with the new system architecture of torch_ecg, and hence need updating and testing.

| Benchmark    | Architecture              | Source           | Finished           | Updated            | Tested             |
| ------------ | ------------------------- | ---------------- | ------------------ | ------------------ | ------------------ |
| CinC2020     | CRNN                      | DeepPSP/cinc2020 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CinC2021     | CRNN                      | DeepPSP/cinc2021 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CinC2022[^1] | Multi Task Learning (MTL) | DeepPSP/cinc2022 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CPSC2019     | SequenceTagging/U-Net     | NA               | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CPSC2020     | CRNN/SequenceTagging      | DeepPSP/cpsc2020 | :heavy_check_mark: | :x:                | :x:                |
| CPSC2021     | CRNN/SequenceTagging/LSTM | DeepPSP/cpsc2021 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| LUDB         | U-Net                     | NA               | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |

[^1]: Although CinC2022 dealt with acoustic cardiac signals (phonocardiogram, PCG), the tasks and signals can be treated similarly.

Taking CPSC2021 for example, the steps are

  1. Write a Dataset to fit the training data for the model(s) and the training workflow, or directly use the built-in Datasets in torch_ecg.databases.datasets. In this example, 3 tasks are considered, 2 of which use a MaskedBCEWithLogitsLoss function, hence the Dataset produces an extra mask tensor for these 2 tasks (see the loss sketch after these steps):

    def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]:
        if self.lazy:
            if self.task in ["qrs_detection"]:
                # qrs_detection needs no mask: return (data, label) only
                return self.fdr[index][:2]
            else:
                return self.fdr[index]
        else:
            if self.task in ["qrs_detection"]:
                return self._all_data[index], self._all_labels[index]
            else:
                # the extra mask tensor for the tasks using MaskedBCEWithLogitsLoss
                return self._all_data[index], self._all_labels[index], self._all_masks[index]
    
  2. Inherit a base model to create task specific models, along with tailored model configs

  3. Inherit the BaseTrainer to build the training pipeline, with the abstractmethods (_setup_dataloaders, run_one_step, evaluate, batch_dim, etc.) implemented.
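
The extra mask tensor produced in step 1 is consumed by the loss. A sketch (assuming MaskedBCEWithLogitsLoss lives in torch_ecg.models.loss and takes (input, target, mask) in its forward call):

import torch
from torch_ecg.models.loss import MaskedBCEWithLogitsLoss  # assumed import path

criterion = MaskedBCEWithLogitsLoss()
logits = torch.randn(2, 5000, 1)                    # model outputs
labels = torch.randint(0, 2, (2, 5000, 1)).float()  # binary targets
mask = torch.ones(2, 5000, 1)                       # per-position weight mask
loss = criterion(logits, labels, mask)              # assumed call convention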


CAUTION

Most of the time, but not always, I run the notebooks in the benchmarks manually after updates. If someone finds a bug, please raise an issue. The test workflow is to be enhanced and automated; see this project.


Work in progress

See the projects page.


Citation

@misc{torch_ecg,
      title = {{torch\_ecg: An ECG Deep Learning Framework Implemented using PyTorch}},
     author = {WEN, Hao and KANG, Jingsu},
        doi = {10.5281/ZENODO.6435048},
        url = {https://zenodo.org/record/6435048},
  publisher = {Zenodo},
       year = {2022},
  copyright = {{MIT License}}
}
@article{torch_ecg_paper,
      title = {{A Novel Deep Learning Package for Electrocardiography Research}},
     author = {Hao Wen and Jingsu Kang},
    journal = {{Physiological Measurement}},
        doi = {10.1088/1361-6579/ac9451},
       year = {2022},
      month = {11},
  publisher = {{IOP Publishing}},
     volume = {43},
     number = {11},
      pages = {115006}
}


Thanks

Much is learned, especially the modular design, from the adversarial NLP library TextAttack and from Hugging Face transformers.

