Skip to main content

BIT ININ 课题组自用服务器 GPU 资源调度及管理工具

Project description

app.py

app.py

from future import annotations

import importlib from pathlib import Path

from flask import Flask, Response, jsonify from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker

importlib.import_module("polar_flow.server.models")

from polar_flow.server.auth import auth_bp, login_manager, set_session_factory # noqa: E402 from polar_flow.server.config import Config # noqa: E402 from polar_flow.server.models import Base # noqa: E402 from polar_flow.server.schemas import UserRead # noqa: E402

# -------- App Factory --------

def create_app(config_path: str | None = None) -> Flask:
    """Build and configure the Flask application.

    Args:
        config_path: Path to the TOML config file. A falsy value (None or "")
            falls back to ``config.toml`` in the current working directory.

    Returns:
        The Flask app with DB session factories injected, Flask-Login set up,
        the auth/API blueprints registered, plus ``/healthz`` and ``/me``.
    """
    app = Flask(__name__)

    # 1) Load configuration. Never print/log the whole Config object: it holds
    #    server.secret_key, the key that signs Flask session cookies.
    cfg = Config.load(Path(config_path) if config_path else Path("config.toml"))
    app.config["SECRET_KEY"] = cfg.server.secret_key

    # 2) Engine + session factory, built by the same helper the worker uses so
    #    both processes share identical session settings.
    from polar_flow.server.db import create_session_factory  # noqa: PLC0415

    session_local, engine = create_session_factory(cfg.server.database_url)

    # Auto-create tables (handy in development; prefer migrations in production).
    Base.metadata.create_all(bind=engine)

    # 3) Inject the session factory into the auth module; register Flask-Login.
    set_session_factory(session_local)
    login_manager.init_app(app)

    # 4) Register blueprints (routes is imported lazily, as before).
    from polar_flow.server.routes import (  # noqa: PLC0415
        api_bp,
        set_session_factory as routes_set_session_factory,
    )

    routes_set_session_factory(session_local)
    app.register_blueprint(auth_bp)
    app.register_blueprint(api_bp)

    # 5) Liveness probe.
    @app.get("/healthz")
    def healthz() -> Response:
        return jsonify({"status": "ok"})

    # 6) Profile of the logged-in user.
    from flask_login import current_user, login_required  # noqa: PLC0415

    @app.get("/me")
    @login_required
    def me() -> Response:
        return jsonify(UserRead.model_validate(current_user).model_dump())

    return app

def main() -> None:
    """Run the Flask development server for local development.

    Use a real WSGI/ASGI server in production. The Werkzeug interactive
    debugger allows arbitrary code execution from the browser, and the server
    binds to 0.0.0.0. For that reason the debugger is only enabled when the
    ``FLASK_DEBUG`` environment variable is exactly ``"1"``.
    """
    import os  # noqa: PLC0415

    app = create_app("data/config.toml")
    app.run(host="0.0.0.0", port=5000, debug=os.environ.get("FLASK_DEBUG") == "1")

config.py

from future import annotations

import logging import os from typing import TYPE_CHECKING

import toml from pydantic import BaseModel, Field, ValidationError, field_validator

if TYPE_CHECKING: from pathlib import Path

# Module logger, named after this module (``name`` alone would be a NameError).
logger = logging.getLogger(__name__)

class ServerConfig(BaseModel):
    """Server settings (the ``[server]`` table of the config file)."""

    secret_key: str = Field(...)  # becomes Flask's SECRET_KEY in create_app
    database_url: str = Field(...)  # SQLAlchemy database URL
    redis_url: str = Field(...)
    scheduler_poll_interval: int = Field(
        ...,
        gt=0,
        description="轮询间隔(秒),必须大于 0",
    )

    @field_validator("scheduler_poll_interval", mode="after")
    @classmethod
    def check_positive_interval(cls, value: int) -> int:
        """Reject non-positive intervals (mirrors the ``gt=0`` field constraint)."""
        if value > 0:
            return value
        raise ValueError("scheduler_poll_interval 必须大于 0")

class DefaultsConfig(BaseModel):
    """Default settings (the ``[defaults]`` table of the config file).

    NOTE(review): ``user_priority`` is not read by the code visible here;
    routes.create_user hard-codes priority 100. Confirm whether it should be used.
    """

    user_priority: int = Field(
        default=100,
        ge=0,
        description="默认普通用户提交任务可用的最大优先级(>= 0)",
    )

    @field_validator("user_priority", mode="after")
    @classmethod
    def check_non_negative(cls, value: int) -> int:
        """Reject negative priorities (mirrors the ``ge=0`` field constraint)."""
        if value >= 0:
            return value
        raise ValueError("user_priority 必须大于等于 0")

class Config(BaseModel):
    """Top-level application configuration, normally built with :meth:`load`."""

    server: ServerConfig
    defaults: DefaultsConfig

    @classmethod
    def load(cls, config_path: Path) -> Config:
        """Load configuration from the TOML file at *config_path*.

        Resolution per key (first hit wins):
          * ``server.secret_key``: file, ``$SECRET_KEY``, then an insecure
            built-in default (a warning is logged).
          * ``server.database_url``: file, ``$DATABASE_URL``, then a SQLite
            file ``app.db`` next to the config file.
          * ``server.redis_url``: file, ``$REDIS_URL``, then
            ``redis://localhost:6379/0``.
          * ``server.scheduler_poll_interval``: file,
            ``$SCHEDULER_POLL_INTERVAL``, then 5. Unparsable values fall back to 5.
          * ``defaults.user_priority``: file, then 100. Unparsable values fall
            back to 100.

        A missing or unreadable file is logged and treated as empty.

        Raises:
            pydantic.ValidationError: if the resolved values fail validation
                (e.g. a non-positive poll interval); it is also logged.
        """
        data: dict = {}
        if config_path.exists():
            try:
                data = toml.load(config_path)
            except Exception as e:  # noqa: BLE001
                logger.warning("无法加载配置文件 %s: %s; 将使用默认配置", config_path, e)
        else:
            logger.warning("未找到 %s; 将使用默认配置", config_path)

        def as_int(raw: object, key: str, default: int) -> int:
            """Coerce *raw* to int; log and return *default* if it is absent or unparsable."""
            if raw is None:
                return default
            try:
                return int(raw)  # type: ignore[call-overload]
            except (TypeError, ValueError):
                # TypeError covers non-scalar TOML values such as lists/tables.
                logger.warning("%s 无法解析为整数: %s, 使用默认 %s", key, raw, default)
                return default

        basedir = config_path.parent
        server_data = data.get("server", {}) or {}

        secret_key = server_data.get("secret_key", os.environ.get("SECRET_KEY"))
        if secret_key is None:
            # The fallback is public (it is in the source), so sessions signed
            # with it can be forged; make that visible.
            logger.warning("未配置 secret_key / SECRET_KEY, 正在使用不安全的默认密钥; 生产环境请务必设置")
            secret_key = "you-will-never-guess"

        database_url = server_data.get("database_url") or os.environ.get("DATABASE_URL")
        if not database_url:
            database_url = f"sqlite:///{(basedir / 'app.db').as_posix()}"

        redis_url = server_data.get("redis_url") or os.environ.get(
            "REDIS_URL", "redis://localhost:6379/0",
        )

        spi_raw = server_data.get("scheduler_poll_interval")
        if spi_raw is None:
            spi_raw = os.environ.get("SCHEDULER_POLL_INTERVAL")
        spi = as_int(spi_raw, "scheduler_poll_interval", 5)

        up_raw = (data.get("defaults", {}) or {}).get("user_priority")
        up = as_int(up_raw, "user_priority", 100)

        try:
            return cls(
                server=ServerConfig(
                    secret_key=secret_key,
                    database_url=database_url,
                    redis_url=redis_url,
                    scheduler_poll_interval=spi,
                ),
                defaults=DefaultsConfig(user_priority=up),
            )
        except ValidationError:
            logger.exception("配置文件解析错误")
            raise

