Multi-Task Attention-based Transformer for Sequential API Recommendation
Context Engineer
A reproducible research package for the paper:
Rethink Context Engineering Using an Attention-based Architecture
Yiqiao Yin, University of Chicago Booth School of Business / Columbia University
This package implements a multi-task attention-based transformer for sequential API recommendation. It simultaneously predicts the next API action, session goal, and session boundary from user interaction sequences modeled as Markov chains.
Key Results (from the paper)
| Metric | Value |
|---|---|
| API Prediction Accuracy (Top-1) | 79.83% |
| Top-5 Hit Rate | 99.97% |
| Top-10 Hit Rate | 100.00% |
| Goal Prediction Accuracy | 81.6% |
| Session End Accuracy | 99.3% |
| Improvement over Markov baseline | +432% |
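The table does not say how the Markov baseline is built. A common choice, assumed here rather than taken from the paper, is a first-order model that always predicts the most frequent successor of the last observed API. The sketch below uses only the standard library; fit_markov_baseline, markov_top_k and markov_accuracy are hypothetical helpers, not part of context_engineer.

```python
from collections import Counter, defaultdict


def fit_markov_baseline(sequences):
    """Count first-order transitions prev_api -> next_api across all sessions."""
    counts = defaultdict(Counter)
    for seq in sequences:
        for prev_api, next_api in zip(seq, seq[1:]):
            counts[prev_api][next_api] += 1
    return counts


def markov_top_k(counts, last_api, k=1):
    """Return up to k most frequent successors of last_api (empty list if unseen)."""
    return [api for api, _ in counts[last_api].most_common(k)]


def markov_accuracy(counts, sequences):
    """Top-1 accuracy of the baseline over every (last API -> next API) step."""
    correct = total = 0
    for seq in sequences:
        for prev_api, next_api in zip(seq, seq[1:]):
            pred = markov_top_k(counts, prev_api, k=1)
            correct += int(bool(pred) and pred[0] == next_api)
            total += 1
    return correct / total if total else 0.0
```

Because this baseline only looks at the last call, it cannot use longer context or the session's goal, which is the gap the attention-based model is meant to close.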
Installation
```bash
pip install context-engineer
```
For visualization support:
```bash
# Quotes stop shells such as zsh from treating [viz] as a glob pattern
pip install "context-engineer[viz]"
```
Quick Start
Reproduce the Full Experiment
```python
from context_engineer import run_pipeline

# Run with paper defaults: 2000 users, 100 APIs, 60 epochs
results = run_pipeline()

# Access results
model = results["model"]                # Trained PyTorch model
metrics = results["metrics"]            # Evaluation metrics dict
history = results["training_history"]   # Training curves
```
Generate Simulation Data Only
```python
from context_engineer import simulate_multitask_markov_data

# Generate user session logs with 4 persona types
sequences, goals = simulate_multitask_markov_data(
    num_users=500,
    num_apis=100,
    clicks_per_user=10,
)
# sequences: list of API call sequences per user
# goals: list of session goal labels (0-3)
```
Train with Custom Parameters
```python
from context_engineer import run_pipeline

results = run_pipeline(
    num_users=1000,
    num_epochs=30,
    embed_dim=64,
    num_heads=4,
    learning_rate=0.001,
    seed=123,
)
```
Command-Line Interface
```bash
# Run the full pipeline
context-engineer run

# Custom parameters
context-engineer run --num-users 1000 --epochs 30 --seed 123

# Generate data only (outputs JSON)
context-engineer generate --num-users 500 --output data.json
```
How It Works
The Problem
In enterprise platforms, users interact with APIs in sequential patterns. Different users have different goals (personas), and their API call patterns reflect those goals. If you can predict what a user will do next, you can provide better recommendations, prefetch resources, or optimize the platform.
Training Data Format
The model expects data in two lists:
- sequences: a list of API call sequences, where each sequence is a list of integer IDs (e.g., [[0, 20, 30, 40, 50], [1, 25, 35, 65], ...])
- goals: a list of session goal labels, one per sequence (e.g., [0, 1, 2, 0, ...])
The function create_multitask_training_pairs() then slides a window through each session to produce supervised examples. For a session like [0, 20, 30, 40, 50] with goal 0:
| Input Sequence | Target: Next API | Target: Goal | Target: Session End |
|---|---|---|---|
| [0] | 20 | 0 | 0 |
| [0, 20] | 30 | 0 | 0 |
| [0, 20, 30] | 40 | 0 | 0 |
| [0, 20, 30, 40] | 50 | 0 | 1 |
The model learns to predict all three targets simultaneously from the input sequence alone. The goal and session-end labels are supervision signals used only to train the auxiliary heads.
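To make the expansion concrete, here is a minimal plain-Python sketch that reproduces the table above. make_training_pairs is a hypothetical re-implementation, not the package's create_multitask_training_pairs, and truncating long prefixes to the last max_seq_len calls is an assumption about how that argument is used.

```python
def make_training_pairs(sequences, goals, max_seq_len=6):
    """Expand each session into (prefix, next API, goal, session-end) examples.

    Assumption: prefixes longer than max_seq_len keep only their most recent calls.
    """
    inputs, next_apis, goal_labels, end_labels = [], [], [], []
    for seq, goal in zip(sequences, goals):
        for t in range(1, len(seq)):
            inputs.append(seq[:t][-max_seq_len:])      # calls observed before step t
            next_apis.append(seq[t])                   # the API called at step t
            goal_labels.append(goal)                   # session goal repeated per example
            end_labels.append(int(t == len(seq) - 1))  # 1 only when seq[t] is the last call
    return inputs, next_apis, goal_labels, end_labels


inputs, next_apis, goal_labels, end_labels = make_training_pairs([[0, 20, 30, 40, 50]], [0])
```

A session of length n yields n - 1 examples, and only the last one carries a session-end label of 1.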
Inference
At inference time, you only provide the sequence of API calls observed so far. The model returns all three predictions simultaneously:
- Next API: a probability distribution over all APIs (e.g., API 40 at 72%, API 45 at 15%, and so on). Pick the top-K as recommendations.
- Goal: a probability distribution over goal types (e.g., 85% ml_pipeline, 10% data_analysis, and so on). This tells you what the user is trying to do.
- Session end: the probability that the session is about to end (e.g., 0.03, a 3% chance the user is about to leave).

You do not need the goal label at inference time; the model infers it. From the actions observed so far, the model estimates what the user will do next, what they are trying to accomplish, and whether they are about to leave.
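Turning raw head outputs into the three quantities above takes a softmax for the two classification heads and a sigmoid for the binary session-end head. The sketch below assumes each head returns unnormalized logits, given here as plain Python lists (a PyTorch model would typically return tensors; the package's exact output format is not documented here). summarize_predictions and the goal names are illustrative.

```python
import math


def softmax(logits):
    """Convert logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    """Numerically stable logistic function for the binary session-end head."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def summarize_predictions(api_logits, goal_logits, end_logit, goal_names, k=5):
    """Turn the three head outputs into top-K APIs, a goal guess and P(session end)."""
    api_probs = softmax(api_logits)
    top_k = sorted(range(len(api_probs)), key=lambda i: api_probs[i], reverse=True)[:k]
    goal_probs = softmax(goal_logits)
    best_goal = max(range(len(goal_probs)), key=lambda i: goal_probs[i])
    return {
        "top_k_apis": [(i, api_probs[i]) for i in top_k],
        "goal": goal_names[best_goal],
        "goal_prob": goal_probs[best_goal],
        "p_session_end": sigmoid(end_logit),
    }
```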
Architecture
The model uses a shared transformer encoder with three task-specific prediction heads:
```
        Input Sequence [API_1, API_2, ..., API_t]
                           |
             Embedding + Positional Encoding
                           |
         Transformer Encoder (3 layers, 8 heads)
                           |
              Shared Feature Representation
                          /|\
                         / | \
                        /  |  \
                       v   v   v
                   Next   Goal   Session
                   API    Head   End Head
                   Head
```
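Each attention head inside the Transformer Encoder box computes scaled dot-product attention. The sketch below shows one head on plain Python lists, assuming the encoder follows the standard Transformer formulation softmax(QK^T / sqrt(d_k)) V; it omits the learned Q/K/V projections, multiple heads, masking, and the feed-forward sublayers. The function name is illustrative, not part of context_engineer.

```python
import math


def scaled_dot_product_attention(queries, keys, values):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V on lists of vectors.

    queries: n_q vectors; keys and values: n_k vectors each (same n_k).
    Returns (outputs, weights), where weights[i] is query i's distribution over keys.
    """
    d_k = len(keys[0])
    outputs, all_weights = [], []
    for q in queries:
        # Similarity of this query to every key, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Output is the attention-weighted average of the value vectors
        out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights
```

In a full encoder, Q, K and V would be learned projections of the embedded API sequence, and several such heads (8 in the diagram) would run in parallel in each layer.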
- Primary task: Next API prediction (100-class classification)
- Auxiliary task 1: Session goal classification (4 classes)
- Auxiliary task 2: Session end detection (binary)
- Loss: weighted sum of the three task losses, with weights 1.0 (next API) / 0.3 (goal) / 0.2 (session end)
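A per-example version of the weighted loss can be sketched as follows, assuming cross-entropy for the next-API and goal heads and binary cross-entropy on a single logit for the session-end head (the loss types are an assumption; only the 1.0 / 0.3 / 0.2 weights come from the list above). In training, this quantity would normally be averaged over a batch.

```python
import math


def cross_entropy(logits, target):
    """Negative log-probability of the target class under softmax(logits)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))  # log-sum-exp
    return log_z - logits[target]


def binary_cross_entropy_with_logits(logit, target):
    """Stable BCE: max(x, 0) - x * y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def multitask_loss(api_logits, api_target, goal_logits, goal_target,
                   end_logit, end_target, weights=(1.0, 0.3, 0.2)):
    """Weighted sum of the next-API, goal and session-end losses for one example."""
    w_api, w_goal, w_end = weights
    return (w_api * cross_entropy(api_logits, api_target)
            + w_goal * cross_entropy(goal_logits, goal_target)
            + w_end * binary_cross_entropy_with_logits(end_logit, end_target))
```

With uniform logits over 4 APIs and 4 goals and a zero session-end logit, the loss is 1.3 log 4 + 0.2 log 2 = 2.8 log 2, which makes the relative pull of each head easy to see.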
Bring Your Own Data
Every company has its own schema and naming conventions. To use this package with your own user log data, you only need to map your logs into the format above (integer-encoded API sequences and goal labels). The rest of the pipeline handles everything else:
```python
from context_engineer import (
    create_multitask_training_pairs,
    MultiTaskMarkovDataset,
    MultiTaskMarkovAPIRecommender,
    train_multitask_model,
    evaluate_multitask_model,
    set_random_seeds,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

set_random_seeds(42)

# Your real data (integer-encoded); replace the "..." placeholders with your sessions
sequences = [[0, 5, 12, 30, 50], [1, 8, 20, 45], ...]
goals = [2, 0, 1, ...]
num_apis = 60   # total unique API IDs in your data
num_goals = 3   # total unique goal types in your data

# Create training pairs (one per prefix of each session)
input_seqs, targets, goal_labels, end_labels = create_multitask_training_pairs(
    sequences, goals, max_seq_len=6
)

# Split at the pair level. Prefixes of the same session can land in both splits;
# split `sequences` by session before creating pairs for a stricter held-out test.
data = list(zip(input_seqs, targets, goal_labels, end_labels))
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
train_seqs, train_tgt, train_g, train_e = zip(*train_data)
test_seqs, test_tgt, test_g, test_e = zip(*test_data)

# Build loaders
train_loader = DataLoader(
    MultiTaskMarkovDataset(train_seqs, train_tgt, train_g, train_e, max_seq_len=6),
    batch_size=128,
    shuffle=True,
)
test_loader = DataLoader(
    MultiTaskMarkovDataset(test_seqs, test_tgt, test_g, test_e, max_seq_len=6),
    batch_size=128,
)

# Train
model = MultiTaskMarkovAPIRecommender(num_apis=num_apis, num_goals=num_goals)
history = train_multitask_model(model, train_loader, test_loader, num_epochs=30)

# Evaluate
metrics = evaluate_multitask_model(model, test_loader)
```
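Producing the integer-encoded sequences and goals is the schema-specific step. As a minimal sketch, assume your raw logs are (session_id, timestamp, api_name) events plus a goal label per session; encode_logs is a hypothetical helper, not part of context_engineer, that groups events by session, orders them by timestamp, and assigns integer IDs in order of first appearance.

```python
from collections import defaultdict


def encode_logs(events, goal_of_session):
    """Turn raw (session_id, timestamp, api_name) events into integer sequences.

    goal_of_session maps session_id -> goal name. Returns
    (sequences, goals, api_vocab, goal_vocab).
    """
    by_session = defaultdict(list)
    for session_id, ts, api_name in events:
        by_session[session_id].append((ts, api_name))

    api_vocab, goal_vocab = {}, {}
    sequences, goals = [], []
    for session_id in sorted(by_session):
        calls = [api for _, api in sorted(by_session[session_id])]  # time order
        seq = [api_vocab.setdefault(api, len(api_vocab)) for api in calls]
        if len(seq) < 2:  # need at least one (prefix, next API) pair
            continue
        sequences.append(seq)
        goals.append(goal_vocab.setdefault(goal_of_session[session_id], len(goal_vocab)))
    return sequences, goals, api_vocab, goal_vocab
```

Afterwards, num_apis = len(api_vocab) and num_goals = len(goal_vocab); keep both dictionaries so that predicted IDs can be mapped back to API and goal names.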
Dataset Design
The simulator generates realistic API usage patterns with:
- 100 APIs across 10 functional categories (Auth, User Mgmt, Data Input, Data Processing, ML Training, ML Prediction, Basic Viz, Advanced Viz, Export, Admin)
- 4 session goals: ML Pipeline (85% workflow adherence), Data Analysis (80%), User Management (90%), Quick Visualization (75%)
- Markov chain transitions with configurable, high-probability workflow patterns
- Configurable simulator: Override transition probabilities, starting probabilities, and goal distributions to match your domain
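One plausible reading of "workflow adherence", assumed here and not taken from the package source, is a first-order transition matrix in which the next step of the goal's workflow receives the adherence probability and the remaining mass is spread uniformly over all APIs. The sketch below builds such a matrix and samples a session from it; build_transitions and sample_session are hypothetical helpers, not the interface of MultiTaskMarkovChainAPISimulator.

```python
import random


def build_transitions(workflow, num_apis, adherence):
    """Row-stochastic matrix: the preferred successor gets `adherence` extra mass.

    The remaining 1 - adherence is spread uniformly over all APIs, so the
    effective chance of following the workflow is adherence + (1 - adherence) / num_apis.
    """
    preferred = {a: b for a, b in zip(workflow, workflow[1:])}
    matrix = []
    for src in range(num_apis):
        target = preferred.get(src, workflow[0])  # off-workflow or final state restarts
        row = [(1.0 - adherence) / num_apis] * num_apis
        row[target] += adherence
        matrix.append(row)
    return matrix


def sample_session(matrix, start, length, rng):
    """Walk the Markov chain for `length` steps starting from API `start`."""
    session = [start]
    for _ in range(length - 1):
        row = matrix[session[-1]]
        session.append(rng.choices(range(len(row)), weights=row, k=1)[0])
    return session
```

With adherence=1.0 the walk repeats the workflow exactly; lower values (for example 0.75 for Quick Visualization) inject more random detours, which is what makes some personas harder to predict.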
Package Structure
```
src/context_engineer/
    __init__.py    # Public API
    data.py        # Data simulation & dataset classes
    model.py       # Transformer model architecture
    train.py       # Training, evaluation, inference
    pipeline.py    # End-to-end pipeline
    utils.py       # Model save/load, data export (JSON/CSV)
    viz.py         # Publication-ready visualization functions
    cli.py         # Command-line interface
    py.typed       # PEP 561 type checker marker
```
API Reference
Core Functions
| Function | Description |
|---|---|
| `run_pipeline(**kwargs)` | Run the full experiment end-to-end |
| `simulate_multitask_markov_data(...)` | Generate simulated user sessions |
| `create_multitask_training_pairs(...)` | Convert sequences to supervised pairs |
| `train_multitask_model(...)` | Train with early stopping and cosine annealing |
| `evaluate_multitask_model(...)` | Evaluate with accuracy, MRR, Hit Rate@K |
| `set_random_seeds(seed)` | Set seeds for reproducibility |
| `save_model(model, path)` | Save trained model checkpoint to disk |
| `load_model(path)` | Load model from checkpoint (reconstructs architecture) |
| `export_metrics(metrics, path)` | Export evaluation metrics to JSON or CSV |
| `export_training_history(history, path)` | Export training curves to CSV |
| `export_sequences(sequences, goals, path)` | Export session data to JSON or CSV |
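The ranking metrics reported by evaluate_multitask_model follow standard definitions: Hit Rate@K is the fraction of examples whose true next API appears in the top K predictions, and MRR averages 1/rank of the true API. The sketch below computes them from ranked prediction lists; ranking_metrics is a hypothetical helper, and scoring a missing target as 0 is an assumption about how truncated rankings are handled.

```python
def ranking_metrics(ranked_predictions, targets, ks=(1, 5, 10)):
    """Hit Rate@K for each K (Hit@1 equals top-1 accuracy) and MRR.

    ranked_predictions[i] lists API IDs from most to least likely for example i.
    A target missing from its ranking contributes 0 to every metric.
    """
    n = len(targets)
    hits = {k: 0 for k in ks}
    reciprocal_rank_sum = 0.0
    for ranking, target in zip(ranked_predictions, targets):
        if target in ranking:
            rank = ranking.index(target) + 1  # 1-based rank of the true API
            reciprocal_rank_sum += 1.0 / rank
            for k in ks:
                if rank <= k:
                    hits[k] += 1
    metrics = {f"hit@{k}": hits[k] / n for k in ks}
    metrics["mrr"] = reciprocal_rank_sum / n
    return metrics
```

Because Hit Rate@K can only grow with K, the near-100% Top-5 and Top-10 figures in Key Results sit well above the 79.83% top-1 accuracy.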
Visualization Functions (requires pip install context-engineer[viz])
| Function | Description |
|---|---|
| `plot_data_overview(...)` | 6-panel Markov chain data analysis (Paper Figure 1) |
| `plot_api_transitions(...)` | Per-goal transition heatmaps (Paper Figure 2) |
| `plot_training_history(...)` | 6-panel training progress (Paper Figure 5) |
| `plot_evaluation(...)` | Evaluation dashboard with baselines (Paper Figure 6) |
Core Classes
| Class | Description |
|---|---|
| `MultiTaskMarkovChainAPISimulator` | Configurable Markov chain data generator |
| `MultiTaskMarkovAPIRecommender` | Multi-task transformer model |
| `MultiTaskMarkovDataset` | PyTorch Dataset for training data |
Citation
If you use this package in your research, please cite:
```bibtex
@article{yin2025rethink,
  title={Rethink Context Engineering Using an Attention-based Architecture},
  author={Yin, Yiqiao},
  year={2025}
}
```
License
MIT License. See LICENSE for details.
Download files
File details
Details for the file context_engineer-0.2.1.tar.gz.
File metadata
- Download URL: context_engineer-0.2.1.tar.gz
- Upload date:
- Size: 4.5 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.7.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f304efab02081167a960b20a22f03bdecedefefe52e440cb5386f8b2fb783e66 |
| MD5 | 83865933cd38f185d014819b507c6bfd |
| BLAKE2b-256 | 2fa66f86afb2485138c7cfa87dccd9f06915046c1263ef24cf34d588463c8f9f |
File details
Details for the file context_engineer-0.2.1-py3-none-any.whl.
File metadata
- Download URL: context_engineer-0.2.1-py3-none-any.whl
- Upload date:
- Size: 29.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.7.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7a258e69a6d70fc05ebf8136b99b23a683d20922f85456075c5868e110816f68 |
| MD5 | ac6f9f331f8217c545e12cd17e60b84f |
| BLAKE2b-256 | 01ca676178c052b46f8a2ea58ac68336186728c97f059a5c1df97031f132fada |
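To check a downloaded file against the SHA256 digests listed above, the standard library's hashlib is enough. The sketch below streams the file in chunks so large archives are not loaded into memory at once; sha256_of_file and verify_download are illustrative helper names.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path, expected_sha256):
    """True if the file's SHA256 matches the published digest (case-insensitive)."""
    return sha256_of_file(path) == expected_sha256.strip().lower()
```

For example, verify_download("context_engineer-0.2.1.tar.gz", "f304efab02081167a960b20a22f03bdecedefefe52e440cb5386f8b2fb783e66") should return True for an intact source archive. pip's hash-checking mode (--require-hashes with --hash=sha256:... entries in a requirements file) can enforce the same check at install time.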