
Pretrained MNet model for classifying dementia subclasses (HV, AD, DLB, and iNPH)

Project description

Installation

$ pip install eeg-dementia-classification-MNet

Pretrained Weights

Pretrained weights are available on our Google Drive.

  1. Download 'pretrained_weights.tar.gz'.
  2. Extract the file using the following command:
$ tar xvf pretrained_weights.tar.gz
  3. Place the extracted 'pretrained_weights' directory in the working directory. The weight files (.pth) should be organized as follows:
./pretrained_weights/
├── AD_vs_DLB
│   ├── model_fold#0_epoch#045.pth
│   ├── model_fold#1_epoch#031.pth
│   ├── model_fold#2_epoch#029.pth
│   ├── model_fold#3_epoch#031.pth
│   └── model_fold#4_epoch#028.pth
├── AD_vs_DLB_vs_NPH
│   ├── model_fold#0_epoch#024.pth
│   ├── model_fold#1_epoch#035.pth
...
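
To confirm that the archive was extracted to the expected place, the layout above can be checked with a short script. This is a minimal sketch using only the Python standard library; the helper name `list_weight_files` is hypothetical and not part of the package, and the file-name pattern is inferred from the listing above.

```python
import re
from pathlib import Path

# Pattern for names such as "model_fold#0_epoch#045.pth" (inferred from the listing above).
WEIGHT_NAME = re.compile(r"model_fold#(\d+)_epoch#(\d+)\.pth")


def list_weight_files(root="./pretrained_weights"):
    """Map each comparison directory to {fold index: weight file name}.

    Hypothetical helper for illustration; it is not part of the package.
    """
    found = {}
    for subdir in sorted(Path(root).iterdir()):
        if not subdir.is_dir():
            continue
        folds = {}
        for path in sorted(subdir.glob("*.pth")):
            match = WEIGHT_NAME.fullmatch(path.name)
            if match:
                folds[int(match.group(1))] = path.name
        found[subdir.name] = folds
    return found


if __name__ == "__main__":
    root = Path("./pretrained_weights")
    if root.is_dir():
        for name, folds in list_weight_files(root).items():
            print(f"{name}: folds {sorted(folds)}")
    else:
        print(f"{root} not found; extract the archive into the working directory first")
```

A directory that lists no folds, or fewer folds than expected, suggests that the archive was extracted elsewhere or only partially.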

Usage

from eeg_dementia_classification_MNet import MNet_1000
import torch

## Parameters
DISEASE_TYPES = ["HV", "AD", "DLB", "NPH"]

## MNet
model = MNet_1000(DISEASE_TYPES, is_ensemble=True)
model.load_weights(i_fold=0)

## Feed dummy data shaped (batch size, channels, sequence length)
bs, n_chs, seq_len = 16, 19, 1000
x = torch.rand(bs, n_chs, seq_len)
y = model(x)
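
The example above feeds random data shaped (batch, 19 channels, 1000 samples). A real recording is usually longer than 1000 samples, so it has to be cut into windows of that length first. The following is a minimal sketch of the reshaping only: the helper `split_into_windows` is hypothetical and not part of the package, and any preprocessing the model may expect (sampling rate, filtering, normalization, channel order) is not described here and is not covered.

```python
import torch


def split_into_windows(recording, seq_len=1000):
    """Cut a (n_chs, n_samples) tensor into non-overlapping windows.

    Returns a tensor shaped (n_windows, n_chs, seq_len). Trailing samples
    that do not fill a whole window are dropped. Hypothetical helper,
    not part of the package.
    """
    # unfold gives (n_chs, n_windows, seq_len); move the window axis to the front.
    windows = recording.unfold(1, seq_len, seq_len)
    return windows.permute(1, 0, 2).contiguous()


# Stand-in recording: 19 channels, 3500 samples of random values.
recording = torch.rand(19, 3500)
x = split_into_windows(recording)
print(x.shape)  # torch.Size([3, 19, 1000])
```

The resulting `x` has the same layout as the random batch in the usage example, so it could be passed to `model(x)` in the same way. A recording shorter than `seq_len` raises an error in `unfold` and would need padding or rejection beforehand.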

Contact

Please feel free to contact the author.


Source Distribution

eeg_dementia_classification_MNet-1.0.0.tar.gz (6.4 kB)