server/gpu_monitor.py

from future import annotations

import logging import time from typing import TypedDict

from pynvml import ( NVMLError, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlDeviceGetUtilizationRates, nvmlInit, )

# Module logger, named after this module (``name`` alone would be a NameError).
logger = logging.getLogger(__name__)

class GPUInfo(TypedDict):
    """Snapshot of one GPU's state as read from NVML."""

    id: int  # NVML device index
    memory_total: int  # bytes
    memory_free: int  # bytes
    memory_used: int  # bytes
    util_gpu: int  # NVML utilization ``gpu`` field (presumably percent)
    util_mem: int  # NVML utilization ``memory`` field (presumably percent)

def get_all_gpu_info() -> list[GPUInfo]:
    """Return the current state of every GPU visible to NVML.

    NVML is initialised and shut down on every call, so the NVML init
    reference count stays balanced. Any NVML error, whether at init or while
    querying a device, is logged and yields an empty list. Callers (the
    scheduler and the ``/api/gpus`` route) then treat the GPUs as unavailable
    instead of crashing.
    """
    from pynvml import nvmlShutdown  # noqa: PLC0415

    try:
        nvmlInit()
    except NVMLError:
        logger.exception("初始化 GPU 失败")
        return []

    infos: list[GPUInfo] = []
    try:
        for i in range(nvmlDeviceGetCount()):
            handle = nvmlDeviceGetHandleByIndex(i)
            mem = nvmlDeviceGetMemoryInfo(handle)
            util = nvmlDeviceGetUtilizationRates(handle)
            infos.append(
                GPUInfo(
                    id=i,
                    memory_total=int(mem.total),
                    memory_free=int(mem.free),
                    memory_used=int(mem.used),
                    util_gpu=int(util.gpu),
                    util_mem=int(util.memory),
                ),
            )
    except NVMLError:
        # A partial list could make the scheduler pick the wrong GPUs.
        logger.exception("查询 GPU 状态失败")
        return []
    finally:
        try:
            nvmlShutdown()
        except NVMLError:
            logger.exception("关闭 NVML 失败")
    return infos

def monitor_loop(poll_interval: float = 5.0) -> None:
    """Poll GPU state forever, sleeping *poll_interval* seconds between polls.

    Meant for a background thread or process. Snapshots are only printed for
    now; publishing them to a shared cache for the scheduler/web UI is a TODO.
    """
    while True:
        snapshot = get_all_gpu_info()
        # TODO: store `snapshot` in a global cache / shared state instead of printing.
        print("GPU infos:", snapshot)
        time.sleep(poll_interval)

server/models.py

from future import annotations

import datetime as dt from enum import Enum

from flask_login import UserMixin from sqlalchemy import JSON, DateTime, Enum as SAEnum, ForeignKey, Integer, String, Text from sqlalchemy.ext.mutable import MutableList from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

class Base(DeclarativeBase):
    """Declarative base for all ORM models; ``Base.metadata.create_all`` builds their tables."""

    pass

class Role(Enum):
    """Account role. ADMIN skips the per-user GPU whitelist and may manage users."""

    USER = "user"
    ADMIN = "admin"

class TaskStatus(Enum):
    """Task lifecycle state; each value equals its member name."""

    PENDING = "PENDING"  # waiting for the scheduler
    RUNNING = "RUNNING"
    SUCCESS = "SUCCESS"  # process exited with code 0
    FAILED = "FAILED"  # non-zero exit code
    CANCELLED = "CANCELLED"

class User(Base, UserMixin):
    """A login account (Flask-Login user) that submits GPU tasks."""

    __tablename__ = "users"

    id: Mapped[int] = mapped_column(primary_key=True)
    username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
    # Output of werkzeug's generate_password_hash. Its default scrypt format
    # (Werkzeug >= 3.0) is ~162 characters, which overflows String(128) on
    # databases that enforce VARCHAR length (SQLite does not).
    password_hash: Mapped[str] = mapped_column(String(256), nullable=False)
    role: Mapped[Role] = mapped_column(SAEnum(Role), default=Role.USER, nullable=False)
    # GPU indices a non-admin may be scheduled on. default=list is a callable,
    # so each row gets its own fresh list rather than a shared one.
    visible_gpus: Mapped[list[int]] = mapped_column(
        MutableList.as_mutable(JSON),
        default=list,
        nullable=False,
    )
    # Highest task priority a non-admin may request (requests are capped to it).
    priority: Mapped[int] = mapped_column(Integer, default=100, nullable=False)

    # Forward reference given as a string to keep static type checkers happy.
    tasks: Mapped[list[Task]] = relationship(
        "Task",
        back_populates="user",
        cascade="all, delete-orphan",
    )

    def set_password(self, raw: str) -> None:
        """Store a salted hash of *raw*; the plain password is never stored."""
        from werkzeug.security import generate_password_hash  # noqa: PLC0415

        self.password_hash = generate_password_hash(raw)

    def check_password(self, raw: str) -> bool:
        """Return True if *raw* matches the stored password hash."""
        from werkzeug.security import check_password_hash  # noqa: PLC0415

        return check_password_hash(self.password_hash, raw)

    def get_visible_gpus_list(self) -> list[int]:
        """Return a plain copy of the GPU whitelist.

        Callers cannot accidentally mutate the change-tracked MutableList (and
        dirty the row) through the result. A not-yet-flushed user without a
        value yields [].
        """
        return list(self.visible_gpus or [])

