A comprehensive PyTorch Lightning framework for medical image classification with support for 2D/3D images
Project description
MedVision-Classification
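MedVision-Classification is a medical image classification framework built on PyTorch Lightning. It provides simple interfaces for training and inference.
Features
- High-level interface built on PyTorch Lightning
- Support for common medical imaging formats (NIfTI, DICOM, etc.)
- Built-in classification architectures (ResNet, DenseNet, EfficientNet, etc.)
- Flexible data loading and preprocessing pipeline
- Modular design that is easy to extend
- Command-line interface for training and inference
- Support for binary and multi-class classification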
Installation
System requirements
- Python 3.8+
- PyTorch 2.0+
- CUDA (optional, for GPU acceleration)
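You can check the requirements above before installing. The following is a minimal sketch that uses only the standard library. `parse_version`, `check_environment` and `cuda_available` are hypothetical helpers, not part of MedVision-Classification. The version parsing is deliberately simplified and ignores pre-release tags.

```python
import sys
from importlib import metadata


def parse_version(text: str) -> tuple:
    """Turn e.g. '2.1.0+cu118' into (2, 1, 0); pre-release tags are ignored."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    while len(parts) < 3:  # pad so "2" compares like "2.0.0"
        parts.append(0)
    return tuple(parts[:3])


def check_environment() -> list:
    """Return human-readable problems; an empty list means the basics are met."""
    problems = []
    if sys.version_info < (3, 8):
        problems.append(f"Python 3.8+ required, found {sys.version.split()[0]}")
    try:
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        problems.append("PyTorch is not installed")
    else:
        if parse_version(torch_version) < (2, 0, 0):
            problems.append(f"PyTorch 2.0+ required, found {torch_version}")
    return problems


def cuda_available() -> bool:
    """CUDA is optional: report False when torch is missing or has no GPU."""
    try:
        import torch
    except ImportError:
        return False
    return bool(torch.cuda.is_available())


if __name__ == "__main__":
    issues = check_environment()
    print("OK" if not issues else "\n".join(issues))
    print("CUDA available:", cuda_available())
```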
Basic installation
The simplest way to install (run from the root of a source checkout):
pip install -e .
Install from source
git clone https://github.com/yourusername/medvision-classification.git
cd medvision-classification
pip install -e .
Using requirements files
# Base environment
pip install -r requirements.txt
# Development environment
pip install -r requirements-dev.txt
Using a conda environment
A dedicated conda virtual environment is recommended:
# Create and activate the environment
conda env create -f environment.yml
conda activate medvision-cls
# Install the project itself
pip install -e .
Quick start
Train a 2D model
medvision-cls train configs/train_config.yml
Train a 3D model
medvision-cls train configs/train_3d_resnet_config.yml
Test a model
medvision-cls test configs/test_config.yml
Inference
medvision-cls predict configs/inference_config.yml --input /path/to/image --output /path/to/output
Configuration format
Example 2D classification training configuration
# 2D ResNet Training Configuration
seed: 42
task_dim: 2d
# Model configuration
model:
  type: "classification"
  network:
    name: "resnet50"
    pretrained: true
  num_classes: 4
  # Metrics to compute
  metrics:
    accuracy:
      type: "accuracy"
    f1:
      type: "f1"
    precision:
      type: "precision"
    recall:
      type: "recall"
    auc:
      type: "auroc"
  # Loss configuration
  loss:
    type: "cross_entropy"
    weight: null
    label_smoothing: 0.0
  # Optimizer configuration
  optimizer:
    type: "adam"
    lr: 0.001
    weight_decay: 0.0001
  # Scheduler configuration
  scheduler:
    type: "cosine"
    T_max: 100
    eta_min: 0.00001
# Data configuration
data:
  type: "medical"
  batch_size: 4
  num_workers: 4
  data_dir: "data/classification"
  image_format: "*.png"
  # Transform configuration for 2D data
  transforms:
    image_size: [224, 224]
    normalize: true
    augment: true
  # Data split configuration
  train_val_split: [0.8, 0.2]
  seed: 42
# Training configuration
training:
  max_epochs: 10
  accelerator: "gpu"
  devices: [0,1,2,3] # Multi-GPU training
  precision: 16
  save_metrics: true
  # Callbacks
  model_checkpoint:
    monitor: "val/accuracy"
    mode: "max"
    save_top_k: 3
    filename: "epoch_{epoch:02d}-val_acc_{val/accuracy:.3f}"
# Validation configuration
validation:
  check_val_every_n_epoch: 1
# Class names
class_names:
  - "Class_0"
  - "Class_1"
# Output paths
outputs:
  output_dir: "outputs"
# Logging
logging:
  log_every_n_steps: 10
  wandb:
    enabled: false
    project: "medvision-2d-classification"
    entity: null
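The sketch below makes the `metrics` block concrete. It is a minimal, dependency-free illustration of what accuracy, precision, recall, F1 and AUROC compute for a binary task. The README does not document how the framework itself computes these metrics, including the averaging it uses for multi-class runs such as `num_classes: 4`. `binary_metrics` is a hypothetical helper.

```python
from typing import Dict, Sequence


def binary_metrics(labels: Sequence[int], scores: Sequence[float],
                   threshold: float = 0.5) -> Dict[str, float]:
    """Accuracy, precision, recall, F1 and AUROC for a binary task.

    labels: ground-truth 0/1 values; scores: predicted probability of class 1.
    """
    preds = [1 if s >= threshold else 0 for s in scores]
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    tn = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 0)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    # AUROC = probability that a random positive scores above a random
    # negative (ties count as half), i.e. the normalised Mann-Whitney U.
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if pos and neg:
        wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
        auroc = wins / (len(pos) * len(neg))
    else:
        auroc = float("nan")  # undefined when only one class is present

    return {
        "accuracy": (tp + tn) / len(labels),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auroc": auroc,
    }
```

For example, labels `[0, 0, 1, 1]` with scores `[0.1, 0.6, 0.4, 0.9]` give 0.5 for accuracy, precision, recall and F1, and an AUROC of 0.75. AUROC is threshold-free, so it can rank a model well even when the 0.5 cut-off is poorly placed.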
Example 3D classification training configuration
# 3D ResNet Training Configuration
seed: 42
task_dim: 3D
# Model configuration
model:
  type: "classification"
  network:
    name: "resnet3d_18" # Options: resnet3d_18, resnet3d_34, resnet3d_50
    pretrained: false # No pretrained weights for 3D models
    in_channels: 3 # Input channels (typically 1 for medical images)
    dropout: 0.1
  num_classes: 2
  # Metrics to compute
  metrics:
    accuracy:
      type: "accuracy"
    f1:
      type: "f1"
    precision:
      type: "precision"
    recall:
      type: "recall"
    auc:
      type: "auroc"
  # Loss configuration
  loss:
    type: "cross_entropy"
    weight: null
    label_smoothing: 0.0
  # Optimizer configuration
  optimizer:
    type: "adam"
    lr: 0.001
    weight_decay: 0.0001
  # Scheduler configuration
  scheduler:
    type: "cosine"
    T_max: 100
    eta_min: 0.00001
# Data configuration
data:
  type: "medical"
  batch_size: 4 # Keep the batch size small for 3D data
  num_workers: 4
  data_dir: "data/3D"
  image_format: "*.nii.gz" # 3D medical image format
  # Transform configuration for 3D data
  transforms:
    image_size: [64, 64, 64] # [D, H, W] for 3D volumes
    normalize: true
    augment: true
  # Data split configuration
  train_val_split: [0.8, 0.2]
  seed: 42
# Training configuration
training:
  max_epochs: 5
  accelerator: "gpu"
  devices: 1 # Single GPU for 3D (memory intensive)
  precision: 16 # Use mixed precision to save memory
  # Callbacks
  early_stopping:
    monitor: "val/loss"
    patience: 10
    mode: "min"
  model_checkpoint:
    monitor: "val/accuracy"
    mode: "max"
    save_top_k: 3
    filename: "epoch_{epoch:02d}-val_acc_{val/accuracy:.3f}"
# Validation configuration
validation:
  check_val_every_n_epoch: 1
# Output paths
outputs:
  output_dir: "outputs"
# Logging
logging:
  log_every_n_steps: 10
  wandb:
    enabled: false
    project: "medvision-3d-classification"
    entity: null
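Both example configurations pair `scheduler.T_max: 100` with a much smaller `max_epochs` (10 and 5). The README does not say how the cosine scheduler is implemented. The sketch below assumes it maps to PyTorch's `CosineAnnealingLR` and is stepped once per epoch. Under that assumption, the closed-form schedule shows that the learning rate barely decays within such short runs. Setting `T_max` equal to `max_epochs` would instead anneal all the way down to `eta_min`.

```python
import math


def cosine_lr(epoch: int, base_lr: float, t_max: int, eta_min: float) -> float:
    """Closed-form cosine annealing, as documented for PyTorch's CosineAnnealingLR."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


if __name__ == "__main__":
    # Values taken from the example configs: lr 0.001, T_max 100, eta_min 0.00001.
    for max_epochs in (5, 10):
        final = cosine_lr(max_epochs, base_lr=0.001, t_max=100, eta_min=0.00001)
        print(f"max_epochs={max_epochs}: lr after last epoch = {final:.6f}")
```

Under these assumptions, the learning rate ends at about 0.000994 after 5 epochs and about 0.000976 after 10. Both are close to the starting value of 0.001.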
Example inference configuration
# Model configuration
model:
  type: "classification"
  network:
    name: "resnet50"
    pretrained: false
  num_classes: 2
  checkpoint_path: "outputs/checkpoints/best_model.ckpt"
# Inference settings
inference:
  batch_size: 1
  device: "cuda:0" # or "cpu"
  return_probabilities: true
  class_names: ["class0", "class1"]
  confidence_threshold: 0.5
# Preprocessing
preprocessing:
  image_size: [224, 224]
  normalize: true
  mean: [0.485, 0.456, 0.406]
  std: [0.229, 0.224, 0.225]
# Output settings
output:
  save_predictions: true
  include_probabilities: true
  format: "json" # or "csv"
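The `inference` and `output` settings imply a post-processing step that turns model logits into class probabilities and a JSON record. The sketch below is a hypothetical illustration of that step using only the standard library. The README does not document how the framework applies `confidence_threshold`. Here it is assumed to mark low-confidence predictions as `None`, and the record layout is invented for illustration.

```python
import json
import math
from typing import Dict, List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert raw logits to probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def postprocess(image_id: str, logits: Sequence[float], class_names: Sequence[str],
                confidence_threshold: float = 0.5) -> Dict:
    """Build one prediction record (hypothetical output schema)."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    label: Optional[str] = class_names[best]
    # Assumed policy: below the threshold, report no confident prediction.
    if probs[best] < confidence_threshold:
        label = None
    return {
        "image": image_id,
        "prediction": label,
        "confidence": probs[best],
        "probabilities": dict(zip(class_names, probs)),
    }


if __name__ == "__main__":
    record = postprocess("case_001.png", [0.2, 1.4], ["class0", "class1"])
    print(json.dumps(record, indent=2))
```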
Data format
Folder structure
data/
├── classification/
│   ├── train/
│   │   ├── class1/
│   │   │   ├── image1.png
│   │   │   └── image2.png
│   │   └── class2/
│   │       ├── image3.png
│   │       └── image4.png
│   ├── val/
│   │   ├── class1/
│   │   └── class2/
│   └── test/
│       ├── class1/
│       └── class2/
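In this layout, each image's class label comes from its parent folder name. The following minimal sketch indexes one split. It assumes labels are assigned in sorted folder-name order, which is the convention torchvision's `ImageFolder` uses; the README does not document MedVision-Classification's own ordering. `index_split` is a hypothetical helper, and its `pattern` argument plays the role of the config's `image_format`.

```python
from pathlib import Path
from typing import Dict, List, Tuple


def index_split(split_dir: Path, pattern: str = "*.png") -> Tuple[List[Tuple[Path, int]], Dict[str, int]]:
    """Map every file matching `pattern` under split_dir/<class>/ to an integer label.

    Class indices follow sorted folder names, so class1 -> 0, class2 -> 1.
    """
    class_dirs = sorted(p for p in split_dir.iterdir() if p.is_dir())
    class_to_idx = {d.name: i for i, d in enumerate(class_dirs)}
    samples = [
        (path, class_to_idx[d.name])
        for d in class_dirs
        for path in sorted(d.glob(pattern))  # e.g. "*.png" or "*.nii.gz"
        if path.is_file()
    ]
    return samples, class_to_idx


if __name__ == "__main__":
    root = Path("data/classification")
    for split in ("train", "val", "test"):
        split_dir = root / split
        if split_dir.is_dir():
            samples, mapping = index_split(split_dir)
            print(split, len(samples), mapping)
```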
Supported models
- ResNet family: ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
- DenseNet family: DenseNet121, DenseNet161, DenseNet169, DenseNet201
- EfficientNet family: EfficientNet-B0 to EfficientNet-B7
- Vision Transformer: ViT-Base, ViT-Large
- ConvNeXt: ConvNeXt-Tiny, ConvNeXt-Small, ConvNeXt-Base
- Medical-specific: MedNet, RadImageNet pretrained models
License
This project is open source under the MIT License.
Contributing
Contributions are welcome! See CONTRIBUTING.md for details.
Download files
Source Distribution
Built Distribution
File details
Details for the file medvision_classification-0.2.6.tar.gz.
File metadata
- Download URL: medvision_classification-0.2.6.tar.gz
- Upload date:
- Size: 51.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 43a428bab502aafe66fd2a94b10a56395cbe9d5f0074ca28f54ec9433825c372 |
| MD5 | 697260972b3391c1e7cb2fea4061d252 |
| BLAKE2b-256 | 3dfa3499aedc81aca45ec760f5fe300d9d21be074fbe8041ae31023b28a1590b |
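The digests above can be used to check that a downloaded archive is intact. The following minimal sketch uses Python's `hashlib` and assumes the source distribution has been downloaded to the current directory. The same approach works for the wheel, using its own SHA256.

```python
import hashlib
from pathlib import Path

# SHA256 listed above for medvision_classification-0.2.6.tar.gz
EXPECTED_SHA256 = "43a428bab502aafe66fd2a94b10a56395cbe9d5f0074ca28f54ec9433825c372"


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    archive = Path("medvision_classification-0.2.6.tar.gz")
    if archive.exists():
        ok = sha256_of(archive) == EXPECTED_SHA256
        print("hash OK" if ok else "hash MISMATCH")
    else:
        print(f"{archive} not found")
```

pip can also enforce hashes itself through its `--require-hashes` mode for requirements files.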
File details
Details for the file medvision_classification-0.2.6-py3-none-any.whl.
File metadata
- Download URL: medvision_classification-0.2.6-py3-none-any.whl
- Upload date:
- Size: 66.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d44d8b9213ecffa8d88b06acbb80da2d36cf8211694991a8fdfdc7151c07c07c |
| MD5 | de60bcd83d624355118c153aaca9cfb2 |
| BLAKE2b-256 | 9c61745242741a395a9f4f860189f2825209cbdb327dcaea802adacc0de6f701 |