Efficient and customizable GPT-based framework for medical applications, enabling training, inference, and explainability of large language models in healthcare.

GptMed 🤖

Downloads Downloads/Month PyPI version Python 3.8+ License: MIT

A lightweight GPT-based language model framework for training custom question-answering models on any domain. This package provides a transformer-based GPT architecture that you can train on your own Q&A datasets - whether it's casual conversations, technical support, education, or any other domain.

Citation

If you use this model in your research, please cite:

@software{gptmed_2026,
  author = {Sanjog Sigdel},
  title = {GptMed: A custom causal question answering general purpose GPT Transformer Architecture Model},
  year = {2026},
  url = {https://github.com/sigdelsanjog/gptmed}
}

Installation

From PyPI (Recommended)

pip install gptmed

From Source

git clone https://github.com/sigdelsanjog/gptmed.git
cd gptmed
pip install -e .

With Optional Dependencies

# Quote the extras so shells such as zsh do not expand the brackets

# For development
pip install "gptmed[dev]"

# For training with logging integrations
pip install "gptmed[training]"

# For visualization (loss curves, metrics plots)
pip install "gptmed[visualization]"

# For Explainable AI features
pip install "gptmed[xai]"

# All dependencies
pip install "gptmed[dev,training,visualization,xai]"

Quick Start

Using the High-Level API

The easiest way to use GptMed is through the high-level API:

import gptmed

# 1. Create a training configuration
gptmed.create_config('my_config.yaml')

# 2. Edit my_config.yaml with your settings (data paths, model size, etc.)

# 3. Train the model
gptmed.train_from_config('my_config.yaml')

# 4. Generate answers
answer = gptmed.generate(
    checkpoint='model/checkpoints/best_model.pt',
    tokenizer='tokenizer/my_tokenizer.model',
    prompt='What is machine learning?',
    max_length=150,
    temperature=0.7
)
print(answer)

For a complete API testing workflow, see the gptmed-api folder with ready-to-run examples.

Inference (Generate Answers)

import torch  # needed for torch.load when restoring a checkpoint below

from gptmed.inference.generator import TextGenerator
from gptmed.model.architecture import GPTTransformer
from gptmed.model.configs.model_config import get_small_config

# Load model
config = get_small_config()
model = GPTTransformer(config)

# Load your trained checkpoint
# model.load_state_dict(torch.load('path/to/checkpoint.pt'))

# Create generator
generator = TextGenerator(
    model=model,
    tokenizer_path='path/to/tokenizer.model'
)

# Generate answer
question = "What's your favorite programming language?"
answer = generator.generate(
    prompt=question,
    max_length=100,
    temperature=0.7
)

print(f"Q: {question}")
print(f"A: {answer}")
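
The `temperature` argument used above controls how sharply the next-token distribution is peaked: lower values favour the most likely token, higher values spread probability more evenly. The sketch below illustrates that general technique with the standard library only. It is not GptMed's implementation, and the function names and toy logits are hypothetical.

```python
import math
import random


def softmax_with_temperature(logits, temperature=1.0):
    """Turn raw scores into probabilities, scaled by temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature=1.0, rng=random):
    """Draw one token index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


toy_logits = [2.0, 1.0, 0.1]  # hypothetical scores for a 3-token vocabulary
sharp = softmax_with_temperature(toy_logits, temperature=0.5)
flat = softmax_with_temperature(toy_logits, temperature=1.5)
print(f"T=0.5 -> top-token probability {sharp[0]:.2f}")
print(f"T=1.5 -> top-token probability {flat[0]:.2f}")
```

With the same logits, the lower temperature gives the top token a larger share of the probability mass, which is why lower settings tend to produce more conservative answers.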

Using Command Line

# Generate answers
gptmed-generate --prompt "How do I train a custom model?" --max-length 100

# Train model
gptmed-train --model-size small --num-epochs 10 --batch-size 16

Training Your Own Model

from gptmed.training.train import main
from gptmed.configs.train_config import get_default_config
from gptmed.model.configs.model_config import get_small_config

# Configure training
train_config = get_default_config()
train_config.batch_size = 16
train_config.num_epochs = 10
train_config.learning_rate = 3e-4

# Start training
main()

Model Architecture

The model uses a custom GPT-based transformer architecture:

  • Embedding: Token + positional embeddings
  • Transformer Blocks: Multi-head self-attention + feed-forward networks
  • Parameters: ~10M (small), ~50M (medium)
  • Context Length: 512 tokens
  • Vocabulary: Custom SentencePiece tokenizer trained on your data
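
To get a feel for where parameter counts such as ~10M come from, the following sketch estimates the size of a generic GPT-style decoder. The hyperparameters are hypothetical and are not GptMed's actual configurations. The formula also ignores biases and layer norms, and assumes a 4x feed-forward expansion with the output projection tied to the token embedding.

```python
def estimate_gpt_params(vocab_size, context_length, d_model, n_layers):
    """Rough parameter count for a GPT-style decoder.

    Ignores biases and layer norms and assumes a 4x feed-forward expansion
    with the output projection tied to the token embedding.
    """
    token_embedding = vocab_size * d_model
    position_embedding = context_length * d_model
    attention = 4 * d_model * d_model      # Q, K, V and output projections
    feed_forward = 8 * d_model * d_model   # d -> 4d -> d
    return token_embedding + position_embedding + n_layers * (attention + feed_forward)


# Hypothetical hyperparameters, chosen only to illustrate the arithmetic.
total = estimate_gpt_params(vocab_size=8000, context_length=512, d_model=256, n_layers=6)
print(f"{total / 1e6:.1f}M parameters")  # prints "6.9M parameters"
```

The transformer blocks grow with the square of the model width, so widening the model raises the count much faster than adding vocabulary entries does.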

Configuration

Model Sizes

from gptmed.model.configs.model_config import (
    get_tiny_config,   # ~2M parameters - for testing
    get_small_config,  # ~10M parameters - recommended
    get_medium_config  # ~50M parameters - higher quality
)

Training Configuration

from gptmed.configs.train_config import TrainingConfig

config = TrainingConfig(
    batch_size=16,
    learning_rate=3e-4,
    num_epochs=10,
    warmup_steps=100,
    grad_clip=1.0
)
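
The `warmup_steps` and `grad_clip` settings correspond to two common training safeguards: ramping the learning rate up gradually, and capping the gradient norm. The sketch below shows one plain-Python reading of each. It assumes linear warmup followed by a constant rate and clipping by global L2 norm. GptMed's actual schedule may differ, and both function names are hypothetical.

```python
import math


def lr_at_step(step, base_lr=3e-4, warmup_steps=100):
    """Linear warmup from 0 to base_lr, then constant (decay is not modelled)."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return base_lr
    return base_lr * step / warmup_steps


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a flat list of gradient values so their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


print(lr_at_step(50))                    # halfway through warmup
print(clip_by_global_norm([3.0, 4.0]))   # norm 5.0 rescaled to norm 1.0
```

In a PyTorch training loop the clipping step is normally done with `torch.nn.utils.clip_grad_norm_` rather than by hand.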

Observability

New in v0.4.0: Built-in training monitoring with Observer Pattern architecture.

Features

  • 📊 Loss Curves: Track training/validation loss over time
  • 📈 Metrics Tracking: Perplexity, gradient norms, learning rates
  • 🔔 Callbacks: Console output, JSON logging, early stopping
  • Export: CSV export, matplotlib visualizations
  • 🔌 Extensible: Add custom observers for integrations (W&B, TensorBoard)

Quick Example

from gptmed.observability import MetricsTracker, ConsoleCallback, EarlyStoppingCallback

# Create observers
tracker = MetricsTracker(output_dir='./metrics')
console = ConsoleCallback(print_every=50)
early_stop = EarlyStoppingCallback(patience=3)

# Use with TrainingService (automatic)
from gptmed.services import TrainingService
service = TrainingService(config_path='config.yaml')
service.train()  # Automatically creates MetricsTracker

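# Or use with Trainer directly
# (model, train_loader and config must already be defined; the import path
# below is assumed from the project structure and may differ)
from gptmed.training.trainer import Trainer
trainer = Trainer(model, train_loader, config, observers=[tracker, console])
trainer.train()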

Available Observers

| Observer | Description |
|---|---|
| MetricsTracker | Comprehensive metrics collection with export capabilities |
| ConsoleCallback | Real-time console output with progress bars |
| JSONLoggerCallback | Structured JSON logging for analysis |
| EarlyStoppingCallback | Stop training when validation loss plateaus |
| LRSchedulerCallback | Learning rate scheduling integration |
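
For readers unfamiliar with the Observer pattern, the following self-contained sketch shows the idea behind the observers listed above: the training loop notifies each registered observer, and one of them can request an early stop after a number of steps without improvement. Every class and method name here is hypothetical. GptMed's real interfaces live in `gptmed.observability` and may look different.

```python
class TrainingObserver:
    """Hypothetical base interface: observers receive one call per step."""

    def on_step(self, step, metrics):
        pass


class LossHistory(TrainingObserver):
    """Records the validation loss seen at each step."""

    def __init__(self):
        self.history = []

    def on_step(self, step, metrics):
        self.history.append((step, metrics["val_loss"]))


class PatienceStopper(TrainingObserver):
    """Requests a stop after `patience` consecutive steps without improvement."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best = float("inf")
        self.bad_steps = 0
        self.should_stop = False

    def on_step(self, step, metrics):
        if metrics["val_loss"] < self.best:
            self.best = metrics["val_loss"]
            self.bad_steps = 0
        else:
            self.bad_steps += 1
            self.should_stop = self.bad_steps >= self.patience


def run_training(val_losses, observers):
    """Stand-in training loop that only notifies observers."""
    for step, loss in enumerate(val_losses, start=1):
        for observer in observers:
            observer.on_step(step, {"val_loss": loss})
        if any(getattr(o, "should_stop", False) for o in observers):
            return step
    return len(val_losses)


history, stopper = LossHistory(), PatienceStopper(patience=3)
stopped_at = run_training([2.0, 1.5, 1.6, 1.7, 1.8, 1.2], [history, stopper])
print(f"stopped at step {stopped_at}, best loss {stopper.best}")
```

Because the loop only depends on the `on_step` method, a new integration can be added by writing another observer class without touching the loop itself.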

See XAI.md for future Explainable AI features roadmap.

Project Structure

gptmed/
├── model/
│   ├── architecture/       # GPT transformer implementation
│   └── configs/            # Model configurations
├── inference/
│   ├── generator.py        # Text generation
│   └── sampling.py         # Sampling strategies
├── training/
│   ├── train.py            # Training script
│   ├── trainer.py          # Training loop
│   └── dataset.py          # Data loading
├── observability/          # Training monitoring & XAI (v0.4.0+)
│   ├── base.py             # Observer pattern interfaces
│   ├── metrics_tracker.py  # Loss curves & metrics
│   └── callbacks.py        # Console, JSON, early stopping
├── tokenizer/
│   └── train_tokenizer.py  # SentencePiece tokenizer
├── configs/
│   └── train_config.py     # Training configurations
├── services/
│   └── training_service.py # High-level training orchestration
└── utils/
    ├── checkpoints.py      # Model checkpointing
    └── logging.py          # Training logging

Requirements

  • Python >= 3.8
  • PyTorch >= 2.0.0
  • sentencepiece >= 0.1.99
  • numpy >= 1.24.0
  • tqdm >= 4.65.0

Documentation

📚 Complete User Manual - Step-by-step guide for training your own model

Performance

| Model Size | Parameters | Training Time | Inference Speed |
|---|---|---|---|
| Tiny | ~2M | 2 hours | ~100 tokens/sec |
| Small | ~10M | 8 hours | ~80 tokens/sec |
| Medium | ~50M | 24 hours | ~50 tokens/sec |

Tested on GTX 1080 8GB
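
As a rough illustration of what these throughput figures mean in practice, the sketch below converts tokens per second into an approximate wait for a 150-token answer. It uses only the numbers quoted in the table. The helper name is hypothetical, and real latency will vary with hardware and prompt length.

```python
# Throughput figures copied from the table above (tokens per second).
TOKENS_PER_SECOND = {"tiny": 100, "small": 80, "medium": 50}


def estimated_seconds(model_size, num_tokens):
    """Approximate wall-clock time to generate num_tokens at the quoted rate."""
    return num_tokens / TOKENS_PER_SECOND[model_size]


for size in TOKENS_PER_SECOND:
    print(f"{size}: {estimated_seconds(size, 150):.1f} s for 150 tokens")
```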

Examples

Domain-Agnostic Usage

GptMed works with any domain - just train on your own Q&A data:

# `generator` is a TextGenerator instance, created as in the Inference section

# Technical Support Bot
question = "How do I reset my WiFi router?"
answer = generator.generate(question, temperature=0.7)

# Educational Assistant
question = "Explain the water cycle in simple terms"
answer = generator.generate(question, temperature=0.6)

# Customer Service
question = "What is your return policy?"
answer = generator.generate(question, temperature=0.5)

# Medical Q&A (example domain)
question = "What are the symptoms of flu?"
answer = generator.generate(question, temperature=0.7)

Training Observability (v0.4.0+)

Monitor your training with built-in observability:

import gptmed
from gptmed.observability import MetricsTracker, ConsoleCallback

# Create observers
tracker = MetricsTracker(output_dir='./metrics')
console = ConsoleCallback(print_every=10)

# Train with observability
gptmed.train_from_config(
    'my_config.yaml',
    observers=[tracker, console]
)

# After training - get the report
report = tracker.get_report()
print(f"Final Loss: {report['final_loss']:.4f}")
print(f"Total Steps: {report['total_steps']}")

# Export metrics
tracker.export_to_csv('training_metrics.csv')
tracker.plot_loss_curves('loss_curves.png')  # Requires matplotlib
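
Perplexity, one of the tracked metrics, can be derived from a loss value such as the final loss reported above. The sketch assumes the loss is a mean per-token cross-entropy in nats, which is the usual case for PyTorch language-model training but is an assumption about GptMed. The helper name and the loss values are hypothetical.

```python
import math


def perplexity(mean_cross_entropy_loss):
    """Perplexity is the exponential of the mean per-token cross-entropy (in nats)."""
    return math.exp(mean_cross_entropy_loss)


# Hypothetical loss values, for illustration only.
for loss in (4.0, 2.0, 1.0):
    print(f"loss {loss:.2f} -> perplexity {perplexity(loss):.1f}")
```

A perplexity of 1.0 corresponds to a loss of zero, so lower values indicate that the model is less surprised by the evaluation text.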

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

  1. Fork the repository
  2. Create your feature branch (git checkout -b feature/AmazingFeature)
  3. Commit your changes (git commit -m 'Add some AmazingFeature')
  4. Push to the branch (git push origin feature/AmazingFeature)
  5. Open a Pull Request

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

  • MedQuAD dataset creators
  • PyTorch team

Made with ❤️ from Nepal