class Task(Base):
    """A submitted GPU job together with its execution record."""

    __tablename__ = "tasks"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    user_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("users.id"),
        nullable=False,
        index=True,
    )
    user: Mapped[User] = relationship("User", back_populates="tasks")

    name: Mapped[str] = mapped_column(String(128), nullable=False)
    # Shell command line; the scheduler runs it with shell=True.
    command: Mapped[str] = mapped_column(String(512), nullable=False)
    # Either an explicit list such as "0,1" or "AUTO:<n>" (scheduler picks n GPUs).
    requested_gpus: Mapped[str] = mapped_column(String(64), nullable=False)
    # Minimum free memory required on each selected GPU, in MB.
    gpu_memory_limit: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Higher runs first (the scheduler orders PENDING tasks by priority desc).
    priority: Mapped[int] = mapped_column(Integer, default=100, nullable=False)

    # Working directory (cwd) of the task process.
    working_dir: Mapped[str] = mapped_column(String(256), nullable=False)

    status: Mapped[TaskStatus] = mapped_column(
        SAEnum(TaskStatus),
        default=TaskStatus.PENDING,
        nullable=False,
    )

    # Timezone-aware UTC timestamps (DateTime(timezone=True)).
    created_at: Mapped[dt.datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: dt.datetime.now(dt.UTC),
        nullable=False,
    )
    # Set when the scheduler marks the task RUNNING.
    started_at: Mapped[dt.datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    # Set when the process finishes or the task is cancelled.
    finished_at: Mapped[dt.datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)

    # Captured process output, stored after the process exits.
    stdout_log: Mapped[str | None] = mapped_column(Text, nullable=True)
    stderr_log: Mapped[str | None] = mapped_column(Text, nullable=True)

server/scheduler.py

from future import annotations

import datetime as dt import os import subprocess import time from collections.abc import Callable from typing import TYPE_CHECKING

from sqlalchemy import select from sqlalchemy.orm import Session, joinedload

from polar_flow.server.gpu_monitor import get_all_gpu_info from polar_flow.server.models import Role, Task, TaskStatus

if TYPE_CHECKING: from sqlalchemy.orm import sessionmaker

# Any zero-argument callable that returns a new Session (e.g. a sessionmaker[Session]).
SessionFactory = Callable[[], Session]

def resources_available(requested: list[int], gpu_memory_limit: int | None) -> bool:
    """Return True if every GPU in *requested* exists and has enough free memory.

    NVML reports free memory in bytes, while *gpu_memory_limit* is in MB and is
    converted before comparing. With no limit, the GPU only has to exist. An
    empty *requested* list is trivially satisfied.
    """
    free_by_id: dict[int, int] = {
        info["id"]: info["memory_free"] for info in get_all_gpu_info()
    }  # bytes

    for gpu_id in requested:
        if gpu_id not in free_by_id:
            return False
        if gpu_memory_limit is None:
            continue
        if free_by_id[gpu_id] < gpu_memory_limit * 1024 * 1024:  # MB -> bytes
            return False
    return True

def _select_gpus(task: Task) -> list[int]:
    """Resolve ``task.requested_gpus`` into concrete GPU indices.

    Accepted forms:
      * ``"AUTO:<n>"``: the *n* GPUs with the most free memory, among those
        with at least ``task.gpu_memory_limit`` MB free (NVML reports bytes).
      * ``"0,1,..."``: an explicit comma-separated list. Blank entries are
        skipped and duplicates dropped (first occurrence wins).

    Returns [] when the request cannot be met right now or is malformed: a
    non-integer or non-positive AUTO count, or a non-integer or negative index.
    Returning [] leaves the task PENDING instead of raising ValueError, which
    would propagate out of allocate_and_run_task and stop the scheduler loop.
    """
    spec = task.requested_gpus
    if spec.startswith("AUTO:"):
        try:
            num = int(spec.split(":", 1)[1])
        except ValueError:
            return []
        # num <= 0 would make the [:num] slice below pick the wrong GPUs
        # (e.g. [:-1] is "all but one"), so treat it as unsatisfiable.
        if num <= 0:
            return []

        limit_bytes = None
        if task.gpu_memory_limit is not None:
            limit_bytes = task.gpu_memory_limit * 1024 * 1024  # MB -> bytes

        candidates = [
            g for g in get_all_gpu_info()
            if limit_bytes is None or g["memory_free"] >= limit_bytes
        ]
        if len(candidates) < num:
            return []
        candidates.sort(key=lambda g: g["memory_free"], reverse=True)
        return [g["id"] for g in candidates[:num]]

    try:
        ids = [int(part) for part in spec.split(",") if part.strip()]
    except ValueError:
        return []
    if any(i < 0 for i in ids):
        return []
    return list(dict.fromkeys(ids))

def _run_command(command: str, working_dir: str | None, gpu_ids: list[int]) -> tuple[str, str, int]:
    """Run *command* through the shell with only *gpu_ids* visible to CUDA.

    Returns ``(stdout, stderr, returncode)``. If the process cannot be started
    (e.g. *working_dir* was deleted), returns ``("", <error text>, -1)``. The
    caller can then record a FAILED task instead of leaving it RUNNING forever.
    """
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(x) for x in gpu_ids)
    try:
        proc = subprocess.Popen(
            command,
            shell=True,  # noqa: S602 -- tasks are submitted as shell command lines
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=working_dir or os.getcwd(),
            env=env,
            text=True,
        )
    except OSError as exc:
        return "", f"failed to start command: {exc}", -1
    out, err = proc.communicate()
    return out, err, proc.returncode


