Skip to main content

Typer CLI package for the Vit medical multi-axis Vision Transformer project.

Project description

btvit-cli

基于 PyTorch 的多轴 Vision Transformer 医学影像分类 CLI 工具。

安装

pip install btvit-cli

安装后即可使用 vit 命令。确认安装:

vit --version

快速开始

1. 导出配置模板

首次使用前,导出完整注释版 YAML 配置模板到工作目录:

vit get --path ./configs

将在目标目录生成以下文件:

文件 用途
model_config.yaml 模型架构:dim, depth, heads, patch_size, RoPE 参数
data_config.yaml 数据加载:内存预加载、pin_memory
training_config.yaml 训练超参:lr, optimizer, scheduler, stratified 分割
experiment_config.yaml 实验目录命名、日志频率
multi_axis_config.yaml 多轴路径、active_axes、KNN 参数、投票权重
monitoring_config.yaml 早停、最佳模型选择、预测阈值、verbose
USAGE.md 完整 CLI 使用说明

2. 训练

vit train --config configs/base --data dataset/data --label dataset/label.csv --verbose

必需参数:

  • --config — YAML 配置目录路径,目录内应包含上述全部配置文件

可选参数:

参数 默认值 说明
--data dataset/data 训练数据目录
--label dataset/label.csv 标签 CSV(表头 patient_id,label
--experiment-name 配置文件值 实验名称覆盖项
--epochs 配置文件值 训练轮数覆盖项
--batch-size 配置文件值 批次大小覆盖项
--device 配置文件值 训练设备(cpu / cuda
--random-seed 配置文件值 随机种子覆盖项
--verbose False 显示实时 tqdm 进度条
--resume False latest_model.pth 恢复训练
--disable-memory-preload False 关闭默认启用的内存预加载
--preload-batch-size 配置文件值 预加载批次大小
--pre-model 预训练模型路径(迁移学习)
--freeze-layers 冻结层配置
--freeze-ratio 冻结比例(0-1)
--load-weights-only False 仅加载模型权重
--strict-load False 严格加载预训练权重
--train-split-csv 预分割 CSV 路径
--csv-split-seed CSV 重新划分随机种子
--use-csv-direct False 直接使用 CSV 分割,不重新划分

训练示例:

# 基础训练
vit train --config configs/base --data dataset/data --label dataset/label.csv

# 恢复训练
vit train --config configs/base --label dataset/label.csv --experiment-name my_exp --resume

# 迁移学习
vit train --config configs/base --data dataset/data --label dataset/label.csv \
  --pre-model experiments/pretrained/checkpoints/best_model.pth \
  --freeze-ratio 0.5 --verbose

3. 预测

vit pred --model experiments/my_exp/checkpoints/best_model.pth \
  --data dataset/predict_root \
  --output output \
  --label dataset/label.csv \
  --verbose

必需参数:

  • --model — 训练生成的 checkpoint 路径(内含 config 快照,无需额外 --config
  • --data — 预测数据目录

可选参数:

参数 默认值 说明
--output output 预测结果输出目录
--label 标签 CSV(表头支持 patient_id,label0,1
--thresholds checkpoint 值 覆盖 checkpoint 中的预测阈值
--verbose False 显示实时 tqdm 进度条

--data 支持两种输入布局:

  1. 根目录布局 — 目录内固定包含 data-x/data-y/data-z/ 子目录
  2. 平铺布局 — 单一目录内所有 PNG 文件,文件名必须遵循 {patient_id}_{Z}_{Y}_{X}.png

预测输出:

  • predictions.csv — 患者级预测结果(含 pred_prob、pred_result、label、pass 列)
  • prediction_summary.json — 汇总统计
  • confusion_matrix.png — 混淆矩阵图(提供 label 时)

4. 版本

vit --version
# 或
vit version

多轴 ViT 架构

本项目的核心是多轴 Vision Transformer,用于医学影像二分类(阳性/阴性)。

数据流

输入:每位患者提供 x/y/z 三个轴,每轴 4 张 16x368 灰度 PNG
  -> ImageProcessor: 裁剪为 23 个 16x16 patch,归一化
  -> MemoryPreloader: 缓存至 RAM
  -> MultiAxisDataLoader: 穷举 4x4x4=64 个跨轴组合样本
  -> MultiAxisMedicalViT: 共享编码器 + 有监督 BCE 训练
  -> 验证/预测:per-axis 特征提取 -> KNN -> 加权投票 -> 组合概率
  -> 患者级输出:64 个组合概率均值 -> patient_predict_proba

模型流水线

Input: (batch, 23, 16, 16)
  -> Patch Embedding: flatten + Linear(256, 512) + LayerNorm  ->  (batch, 23, 512)
  -> Learnable Positional Encoding                           ->  (batch, 23, 512)
  -> Golden Gate RoPE (2D rotary position embedding)         ->  (batch, 23, 512)
  -> Transformer Encoder x6 (8 heads, dim_head=64)           ->  (batch, 23, 512)
  -> Global Average Pooling                                   ->  (batch, 512)
  -> LayerNorm + Linear(512, 1) + Sigmoid                    ->  (batch, 1)

关键设计

  • 患者级分割:同一患者的影像始终在同一数据分片中(防止数据泄漏)
  • 内存预加载:默认启用,将全部图像预加载到 RAM;启用时 num_workers 自动设为 0
  • 类别加权 BCE:通过 pos_weight 处理正负样本不平衡
  • Checkpoint 内容:模型权重 + 优化器状态 + 调度器状态 + 最佳指标 + config 快照
  • 双检查点:训练产物始终包含 best_model.pthlatest_model.pth
  • 恢复训练:基于 results/train_state.json 从上一完整 epoch 后的状态恢复

训练产物

experiments/{experiment_name}/
├── checkpoints/
│   ├── best_model.pth       # 最佳模型(按 monitoring.model_selection.metric 选择)
│   └── latest_model.pth     # 最新模型(用于 resume)
├── results/
│   ├── train_state.json     # 恢复训练状态
│   ├── train_split.csv      # 数据划分记录
│   ├── feature_store_best.npz
│   ├── feature_store_latest.npz
│   └── final_results.json
└── logs/
    └── *.log

配置参考

所有 YAML 配置均可通过 CLI 参数覆盖。运行 vit get --path ./configs 导出完整注释版模板。

model_config.yaml

配置项 默认值 说明
dim 512 Transformer 嵌入维度
depth 6 Transformer 编码器层数
heads 8 注意力头数
dim_head 64 每头维度
mlp_dim 2048 MLP 隐藏维度
dropout 0.1 Transformer 内 dropout
patch_size [16, 16] Patch 尺寸
input_shape [16, 368] 输入图像尺寸

training_config.yaml

配置项 默认值 说明
epochs 100 训练轮数
learning_rate 0.001 学习率
optimizer AdamW 优化器
scheduler CosineAnnealingLR 学习率调度器
batch_size 32 批次大小
device cuda 训练设备
train_split.method stratified 数据分割方法
train_split.train_ratio 0.75 训练集比例
train_split.val_ratio 0.15 验证集比例
train_split.test_ratio 0.10 测试集比例

multi_axis_config.yaml

配置项 默认值 说明
enabled true 启用多轴路径
active_axes [x, y, z] 活跃轴列表
voting_weights {x: 0.3, y: 0.3, z: 0.4} 轴投票权重
images_per_axis 4 每轴 PNG 数量
knn_k 5 KNN 的 K 值
knn_metric cosine KNN 距离度量

monitoring_config.yaml

配置项 默认值 说明
early_stopping.enabled true 启用早停
early_stopping.metric val_loss 早停监控指标
early_stopping.patience 10 早停耐心轮数
model_selection.metric predict_accuracy 最佳模型选择指标
predict_threshold 0.5 患者级预测阈值
verbose true 终端详细输出

终端输出规则

  • 不加 --verbose:训练/预测过程中终端不输出实时信息,仅在流程结束后输出总结
  • --verbose:显示一根总 tqdm 进度条,详细调试信息仅写入日志文件

开发与构建

本地开发安装

conda activate Vit
pip install -e .
vit --help

构建发行包

conda run -n Vit python -m build

产物位于 dist/ 目录:btvit_cli-{version}-py3-none-any.whlbtvit-cli-{version}.tar.gz

发布到 PyPI

# 安装发布工具(首次)
pip install twine

# 构建
python -m build

# 校验
twine check dist/*

# 上传到 TestPyPI(可选预览)
twine upload --repository testpypi dist/*

# 上传到 PyPI
twine upload dist/*

发布前检查清单:

  1. 更新 src/__version__ 中的版本号
  2. 确认 pyproject.toml 中的 name = "btvit-cli"description 正确
  3. 运行 python -m build 构建无报错
  4. 运行 twine check dist/* 校验通过
  5. 如需测试,先上传到 TestPyPI 并验证安装

项目结构

src/
├── __init__.py              # __version__ = "0.1.2"
├── cli/
│   ├── main.py              # Typer 入口:train / pred / get / version
│   ├── train_cmd.py         # train 命令参数定义
│   ├── train_impl.py        # 训练核心逻辑(多轴 + 单轴分发)
│   ├── train_runner.py      # 训练执行包装
│   ├── predict_cmd.py       # pred 命令参数定义
│   ├── predict_impl.py      # 预测核心逻辑(checkpoint 驱动)
│   ├── predict_runner.py    # 预测执行包装
│   ├── script_loader.py     # 旧脚本兼容加载
│   └── utils.py             # YAML 模板、数据布局归一化
├── model/
│   ├── vit_model.py         # MedicalViT + MultiAxisMedicalViT
│   └── vit_nd_rotary.py     # Golden Gate RoPE
├── data/
│   ├── preprocessing.py     # ImageProcessor
│   ├── dataset.py           # 单轴/多轴 Dataset + Collator
│   ├── enhanced_splitting.py # 患者级分层/随机分割
│   ├── memory_preloader.py  # RAM 预加载
│   ├── improved_dataloader.py # DataLoaderManager
│   └── testset_export.py    # 测试集导出
├── training/
│   ├── trainer.py           # 单轴训练器
│   └── multi_axis_trainer.py # 多轴训练器
├── prediction/
│   ├── feature_store.py     # 特征存储
│   ├── knn_classifier.py    # KNN 分类
│   └── multi_axis_predictor.py # 多轴预测器
└── utils/
    ├── config.py            # YAML 配置系统
    ├── metrics.py           # Accuracy / Precision / Recall / F1 / AUC
    ├── logger.py            # 日志管理
    └── monitor.py           # TensorBoard

环境要求

  • Python >= 3.11
  • PyTorch >= 2.6.0
  • CUDA(可选,用于 GPU 加速)

许可证

Apache-2.0

作者

SuShuHeng

链接

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

btvit_cli-0.1.2.post2.tar.gz (119.1 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

btvit_cli-0.1.2.post2-py3-none-any.whl (113.0 kB view details)

Uploaded Python 3

File details

Details for the file btvit_cli-0.1.2.post2.tar.gz.

File metadata

  • Download URL: btvit_cli-0.1.2.post2.tar.gz
  • Upload date:
  • Size: 119.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.12

File hashes

Hashes for btvit_cli-0.1.2.post2.tar.gz
Algorithm Hash digest
SHA256 70c76ac53d3b27d707cabcdc4a917c0a8af505b892c810345c9185e62b2ca121
MD5 5954dc62b115e5892e98a613c5a541bc
BLAKE2b-256 98edea978a08f6e76382254115b30b12a966ca98b024fc13ee8834442b63221e

See more details on using hashes here.

File details

Details for the file btvit_cli-0.1.2.post2-py3-none-any.whl.

File metadata

  • Download URL: btvit_cli-0.1.2.post2-py3-none-any.whl
  • Upload date:
  • Size: 113.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.12

File hashes

Hashes for btvit_cli-0.1.2.post2-py3-none-any.whl
Algorithm Hash digest
SHA256 6f3a3baa030c1fffc570adf55ae1955cef5ab73707d629f4a9453bcf614f5fb6
MD5 b82b29c0146c3470f0667cd889bd8668
BLAKE2b-256 bba34414db02a1176d34dfe3e585218aa78095aaeb5f3ebf24d32435434552a5

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page