Typer CLI package for the Vit medical multi-axis Vision Transformer project.
Project description
btvit-cli
基于 PyTorch 的多轴 Vision Transformer 医学影像分类 CLI 工具。
安装
pip install btvit-cli
安装后即可使用 vit 命令。确认安装:
vit --version
快速开始
1. 导出配置模板
首次使用前,导出完整注释版 YAML 配置模板到工作目录:
vit get --path ./configs
将在目标目录生成以下文件:
| 文件 | 用途 |
|---|---|
model_config.yaml |
模型架构:dim, depth, heads, patch_size, RoPE 参数 |
data_config.yaml |
数据加载:内存预加载、pin_memory |
training_config.yaml |
训练超参:lr, optimizer, scheduler, stratified 分割 |
experiment_config.yaml |
实验目录命名、日志频率 |
multi_axis_config.yaml |
多轴路径、active_axes、KNN 参数、投票权重 |
monitoring_config.yaml |
早停、最佳模型选择、预测阈值、verbose |
USAGE.md |
完整 CLI 使用说明 |
2. 训练
vit train --config configs/base --data dataset/data --label dataset/label.csv --verbose
必需参数:
--config— YAML 配置目录路径,目录内应包含上述全部配置文件
可选参数:
| 参数 | 默认值 | 说明 |
|---|---|---|
--data |
dataset/data |
训练数据目录 |
--label |
dataset/label.csv |
标签 CSV(表头 patient_id,label) |
--experiment-name |
配置文件值 | 实验名称覆盖项 |
--epochs |
配置文件值 | 训练轮数覆盖项 |
--batch-size |
配置文件值 | 批次大小覆盖项 |
--device |
配置文件值 | 训练设备(cpu / cuda) |
--random-seed |
配置文件值 | 随机种子覆盖项 |
--verbose |
False |
显示实时 tqdm 进度条 |
--resume |
False |
从 latest_model.pth 恢复训练 |
--disable-memory-preload |
False |
关闭默认启用的内存预加载 |
--preload-batch-size |
配置文件值 | 预加载批次大小 |
--pre-model |
— | 预训练模型路径(迁移学习) |
--freeze-layers |
— | 冻结层配置 |
--freeze-ratio |
— | 冻结比例(0-1) |
--load-weights-only |
False |
仅加载模型权重 |
--strict-load |
False |
严格加载预训练权重 |
--train-split-csv |
— | 预分割 CSV 路径 |
--csv-split-seed |
— | CSV 重新划分随机种子 |
--use-csv-direct |
False |
直接使用 CSV 分割,不重新划分 |
训练示例:
# 基础训练
vit train --config configs/base --data dataset/data --label dataset/label.csv
# 恢复训练
vit train --config configs/base --label dataset/label.csv --experiment-name my_exp --resume
# 迁移学习
vit train --config configs/base --data dataset/data --label dataset/label.csv \
--pre-model experiments/pretrained/checkpoints/best_model.pth \
--freeze-ratio 0.5 --verbose
3. 预测
vit pred --model experiments/my_exp/checkpoints/best_model.pth \
--data dataset/predict_root \
--output output \
--label dataset/label.csv \
--verbose
必需参数:
--model— 训练生成的 checkpoint 路径(内含 config 快照,无需额外--config)--data— 预测数据目录
可选参数:
| 参数 | 默认值 | 说明 |
|---|---|---|
--output |
output |
预测结果输出目录 |
--label |
— | 标签 CSV(表头支持 patient_id,label 或 0,1) |
--thresholds |
checkpoint 值 | 覆盖 checkpoint 中的预测阈值 |
--verbose |
False |
显示实时 tqdm 进度条 |
--data 支持两种输入布局:
- 根目录布局 — 目录内固定包含
data-x/、data-y/、data-z/子目录 - 平铺布局 — 单一目录内所有 PNG 文件,文件名必须遵循
{patient_id}_{Z}_{Y}_{X}.png
预测输出:
predictions.csv— 患者级预测结果(含 pred_prob、pred_result、label、pass 列)prediction_summary.json— 汇总统计confusion_matrix.png— 混淆矩阵图(提供 label 时)
4. 版本
vit --version
# 或
vit version
多轴 ViT 架构
本项目的核心是多轴 Vision Transformer,用于医学影像二分类(阳性/阴性)。
数据流
输入:每位患者提供 x/y/z 三个轴,每轴 4 张 16x368 灰度 PNG
-> ImageProcessor: 裁剪为 23 个 16x16 patch,归一化
-> MemoryPreloader: 缓存至 RAM
-> MultiAxisDataLoader: 穷举 4x4x4=64 个跨轴组合样本
-> MultiAxisMedicalViT: 共享编码器 + 有监督 BCE 训练
-> 验证/预测:per-axis 特征提取 -> KNN -> 加权投票 -> 组合概率
-> 患者级输出:64 个组合概率均值 -> patient_predict_proba
模型流水线
Input: (batch, 23, 16, 16)
-> Patch Embedding: flatten + Linear(256, 512) + LayerNorm -> (batch, 23, 512)
-> Learnable Positional Encoding -> (batch, 23, 512)
-> Golden Gate RoPE (2D rotary position embedding) -> (batch, 23, 512)
-> Transformer Encoder x6 (8 heads, dim_head=64) -> (batch, 23, 512)
-> Global Average Pooling -> (batch, 512)
-> LayerNorm + Linear(512, 1) + Sigmoid -> (batch, 1)
关键设计
- 患者级分割:同一患者的影像始终在同一数据分片中(防止数据泄漏)
- 内存预加载:默认启用,将全部图像预加载到 RAM;启用时
num_workers自动设为 0 - 类别加权 BCE:通过
pos_weight处理正负样本不平衡 - Checkpoint 内容:模型权重 + 优化器状态 + 调度器状态 + 最佳指标 + config 快照
- 双检查点:训练产物始终包含
best_model.pth和latest_model.pth - 恢复训练:基于
results/train_state.json从上一完整 epoch 后的状态恢复
训练产物
experiments/{experiment_name}/
├── checkpoints/
│ ├── best_model.pth # 最佳模型(按 monitoring.model_selection.metric 选择)
│ └── latest_model.pth # 最新模型(用于 resume)
├── results/
│ ├── train_state.json # 恢复训练状态
│ ├── train_split.csv # 数据划分记录
│ ├── feature_store_best.npz
│ ├── feature_store_latest.npz
│ └── final_results.json
└── logs/
└── *.log
配置参考
所有 YAML 配置均可通过 CLI 参数覆盖。运行 vit get --path ./configs 导出完整注释版模板。
model_config.yaml
| 配置项 | 默认值 | 说明 |
|---|---|---|
dim |
512 | Transformer 嵌入维度 |
depth |
6 | Transformer 编码器层数 |
heads |
8 | 注意力头数 |
dim_head |
64 | 每头维度 |
mlp_dim |
2048 | MLP 隐藏维度 |
dropout |
0.1 | Transformer 内 dropout |
patch_size |
[16, 16] | Patch 尺寸 |
input_shape |
[16, 368] | 输入图像尺寸 |
training_config.yaml
| 配置项 | 默认值 | 说明 |
|---|---|---|
epochs |
100 | 训练轮数 |
learning_rate |
0.001 | 学习率 |
optimizer |
AdamW | 优化器 |
scheduler |
CosineAnnealingLR | 学习率调度器 |
batch_size |
32 | 批次大小 |
device |
cuda | 训练设备 |
train_split.method |
stratified | 数据分割方法 |
train_split.train_ratio |
0.75 | 训练集比例 |
train_split.val_ratio |
0.15 | 验证集比例 |
train_split.test_ratio |
0.10 | 测试集比例 |
multi_axis_config.yaml
| 配置项 | 默认值 | 说明 |
|---|---|---|
enabled |
true | 启用多轴路径 |
active_axes |
[x, y, z] | 活跃轴列表 |
voting_weights |
{x: 0.3, y: 0.3, z: 0.4} | 轴投票权重 |
images_per_axis |
4 | 每轴 PNG 数量 |
knn_k |
5 | KNN 的 K 值 |
knn_metric |
cosine | KNN 距离度量 |
monitoring_config.yaml
| 配置项 | 默认值 | 说明 |
|---|---|---|
early_stopping.enabled |
true | 启用早停 |
early_stopping.metric |
val_loss | 早停监控指标 |
early_stopping.patience |
10 | 早停耐心轮数 |
model_selection.metric |
predict_accuracy | 最佳模型选择指标 |
predict_threshold |
0.5 | 患者级预测阈值 |
verbose |
true | 终端详细输出 |
终端输出规则
- 不加
--verbose:训练/预测过程中终端不输出实时信息,仅在流程结束后输出总结 - 加
--verbose:显示一根总 tqdm 进度条,详细调试信息仅写入日志文件
开发与构建
本地开发安装
conda activate Vit
pip install -e .
vit --help
构建发行包
conda run -n Vit python -m build
产物位于 dist/ 目录:btvit_cli-{version}-py3-none-any.whl 和 btvit-cli-{version}.tar.gz。
发布到 PyPI
# 安装发布工具(首次)
pip install twine
# 构建
python -m build
# 校验
twine check dist/*
# 上传到 TestPyPI(可选预览)
twine upload --repository testpypi dist/*
# 上传到 PyPI
twine upload dist/*
发布前检查清单:
- 更新
src/__version__中的版本号 - 确认
pyproject.toml中的name = "btvit-cli"和description正确 - 运行
python -m build构建无报错 - 运行
twine check dist/*校验通过 - 如需测试,先上传到 TestPyPI 并验证安装
项目结构
src/
├── __init__.py # __version__ = "0.1.2"
├── cli/
│ ├── main.py # Typer 入口:train / pred / get / version
│ ├── train_cmd.py # train 命令参数定义
│ ├── train_impl.py # 训练核心逻辑(多轴 + 单轴分发)
│ ├── train_runner.py # 训练执行包装
│ ├── predict_cmd.py # pred 命令参数定义
│ ├── predict_impl.py # 预测核心逻辑(checkpoint 驱动)
│ ├── predict_runner.py # 预测执行包装
│ ├── script_loader.py # 旧脚本兼容加载
│ └── utils.py # YAML 模板、数据布局归一化
├── model/
│ ├── vit_model.py # MedicalViT + MultiAxisMedicalViT
│ └── vit_nd_rotary.py # Golden Gate RoPE
├── data/
│ ├── preprocessing.py # ImageProcessor
│ ├── dataset.py # 单轴/多轴 Dataset + Collator
│ ├── enhanced_splitting.py # 患者级分层/随机分割
│ ├── memory_preloader.py # RAM 预加载
│ ├── improved_dataloader.py # DataLoaderManager
│ └── testset_export.py # 测试集导出
├── training/
│ ├── trainer.py # 单轴训练器
│ └── multi_axis_trainer.py # 多轴训练器
├── prediction/
│ ├── feature_store.py # 特征存储
│ ├── knn_classifier.py # KNN 分类
│ └── multi_axis_predictor.py # 多轴预测器
└── utils/
├── config.py # YAML 配置系统
├── metrics.py # Accuracy / Precision / Recall / F1 / AUC
├── logger.py # 日志管理
└── monitor.py # TensorBoard
环境要求
- Python >= 3.11
- PyTorch >= 2.6.0
- CUDA(可选,用于 GPU 加速)
许可证
Apache-2.0
作者
SuShuHeng
链接
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file btvit_cli-0.1.2.post4.tar.gz.
File metadata
- Download URL: btvit_cli-0.1.2.post4.tar.gz
- Upload date:
- Size: 123.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
81b73b07aa472ed4e1c4831538a0d5cae9f4cae552147b5c177e3476d5ee82e4
|
|
| MD5 |
a403c710964193876b91877f004782a7
|
|
| BLAKE2b-256 |
30d22d43eb28434c6857b6512a1d5c1b38c34f545da15690304c03b83f5ba755
|
File details
Details for the file btvit_cli-0.1.2.post4-py3-none-any.whl.
File metadata
- Download URL: btvit_cli-0.1.2.post4-py3-none-any.whl
- Upload date:
- Size: 115.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.12
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
ea59619bf92f28d80111655d382a5d014c7316770c7908afa1085a3eaee4bbf2
|
|
| MD5 |
b4b596fcfdf0bdd15dffe3473b80d26f
|
|
| BLAKE2b-256 |
249a30ae8b1ca6aed47608da54a883b41b23b9ba24ffeebf07557e494fd0ef77
|