def allocate_and_run_task(task: Task, session_local: SessionFactory) -> bool:
    """Try to start *task* and, if it can start, run it to completion.

    The task is re-loaded in a fresh session (user eagerly loaded) and must
    still be PENDING, because it may have been cancelled since the scheduler
    listed it. Next come GPU selection, the user's GPU whitelist (skipped for
    admins) and a free-memory check. On success the task is marked RUNNING
    and its command runs synchronously with ``CUDA_VISIBLE_DEVICES`` set. The
    outcome and logs are then stored. A cancellation made while the command
    was running is kept.

    NOTE(review): commands run via the shell as the server's OS user and block
    the caller, so the scheduler runs one task at a time.

    Returns:
        False if the task was not started (missing, no longer PENDING, GPUs
        unavailable or not permitted). True once it has been run, including
        when the process failed to launch (then it is marked FAILED).

    Raises:
        Any database error, after rolling back the session.
    """
    session: Session = session_local()
    try:
        task_db = session.execute(
            select(Task).options(joinedload(Task.user)).where(Task.id == task.id),
        ).scalar_one_or_none()
        # Skip tasks that vanished or were cancelled after being listed.
        if task_db is None or task_db.status != TaskStatus.PENDING:
            return False

        selected = _select_gpus(task_db)
        if not selected:
            return False

        # Non-admins may only use GPUs on their whitelist.
        user = task_db.user
        if user.role != Role.ADMIN:
            visible = set(user.get_visible_gpus_list())
            if not visible.issuperset(selected):
                return False

        if not resources_available(selected, task_db.gpu_memory_limit):
            return False

        task_db.status = TaskStatus.RUNNING
        task_db.started_at = dt.datetime.now(dt.UTC)
        session.commit()

        out, err, returncode = _run_command(task_db.command, task_db.working_dir, selected)

        # Re-read the row: the task may have been cancelled via the API while
        # it ran, and that must not be overwritten with SUCCESS/FAILED.
        session.refresh(task_db)
        task_db.stdout_log = out
        task_db.stderr_log = err
        if task_db.status != TaskStatus.CANCELLED:
            task_db.finished_at = dt.datetime.now(dt.UTC)
            task_db.status = TaskStatus.SUCCESS if returncode == 0 else TaskStatus.FAILED
        session.commit()
    except Exception:
        session.rollback()
        raise
    else:
        return True
    finally:
        session.close()

def scheduler_loop(poll_interval: float, session_local: sessionmaker[Session]) -> None:
    """Scheduler main loop; never returns.

    Every *poll_interval* seconds it lists PENDING tasks by priority
    (descending), then created_at (ascending), and tries to start each one
    with :func:`allocate_and_run_task`. Tasks that cannot start now stay
    PENDING for the next round. An exception while handling one task is
    logged and does not stop the loop.
    """
    import logging  # noqa: PLC0415

    log = logging.getLogger(__name__)
    while True:
        # Close the listing session before running anything, so it does not
        # hold a DB connection while (possibly long) commands execute.
        # allocate_and_run_task only needs each task's id and re-loads it.
        with session_local() as session:
            pending = (
                session.query(Task)
                .filter(Task.status == TaskStatus.PENDING)
                .order_by(Task.priority.desc(), Task.created_at.asc())
                .all()
            )
        for task in pending:
            try:
                # False means "not now" (resources or permissions); retry next round.
                allocate_and_run_task(task, session_local)
            except Exception:  # noqa: BLE001
                log.exception("调度任务 #%s 失败", task.id)
        time.sleep(poll_interval)

server/schemas.py

from __future__ import annotations

# datetime is needed at runtime, not only under TYPE_CHECKING: pydantic
# evaluates TaskRead's postponed annotations (dt.datetime) when building the
# model and otherwise fails with "TaskRead is not fully defined".
import datetime as dt  # noqa: TC003
from typing import TYPE_CHECKING  # noqa: F401

from pydantic import BaseModel, ConfigDict, Field

from polar_flow.server.models import Role, TaskStatus  # noqa: TC001

class UserCreate(BaseModel):
    """Request body for creating a user; unknown keys are rejected."""

    model_config = ConfigDict(extra="forbid")

    username: str = Field(..., min_length=1, max_length=64)  # fits User.username String(64)
    password: str = Field(..., min_length=6)

class UserRead(BaseModel):
    """Public view of a User, built from the ORM object (no password hash).

    ``use_enum_values`` stores ``role`` as its string value ("user"/"admin").
    """

    model_config = ConfigDict(from_attributes=True, use_enum_values=True)

    id: int
    username: str
    role: Role
    visible_gpus: list[int] = Field(default_factory=list)
    priority: int

class TaskCreate(BaseModel):
    """Request body for submitting a task; unknown keys are rejected.

    Length limits mirror the Task table columns (name 128, command 512,
    requested_gpus 64, working_dir 256). Without them, over-long values fail
    at INSERT time on databases that enforce VARCHAR length.
    """

    name: str = Field(..., min_length=1, max_length=128)
    command: str = Field(..., min_length=1, max_length=512)
    requested_gpus: str = Field(..., min_length=1, max_length=64)  # "0,1" or "AUTO:n"
    working_dir: str = Field(..., min_length=1, max_length=256)
    gpu_memory_limit: int | None = Field(default=None, ge=0)  # MB per GPU
    priority: int = Field(default=100, ge=0)

    model_config = ConfigDict(extra="forbid")

