
Multi-Task Attention-based Transformer for Sequential API Recommendation


Context Engineer

License: MIT | Requires Python 3.9+

A reproducible research package for the paper:

Rethink Context Engineering Using an Attention-based Architecture
Yiqiao Yin, University of Chicago Booth School of Business / Columbia University

This package implements a multi-task attention-based transformer for sequential API recommendation. It simultaneously predicts the next API action, session goal, and session boundary from user interaction sequences modeled as Markov chains.

Key Results (from the paper)

Metric                              Value
API Prediction Accuracy (Top-1)     79.83%
Top-5 Hit Rate                      99.97%
Top-10 Hit Rate                     100.00%
Goal Prediction Accuracy            81.6%
Session End Accuracy                99.3%
Improvement over Markov baseline    +432%
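Top-1 accuracy counts test examples whose highest-scoring API is the true next API; Hit Rate@K counts examples whose true next API appears anywhere among the K highest-scoring APIs. As a rough illustration of these definitions (plus mean reciprocal rank, which the evaluation function also reports), the sketch below computes them from plain per-example score lists. It is not the package's evaluation code, and the helper names are hypothetical.

```python
from typing import List, Sequence


def rank_apis(scores: Sequence[float]) -> List[int]:
    """Return API IDs ordered from highest to lowest score."""
    return sorted(range(len(scores)), key=lambda api: scores[api], reverse=True)


def hit_rate_at_k(all_scores: Sequence[Sequence[float]], targets: Sequence[int], k: int) -> float:
    """Fraction of examples whose true next API is among the top-k scored APIs."""
    hits = sum(1 for scores, target in zip(all_scores, targets) if target in rank_apis(scores)[:k])
    return hits / len(targets)


def top1_accuracy(all_scores: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Top-1 accuracy is simply Hit Rate@1."""
    return hit_rate_at_k(all_scores, targets, k=1)


def mean_reciprocal_rank(all_scores: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Average of 1 / (1-based rank of the true next API)."""
    total = sum(1.0 / (rank_apis(scores).index(target) + 1)
                for scores, target in zip(all_scores, targets))
    return total / len(targets)


# Toy scores for 3 examples over 3 APIs (made-up numbers, for illustration only)
scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
targets = [1, 1, 0]
print(top1_accuracy(scores, targets))         # 1 of 3 correct at rank 1
print(hit_rate_at_k(scores, targets, k=2))    # 2 of 3 within the top 2
print(mean_reciprocal_rank(scores, targets))  # (1 + 1/2 + 1/3) / 3
```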

Installation

pip install context-engineer

For visualization support:

pip install "context-engineer[viz]"

Quick Start

Reproduce the Full Experiment

from context_engineer import run_pipeline

# Run with paper defaults: 2000 users, 100 APIs, 60 epochs
results = run_pipeline()

# Access results
model = results["model"]        # Trained PyTorch model
metrics = results["metrics"]    # Evaluation metrics dict
history = results["training_history"]

Generate Simulation Data Only

from context_engineer import simulate_multitask_markov_data

# Generate user session logs with 4 persona types
sequences, goals = simulate_multitask_markov_data(
    num_users=500,
    num_apis=100,
    clicks_per_user=10,
)
# sequences: list of API call sequences per user
# goals: list of session goal labels (0-3)

Train with Custom Parameters

from context_engineer import run_pipeline

results = run_pipeline(
    num_users=1000,
    num_epochs=30,
    embed_dim=64,
    num_heads=4,
    learning_rate=0.001,
    seed=123,
)
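Attention layers such as PyTorch's torch.nn.MultiheadAttention split the embedding evenly across heads, so embed_dim must be divisible by num_heads (64 / 4 gives 16 dimensions per head above). Assuming the package follows that convention, a quick pre-flight check like the hypothetical helper below catches a bad combination before a long training run:

```python
def head_dim(embed_dim: int, num_heads: int) -> int:
    """Return the per-head dimension, or raise if the split would be uneven."""
    if num_heads <= 0 or embed_dim % num_heads != 0:
        raise ValueError(f"embed_dim={embed_dim} must be divisible by num_heads={num_heads}")
    return embed_dim // num_heads


print(head_dim(64, 4))  # 16

try:
    head_dim(64, 6)  # 64 is not divisible by 6
except ValueError as err:
    print(err)
```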

Command-Line Interface

# Run the full pipeline
context-engineer run

# Custom parameters
context-engineer run --num-users 1000 --epochs 30 --seed 123

# Generate data only (outputs JSON)
context-engineer generate --num-users 500 --output data.json

How It Works

The Problem

In enterprise platforms, users interact with APIs in sequential patterns. Different users have different goals (personas), and their API call patterns reflect those goals. If you can predict what a user will do next, you can provide better recommendations, prefetch resources, or optimize the platform.

Training Data Format

The model expects data in two lists:

  • sequences: A list of API call sequences, where each sequence is a list of integer IDs (e.g., [[0, 20, 30, 40, 50], [1, 25, 35, 65], ...])
  • goals: A list of session goal labels, one per sequence (e.g., [0, 1, 2, 0, ...])
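Before calling the pipeline, it can help to sanity-check the two lists. The hypothetical helper below assumes API IDs lie in [0, num_apis), goal labels lie in [0, num_goals), and each session has at least two calls (a session of n calls yields n - 1 training pairs, per the example that follows). These constraints are inferred from the format described here, not documented by the package.

```python
from typing import List


def validate_inputs(sequences: List[List[int]], goals: List[int],
                    num_apis: int, num_goals: int) -> None:
    """Raise ValueError if `sequences` and `goals` break the assumed input format."""
    if len(sequences) != len(goals):
        raise ValueError(f"{len(sequences)} sequences but {len(goals)} goal labels")
    for i, (seq, goal) in enumerate(zip(sequences, goals)):
        if len(seq) < 2:
            raise ValueError(f"sequence {i} has fewer than 2 calls and yields no training pair")
        out_of_range = [api for api in seq if not 0 <= api < num_apis]
        if out_of_range:
            raise ValueError(f"sequence {i} has API IDs outside [0, {num_apis}): {out_of_range}")
        if not 0 <= goal < num_goals:
            raise ValueError(f"sequence {i} has goal {goal} outside [0, {num_goals})")


# Passes silently for the example lists above
validate_inputs([[0, 20, 30, 40, 50], [1, 25, 35, 65]], [0, 1], num_apis=100, num_goals=4)
```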

The function create_multitask_training_pairs() then slides a window through each session to produce supervised examples. For a session like [0, 20, 30, 40, 50] with goal 0:

Input Sequence     Target: Next API   Target: Goal   Target: Session End
[0]                20                 0              0
[0, 20]            30                 0              0
[0, 20, 30]        40                 0              0
[0, 20, 30, 40]    50                 0              1
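The prefix expansion in the table can be sketched in a few lines. The helper name is hypothetical, and the sketch ignores the max_seq_len truncation that the package's create_multitask_training_pairs() accepts, since its exact behavior is not described here.

```python
from typing import List, Tuple

# (input prefix, next API, goal, session-end flag)
Pair = Tuple[List[int], int, int, int]


def make_training_pairs(session: List[int], goal: int) -> List[Pair]:
    """Expand one session into one supervised example per prefix."""
    pairs = []
    for t in range(1, len(session)):
        prefix = session[:t]
        next_api = session[t]
        is_end = 1 if t == len(session) - 1 else 0  # the last target closes the session
        pairs.append((prefix, next_api, goal, is_end))
    return pairs


for row in make_training_pairs([0, 20, 30, 40, 50], goal=0):
    print(row)
```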

The model learns to predict all three targets simultaneously from the input sequence alone. The goal and session-end labels serve only as supervision signals: they train the auxiliary heads and are never fed to the model as input.

Inference

At inference time, you only provide the sequence of API calls observed so far. The model returns all three predictions simultaneously:

  1. Probability distribution over all APIs: for example, API 40 has 72% probability, API 45 has 15%, and so on. Pick the top K as recommendations.
  2. Probability distribution over goal types: for example, 85% ml_pipeline, 10% data_analysis, and so on. This tells you what the user is trying to do.
  3. Session-end probability: for example, 0.03 (a 3% chance the user is about to leave).

You do not need the goal label at inference time; the model infers it. From a user's actions alone, the model predicts what they will do next, what they are trying to accomplish, and whether they are about to leave.
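The sketch below shows one plausible way to turn raw head outputs into the three predictions described above. It assumes each head returns unnormalized logits (softmax for the API and goal heads, a single sigmoid logit for session end) and that goal indices follow the order listed under Dataset Design; both are assumptions, the goal-name strings are hypothetical, and the logit values are made up.

```python
import math
from typing import Dict, List, Sequence

# Assumed goal-label order (hypothetical; check your own label encoding)
GOAL_NAMES = ["ml_pipeline", "data_analysis", "user_management", "quick_visualization"]


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert logits to probabilities; subtracting the max avoids overflow."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def summarize_prediction(api_logits: Sequence[float], goal_logits: Sequence[float],
                         end_logit: float, k: int = 3) -> Dict[str, object]:
    api_probs = softmax(api_logits)
    top_k = sorted(range(len(api_probs)), key=lambda i: api_probs[i], reverse=True)[:k]
    goal_probs = softmax(goal_logits)
    best_goal = max(range(len(goal_probs)), key=lambda i: goal_probs[i])
    return {
        "top_k_apis": [(api, api_probs[api]) for api in top_k],  # recommendations
        "goal": GOAL_NAMES[best_goal],                          # inferred session goal
        "goal_prob": goal_probs[best_goal],
        "end_prob": sigmoid(end_logit),                         # chance the session ends next
    }


# Made-up logits for a 5-API vocabulary
result = summarize_prediction([0.1, 2.0, 0.5, -1.0, 1.2], [3.0, 0.5, 0.1, -0.5], end_logit=-3.5)
print(result["top_k_apis"])
print(result["goal"], round(result["end_prob"], 3))
```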

Architecture

The model uses a shared transformer encoder with three task-specific prediction heads:

Input Sequence [API_1, API_2, ..., API_t]
        |
   Embedding + Positional Encoding
        |
   Transformer Encoder (3 layers, 8 heads)
        |
   Shared Feature Representation
       /|\
      / | \
     /  |  \
    v   v   v
 Next  Goal  Session
 API   Head  End Head
 Head

  • Primary task: Next API prediction (100-class classification with the default 100 APIs)
  • Auxiliary task 1: Session goal classification (4 classes)
  • Auxiliary task 2: Session end detection (binary)
  • Loss: Weighted sum of the three task losses, with weights 1.0 (next API), 0.3 (goal), and 0.2 (session end)
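Reading the weights in the order the tasks are listed, the training objective is L = 1.0 * L_api + 0.3 * L_goal + 0.2 * L_end. The sketch below computes that sum for a single example, assuming cross-entropy for the two classification heads and binary cross-entropy on one logit for session end. Those loss choices are standard for such heads but are assumptions here, and the function names are hypothetical.

```python
import math
from typing import Sequence, Tuple


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Negative log-softmax probability of `target`, via a stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def bce_with_logit(logit: float, target: int) -> float:
    """Binary cross-entropy on one logit: max(z, 0) - z*y + log(1 + exp(-|z|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def multitask_loss(api_logits: Sequence[float], api_target: int,
                   goal_logits: Sequence[float], goal_target: int,
                   end_logit: float, end_target: int,
                   weights: Tuple[float, float, float] = (1.0, 0.3, 0.2)) -> float:
    w_api, w_goal, w_end = weights
    return (w_api * cross_entropy(api_logits, api_target)
            + w_goal * cross_entropy(goal_logits, goal_target)
            + w_end * bce_with_logit(end_logit, end_target))


# With all-zero logits each term is the loss of a uniform guess
print(multitask_loss([0.0, 0.0], 0, [0.0] * 4, 1, 0.0, 0))
```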

Bring Your Own Data

Every company has its own log schema and naming conventions. To use this package with your own user log data, map your logs into the format described above: integer-encoded API sequences plus one goal label per sequence. The rest of the pipeline handles everything else:

from context_engineer import (
    create_multitask_training_pairs,
    MultiTaskMarkovDataset,
    MultiTaskMarkovAPIRecommender,
    train_multitask_model,
    evaluate_multitask_model,
    set_random_seeds,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

set_random_seeds(42)

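# Your real data (integer-encoded): replace these toy examples with your own sessions.
# `sequences` and `goals` must have the same length (one goal label per session).
sequences = [[0, 5, 12, 30, 50], [1, 8, 20, 45]]
goals = [2, 0]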

num_apis = 60   # total unique API IDs in your data
num_goals = 3   # total unique goal types in your data

# Create training pairs
input_seqs, targets, goal_labels, end_labels = create_multitask_training_pairs(
    sequences, goals, max_seq_len=6
)

# Split
data = list(zip(input_seqs, targets, goal_labels, end_labels))
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
train_seqs, train_tgt, train_g, train_e = zip(*train_data)
test_seqs, test_tgt, test_g, test_e = zip(*test_data)

# Build loaders
train_loader = DataLoader(
    MultiTaskMarkovDataset(train_seqs, train_tgt, train_g, train_e, max_seq_len=6),
    batch_size=128, shuffle=True,
)
test_loader = DataLoader(
    MultiTaskMarkovDataset(test_seqs, test_tgt, test_g, test_e, max_seq_len=6),
    batch_size=128,
)

# Train
model = MultiTaskMarkovAPIRecommender(num_apis=num_apis, num_goals=num_goals)
history = train_multitask_model(model, train_loader, test_loader, num_epochs=30)

# Evaluate
metrics = evaluate_multitask_model(model, test_loader)
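The snippet above assumes your logs are already integer-encoded. A minimal sketch of that mapping step, assuming raw sessions are lists of endpoint-name strings and goals are label strings (all endpoint and goal names below are hypothetical):

```python
from typing import Dict, List, Tuple


def encode_sessions(raw_sessions: List[List[str]], raw_goals: List[str]
                    ) -> Tuple[List[List[int]], List[int], Dict[str, int], Dict[str, int]]:
    """Map API names and goal labels to consecutive integer IDs in first-seen order."""
    api_to_id: Dict[str, int] = {}
    goal_to_id: Dict[str, int] = {}
    # setdefault keeps an existing ID, or assigns the next free one to a new name
    sequences = [[api_to_id.setdefault(name, len(api_to_id)) for name in session]
                 for session in raw_sessions]
    goals = [goal_to_id.setdefault(label, len(goal_to_id)) for label in raw_goals]
    return sequences, goals, api_to_id, goal_to_id


raw_sessions = [["login", "upload_csv", "train_model"], ["login", "list_users"]]
raw_goals = ["ml_pipeline", "user_management"]
sequences, goals, api_to_id, goal_to_id = encode_sessions(raw_sessions, raw_goals)
num_apis, num_goals = len(api_to_id), len(goal_to_id)
print(sequences, goals, num_apis, num_goals)
```

Keep api_to_id around so predicted IDs can be mapped back to endpoint names, for example with {v: k for k, v in api_to_id.items()}.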

Dataset Design

The simulator generates synthetic API usage logs with:

  • 100 APIs across 10 functional categories (Auth, User Mgmt, Data Input, Data Processing, ML Training, ML Prediction, Basic Viz, Advanced Viz, Export, Admin)
  • 4 session goals: ML Pipeline (85% workflow adherence), Data Analysis (80%), User Management (90%), Quick Visualization (75%)
  • Markov chain transitions with configurable, high-probability workflow patterns
  • Configurable simulator: Override transition probabilities, starting probabilities, and goal distributions to match your domain
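One plausible reading of "workflow adherence" is sketched below: each goal has a canonical workflow, and from each workflow step the chain moves to the next workflow API with probability equal to the adherence, spreading the remaining probability uniformly over all APIs. This is not the package's MultiTaskMarkovChainAPISimulator; the uniform-noise scheme, the example workflow IDs, and the helper names are assumptions for illustration.

```python
import random
from typing import Dict, List


def build_transitions(workflow: List[int], num_apis: int,
                      adherence: float) -> Dict[int, List[float]]:
    """First-order transition table for one goal.

    From each workflow step, the next workflow API gets `adherence` plus its uniform
    share of the remaining 1 - adherence; every other API gets only the uniform share.
    APIs off the workflow (and the final step) transition uniformly.
    Assumes the workflow has no repeated API IDs.
    """
    transitions = {api: [1.0 / num_apis] * num_apis for api in range(num_apis)}
    for current, nxt in zip(workflow, workflow[1:]):
        probs = [(1.0 - adherence) / num_apis] * num_apis
        probs[nxt] += adherence
        transitions[current] = probs
    return transitions


def sample_session(transitions: Dict[int, List[float]], start: int, length: int,
                   rng: random.Random) -> List[int]:
    """Walk the chain: each next API depends only on the current API."""
    session = [start]
    while len(session) < length:
        probs = transitions[session[-1]]
        session.append(rng.choices(range(len(probs)), weights=probs, k=1)[0])
    return session


rng = random.Random(42)
ml_pipeline = build_transitions([0, 20, 30, 40, 50], num_apis=100, adherence=0.85)
print(sample_session(ml_pipeline, start=0, length=5, rng=rng))
```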

Package Structure

src/context_engineer/
    __init__.py    # Public API
    data.py        # Data simulation & dataset classes
    model.py       # Transformer model architecture
    train.py       # Training, evaluation, inference
    pipeline.py    # End-to-end pipeline
    utils.py       # Model save/load, data export (JSON/CSV)
    viz.py         # Publication-ready visualization functions
    cli.py         # Command-line interface
    py.typed       # PEP 561 type checker marker

API Reference

Core Functions

Function                                  Description
run_pipeline(**kwargs)                    Run the full experiment end-to-end
simulate_multitask_markov_data(...)       Generate simulated user sessions
create_multitask_training_pairs(...)      Convert sequences to supervised pairs
train_multitask_model(...)                Train with early stopping and cosine annealing
evaluate_multitask_model(...)             Evaluate with accuracy, MRR, Hit Rate@K
set_random_seeds(seed)                    Set seeds for reproducibility
save_model(model, path)                   Save trained model checkpoint to disk
load_model(path)                          Load model from checkpoint (reconstructs architecture)
export_metrics(metrics, path)             Export evaluation metrics to JSON or CSV
export_training_history(history, path)    Export training curves to CSV
export_sequences(sequences, goals, path)  Export session data to JSON or CSV
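train_multitask_model() is described as using early stopping and cosine annealing. The sketch below shows the standard closed-form cosine schedule (the formula documented for PyTorch's torch.optim.lr_scheduler.CosineAnnealingLR, without restarts) and a patience-based stopping rule. The package's actual patience, minimum learning rate, and monitored metric are not stated here, so the values and names below are assumptions.

```python
import math


def cosine_annealing_lr(epoch: int, total_epochs: int, lr_max: float, lr_min: float = 0.0) -> float:
    """Learning rate that falls from lr_max at epoch 0 to lr_min at total_epochs along a half cosine."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * epoch / total_epochs))


class EarlyStopping:
    """Signal a stop once validation loss fails to improve for `patience` consecutive epochs."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


lrs = [cosine_annealing_lr(epoch, total_epochs=60, lr_max=0.001) for epoch in range(61)]
print(round(lrs[30], 6))  # halfway through the schedule

stopper = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([0.9, 0.7, 0.72, 0.71, 0.69]):
    if stopper.step(val_loss):
        print(f"stopping after epoch {epoch}")
        break
```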

Visualization Functions (requires pip install context-engineer[viz])

Function                    Description
plot_data_overview(...)     6-panel Markov chain data analysis (Paper Figure 1)
plot_api_transitions(...)   Per-goal transition heatmaps (Paper Figure 2)
plot_training_history(...)  6-panel training progress (Paper Figure 5)
plot_evaluation(...)        Evaluation dashboard with baselines (Paper Figure 6)

Core Classes

Class                             Description
MultiTaskMarkovChainAPISimulator  Configurable Markov chain data generator
MultiTaskMarkovAPIRecommender     Multi-task transformer model
MultiTaskMarkovDataset            PyTorch Dataset for training data

Citation

If you use this package in your research, please cite:

@article{yin2025rethink,
  title={Rethink Context Engineering Using an Attention-based Architecture},
  author={Yin, Yiqiao},
  year={2025}
}

License

MIT License. See LICENSE for details.



Download files


Source Distribution

context_engineer-0.2.1.tar.gz (4.5 MB)

Uploaded Source

Built Distribution


context_engineer-0.2.1-py3-none-any.whl (29.3 kB)

Uploaded Python 3

File details

Details for the file context_engineer-0.2.1.tar.gz.

File metadata

  • Download URL: context_engineer-0.2.1.tar.gz
  • Upload date:
  • Size: 4.5 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.7.6

File hashes

Hashes for context_engineer-0.2.1.tar.gz
Algorithm Hash digest
SHA256 f304efab02081167a960b20a22f03bdecedefefe52e440cb5386f8b2fb783e66
MD5 83865933cd38f185d014819b507c6bfd
BLAKE2b-256 2fa66f86afb2485138c7cfa87dccd9f06915046c1263ef24cf34d588463c8f9f


File details

Details for the file context_engineer-0.2.1-py3-none-any.whl.


File hashes

Hashes for context_engineer-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 7a258e69a6d70fc05ebf8136b99b23a683d20922f85456075c5868e110816f68
MD5 ac6f9f331f8217c545e12cd17e60b84f
BLAKE2b-256 01ca676178c052b46f8a2ea58ac68336186728c97f059a5c1df97031f132fada

