
A comprehensive PyTorch Lightning framework for medical image classification with support for 2D/3D images


MedVision-Classification

MedVision-Classification is a medical image classification framework built on PyTorch Lightning that provides a simple interface for training and inference.

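Features

  • High-level interface built on PyTorch Lightning
  • Support for common medical image formats (NIfTI, DICOM, etc.)
  • Multiple built-in classification architectures (ResNet, DenseNet, EfficientNet, etc.)
  • Flexible data loading and preprocessing pipeline
  • Modular design that is easy to extend
  • Command-line interface for training and inference
  • Binary and multi-class classification tasks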

Installation

System requirements

  • Python 3.8+
  • PyTorch 2.0+
  • CUDA (optional, for GPU acceleration)
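To confirm an environment meets the requirements above before installing, a short check such as the following can help. It is a sketch, not part of MedVision-Classification: it reads the installed torch version through the standard `importlib.metadata` module and imports `torch` (to call `torch.cuda.is_available()`) only when PyTorch is present. The helper names `parse_version` and `check_environment` are hypothetical.

```python
import importlib.metadata
import re
import sys
from typing import Dict, Optional, Tuple


def parse_version(version: str) -> Tuple[int, int]:
    """Reduce a version string such as '2.1.0+cu118' to (major, minor)."""
    core = version.split("+")[0]  # drop local build tags like +cu118
    numbers = []
    for piece in core.split(".")[:2]:
        match = re.match(r"\d+", piece)  # leading digits only, so '0rc1' -> 0
        numbers.append(int(match.group()) if match else 0)
    while len(numbers) < 2:  # treat '2' as '2.0'
        numbers.append(0)
    return (numbers[0], numbers[1])


def check_environment() -> Dict[str, object]:
    """Compare the running interpreter and installed PyTorch with the stated minimums."""
    torch_version: Optional[str]
    try:
        torch_version = importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
        torch_version = None

    cuda_available = False
    if torch_version is not None:
        try:
            import torch
            cuda_available = torch.cuda.is_available()  # CUDA is optional
        except ImportError:
            pass

    return {
        "python_ok": sys.version_info[:2] >= (3, 8),
        "torch_version": torch_version,
        "torch_ok": torch_version is not None and parse_version(torch_version) >= (2, 0),
        "cuda_available": cuda_available,
    }


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

Reading the version from package metadata first means the check still reports something useful on a machine where PyTorch is missing entirely.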

Basic installation

The simplest way to install is from PyPI:

pip install medvision-classification

Installing from source

git clone https://github.com/yourusername/medvision-classification.git
cd medvision-classification
pip install -e .

Using the requirements files

# Base environment
pip install -r requirements.txt

# Development environment
pip install -r requirements-dev.txt

Using a conda environment

We recommend creating an isolated environment with conda:

# Create and activate the environment
conda env create -f environment.yml
conda activate medvision-cls

# Install the project itself
pip install -e .

Quick start

Training a 2D model

medvision-cls train configs/train_config.yml

Training a 3D model

medvision-cls train configs/train_3d_resnet_config.yml

Testing a model

medvision-cls test configs/test_config.yml

Inference

medvision-cls predict configs/inference_config.yml --input /path/to/image --output /path/to/output
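To run the `predict` command over a whole folder, one option is to call the CLI once per file from Python. This sketch relies only on the `predict` subcommand and the `--input`/`--output` flags shown above. Whether `--input` also accepts a directory is not documented here, and the per-image output layout (`<output_root>/<image stem>/`) and the helper names are assumptions.

```python
import shutil
import subprocess
from pathlib import Path
from typing import List


def build_predict_command(config: str, image: Path, output: Path) -> List[str]:
    """Assemble the documented `medvision-cls predict` call for one input file."""
    return ["medvision-cls", "predict", config,
            "--input", str(image), "--output", str(output)]


def predict_folder(config: str, image_dir: str, output_root: str,
                   pattern: str = "*.png") -> List[List[str]]:
    """Run the CLI once per matching file and return the commands that ran."""
    if shutil.which("medvision-cls") is None:
        raise RuntimeError("medvision-cls not found on PATH; install the package first")
    commands = []
    for image in sorted(Path(image_dir).glob(pattern)):
        # Hypothetical layout: one output directory per input, named after the file stem.
        out_dir = Path(output_root) / image.stem
        out_dir.mkdir(parents=True, exist_ok=True)
        cmd = build_predict_command(config, image, out_dir)
        subprocess.run(cmd, check=True)  # raises CalledProcessError if the CLI fails
        commands.append(cmd)
    return commands
```

For example, `predict_folder("configs/inference_config.yml", "data/classification/test/class1", "outputs/predictions")` (the output path is illustrative). Because each call starts a new process, the checkpoint is reloaded for every image. This is slow for large folders, and the sketch only follows the per-file usage that the command above documents.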

Configuration format

Example 2D classification training configuration

# 2D ResNet Training Configuration
seed: 42

task_dim: 2d

# Model configuration
model:
  type: "classification"
  network:
    name: "resnet50"
    pretrained: true
  num_classes: 2   # matches the two entries under class_names

  # Metrics to compute
  metrics:
    accuracy:
      type: "accuracy"
    f1:
      type: "f1"
    precision:
      type: "precision"
    recall:
      type: "recall"
    auc:
      type: "auroc"
        
  # Loss configuration
  loss:
    type: "cross_entropy"
    weight: null
    label_smoothing: 0.0
  
  # Optimizer configuration
  optimizer:
    type: "adam"
    lr: 0.001
    weight_decay: 0.0001
  
  # Scheduler configuration
  scheduler:
    type: "cosine"
    T_max: 100
    eta_min: 0.00001

# Data configuration
data:
  type: "medical"
  batch_size: 4
  num_workers: 4
  data_dir: "data/classification"
  image_format: "*.png"
  
  # Transform configuration for 2D data
  transforms:
    image_size: [224, 224]
    normalize: true
    augment: true
    
  # Data split configuration
  train_val_split: [0.8, 0.2]
  seed: 42

# Training configuration
training:
  max_epochs: 10
  accelerator: "gpu"
  devices: [0,1,2,3]  # Multi-GPU training
  precision: 16
  save_metrics: true
  
  # Callbacks
  model_checkpoint:
    monitor: "val/accuracy"
    mode: "max"
    save_top_k: 3
    filename: "epoch_{epoch:02d}-val_acc_{val/accuracy:.3f}"

# Validation configuration
validation:
  check_val_every_n_epoch: 1

# Class names
class_names:
  - "Class_0"
  - "Class_1"

# Output paths
outputs:
  output_dir: "outputs"

# Logging
logging:
  log_every_n_steps: 10
  wandb:
    enabled: false
    project: "medvision-2d-classification"
    entity: null

Example 3D classification training configuration

# 3D ResNet Training Configuration
seed: 42

task_dim: 3D

# Model configuration
model:
  type: "classification"
  network:
    name: "resnet3d_18"  # Options: resnet3d_18, resnet3d_34, resnet3d_50
    pretrained: false    # No pretrained weights for 3D models
    in_channels: 3       # Input channels; use 1 for single-channel volumes (typical for CT/MRI)
    dropout: 0.1
  num_classes: 2

  # Metrics to compute
  metrics:
    accuracy:
      type: "accuracy"
    f1:
      type: "f1"
    precision:
      type: "precision"
    recall:
      type: "recall"
    auc:
      type: "auroc"

  # Loss configuration
  loss:
    type: "cross_entropy"
    weight: null
    label_smoothing: 0.0
  
  # Optimizer configuration
  optimizer:
    type: "adam"
    lr: 0.001
    weight_decay: 0.0001
  
  # Scheduler configuration
  scheduler:
    type: "cosine"
    T_max: 100
    eta_min: 0.00001

# Data configuration
data:
  type: "medical"
  batch_size: 4         # Keep small for 3D data (memory intensive)
  num_workers: 4
  data_dir: "data/3D"
  image_format: "*.nii.gz"  # 3D medical image format
  
  # Transform configuration for 3D data
  transforms:
    image_size: [64, 64, 64]  # [D, H, W] for 3D volumes
    normalize: true
    augment: true
    
  # Data split configuration
  train_val_split: [0.8, 0.2]
  seed: 42

# Training configuration
training:
  max_epochs: 5
  accelerator: "gpu"
  devices: 1            # Single GPU for 3D (memory intensive)
  precision: 16         # Use mixed precision to save memory
  
  # Callbacks
  early_stopping:
    monitor: "val/loss"
    patience: 10          # Only takes effect when max_epochs exceeds patience
    mode: "min"
  
  model_checkpoint:
    monitor: "val/accuracy"
    mode: "max"
    save_top_k: 3
    filename: "epoch_{epoch:02d}-val_acc_{val/accuracy:.3f}"

# Validation configuration
validation:
  check_val_every_n_epoch: 1

# Output paths
outputs:
  output_dir: "outputs"

# Logging
logging:
  log_every_n_steps: 10
  wandb:
    enabled: false
    project: "medvision-3d-classification"
    entity: null
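Both example configs use `type: "cosine"` with `T_max: 100`, while `max_epochs` is 10 (2D) or 5 (3D). The source does not confirm how this type is implemented. Assuming it maps to PyTorch's `CosineAnnealingLR` stepped once per epoch, the learning rate follows `eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2`, which the sketch below evaluates directly without PyTorch.

```python
import math
from typing import List


def cosine_lr(step: int, base_lr: float, t_max: int, eta_min: float) -> float:
    """Closed-form cosine annealing from base_lr down to eta_min over t_max steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


def lr_schedule(max_epochs: int, base_lr: float, t_max: int, eta_min: float) -> List[float]:
    """Learning rate used in each epoch, assuming one scheduler step per epoch."""
    return [cosine_lr(epoch, base_lr, t_max, eta_min) for epoch in range(max_epochs)]


if __name__ == "__main__":
    # Values from the 3D example: lr=0.001, T_max=100, eta_min=0.00001, max_epochs=5.
    for epoch, lr in enumerate(lr_schedule(5, 1e-3, 100, 1e-5)):
        print(f"epoch {epoch}: lr={lr:.6f}")
```

With these values the rate stays within about half a percent of the initial 0.001 for all five epochs. Setting `T_max` equal to `max_epochs` would instead anneal it toward `eta_min` by the end of training.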

Example inference configuration

# Model configuration
model:
  type: "classification"
  network:
    name: "resnet50"
    pretrained: false
  num_classes: 2
  checkpoint_path: "outputs/checkpoints/best_model.ckpt"

# Inference settings
inference:
  batch_size: 1
  device: "cuda:0"  # or "cpu"
  return_probabilities: true
  class_names: ["class0", "class1"]
  confidence_threshold: 0.5

# Preprocessing
preprocessing:
  image_size: [224, 224]
  normalize: true
  mean: [0.485, 0.456, 0.406]
  std: [0.229, 0.224, 0.225]

# Output settings
output:
  save_predictions: true
  include_probabilities: true
  format: "json"  # or "csv"
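The inference settings `return_probabilities`, `class_names` and `confidence_threshold` suggest a post-processing step that turns raw logits into labeled predictions. How the framework actually applies `confidence_threshold` is not documented here. The sketch below assumes one plausible reading: a prediction whose top-class probability falls below the threshold is flagged as low confidence rather than dropped. The record schema and function names are hypothetical.

```python
import json
import math
from typing import Dict, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one sample's logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_record(image: str, logits: Sequence[float], class_names: Sequence[str],
              confidence_threshold: float = 0.5,
              include_probabilities: bool = True) -> Dict[str, object]:
    """Turn one sample's logits into a prediction record (hypothetical schema)."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    record: Dict[str, object] = {
        "image": image,
        "predicted_class": class_names[best],
        "confidence": probs[best],
        # Assumption: below-threshold predictions are flagged, not removed.
        "low_confidence": probs[best] < confidence_threshold,
    }
    if include_probabilities:
        record["probabilities"] = dict(zip(class_names, probs))
    return record


if __name__ == "__main__":
    rec = to_record("case_001.png", [0.2, 1.4], ["class0", "class1"])
    print(json.dumps(rec, indent=2))
```

With two classes, the top softmax probability is always at least 0.5, so a threshold of 0.5 never flags anything under this reading. The setting only matters for multi-class models or higher thresholds.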

Data format

Folder structure

data/
├── classification/
│   ├── train/
│   │   ├── class1/
│   │   │   ├── image1.png
│   │   │   └── image2.png
│   │   └── class2/
│   │       ├── image3.png
│   │       └── image4.png
│   ├── val/
│   │   ├── class1/
│   │   └── class2/
│   └── test/
│       ├── class1/
│       └── class2/
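In this layout each class is a subfolder, and the folder name serves as the label. The sketch below shows one way to turn a split folder into (path, label) pairs and to apply a seeded 80/20 split like `train_val_split: [0.8, 0.2]` with `seed: 42`. It illustrates the idea and is not the framework's loader. The label order (sorted folder names), the shuffle-then-cut strategy and the helper names are assumptions.

```python
import random
from pathlib import Path
from typing import List, Sequence, Tuple


def scan_split(split_dir: str, pattern: str = "*.png") -> Tuple[List[Tuple[Path, int]], List[str]]:
    """Collect (image_path, label) pairs from <split_dir>/<class_name>/ subfolders.

    Labels follow sorted folder names, so class1 -> 0 and class2 -> 1.
    """
    root = Path(split_dir)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    samples = []
    for label, name in enumerate(class_names):
        for image in sorted((root / name).glob(pattern)):
            samples.append((image, label))
    return samples, class_names


def train_val_split(samples: Sequence, ratios: Sequence[float] = (0.8, 0.2), seed: int = 42):
    """Shuffle with a fixed seed, then cut at ratios[0]; the same seed gives the same split."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)  # private RNG: does not touch global random state
    cut = int(len(shuffled) * ratios[0])
    return shuffled[:cut], shuffled[cut:]
```

For example, `samples, names = scan_split("data/classification/train")` followed by `train, val = train_val_split(samples, seed=42)`. For 3D data, pass `pattern="*.nii.gz"`.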

Supported models

  • ResNet family: ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
  • DenseNet family: DenseNet121, DenseNet161, DenseNet169, DenseNet201
  • EfficientNet family: EfficientNet-B0 to EfficientNet-B7
  • Vision Transformer: ViT-Base, ViT-Large
  • ConvNeXt: ConvNeXt-Tiny, ConvNeXt-Small, ConvNeXt-Base
  • Medical-specific: MedNet, RadImageNet-pretrained models

License

This project is released under the MIT License.

Contributing

Contributions are welcome! See CONTRIBUTING.md for details.


Download files

Source Distribution

medvision_classification-0.2.9.tar.gz (49.5 kB)

Uploaded Source

Built Distribution

medvision_classification-0.2.9-py3-none-any.whl (64.9 kB)

Uploaded Python 3

File details

Details for the file medvision_classification-0.2.9.tar.gz.

File metadata

  • Download URL: medvision_classification-0.2.9.tar.gz
  • Upload date:
  • Size: 49.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.11

File hashes

Hashes for medvision_classification-0.2.9.tar.gz
Algorithm Hash digest
SHA256 3a7d37c6d6bb3364e149e1dfd126f143cd0fb84326a0450351d0cf68bb796825
MD5 7db3440471287e890b5ae31c42ad08a2
BLAKE2b-256 c3346ca7adc38747678ccd9d95cfe165409f8150a50fbfc7e797a77ab8f69e93
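To check a downloaded archive against the SHA256 digest listed above, Python's standard `hashlib` module is enough. The sketch streams the file in chunks so large archives do not have to fit in memory. The function names are illustrative.

```python
import hashlib
from pathlib import Path


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: str, expected_sha256: str) -> bool:
    """Compare against a published digest, ignoring letter case."""
    return sha256_of(path) == expected_sha256.strip().lower()


if __name__ == "__main__":
    sdist = "medvision_classification-0.2.9.tar.gz"
    expected = "3a7d37c6d6bb3364e149e1dfd126f143cd0fb84326a0450351d0cf68bb796825"
    if Path(sdist).exists():
        print("OK" if verify(sdist, expected) else "HASH MISMATCH")
```

pip can enforce the same check at install time through its hash-checking mode. In a requirements file, list the requirement with one `--hash=sha256:<digest>` entry per distribution file, then install with `--require-hashes`.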

File details

Details for the file medvision_classification-0.2.9-py3-none-any.whl.

File metadata

File hashes

Hashes for medvision_classification-0.2.9-py3-none-any.whl
Algorithm Hash digest
SHA256 97db30a1fbe821e1c283aa06620e2918933072560a11c49ac99d1845a4bfe696
MD5 ef54684c9445132581236ae6c76736e6
BLAKE2b-256 6b5186b593f2f3527a1b30224b0d5b38a56d5946bf8c195ee415d7797decd4b3