Training utilities for keynet - MLflow and model training support
Project description
keynet-train
MLflow와 통합된 모델 훈련 유틸리티
설치
pip install keynet-train
주요 기능
🚀 자동화된 훈련 API
- 모델에서 자동으로 스키마 추론
- PyTorch 모델을 ONNX로 자동 변환
- MLflow에 자동 로깅 및 버전 관리
📊 지원 프레임워크
- PyTorch (TorchScript, ONNX 변환)
- ONNX (네이티브 지원)
- 다중 입력/출력 모델 지원
🔧 MLflow 통합
- 실험 자동 생성 및 관리
- 모델 아티팩트 자동 저장
- 메트릭 및 파라미터 추적
🚀 기본 사용법
from keynet_train import trace
import torch
# 🎯 decorator에 샘플 입력을 제공하고, 함수에서는 모델만 반환
@trace("my_experiment", torch.randn(1, 3, 224, 224))
def train_model():
model = MyModel()
# 학습 코드...
for epoch in range(10):
# 실제 학습 로직
pass
return model # ⚠️ 반드시 torch.nn.Module만 반환
📋 반환값 제약사항
@trace 데코레이터를 사용하는 함수는 반드시 torch.nn.Module 객체만 반환해야 합니다.
✅ 올바른 사용법
@trace("experiment", torch.randn(1, 784))
def train_mnist():
model = torch.nn.Sequential(
torch.nn.Linear(784, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 10)
)
# 훈련 로직
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(100):
# 실제 훈련...
loss = train_one_epoch(model, optimizer, train_loader)
# 메트릭은 mlflow.log_* 함수로 기록
mlflow.log_metric("train_loss", loss, step=epoch)
return model # 🎯 모델만 반환!
❌ 잘못된 사용법들
@trace("experiment", torch.randn(1, 784))
def wrong_usage1():
model = MyModel()
loss = train(model)
return model, loss # ❌ 튜플 반환 불가
@trace("experiment", torch.randn(1, 784))
def wrong_usage2():
model = MyModel()
train(model)
return {
"model": model,
"accuracy": 0.95
} # ❌ 딕셔너리 반환 불가
@trace("experiment", torch.randn(1, 784))
def wrong_usage3():
model = MyModel()
train(model)
return "model_saved.pth" # ❌ 문자열 반환 불가
💡 왜 이런 제약이 있나요?
@trace 데코레이터는 내부적으로 다음 작업을 자동화합니다:
- MLflow 모델 로깅:
mlflow.pytorch.log_model(pytorch_model=model, ...) - ONNX 변환:
torch.onnx.export(model, ...) - Triton 배포: 자동
config.pbtxt생성
이 모든 작업이 torch.nn.Module 객체를 필요로 하므로, 다른 타입의 반환값은 지원하지 않습니다.
📝 ONNX 모델 입출력 파라미터명 규칙
@trace 데코레이터를 사용할 때 생성되는 ONNX 모델의 입출력 파라미터명은 다음과 같이 결정됩니다:
입력 파라미터 (Inputs)
# ✅ Dictionary 형태로 입력하면 키 이름을 사용 (권장)
@trace("my_experiment", {"image": torch.randn(1, 3, 224, 224), "label": torch.randn(1, 10)})
def train_model():
# 생성되는 ONNX의 입력명: "image", "label"
...
# ✅ 단일 텐서로 입력하면 자동 생성
@trace("my_experiment", torch.randn(1, 3, 224, 224))
def train_model():
# 생성되는 ONNX의 입력명: "input_0"
...
출력 파라미터 (Outputs)
# 출력명은 항상 자동 생성됩니다
@trace("my_experiment", torch.randn(1, 3, 224, 224))
def train_model():
# 단일 출력: "output_0"
return model
# 다중 출력 모델의 경우
def train_multi_output_model():
class MultiOutputModel(torch.nn.Module):
def forward(self, x):
return output1, output2 # 튜플 반환
# 실제로는 MLflow가 튜플을 하나의 배열로 처리하여 "output_0"만 생성됨
return model
⚠️ 중요한 제한사항
- 지원되는 입력 형태:
torch.Tensor또는Dict[str, torch.Tensor]만 지원 - 튜플 입력 미지원:
(tensor1, tensor2)형태의 튜플 입력은 현재 지원되지 않음 - 다중 출력 처리: PyTorch 모델이 튜플로 다중 출력을 반환해도 MLflow signature 추론에 의해
output_0하나로 처리됨 - MLflow 의존성: 파라미터명 생성은 MLflow의 자동 signature 추론에 의존하므로 일부 제한사항이 있음
💡 권장사항
# 🎯 최적의 사용법: Dictionary 입력으로 명시적인 이름 지정
@trace("experiment", {
"image": torch.randn(1, 3, 224, 224),
"mask": torch.randn(1, 1, 224, 224)
})
def train_model():
# 생성되는 config.pbtxt에서 명확한 입력명 확인 가능:
# input { name: "image", data_type: TYPE_FP32, dims: [-1, 3, 224, 224] }
# input { name: "mask", data_type: TYPE_FP32, dims: [-1, 1, 224, 224] }
return model
Note: 생성된 ONNX 모델은 Triton Inference Server 배포 시 자동으로
config.pbtxt파일이 생성되어 정확한 입출력 스키마를 확인할 수 있습니다.
다중 입력 모델
@trace("multi_input_exp", {
"image": torch.randn(1, 3, 224, 224),
"mask": torch.randn(1, 1, 224, 224)
})
def train_multi_input():
model = MultiInputModel()
# 모델이 여러 입력을 받는 경우
class MultiInputModel(torch.nn.Module):
def forward(self, image, mask):
# image와 mask를 함께 처리
combined = torch.cat([image, mask], dim=1)
return self.classifier(combined)
# 훈련 로직...
return model
라이선스
MIT License
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file keynet_train-0.2.0.dev0.tar.gz.
File metadata
- Download URL: keynet_train-0.2.0.dev0.tar.gz
- Upload date:
- Size: 17.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: python-httpx/0.28.1
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
04fc03ef1dab9d6dca597065ac7e29b99ef297a7b4beae44a49d9d994d17c041
|
|
| MD5 |
d77b714bf79afd54f8db7ffcc8569852
|
|
| BLAKE2b-256 |
20383fb7fe733e20ee76be1bcb7047495dfa22dd584c94dc1a6612bc8abec368
|
File details
Details for the file keynet_train-0.2.0.dev0-py3-none-any.whl.
File metadata
- Download URL: keynet_train-0.2.0.dev0-py3-none-any.whl
- Upload date:
- Size: 21.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: python-httpx/0.28.1
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
addab98bdbfec3461013a2ee17034020fa86eb1871fedcfcdee92fb16d9b0e06
|
|
| MD5 |
c30003ad95569a6124aa6e2982ebf5de
|
|
| BLAKE2b-256 |
ceefbcd7b3baa131263bfc00022c12df3e3265c968829937f5c9b5b2a5727daf
|