class TaskRead(BaseModel):
    """Public view of a Task, built from the ORM object.

    NOTE: pydantic resolves these postponed annotations at runtime, so ``dt``
    must be importable at module level, not only under TYPE_CHECKING.
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    user_id: int
    name: str
    command: str
    requested_gpus: str
    working_dir: str
    gpu_memory_limit: int | None
    priority: int
    status: TaskStatus
    created_at: dt.datetime
    started_at: dt.datetime | None
    finished_at: dt.datetime | None
    stdout_log: str | None
    stderr_log: str | None

server/worker.py

from future import annotations

from pathlib import Path

from polar_flow.server.config import Config from polar_flow.server.db import create_session_factory from polar_flow.server.models import Base from polar_flow.server.scheduler import scheduler_loop

def run_worker(config_path: str | None = None) -> None:
    """Run the scheduler worker; blocks forever.

    Loads the config (``config.toml`` when *config_path* is falsy), ensures the
    tables exist, then hands control to ``scheduler_loop``.
    """
    cfg = Config.load(Path(config_path or "config.toml"))
    session_local, engine = create_session_factory(cfg.server.database_url)
    Base.metadata.create_all(engine)  # ensure tables exist
    scheduler_loop(
        poll_interval=cfg.server.scheduler_poll_interval,
        session_local=session_local,
    )

server/db.py

from future import annotations

from sqlalchemy import Engine, create_engine from sqlalchemy.orm import Session, sessionmaker

def create_session_factory(database_url: str) -> tuple[sessionmaker[Session], Engine]:
    """Create a SQLAlchemy engine and a session factory bound to it.

    Shared by the worker and the web app. Sessions do not autoflush, and
    ``expire_on_commit=False`` keeps loaded attribute values after a commit.

    Returns:
        ``(session_factory, engine)``.
    """
    engine = create_engine(database_url, future=True)
    return (
        sessionmaker(
            bind=engine,
            autoflush=False,
            autocommit=False,
            expire_on_commit=False,
            future=True,
        ),
        engine,
    )

server/auth.py

from future import annotations

from functools import wraps from typing import TYPE_CHECKING

from flask import Blueprint, Response, jsonify, request from flask_login import LoginManager, current_user, login_required, login_user, logout_user

from polar_flow.server.models import Role, User from polar_flow.server.schemas import UserRead

if TYPE_CHECKING: from collections.abc import Callable

from flask.typing import ResponseReturnValue
from sqlalchemy.orm import Session, sessionmaker

# ---- Flask-Login 基础对象 ----

# Blueprint for the /auth/login and /auth/logout endpoints.
auth_bp = Blueprint("auth", __name__)
# Flask-Login manager; bound to the app via login_manager.init_app() in create_app.
login_manager = LoginManager()

@login_manager.unauthorized_handler
def _unauthorized():  # noqa: ANN202
    """Answer unauthenticated access to @login_required views with a JSON 401."""
    body = {"error": "login required"}
    return jsonify(body), 401

# ---- 会话工厂注入 ----

# Injected once at startup by set_session_factory(); None until then.
_session_factory: sessionmaker[Session] | None = None


def set_session_factory(session_factory: sessionmaker[Session]) -> None:
    """Inject the session factory; call once while initialising the application."""
    global _session_factory  # noqa: PLW0603
    _session_factory = session_factory


def _get_session() -> Session:
    """Open a new Session from the injected factory.

    Raises:
        RuntimeError: if set_session_factory() has not been called yet.
    """
    factory = _session_factory
    if factory is None:
        raise RuntimeError("Session factory is not initialized. Call set_session_factory() first.")
    return factory()

def get_user_by_username(username: str) -> User | None:
    """Return the user called *username*, or None; the session is closed before returning."""
    with _get_session() as session:
        return session.query(User).filter_by(username=username).first()

@auth_bp.route("/auth/login", methods=["POST"])
def login() -> tuple[Response, int]:
    """Log in with a JSON body ``{"username": ..., "password": ...}``.

    Returns 200 with the user's profile on success, and 401 for bad
    credentials. Returns 400 when either field is missing, empty or not a
    string, including when the body is not a JSON object. ``get_json(silent=True)``
    avoids Flask's 415 for non-JSON requests.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    username = data.get("username")
    password = data.get("password")
    # Non-string values would otherwise reach the hash check and cause a 500.
    if not (isinstance(username, str) and isinstance(password, str) and username and password):
        return jsonify({"error": "username and password required"}), 400

    user = get_user_by_username(username)
    if user is None or not user.check_password(password):
        return jsonify({"error": "invalid credentials"}), 401

    login_user(user)

    return jsonify(
        {"message": "logged in", "user": UserRead.model_validate(user).model_dump()},
    ), 200

@auth_bp.route("/auth/logout", methods=["POST"])
@login_required
def logout() -> tuple[Response, int]:
    """End the current login session (requires being logged in)."""
    logout_user()
    body = {"message": "logged out"}
    return jsonify(body), 200

def admin_required[**P](func: Callable[P, ResponseReturnValue]) -> Callable[P, ResponseReturnValue]: """ 管理员权限校验装饰器: - 使用 ParamSpec 保留被装饰函数的参数签名(*args/**kwargs 的类型信息) - 返回类型采用 Flask 的 ResponseReturnValue(str | bytes | Response | (Response, status) ...) """

@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> ResponseReturnValue:
    if not current_user.is_authenticated:
        return jsonify({"error": "login required"}), 401
    if getattr(current_user, "role", None) != Role.ADMIN:
        return jsonify({"error": "admin required"}), 403
    return func(*args, **kwargs)

return wrapper

@login_manager.user_loader
def load_user(user_id: str) -> User | None:
    """Flask-Login callback: turn the user id kept in the session back into a User.

    Returns None for an id that is not an integer or not in the database.
    """
    try:
        uid = int(user_id)
    except (TypeError, ValueError):
        return None
    with _get_session() as session:
        return session.get(User, uid)

server/routes.py

from future import annotations

import datetime as dt from pathlib import Path from typing import TYPE_CHECKING

from flask import Blueprint, Response, jsonify, request from flask_login import current_user, login_required

from .auth import admin_required from .models import Role, Task, TaskStatus, User from .schemas import TaskCreate, TaskRead, UserCreate, UserRead

if TYPE_CHECKING: from sqlalchemy.orm import Session, sessionmaker

# 通过 app.py 在应用初始化阶段注入 session factory

# Injected once by create_app via set_session_factory(); None until then.
_session_factory: sessionmaker[Session] | None = None


def set_session_factory(session_factory: sessionmaker[Session]) -> None:
    """Inject the session factory used by the API routes."""
    global _session_factory  # noqa: PLW0603
    _session_factory = session_factory


def _get_session() -> Session:
    """Open a new Session; raise RuntimeError if no factory was injected."""
    factory = _session_factory
    if factory is None:
        raise RuntimeError("routes: Session factory is not initialized")
    return factory()

# Blueprint for all JSON API endpoints; every route lives under /api.
api_bp = Blueprint("api", __name__, url_prefix="/api")

# ---------- GPU 可见性与健康 ----------

@api_bp.get("/gpus")
@login_required
def list_gpus() -> Response:
    """Return the current state of every GPU as a JSON list (any logged-in user)."""
    # Imported lazily so NVML is only loaded when this endpoint is used.
    from .gpu_monitor import get_all_gpu_info  # noqa: PLC0415

    return jsonify(get_all_gpu_info())

# ---------- 任务 CRUD(当前用户域) ----------

