Typer CLI package for the Vit medical multi-axis Vision Transformer project.
Project description
ViT医学影像分类项目
基于PyTorch框架实现的Vision in Transformer (ViT) 模型,专门用于医学影像分类任务。
项目概述
本项目采用transformer架构对16x368尺寸的医学图像进行分析,通过将大图像裁剪为23个16x16的图像块,预测样本为阳性的概率。模型集成了Golden Gate Rotary Position Embedding技术,用于更好地捕捉图像的空间位置信息。
多轴主线(Phase 1-6)
当前项目已经补充多轴 ViT 主线,并保留原有单轴架构可用:
- 输入端:
x/y/z三个 axis 各输入4张368x16PNG - 训练端:对每位患者穷举
4 x 4 x 4 = 64个跨轴组合样本,仍使用监督分类训练 encoder - 验证/预测端:先提取每个 axis 的特征,再执行 per-axis KNN,随后按 axis 权重投票得到组合概率
- 患者级输出:将 64 个组合概率求均值,得到
patient_predict_proba,并按predict_acc_prob阈值输出 0/1 预测
关键配置
configs/base/multi_axis_config.yaml- 控制
axis_data_paths、active_axes、voting_weights、images_per_axis、knn_k
- 控制
configs/base/monitoring_config.yaml- 控制早停、最佳模型选择、
predict_threshold、verbose
- 控制早停、最佳模型选择、
训练产物
多轴训练至少会生成以下文件:
checkpoints/best_model.pthcheckpoints/latest_model.pthresults/train_state.jsonresults/feature_store_latest.npzresults/feature_store_best.npzlogs/*.log
train_state.json 用于从 latest_model.pth 恢复训练;恢复粒度为“上一完整 epoch 结束后”的状态。
终端输出规则
- 默认不加
--verbose- 训练/预测过程中终端不输出实时信息
- 仅在流程完全结束后输出总结、总耗时和关键文件绝对路径
- 使用
--verbose- 训练仅显示一根总 tqdm
- 预测仅显示一根总 tqdm
- 详细调试信息仍只写日志文件
多轴训练示例
conda run -n Vit python train.py \
--config_path configs/base \
--data_path dataset/data \
--label_path dataset/label.csv \
--experiment_name multi_axis_exp
恢复训练:
conda run -n Vit python train.py \
--config_path configs/base \
--label_path dataset/label.csv \
--experiment_name multi_axis_exp \
--resume
多轴预测:
conda run -n Vit python predict.py \
--model experiments/multi_axis_exp/checkpoints/best_model.pth \
--data dataset/predict_root \
--output experiments/multi_axis_exp/predictions \
--label dataset/label.csv
Typer CLI
项目现在同时提供 Typer CLI 主命令 vit,并保留 train.py / predict.py 旧入口。
vit get --path /tmp/vit_templates
vit train --config configs/base --data dataset/data --label dataset/label.csv --verbose
vit pred --model experiments/multi_axis_exp/checkpoints/best_model.pth --data dataset/predict_root --output output --label dataset/label.csv --thresholds 0.5
vit --version
vit pred --data 支持两种输入布局:
- 根目录内固定包含
data-x/、data-y/、data-z/ - 平铺目录,文件名必须遵循
{patient_id}_{Z}_{Y}_{X}.png
打包与构建
conda run -n Vit pip install -e .
conda run -n Vit vit --help
conda run -n Vit python -m build
模型介绍
MedicalViT 网络架构
本项目采用基于Golden Gate Rotary Position Embedding的Vision Transformer (ViT) 架构,专门针对16x368尺寸的医学影像分类任务进行优化。
graph TD
A[输入图像] --> B["16x368x1 医学图像"]
B --> C[补丁嵌入层 Patch Embedding]
subgraph "补丁嵌入层"
C --> D["图像重组: 16x368 → 23x16x16"]
D --> E["展平操作: 23x256维向量"]
E --> F["线性投影: 256 → 512维"]
F --> G[层归一化]
end
G --> H[位置编码]
subgraph "Golden Gate Rotary 位置编码"
H --> I["可学习位置编码: 1x23x512"]
I --> J["位置坐标生成: 2D坐标"]
J --> K[旋转位置编码: RoPE]
end
K --> L["Transformer编码器 x6层"]
subgraph "单个Transformer层"
L --> M[多头自注意力机制]
M --> N["残差连接 + 层归一化"]
N --> O[前馈神经网络]
O --> P["残差连接 + 层归一化"]
end
subgraph "多头自注意力机制"
M --> Q["查询/键/值投影: 512→8x64"]
Q --> R[旋转位置编码应用]
R --> S[注意力权重计算]
S --> T[注意力输出聚合]
T --> U["输出投影: 8x64→512"]
end
subgraph "前馈神经网络"
O --> V["线性层: 512→2048"]
V --> W[GELU激活]
W --> X["Dropout: 0.1"]
X --> Y["线性层: 2048→512"]
Y --> Z["Dropout: 0.1"]
end
P --> AA[全局平均池化]
AA --> BB[MLP分类头]
subgraph "MLP分类头"
BB --> CC["层归一化: 512维"]
CC --> DD["线性分类器: 512→1"]
DD --> EE[Sigmoid激活]
end
EE --> FF[输出: 阳性概率]
style A fill:#e1f5fe
style FF fill:#c8e6c9
style L fill:#fff3e0
style M fill:#f3e5f5
style O fill:#e8f5e8
核心技术特点
1. Golden Gate Rotary Position Embedding
- N维旋转位置编码: 采用Golden Gate方法实现2D空间位置编码
- 方向向量生成: 使用低差异序列生成均匀分布的方向向量
- 频率插值: 支持多尺度频率的位置感知
- 计算效率: 避免传统位置编码的平方复杂度
2. 分层注意力机制
- 多头注意力: 8个注意力头,每个头维度64
- 查询/键/值分离: Q和K共享投影,V独立投影
- 残差连接: 保持梯度流动,缓解消失问题
- 层归一化: 稳定训练过程
3. 补丁嵌入策略
- 自适应补丁分割: 16×368图像自动分割为23个16×16补丁
- 线性投影: 将256维补丁向量映射到512维特征空间
- 双层归一化: 补丁级和特征级双重归一化
4. 分类头设计
- 全局池化: 汇聚所有补丁特征
- 线性分类器: 简单高效的单层分类
- Sigmoid输出: 输出阳性概率
模型参数配置
| 配置项 | 数值 | 说明 |
|---|---|---|
| 输入尺寸 | 16x368x1 | 医学影像标准尺寸 |
| 补丁尺寸 | 16x16 | 补丁分割大小 |
| 补丁数量 | 23 | 16x368 → 23x16x16 |
| 模型维度 | 512 | Transformer特征维度 |
| 注意力头数 | 8 | 多头注意力配置 |
| 注意力维度 | 64 | 每个注意力头的维度 |
| Transformer层数 | 6 | 编码器深度 |
| MLP隐藏维度 | 2048 | 前馈网络隐藏层 |
| Dropout比率 | 0.1 | 正则化参数 |
| 分类维度 | 1 | 二分类输出 |
数据流程
- 输入处理: 16x368医学图像 → 23个16x16补丁
- 特征提取: 每个补丁线性投影到512维特征空间
- 位置编码: 应用可学习和旋转位置编码
- Transformer处理: 6层编码器提取深层特征
- 特征汇聚: 全局平均池化整合补丁信息
- 分类预测: 单层线性分类器输出阳性概率
创新点
- 医学影像专用: 针对16x368尺寸优化的ViT架构
- 高效位置编码: Golden Gate RoPE提供更强的空间位置感知
- 轻量化设计: 合理的参数量配置,适合医疗场景部署
- 稳定训练: 多重归一化和残差连接确保训练稳定性
环境要求
- Python 3.11+
- PyTorch 2.6.0+
- CUDA (可选,用于GPU加速)
安装依赖
pip install -r requirements.txt
项目结构
VIT/
├── src/ # 源代码
│ ├── model/vit_model.py # ViT模型实现
│ ├── model/vit_nd_rotary.py # 外接入的包
│ ├── data/preprocessing.py # 数据预处理
│ ├── data/dataset.py # 数据集类
│ ├── data/enhanced_splitting.py # 加强数据集划分类
│ ├── data/improved_dataloader.py# 增强数据集加载类
│ ├── data/memory_preloader.py # 数据集预加载类
│ ├── data/advanced_splitting.py # 基础数据集划分类
│ ├── utils/config.py # 配置管理
│ ├── utils/logger.py # 日志管理
│ ├── utils/metrics.py # 评估指标
│ ├── utils/monitor.py # 监控可视化
│ └── training/trainer.py # 训练器
├── configs/ # 配置文件
│ ├── data_config.yaml # 数据配置
│ ├── experiment_config.yaml # 实验配置
│ ├── model_config.yaml # 模型配置
│ └── training_config.yaml # 训练配置
├── train.py # 主训练脚本
├── predict.py # 预测脚本
├── experiments/ # CV多折交叉验证结果
├── dataset/ # 数据集
│ ├── data/ # 图像数据
│ └── label.csv # 标签文件
├── checkpoints/ # 保存预训练的最佳权重的训练结果
├── requirements_pip.txt # 依赖列表
└── README.md # 项目说明
训练前测试流程
1. 环境验证
验证PyTorch和CUDA环境:
python -c "import torch; print('PyTorch版本:', torch.__version__); print('CUDA可用:', torch.cuda.is_available())"
2. 数据一致性验证
验证数据集的完整性:
python -c "from src.data.preprocessing import validate_data_consistency; validate_data_consistency('dataset/data', 'dataset/label.csv')"
3. 模型测试
测试ViT模型的结构和前向传播:
python src/model/vit_model.py
4. 数据加载器测试
测试数据预处理和加载流程:
python src/data/dataset.py
5. 配置系统测试
测试配置文件的加载和保存:
python src/utils/config.py
6.训练功能测试(cv/stratified)
6a. 分层分割训练测试(--mode stratified)
使用分层分割进行训练:
python train.py --batch_size 32 --epochs 5 --experiment_name stratified_test --mode stratified --device cuda
6b. 交叉验证训练获取最佳训练参数测试(默认模式)
使用交叉验证进行训练(默认模式):
python train.py --batch_size 32 --epochs 3 --experiment_name cv_test --cv_folds 3 --device cuda
7. 预测功能测试
7a. 图像预测模式(推荐)(JSON + CSV输出)
测试单个图像预测(不使用标签文件):
python predict.py --checkpoint checkpoints/train_001/checkpoints/best_model.pth --init checkpoints/train_001/results/train_001_initialization.pth --mode image --input dataset/data/MR201201030297_015_0_0.png --output predictions/predictions.json --label_csv ""
测试单个图像预测(使用标签文件验证):
python predict.py --checkpoint checkpoints/train_001/checkpoints/best_model.pth --init checkpoints/train_001/results/train_001_initialization.pth --mode image --input dataset/data/MR201201030297_015_0_0.png --output predictions/predictions.json --label_csv dataset/label.csv
测试批量图像预测(不使用标签文件):
python predict.py --checkpoint checkpoints/train_001/checkpoints/best_model.pth --init checkpoints/train_001/results/train_001_initialization.pth --mode image --input dataset/data/ --output predictions/batch_predictions.json --label_csv ""
测试批量图像预测(使用标签文件验证):
python predict.py --checkpoint checkpoints/train_001/checkpoints/best_model.pth --init checkpoints/train_001/results/train_001_initialization.pth --mode image --input dataset/data/ --output predictions/batch_predictions.json --label_csv dataset/label.csv
禁用CSV输出(仅生成JSON文件):
python predict.py --checkpoint checkpoints/train_001/checkpoints/best_model.pth --init checkpoints/train_001/results/train_001_initialization.pth --mode image --input dataset/data/ --output predictions/batch_predictions.json --no_csv
预测模式说明:
- image: 单个图像或批量图像预测(最常用)
预测参数说明:
--checkpoint: 模型检查点路径(必需)--mode: 预测模式,推荐使用image模式--init: 模型初始化参数文件路径(推荐,确保预测一致性)--config_path: 配置文件路径(默认:configs/)--input: 输入图像路径或目录(必需)--output: 预测结果输出文件(默认:predictions.json)--threshold: 分类阈值(默认:0.5)--save_csv: 启用CSV格式输出(默认启用)--no_csv: 禁用CSV格式输出--label_csv: 标签文件路径,用于生成correct字段(默认:dataset/label.csv)
重要提示:
- 使用
--init参数加载训练时的初始化参数文件,确保预测结果的一致性和可重现性 - 初始化参数文件通常位于
{实验目录}/results/{实验名称}_initialization.pth - CSV输出功能默认启用,会生成包含
image_id, probability, prediction, correct字段的CSV文件 - 如果不需要CSV文件,可以使用
--no_csv参数禁用CSV输出 - CSV文件便于在Excel等工具中查看预测结果和统计分析
--label_csv参数控制是否包含correct字段:- 如果标签文件存在且包含预测样本,CSV会包含correct字段并显示True/False/NaN
- 如果标签文件不存在或预测样本不在标签文件中,correct字段会显示为NaN
- 如果不提供标签文件或使用空字符串,CSV不会包含correct字段
- 支持单个图像和批量图像预测,批量预测时显示实时进度条
完整训练流程
交叉验证训练(推荐,默认模式)
python train.py --mode cv --experiment_name cv_001
分层分割训练
python train.py --mode stratified --experiment_name train_001
带检查点恢复的训练
python train.py --resume --epochs 100 --experiment_name vit_experiment_v1_resume
指定设备的训练
python train.py --device cuda --batch_size 64 --epochs 100 --experiment_name vit_experiment_gpu --mode stratified
训练后的预测示例
预测单个图像:
python predict.py --checkpoint experiments/train_001/checkpoints/best_model.pth --init experiments/train_001/results/train_001_initialization.pth --mode image --input dataset/data/MR201201030297_015_0_0.png --label_csv dataset/label.csv
批量预测整个数据集:
python predict.py --checkpoint experiments/train_001/checkpoints/best_model.pth --init experiments/train_001/results/train_001_initialization.pth --mode image --input dataset/data/ --output predictions/batch_predictions.json --label_csv dataset/label.csv
预测指定训练的验证集:
python predict.py --checkpoint "checkpoints\train_001\checkpoints\best_model.pth" --init "checkpoints\train_001\results\train_001_initialization.pth" --mode val_predict --train_split_csv "checkpoints\train_001\results\train_split.csv" --data_path "dataset\data" --label_path "dataset\label.csv" --output_dir "predictions/val_train_001_predictions"
配置说明
模型配置 (configs/model_config.yaml)
input_shape: 输入图像尺寸 [16, 368]patch_size: 补丁尺寸 [16, 16]dim: Transformer维度depth: Transformer层数heads: 注意力头数
数据配置 (configs/data_config.yaml)
data_path: 图像数据路径label_path: 标签文件路径batch_size: 批次大小num_workers: 数据加载器工作进程数(启用预加载时自动设为0)pin_memory: 是否锁定内存页enable_memory_preload: 是否启用内存预加载(默认true)preload_batch_size: 分批预加载数量(null表示全部加载)preload_verbose: 是否显示预加载详细信息
训练配置 (configs/training_config.yaml)
epochs: 训练轮数learning_rate: 学习率optimizer: 优化器类型scheduler: 学习率调度器train_split.method: 数据分割方法 ("stratified", "random", "cv"),默认为"cv"train_split.cv_folds: 交叉验证折数(仅CV模式)train_split.train_ratio: 训练集比例(非CV模式)train_split.val_ratio: 验证集比例(非CV模式)train_split.test_ratio: 测试集比例(非CV模式)train_split.random_seed: 随机种子export_split_info: 是否导出数据划分信息split_filename: 数据划分CSV文件名
注意:交叉验证模式下,自动将test_ratio合并到val_ratio中,充分利用数据
实验结果
训练完成后,结果将保存在以下位置:
logs/: 训练日志和TensorBoard文件checkpoints/: 模型检查点和初始化参数文件results/: 数据划分CSV、评估报告和可视化结果
每个实验都有独立的文件夹,包含完整的实验数据和结果。
性能指标
模型支持以下评估指标Tensorboard记录:
- 准确率 (Accuracy)
- 精确率 (Precision)
- 召回率 (Recall)
- F1分数 (F1-Score)
- AUC值
- 混淆矩阵
- ROC曲线
- PR曲线
故障排除
-
CUDA内存不足
- 减小batch_size
- 启用混合精度训练
- 使用梯度累积
-
数据加载失败
- 检查数据路径是否正确
- 验证图像文件格式
- 确认标签文件格式
-
模型不收敛
- 使用CV模式训练获得合适的学习率
- 检查数据预处理
- 验证损失函数设置
-
内存预加载相关问题
- 内存不足: 减小preload_batch_size或禁用enable_memory_preload
- 预加载失败: 系统会自动回退到传统数据加载模式
- 性能提升不明显: 检查数据集大小,小数据集效果更明显
- 缓存命中率低: 确保enable_memory_preload设置为true
-
训练速度仍然较慢
- 确认内存预加载已启用(检查训练日志中的预加载信息)
- 减少num_workers(启用预加载时会自动设为0)
- 检查硬盘性能,传统模式可能受限于磁盘IO,推荐enable_memory_preload设置为true
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file btvit_cli-0.1.1.tar.gz.
File metadata
- Download URL: btvit_cli-0.1.1.tar.gz
- Upload date:
- Size: 121.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
89c6116df8785cabb83990f65cec0e71fcd4e99df92bdcad03790ba143a1036f
|
|
| MD5 |
e8bab7b148232056e0b856a3cdbe3634
|
|
| BLAKE2b-256 |
03e063ab4aab2e1eaae6509e00cb001a80e19838f15d7ba55b8d5652a1b1709c
|
File details
Details for the file btvit_cli-0.1.1-py3-none-any.whl.
File metadata
- Download URL: btvit_cli-0.1.1-py3-none-any.whl
- Upload date:
- Size: 112.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
fe8a7a46df44d6997d3f32df724dff503c2dd40da42f5087dc23779abed30191
|
|
| MD5 |
ea20ea4203c6c67d5006b99ded2d5b48
|
|
| BLAKE2b-256 |
284a920f82074fa54b42dbb0c342018e28ca1bc08e90038c9813b5c781b08c80
|