A comprehensive PyTorch Lightning framework for medical image classification with support for 2D/3D images

MedVision-Classification

MedVision-Classification is a medical image classification framework built on PyTorch Lightning that provides simple interfaces for training and inference.

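Features

  • High-level interface built on PyTorch Lightning
  • Support for common medical image formats (NIfTI, DICOM, etc.)
  • Multiple built-in classification architectures (ResNet, DenseNet, EfficientNet, etc.)
  • Flexible data loading and preprocessing pipeline
  • Modular design that is easy to extend
  • Command-line interface for training and inference
  • Support for binary and multi-class classification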

Installation

Requirements

  • Python 3.8+
  • PyTorch 2.0+
  • CUDA (optional, for GPU acceleration)
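
The requirements above can be checked before installing. The following is a minimal sketch and not part of MedVision-Classification: the helper names `parse_version`, `meets_minimum`, and `check_environment` are hypothetical, and it assumes that PyTorch, if installed, exposes its version through standard package metadata under the name `torch`.

```python
import re
import sys


def parse_version(text):
    """Return the leading numeric components, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    parts = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare two version strings component by component."""
    return parse_version(installed) >= parse_version(minimum)


def check_environment():
    """Collect human-readable problems; an empty list means the checks passed."""
    if sys.version_info < (3, 8):
        return ["Python 3.8+ is required"]

    # importlib.metadata is in the standard library from Python 3.8 onward.
    from importlib import metadata

    problems = []
    try:
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        problems.append("PyTorch is not installed")
    else:
        if not meets_minimum(torch_version, "2.0"):
            problems.append(f"PyTorch 2.0+ is required, found {torch_version}")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print(problem)
```

The script prints nothing when both checks pass. CUDA is optional, so it is not checked here.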

Basic installation

The simplest way to install is from PyPI:

pip install medvision-classification

Install from source

git clone https://github.com/yourusername/medvision-classification.git
cd medvision-classification
pip install -e .

Using the requirements files

# Base environment
pip install -r requirements.txt

# Development environment
pip install -r requirements-dev.txt

Using a conda environment

A separate conda virtual environment is recommended:

# Create and activate the environment
conda env create -f environment.yml
conda activate medvision-cls

# Install the project itself
pip install -e .

Quick start

Train a 2D model

medvision-cls train configs/train_config.yml

Train a 3D model

medvision-cls train configs/train_3d_resnet_config.yml

Test a model

medvision-cls test configs/test_config.yml

Inference

medvision-cls predict configs/inference_config.yml --input /path/to/image --output /path/to/output
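
To run the predict command over several images, the CLI can be driven from a script. The sketch below only assembles the invocation shown above (the `predict` subcommand with `--input` and `--output`). `build_predict_command` and `predict_all` are hypothetical helpers, not part of the package, and all paths are placeholders.

```python
import subprocess
from pathlib import Path


def build_predict_command(config, image, output_dir):
    """Assemble the predict invocation from the quick start as an argument list."""
    return [
        "medvision-cls", "predict", str(config),
        "--input", str(image),
        "--output", str(output_dir),
    ]


def predict_all(config, image_dir, output_dir, pattern="*.png", dry_run=True):
    """Build one command per matching image; run them only when dry_run is False."""
    commands = [
        build_predict_command(config, image, output_dir)
        for image in sorted(Path(image_dir).glob(pattern))
    ]
    if not dry_run:
        for command in commands:
            # check=True raises CalledProcessError if the CLI exits with an error.
            subprocess.run(command, check=True)
    return commands
```

Keeping `dry_run=True` by default makes it possible to inspect the generated commands before anything is executed.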

Configuration format

2D classification training config example

# 2D ResNet Training Configuration
seed: 42

task_dim: 2d

# Model configuration
model:
  type: "classification"
  network:
    name: "resnet50"
    pretrained: true
  num_classes: 4

  # Metrics to compute
  metrics:
    accuracy:
      type: "accuracy"
    f1:
      type: "f1"
    precision:
      type: "precision"
    recall:
      type: "recall"
    auc:
      type: "auroc"
        
  # Loss configuration
  loss:
    type: "cross_entropy"
    weight: null
    label_smoothing: 0.0
  
  # Optimizer configuration
  optimizer:
    type: "adam"
    lr: 0.001
    weight_decay: 0.0001
  
  # Scheduler configuration
  scheduler:
    type: "cosine"
    T_max: 100
    eta_min: 0.00001

# Data configuration
data:
  type: "medical"
  batch_size: 4
  num_workers: 4
  data_dir: "data/classification"
  image_format: "*.png"
  
  # Transform configuration for 2D data
  transforms:
    image_size: [224, 224]
    normalize: true
    augment: true
    
  # Data split configuration
  train_val_split: [0.8, 0.2]
  seed: 42

# Training configuration
training:
  max_epochs: 10
  accelerator: "gpu"
  devices: [0,1,2,3]  # Multi-GPU training
  precision: 16
  save_metrics: true
  
  # Callbacks
  model_checkpoint:
    monitor: "val/accuracy"
    mode: "max"
    save_top_k: 3
    filename: "epoch_{epoch:02d}-val_acc_{val/accuracy:.3f}"

# Validation configuration
validation:
  check_val_every_n_epoch: 1

# Class names
class_names:
  - "Class_0"
  - "Class_1"
  - "Class_2"
  - "Class_3"

# Output paths
outputs:
  output_dir: "outputs"

# Logging
logging:
  log_every_n_steps: 10
  wandb:
    enabled: false
    project: "medvision-2d-classification"
    entity: null
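
A few fields in a configuration like this one need to agree with each other: the number of `class_names` entries should match `model.num_classes`, and `train_val_split` should sum to 1. The sketch below checks these on an already parsed config. `validate_config` is a hypothetical helper rather than part of the framework, and it assumes the YAML has been loaded into a plain dict (for example with PyYAML's `yaml.safe_load`).

```python
import math


def validate_config(config):
    """Return a list of consistency problems found in a parsed training config."""
    problems = []

    num_classes = config["model"]["num_classes"]
    class_names = config.get("class_names")
    if class_names is not None and len(class_names) != num_classes:
        problems.append(
            f"class_names has {len(class_names)} entries but num_classes is {num_classes}"
        )

    split = config["data"].get("train_val_split")
    if split is not None and not math.isclose(sum(split), 1.0):
        problems.append(f"train_val_split sums to {sum(split)}, expected 1.0")

    if config["data"]["batch_size"] < 1:
        problems.append("batch_size must be at least 1")

    return problems


# A deliberately inconsistent example: four classes but only two names.
config = {
    "model": {"num_classes": 4},
    "data": {"batch_size": 4, "train_val_split": [0.8, 0.2]},
    "class_names": ["Class_0", "Class_1"],
}
print(validate_config(config))
```

Running a check like this before training surfaces mismatches early instead of partway through a run.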

3D classification training config example

# 3D ResNet Training Configuration
seed: 42

task_dim: 3D

# Model configuration
model:
  type: "classification"
  network:
    name: "resnet3d_18"  # Options: resnet3d_18, resnet3d_34, resnet3d_50
    pretrained: false    # No pretrained weights for 3D models
    in_channels: 3       # Input channels (single-channel medical volumes typically use 1)
    dropout: 0.1
  num_classes: 2

  # Metrics to compute
  metrics:
    accuracy:
      type: "accuracy"
    f1:
      type: "f1"
    precision:
      type: "precision"
    recall:
      type: "recall"
    auc:
      type: "auroc"

  # Loss configuration
  loss:
    type: "cross_entropy"
    weight: null
    label_smoothing: 0.0
  
  # Optimizer configuration
  optimizer:
    type: "adam"
    lr: 0.001
    weight_decay: 0.0001
  
  # Scheduler configuration
  scheduler:
    type: "cosine"
    T_max: 100
    eta_min: 0.00001

# Data configuration
data:
  type: "medical"
  batch_size: 4         # Smaller batch size for 3D data
  num_workers: 4
  data_dir: "data/3D"
  image_format: "*.nii.gz"  # 3D medical image format
  
  # Transform configuration for 3D data
  transforms:
    image_size: [64, 64, 64]  # [D, H, W] for 3D volumes
    normalize: true
    augment: true
    
  # Data split configuration
  train_val_split: [0.8, 0.2]
  seed: 42

# Training configuration
training:
  max_epochs: 5
  accelerator: "gpu"
  devices: 1            # Single GPU for 3D (memory intensive)
  precision: 16         # Use mixed precision to save memory
  
  # Callbacks
  early_stopping:
    monitor: "val/loss"
    patience: 10
    mode: "min"
  
  model_checkpoint:
    monitor: "val/accuracy"
    mode: "max"
    save_top_k: 3
    filename: "epoch_{epoch:02d}-val_acc_{val/accuracy:.3f}"

# Validation configuration
validation:
  check_val_every_n_epoch: 1

# Output paths
outputs:
  output_dir: "outputs"

# Logging
logging:
  log_every_n_steps: 10
  wandb:
    enabled: false
    project: "medvision-3d-classification"
    entity: null
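
The comments on `batch_size`, `devices`, and `precision` reflect how quickly 3D inputs grow. As a rough illustration, the sketch below counts the bytes in one input batch for the shapes used in the two example configs, at 2 bytes per element. `input_batch_bytes` is a hypothetical helper, and the 3-channel input for the 2D case is an assumption because that config does not set `in_channels`. The estimate covers only the input tensor; activations inside the network usually dominate real memory use.

```python
from math import prod


def input_batch_bytes(batch_size, in_channels, image_size, bytes_per_element):
    """Bytes needed to hold one input batch of shape [B, C, *image_size]."""
    return batch_size * in_channels * prod(image_size) * bytes_per_element


# 2D example: batch 4, image_size [224, 224]; 3 channels is an assumption.
bytes_2d = input_batch_bytes(4, 3, [224, 224], 2)
# 3D example: batch 4, in_channels 3, image_size [64, 64, 64].
bytes_3d = input_batch_bytes(4, 3, [64, 64, 64], 2)

print(bytes_2d)  # 1204224
print(bytes_3d)  # 6291456
```

Even at a modest 64 x 64 x 64 resolution, the 3D batch is roughly five times larger than the 2D one, and doubling each spatial dimension multiplies it by eight.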

Inference config example

# Model configuration
model:
  type: "classification"
  network:
    name: "resnet50"
    pretrained: false
  num_classes: 2
  checkpoint_path: "outputs/checkpoints/best_model.ckpt"

# Inference settings
inference:
  batch_size: 1
  device: "cuda:0"  # or "cpu"
  return_probabilities: true
  class_names: ["class0", "class1"]
  confidence_threshold: 0.5

# Preprocessing
preprocessing:
  image_size: [224, 224]
  normalize: true
  mean: [0.485, 0.456, 0.406]
  std: [0.229, 0.224, 0.225]

# Output settings
output:
  save_predictions: true
  include_probabilities: true
  format: "json"  # or "csv"
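
The `preprocessing` and `inference` settings above describe per-channel normalization and a confidence threshold. How the framework applies them is not shown here, so the following is only a sketch of the usual interpretation: subtract `mean` and divide by `std` per channel, apply softmax to the model's logits, and report the top class only when its probability reaches `confidence_threshold`. All function names are hypothetical.

```python
import math


def normalize_pixel(rgb, mean, std):
    """Normalize one RGB pixel (values in [0, 1]) channel by channel."""
    return [(value - m) / s for value, m, s in zip(rgb, mean, std)]


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decide(logits, class_names, confidence_threshold):
    """Return (label, probability); label is None below the threshold."""
    probabilities = softmax(logits)
    best = max(range(len(probabilities)), key=probabilities.__getitem__)
    if probabilities[best] < confidence_threshold:
        return None, probabilities[best]
    return class_names[best], probabilities[best]


mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
print(normalize_pixel(mean, mean, std))  # a pixel equal to the mean maps to zeros
print(decide([2.0, 0.0], ["class0", "class1"], 0.5))
```

Under this interpretation, a threshold of 0.5 never rejects a two-class prediction, because the larger of two probabilities is always at least 0.5; a higher value is needed for the threshold to have an effect.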

Data format

Folder structure

data/
├── classification/
│   ├── train/
│   │   ├── class1/
│   │   │   ├── image1.png
│   │   │   └── image2.png
│   │   └── class2/
│   │       ├── image3.png
│   │       └── image4.png
│   ├── val/
│   │   ├── class1/
│   │   └── class2/
│   └── test/
│       ├── class1/
│       └── class2/
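
With this layout, class labels follow from the subfolder names. The sketch below shows one way to derive them: sort the class folders alphabetically and assign indices in that order. Both the ordering and the helper `scan_split` are assumptions for illustration; the framework's own loader may behave differently.

```python
from pathlib import Path


def scan_split(split_dir, pattern="*.png"):
    """Return (class_names, samples) for a folder with one subfolder per class.

    samples is a list of (path, class_index) pairs.
    """
    split_dir = Path(split_dir)
    class_names = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    samples = []
    for index, name in enumerate(class_names):
        for image in sorted((split_dir / name).glob(pattern)):
            samples.append((image, index))
    return class_names, samples


if __name__ == "__main__":
    train_dir = Path("data/classification/train")
    if train_dir.is_dir():
        names, samples = scan_split(train_dir)
        print(names, len(samples))
```

Counting samples per class from the returned list is also a quick way to spot class imbalance before choosing loss weights.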

Supported models

  • ResNet family: ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
  • DenseNet family: DenseNet121, DenseNet161, DenseNet169, DenseNet201
  • EfficientNet family: EfficientNet-B0 through EfficientNet-B7
  • Vision Transformer: ViT-Base, ViT-Large
  • ConvNeXt: ConvNeXt-Tiny, ConvNeXt-Small, ConvNeXt-Base
  • Medical-specific: MedNet, RadImageNet pretrained models

License

This project is released under the MIT License.

Contributing

Contributions are welcome! See CONTRIBUTING.md for details.