def _requested_gpus_error(spec: str) -> str | None:
    """Return an error message if *spec* is not a valid GPU request, else None.

    Valid forms: ``"AUTO:<n>"`` with n > 0, or a comma-separated list with at
    least one non-negative GPU index (blank entries ignored).
    """
    if spec.startswith("AUTO:"):
        try:
            n = int(spec.split(":", 1)[1])
        except ValueError:
            return "requested_gpus AUTO:<n> 格式错误"
        if n <= 0:
            return "AUTO 台数必须 > 0"
        return None
    try:
        ids = [int(x) for x in spec.split(",") if x.strip()]
    except ValueError:
        return "requested_gpus 需为 '0,1' 或 'AUTO:n'"
    # An empty list (e.g. ",") could never be scheduled; negative ids are not GPUs.
    if not ids or any(i < 0 for i in ids):
        return "requested_gpus 需为 '0,1' 或 'AUTO:n'"
    return None


@api_bp.post("/tasks")
@login_required
def create_task() -> tuple[Response, int]:
    """Create a PENDING task for the current user from a JSON ``TaskCreate`` body.

    A non-admin's priority is capped at their own ``User.priority``.
    ``working_dir`` must be an existing directory *on the server*; it is stored
    as an absolute path. Returns 201 with the task, or 400 with an ``error``.
    """
    data = request.get_json(silent=True) or {}
    try:
        payload = TaskCreate.model_validate(data)
    except Exception as e:  # noqa: BLE001
        return jsonify({"error": f"invalid payload: {e}"}), 400

    if (error := _requested_gpus_error(payload.requested_gpus)) is not None:
        return jsonify({"error": error}), 400

    # Non-admins cannot raise their priority above their own cap.
    priority = payload.priority
    if current_user.role != Role.ADMIN:
        priority = min(priority, current_user.priority)

    # The scheduler uses this as the subprocess cwd, so it must be a directory;
    # a regular file would pass .exists() and then fail at launch time.
    working_dir = Path(payload.working_dir)
    if not working_dir.is_dir():
        return jsonify({"error": f"working_dir 不存在或不是目录: {payload.working_dir}"}), 400

    with _get_session() as sess:
        task = Task(
            user_id=current_user.id,
            name=payload.name,
            command=payload.command,
            requested_gpus=payload.requested_gpus,
            gpu_memory_limit=payload.gpu_memory_limit,
            priority=priority,
            working_dir=str(working_dir.resolve()),
            status=TaskStatus.PENDING,
        )
        sess.add(task)
        sess.commit()
        sess.refresh(task)
        return jsonify(TaskRead.model_validate(task).model_dump()), 201

@api_bp.get("/tasks")
@login_required
def list_tasks() -> tuple[Response, int] | Response:
    """List tasks, newest first.

    Regular users only see their own tasks. Admins see all tasks and may
    narrow them with ``?user_id=``. ``?status=`` filters by a TaskStatus value
    (e.g. ``PENDING``); an unknown status yields 400.
    """
    user_id = request.args.get("user_id", type=int)
    status = request.args.get("status")

    sess = _get_session()
    try:
        query = sess.query(Task)
        if current_user.role == Role.ADMIN:
            if user_id:
                query = query.filter(Task.user_id == user_id)
        else:
            query = query.filter(Task.user_id == current_user.id)
        if status:
            try:
                wanted = TaskStatus(status)
            except Exception:  # noqa: BLE001
                return jsonify({"error": "status 无效"}), 400
            query = query.filter(Task.status == wanted)
        tasks = query.order_by(Task.created_at.desc()).all()
        return jsonify([TaskRead.model_validate(t).model_dump() for t in tasks])
    finally:
        sess.close()

@api_bp.get("/tasks/<int:task_id>")
@login_required
def get_task(task_id: int) -> tuple[Response, int]:
    """Return one task, including its stdout/stderr logs.

    404 if it does not exist; 403 if it belongs to someone else and the caller
    is not an admin.
    """
    with _get_session() as sess:
        task = sess.get(Task, task_id)
        if task is None:
            return jsonify({"error": "not found"}), 404
        if current_user.role != Role.ADMIN and task.user_id != current_user.id:
            return jsonify({"error": "forbidden"}), 403
        return jsonify(TaskRead.model_validate(task).model_dump()), 200

@api_bp.post("/tasks/<int:task_id>/cancel")
@login_required
def cancel_task(task_id: int) -> tuple[Response, int]:
    """Mark a task CANCELLED (owner or admin only).

    Already-finished tasks (SUCCESS/FAILED/CANCELLED) are left untouched and
    reported with 200. NOTE: this only changes the stored status; a process
    that is already running is not terminated here.
    """
    with _get_session() as sess:
        task = sess.get(Task, task_id)
        if task is None:
            return jsonify({"error": "not found"}), 404
        if current_user.role != Role.ADMIN and task.user_id != current_user.id:
            return jsonify({"error": "forbidden"}), 403
        if task.status in (TaskStatus.SUCCESS, TaskStatus.FAILED, TaskStatus.CANCELLED):
            return jsonify({"message": f"task already {task.status.value}"}), 200
        task.status = TaskStatus.CANCELLED
        task.finished_at = dt.datetime.now(dt.UTC)
        sess.commit()
        return jsonify({"message": "cancelled"}), 200

# ---------- 用户管理(仅管理员) ----------

@api_bp.post("/admin/users")
@admin_required
def create_user() -> tuple[Response, int]:
    """Create a regular (non-admin) user from a JSON ``UserCreate`` body.

    New users get priority 100 and no visible GPUs. Returns 201 with the user,
    400 for an invalid body, or 409 if the username is already taken.
    """
    from sqlalchemy.exc import IntegrityError  # noqa: PLC0415

    data = request.get_json(silent=True) or {}
    try:
        payload = UserCreate.model_validate(data)
    except Exception as e:  # noqa: BLE001
        return jsonify({"error": f"invalid payload: {e}"}), 400

    with _get_session() as sess:
        if sess.query(User).filter(User.username == payload.username).first():
            return jsonify({"error": "username exists"}), 409
        u = User(
            username=payload.username,
            role=Role.USER,
            priority=100,
            visible_gpus=[],
        )
        u.set_password(payload.password)
        sess.add(u)
        try:
            sess.commit()
        except IntegrityError:
            # A concurrent request inserted the same username after the check
            # above (users.username is UNIQUE); report it instead of a 500.
            sess.rollback()
            return jsonify({"error": "username exists"}), 409
        sess.refresh(u)
        return jsonify(UserRead.model_validate(u).model_dump()), 201

