
A comprehensive PyTorch Lightning framework for medical image classification with support for 2D/3D images


MedVision-Classification

MedVision-Classification is a medical image classification framework built on PyTorch Lightning. It provides simple interfaces for training and inference.

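Features

  • High-level interface built on PyTorch Lightning
  • Support for common medical imaging formats (NIfTI, DICOM, and others)
  • Built-in classification architectures (ResNet, DenseNet, EfficientNet, and others)
  • Flexible data loading and preprocessing pipeline
  • Modular design that is easy to extend
  • Command-line interface for training and inference
  • Support for binary and multi-class classification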

Installation

System requirements

  • Python 3.8+
  • PyTorch 2.0+
  • CUDA (optional, for GPU acceleration)
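
The requirements above can be checked before installing. The following is a minimal sketch, not part of MedVision-Classification: the helper names `parse_version`, `meets_minimum`, and `check_environment` are hypothetical, and the only PyTorch attributes used are `torch.__version__` and `torch.cuda.is_available()`.

```python
import re
import sys


def parse_version(text):
    """Turn a string such as '2.1.0+cu118' into (2, 1, 0), dropping build suffixes."""
    numbers = re.findall(r"\d+", text.split("+")[0])
    return tuple(int(n) for n in numbers[:3])


def meets_minimum(version_text, minimum):
    """Return True if the parsed version is at least `minimum` (a tuple of ints)."""
    return parse_version(version_text) >= tuple(minimum)


def check_environment():
    """Report whether Python and PyTorch satisfy the minimums listed above."""
    report = {"python": sys.version_info[:2] >= (3, 8)}
    try:
        import torch  # imported lazily so the check also runs without PyTorch
    except ImportError:
        report["torch"] = False
        report["cuda"] = False
    else:
        report["torch"] = meets_minimum(torch.__version__, (2, 0))
        report["cuda"] = bool(torch.cuda.is_available())  # optional requirement
    return report
```

Calling `check_environment()` returns a dictionary with one boolean per requirement. A `False` for `cuda` only means training will fall back to the CPU.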

Basic installation

The simplest way to install is from PyPI:

pip install medvision-classification

Install from source

git clone https://github.com/yourusername/medvision-classification.git
cd medvision-classification
pip install -e .

Using the requirements files

# Base environment
pip install -r requirements.txt

# Development environment
pip install -r requirements-dev.txt

Using a conda environment

Creating an isolated conda environment is recommended:

# Create and activate the environment
conda env create -f environment.yml
conda activate medvision-cls

# Install the project itself
pip install -e .

Quick start

Train a 2D model

medvision-cls train configs/train_config.yml

Train a 3D model

medvision-cls train configs/train_3d_resnet_config.yml

Test a model

medvision-cls test configs/test_config.yml

Inference

medvision-cls predict configs/inference_config.yml --input /path/to/image --output /path/to/output
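
When inference runs from a script rather than a shell, the same command can be assembled and launched with the standard library. This is a minimal sketch: `build_predict_command` and `run_predict` are hypothetical helpers, the command name is written in lowercase to match the train and test commands, and only the `--input` and `--output` flags shown above are used.

```python
import shlex
import subprocess


def build_predict_command(config_path, input_path, output_path):
    """Assemble the argv list for the `medvision-cls predict` call shown above."""
    return [
        "medvision-cls", "predict", str(config_path),
        "--input", str(input_path),
        "--output", str(output_path),
    ]


def run_predict(config_path, input_path, output_path, dry_run=True):
    """Print the command and execute it only when dry_run is False."""
    command = build_predict_command(config_path, input_path, output_path)
    print(shlex.join(command))
    if dry_run:
        return None
    # check=True raises CalledProcessError if the CLI exits with a non-zero status.
    return subprocess.run(command, check=True)
```

Passing an argument list instead of a single string avoids shell quoting problems when paths contain spaces.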

Configuration format

2D classification training configuration example

# 2D ResNet Training Configuration
seed: 42

task_dim: 2d

# Model configuration
model:
  type: "classification"
  network:
    name: "resnet50"
    pretrained: true
  num_classes: 4

  # Metrics to compute
  metrics:
    accuracy:
      type: "accuracy"
    f1:
      type: "f1"
    precision:
      type: "precision"
    recall:
      type: "recall"
    auc:
      type: "auroc"
        
  # Loss configuration
  loss:
    type: "cross_entropy"
    weight: null
    label_smoothing: 0.0
  
  # Optimizer configuration
  optimizer:
    type: "adam"
    lr: 0.001
    weight_decay: 0.0001
  
  # Scheduler configuration
  scheduler:
    type: "cosine"
    T_max: 100
    eta_min: 0.00001

# Data configuration
data:
  type: "medical"
  batch_size: 4
  num_workers: 4
  data_dir: "data/classification"
  image_format: "*.png"
  
  # Transform configuration for 2D data
  transforms:
    image_size: [224, 224]
    normalize: true
    augment: true
    
  # Data split configuration
  train_val_split: [0.8, 0.2]
  seed: 42

# Training configuration
training:
  max_epochs: 10
  accelerator: "gpu"
  devices: [0,1,2,3]  # Multi-GPU training
  precision: 16
  save_metrics: true
  
  # Callbacks
  model_checkpoint:
    monitor: "val/accuracy"
    mode: "max"
    save_top_k: 3
    filename: "epoch_{epoch:02d}-val_acc_{val/accuracy:.3f}"

# Validation configuration
validation:
  check_val_every_n_epoch: 1

# Class names (one entry per class; must match model.num_classes)
class_names:
  - "Class_0"
  - "Class_1"
  - "Class_2"
  - "Class_3"

# Output paths
outputs:
  output_dir: "outputs"

# Logging
logging:
  log_every_n_steps: 10
  wandb:
    enabled: false
    project: "medvision-2d-classification"
    entity: null

3D classification training configuration example

# 3D ResNet Training Configuration
seed: 42

task_dim: 3D

# Model configuration
model:
  type: "classification"
  network:
    name: "resnet3d_18"  # Options: resnet3d_18, resnet3d_34, resnet3d_50
    pretrained: false    # No pretrained weights for 3D models
    in_channels: 3       # Input channels (typically 1 for medical images)
    dropout: 0.1
  num_classes: 2

  # Metrics to compute
  metrics:
    accuracy:
      type: "accuracy"
    f1:
      type: "f1"
    precision:
      type: "precision"
    recall:
      type: "recall"
    auc:
      type: "auroc"

  # Loss configuration
  loss:
    type: "cross_entropy"
    weight: null
    label_smoothing: 0.0
  
  # Optimizer configuration
  optimizer:
    type: "adam"
    lr: 0.001
    weight_decay: 0.0001
  
  # Scheduler configuration
  scheduler:
    type: "cosine"
    T_max: 100
    eta_min: 0.00001

# Data configuration
data:
  type: "medical"
  batch_size: 4         # Smaller batch size for 3D data
  num_workers: 4
  data_dir: "data/3D"
  image_format: "*.nii.gz"  # 3D medical image format
  
  # Transform configuration for 3D data
  transforms:
    image_size: [64, 64, 64]  # [D, H, W] for 3D volumes
    normalize: true
    augment: true
    
  # Data split configuration
  train_val_split: [0.8, 0.2]
  seed: 42

# Training configuration
training:
  max_epochs: 5
  accelerator: "gpu"
  devices: 1            # Single GPU for 3D (memory intensive)
  precision: 16         # Use mixed precision to save memory
  
  # Callbacks
  early_stopping:
    monitor: "val/loss"
    patience: 10
    mode: "min"
  
  model_checkpoint:
    monitor: "val/accuracy"
    mode: "max"
    save_top_k: 3
    filename: "epoch_{epoch:02d}-val_acc_{val/accuracy:.3f}"

# Validation configuration
validation:
  check_val_every_n_epoch: 1

# Output paths
outputs:
  output_dir: "outputs"

# Logging
logging:
  log_every_n_steps: 10
  wandb:
    enabled: false
    project: "medvision-3d-classification"
    entity: null
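
The comments in this configuration describe 3D training as memory intensive. A rough way to see why is to count the bytes in a single input batch. This sketch covers only the input tensor, not activations or gradients, and `input_batch_bytes` is a hypothetical helper. The 2D comparison assumes 3 input channels, which the 2D configuration does not state.

```python
from math import prod


def input_batch_bytes(batch_size, channels, spatial_size, bytes_per_value):
    """Bytes needed to hold one input batch of shape [B, C, *spatial_size]."""
    return batch_size * channels * prod(spatial_size) * bytes_per_value


# 3D example config: batch_size 4, in_channels 3, image_size [64, 64, 64]
volume_fp32 = input_batch_bytes(4, 3, [64, 64, 64], 4)
volume_fp16 = input_batch_bytes(4, 3, [64, 64, 64], 2)
# 2D example config: batch_size 4, image_size [224, 224]; 3 channels assumed
image_fp32 = input_batch_bytes(4, 3, [224, 224], 4)

print(f"3D batch, float32: {volume_fp32 / 2**20:.1f} MiB")
print(f"3D batch, float16: {volume_fp16 / 2**20:.1f} MiB")
print(f"2D batch, float32: {image_fp32 / 2**20:.1f} MiB")
```

A 64x64x64 volume holds about five times as many values as a 224x224 image, and the feature maps inside a 3D network grow with the same voxel count. That is the usual reason for pairing a single device with a small batch and reduced precision.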

Inference configuration example

# Model configuration
model:
  type: "classification"
  network:
    name: "resnet50"
    pretrained: false
  num_classes: 2
  checkpoint_path: "outputs/checkpoints/best_model.ckpt"

# Inference settings
inference:
  batch_size: 1
  device: "cuda:0"  # or "cpu"
  return_probabilities: true
  class_names: ["class0", "class1"]
  confidence_threshold: 0.5

# Preprocessing
preprocessing:
  image_size: [224, 224]
  normalize: true
  mean: [0.485, 0.456, 0.406]
  std: [0.229, 0.224, 0.225]

# Output settings
output:
  save_predictions: true
  include_probabilities: true
  format: "json"  # or "csv"
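
The inference settings above (`return_probabilities`, `class_names`, `confidence_threshold`, and the JSON output format) suggest a post-processing step along the following lines. This is a sketch under assumptions, not the library's implementation: `softmax` and `to_prediction` are hypothetical helpers, and reporting a below-threshold prediction as `None` is one possible reading of `confidence_threshold`.

```python
import json
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_prediction(logits, class_names, confidence_threshold=0.5):
    """Build a JSON-ready record from the logits of one sample."""
    probabilities = softmax(logits)
    best = max(range(len(probabilities)), key=probabilities.__getitem__)
    confident = probabilities[best] >= confidence_threshold
    return {
        # Assumption: predictions below the threshold are reported as None.
        "label": class_names[best] if confident else None,
        "confidence": round(probabilities[best], 4),
        "probabilities": dict(zip(class_names, (round(p, 4) for p in probabilities))),
    }


record = to_prediction([0.2, 1.7], ["class0", "class1"])
print(json.dumps(record, indent=2))
```

The example logits are made up for illustration. In a real run they would come from the model loaded from `checkpoint_path`.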

Data format

Folder structure

data/
├── classification/
│   ├── train/
│   │   ├── class1/
│   │   │   ├── image1.png
│   │   │   └── image2.png
│   │   └── class2/
│   │       ├── image3.png
│   │       └── image4.png
│   ├── val/
│   │   ├── class1/
│   │   └── class2/
│   └── test/
│       ├── class1/
│       └── class2/
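
A layout with one subfolder per class can be turned into labelled samples with a short directory scan. The sketch below shows the idea using only `pathlib`. `scan_split` is a hypothetical helper, and assigning label indices by sorted folder name is an assumption, not necessarily how MedVision-Classification orders classes.

```python
from pathlib import Path


def scan_split(split_dir, pattern="*.png"):
    """Return (samples, class_names) for a folder with one subfolder per class."""
    split_dir = Path(split_dir)
    # Sorting keeps the class-to-index mapping identical across runs.
    class_names = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    samples = []
    for label, name in enumerate(class_names):
        for image_path in sorted((split_dir / name).glob(pattern)):
            samples.append((image_path, label))
    return samples, class_names
```

For the tree above, `scan_split("data/classification/train")` would yield four `(path, label)` pairs, with `class1` mapped to 0 and `class2` mapped to 1. The `pattern` argument plays the role of `image_format` in the training configuration.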

Supported models

  • ResNet family: ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
  • DenseNet family: DenseNet121, DenseNet161, DenseNet169, DenseNet201
  • EfficientNet family: EfficientNet-B0 through EfficientNet-B7
  • Vision Transformer: ViT-Base, ViT-Large
  • ConvNeXt: ConvNeXt-Tiny, ConvNeXt-Small, ConvNeXt-Base
  • Medical-specific: MedNet, RadImageNet pretrained models

License

This project is released under the MIT License.

Contributing

Contributions are welcome! See CONTRIBUTING.md for details.