@api_bp.patch("/admin/users/<int:user_id>")
@admin_required
def patch_user(user_id: int) -> tuple[Response, int]:
    """Partially update a user (admin only).

    Accepted JSON keys:
      * ``role``: "user" or "admin".
      * ``priority``: int >= 0.
      * ``visible_gpus``: list of non-negative ints.
      * ``password``: string of at least 6 characters, the same rule as UserCreate.

    Nothing is committed if any field is invalid. Returns 200 with the updated
    user, 400, or 404.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    with _get_session() as sess:
        u = sess.get(User, user_id)
        if u is None:
            return jsonify({"error": "not found"}), 404
        if "role" in data:
            try:
                u.role = Role(data["role"])  # type: ignore[assignment]
            except Exception:  # noqa: BLE001
                return jsonify({"error": "role must be 'user'|'admin'"}), 400
        if "priority" in data:
            raw = data["priority"]
            try:
                # bool is an int subclass; true/false is not a priority.
                if isinstance(raw, bool):
                    raise ValueError  # noqa: TRY301
                p = int(raw)
                if p < 0:
                    raise ValueError  # noqa: TRY301
                u.priority = p
            except Exception:  # noqa: BLE001
                return jsonify({"error": "priority must be >= 0"}), 400
        if "visible_gpus" in data:
            v = data["visible_gpus"]
            valid = isinstance(v, list) and all(
                isinstance(x, int) and not isinstance(x, bool) and x >= 0 for x in v
            )
            if not valid:
                return jsonify({"error": "visible_gpus must be int list"}), 400
            u.visible_gpus = v
        if "password" in data:
            pw = data["password"]
            # str(None) would otherwise silently set the password to "None".
            if not isinstance(pw, str) or len(pw) < 6:
                return jsonify({"error": "password must be a string of at least 6 characters"}), 400
            u.set_password(pw)
        sess.commit()
        return jsonify(UserRead.model_validate(u).model_dump()), 200

@api_bp.get("/admin/users")
@admin_required
def list_users() -> Response:
    """List every user ordered by id (admin only)."""
    with _get_session() as sess:
        users = sess.query(User).order_by(User.id.asc()).all()
        payload = [UserRead.model_validate(u).model_dump() for u in users]
    return jsonify(payload)

cli/entry.py

from future import annotations

import json import os from pathlib import Path from typing import Optional

import colorama import toml import typer import requests

# Root Typer application; the @app.command() functions below are its subcommands.
app = typer.Typer(add_completion=False, help="BIT ININ 课题组自用 服务器 GPU 资源分配器")

# Server base URL; override with $POLAR_BASE_URL.
DEFAULT_BASE_URL = os.environ.get("POLAR_BASE_URL", "http://127.0.0.1:5000")
# Local state directory (holds the cookie file); override with $POLAR_STATE_DIR.
STATE_DIR = Path(os.environ.get("POLAR_STATE_DIR", "~/.polar_flow")).expanduser()
# Session cookies persisted between CLI runs, stored as JSON despite the .txt suffix.
COOKIE_FILE = STATE_DIR / "cookies.txt"

class Client:
    """Small HTTP client for the polar_flow server API.

    The ``requests.Session`` cookie jar, which carries the Flask-Login session
    cookie, is loaded from ``COOKIE_FILE`` on construction and saved back after
    login/logout. Separate CLI invocations therefore share one login.
    """

    def __init__(self, base_url: str = DEFAULT_BASE_URL, timeout: float | None = 30.0) -> None:
        self.base_url = base_url.rstrip("/")
        # Per-request timeout in seconds. Without one, requests can hang forever
        # on an unresponsive server; None restores that behaviour.
        self.timeout = timeout
        self.session = requests.Session()
        STATE_DIR.mkdir(parents=True, exist_ok=True)
        if COOKIE_FILE.exists():
            try:
                self.session.cookies.update(
                    requests.utils.cookiejar_from_dict(json.loads(COOKIE_FILE.read_text()))
                )
            except (OSError, ValueError, TypeError, AttributeError):
                # Best effort: an unreadable or corrupt cookie file just means
                # "not logged in"; the user can log in again.
                pass

    def _save_cookies(self) -> None:
        """Write the cookie jar to COOKIE_FILE with owner-only (0600) permissions.

        The file holds a live session cookie, i.e. a login credential.
        """
        content = json.dumps(requests.utils.dict_from_cookiejar(self.session.cookies))
        fd = os.open(COOKIE_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as fh:
            fh.write(content)
        # os.open's mode only applies on creation; also tighten an existing file.
        try:
            COOKIE_FILE.chmod(0o600)
        except OSError:
            pass  # e.g. filesystems without POSIX permissions

    def _raise_for_status(self, r: requests.Response, label: str) -> None:
        """Exit with the server's error message if *r* is an HTTP error.

        Raises:
            SystemExit: with a coloured ``[label]: <error> (<HTTPError>)`` message.
        """
        try:
            r.raise_for_status()
        except requests.HTTPError as e:
            # Prefer the server's JSON "error" field; fall back to the raw body.
            try:
                err = r.json().get("error")
            except Exception:  # noqa: BLE001
                err = r.text
            raise SystemExit(
                f"{colorama.Fore.BLUE}[{label}]: {colorama.Fore.RED}{err} ({e}){colorama.Style.RESET_ALL}"
            ) from e

    # ---- Auth ----
    def login(self, username: str, password: str) -> dict:
        """Log in, persist the session cookie and return the server's JSON reply.

        Raises:
            SystemExit: on an HTTP error, with the server's message.
        """
        r = self.session.post(
            f"{self.base_url}/auth/login",
            json={"username": username, "password": password},
            timeout=self.timeout,
        )
        self._raise_for_status(r, "登录失败")
        # Without this, later CLI invocations would not be authenticated.
        self._save_cookies()
        return r.json()

    def logout(self) -> dict:
        """Log out and persist the cleared cookie jar; raises requests.HTTPError on failure."""
        r = self.session.post(f"{self.base_url}/auth/logout", timeout=self.timeout)
        r.raise_for_status()
        self._save_cookies()
        return r.json()

    # ---- Tasks ----
    def create_task(self, payload: dict) -> dict:
        """Submit a task; raises requests.HTTPError on failure."""
        r = self.session.post(f"{self.base_url}/api/tasks", json=payload, timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    def list_tasks(self, status: Optional[str] = None) -> list[dict]:
        """List visible tasks, optionally filtered by status; raises requests.HTTPError on failure."""
        params = {"status": status} if status else None
        r = self.session.get(f"{self.base_url}/api/tasks", params=params, timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    def get_task(self, task_id: int) -> dict:
        """Fetch one task (with logs); raises requests.HTTPError on failure."""
        r = self.session.get(f"{self.base_url}/api/tasks/{task_id}", timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    def cancel_task(self, task_id: int) -> dict:
        """Cancel a task; raises requests.HTTPError on failure."""
        r = self.session.post(f"{self.base_url}/api/tasks/{task_id}/cancel", timeout=self.timeout)
        r.raise_for_status()
        return r.json()

    def list_gpus(self) -> list[dict]:
        """Return GPU states; exits with the server's message on an HTTP error."""
        r = self.session.get(f"{self.base_url}/api/gpus", timeout=self.timeout)
        self._raise_for_status(r, "查询失败")
        return r.json()

# ---------------- CLI commands ----------------

@app.command()
def login(
    username: str = typer.Option(..., "--username", "-u"),
    password: str = typer.Option(..., "--password", "-p", prompt=True, hide_input=True),
    base_url: str = typer.Option(DEFAULT_BASE_URL, "--base-url"),
):
    """登录并保存会话 Cookie。"""
    # (The docstring above is this command's --help text.)
    c = Client(base_url)
    res = c.login(username, password) or {}
    # Fall back to the submitted name if the reply has no user object,
    # instead of crashing on a missing key.
    shown = (res.get("user") or {}).get("username", username)
    typer.echo(f"Logged in as {shown}")

@app.command()
def logout(base_url: str = typer.Option(DEFAULT_BASE_URL, "--base-url")):
    """注销当前会话。"""
    # HTTP error statuses are reported rather than raised.
    try:
        res = Client(base_url).logout()
    except requests.HTTPError as e:
        typer.echo(f"Logout failed: {e}")
    else:
        typer.echo(res.get("message", "logged out"))

@app.command("gpus")
def gpus_cmd(base_url: str = typer.Option(DEFAULT_BASE_URL, "--base-url")):
    """查看 GPU 状态。"""
    # Print one JSON object per GPU, one per line.
    for info in Client(base_url).list_gpus():
        typer.echo(json.dumps(info, ensure_ascii=False))

@app.command("submit")
def submit_cmd(
    config: Path = typer.Option(
        ..., "--config", "-c", exists=True, readable=True, help="TOML 任务配置文件"
    ),
    base_url: str = typer.Option(DEFAULT_BASE_URL, "--base-url"),
):
    """从 TOML 提交任务。"""
    # Reads the [task] table:
    #   name, command        -- required
    #   requested_gpus       -- default "AUTO:1"
    #   working_dir          -- default: current directory
    #   gpu_memory_limit     -- MB, optional
    #   priority             -- default 100
    data = toml.load(config)
    t = data.get("task", {})
    name = t.get("name")
    command = t.get("command")
    if not name or not command:
        raise typer.BadParameter("[task] 需要 name 与 command", param_hint="--config")
    # The server resolves relative paths against its *own* cwd, so send an
    # absolute path as seen from where the CLI runs.
    working_dir = Path(t.get("working_dir", str(Path.cwd()))).expanduser().resolve()
    payload = {
        "name": name,
        "command": command,
        "requested_gpus": t.get("requested_gpus", "AUTO:1"),
        "working_dir": str(working_dir),
        "gpu_memory_limit": t.get("gpu_memory_limit"),
        "priority": t.get("priority", 100),
    }
    res = Client(base_url).create_task(payload)
    typer.echo(json.dumps(res, ensure_ascii=False, indent=2))

@app.command("ls")
def list_cmd(
    status: Optional[str] = typer.Option(
        None, "--status", help="过滤任务状态(PENDING/RUNNING/SUCCESS/FAILED/CANCELLED)"
    ),
    base_url: str = typer.Option(DEFAULT_BASE_URL, "--base-url"),
):
    """列出我的任务。"""
    # One summary line per task, in the order the server returns them.
    for task in Client(base_url).list_tasks(status=status):
        line = "#{id} [{status}] {name} prio={priority} created={created_at}".format(**task)
        typer.echo(line)

@app.command("logs")
def logs_cmd(
    task_id: int = typer.Argument(...), base_url: str = typer.Option(DEFAULT_BASE_URL, "--base-url")
):
    """查看任务日志(stdout/stderr)。"""
    task = Client(base_url).get_task(task_id)
    # A missing/None log is printed as an empty section.
    for label, key in (("stdout", "stdout_log"), ("stderr", "stderr_log")):
        typer.echo(f"== {label} ==\n{task.get(key) or ''}\n")

@app.command("cancel")
def cancel_cmd(
    task_id: int = typer.Argument(...), base_url: str = typer.Option(DEFAULT_BASE_URL, "--base-url")
):
    # Cancel TASK_ID on the server and print its reply message.
    # (No docstring: Typer would turn it into new --help text.)
    result = Client(base_url).cancel_task(task_id)
    typer.echo(result.get("message", "ok"))

def main() -> None:
    """Entry point: run the Typer CLI application."""
    app()

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

polarflow-0.0.1.tar.gz (38.7 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

polarflow-0.0.1-py3-none-any.whl (36.8 kB view details)

Uploaded Python 3

File details

Details for the file polarflow-0.0.1.tar.gz.

File metadata

  • Download URL: polarflow-0.0.1.tar.gz
  • Upload date:
  • Size: 38.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for polarflow-0.0.1.tar.gz
Algorithm Hash digest
SHA256 fff71fd73cad6450d044ed8cc40a85066315c08b9a96ca2fa36a7199c742e463
MD5 32f967e04ef8c153cd2a70cdd1c68332
BLAKE2b-256 7bc004019c6218901701bd7562a9dcdc9a5d3846aa2f17b943cd66be40ec90bc

See more details on using hashes here.

File details

Details for the file polarflow-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: polarflow-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 36.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.0

File hashes

Hashes for polarflow-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 236539ca2b1cab9872cbcf184d93b3dda353bc59d9926e2c108725251333ec7c
MD5 95595221e79aad8572981e31430110af
BLAKE2b-256 ed7cdf0914e1c15d82174fd357e0cdc35c121e32b4b26225d6f00b2a7e46ac46

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page