
Multimodal TimesFM

A multimodal extension of Google's TimesFM for time series forecasting with text inputs.

Installation

pip install "multimodal-timesfm[all]"

Quick Start

1. Define a Custom Dataset

Implement a custom dataset by extending MultimodalDatasetBase:

from pathlib import Path
from typing import Literal
import numpy as np
from multimodal_timesfm.multimodal_dataset import MultimodalDatasetBase

class CustomDataset(MultimodalDatasetBase):
    """Custom dataset for your multimodal time series data."""

    def __init__(
        self,
        data_dir: Path,
        split_ratio: float = 0.8,
        split: Literal["train", "test"] = "train",
        patch_len: int = 32,
        context_len: int = 128,
        horizon_len: int = 32,
    ):
        super().__init__(data_dir, split_ratio, split, patch_len, context_len, horizon_len)

    def _load_data(self) -> None:
        """Load and process your custom data format.

        Populate self.data with dictionaries containing:
        - context: np.ndarray of shape (context_len,) - historical time series values
        - future: np.ndarray of shape (horizon_len,) - target future values
        - freq: int - frequency indicator (0=daily, 1=weekly/monthly, 2=quarterly+)
        - patched_texts: list of lists - text organized by temporal patches
        - metadata: dict - additional sample information
        """
        # Your custom data loading logic
        for sample in self._read_your_data_files():
            # Organize texts by patches (one list per patch)
            num_patches = self.context_len // self.patch_len
            patched_texts = [[] for _ in range(num_patches)]

            # Assign text descriptions to appropriate patches
            for text_item in sample["texts"]:
                patch_idx = self._get_patch_index(text_item["timestamp"])
                patched_texts[patch_idx].append(text_item["text"])

            self.data.append({
                "context": sample["historical_values"],  # shape: (context_len,)
                "future": sample["target_values"],       # shape: (horizon_len,)
                "freq": sample["frequency"],             # 0, 1, or 2
                "patched_texts": patched_texts,          # list of text lists
                "metadata": sample["info"]
            })
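
_read_your_data_files and _get_patch_index above are placeholder helpers that you implement for your own data format. A minimal sketch of _get_patch_index, assuming each text item's timestamp is a 0-based integer offset within the context window:

def _get_patch_index(self, timestamp: int) -> int:
    """Map a 0-based offset within the context window to its patch index."""
    num_patches = self.context_len // self.patch_len
    # Clamp so offsets at the right edge still land in the last patch.
    return min(timestamp // self.patch_len, num_patches - 1)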

2. Train the Model

from pathlib import Path

from multimodal_timesfm.multimodal_patched_decoder import MultimodalPatchedDecoder, MultimodalTimesFMConfig
from multimodal_timesfm.trainer import MultimodalTrainer

# Create datasets (data_dir expects a Path, matching the dataset signature above)
train_dataset = CustomDataset(data_dir=Path("path/to/data"), split="train")
val_dataset = CustomDataset(data_dir=Path("path/to/data"), split="test")

# Initialize model with pretrained TimesFM weights
config = MultimodalTimesFMConfig(
    text_encoder_type="english",  # or "japanese"
    context_len=128,
    horizon_len=32,
    input_patch_len=32,
)
model = MultimodalPatchedDecoder(config)

# Load pretrained TimesFM checkpoint
model.load_pretrained_timesfm("path/to/timesfm_checkpoint.ckpt")

# Train
trainer = MultimodalTrainer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    weight_decay=0.01,
    log_dir=Path("logs"),
    checkpoint_dir=Path("checkpoints"),
    wandb_project="my-project",
    wandb_run_name="experiment-1"
)

# Train for 10 epochs, saving checkpoints every 5 epochs
trainer.train(num_epochs=10, save_frequency=5)

3. Evaluate the Model

import torch
from torch.utils.data import DataLoader

from multimodal_timesfm.evaluation import evaluate_multimodal_model
from multimodal_timesfm.utils.collate import multimodal_collate_fn

# Create test dataset
test_dataset = CustomDataset(data_dir="path/to/data", split="test")
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=multimodal_collate_fn
)

# Load trained model (reuses the `config` defined in the training step)
model = MultimodalPatchedDecoder(config)
checkpoint = torch.load("checkpoints/best_model.pt")
model.load_state_dict(checkpoint["model_state_dict"])

# Evaluate
metrics = evaluate_multimodal_model(model, test_loader, device="cuda")
print(f"MSE: {metrics['mse']:.4f}")
print(f"MAE: {metrics['mae']:.4f}")

4. Visualize Predictions

import matplotlib.pyplot as plt
import torch

# Get predictions for visualization
model.eval()
sample = test_dataset[0]
with torch.no_grad():
    prediction = model.forecast(
        context=torch.tensor(sample["context"]).unsqueeze(0),
        text_inputs=[sample["patched_texts"]],
        freq=torch.tensor([sample["freq"]])
    )

# Plot
plt.figure(figsize=(12, 4))
context_len = len(sample["context"])
horizon_len = len(sample["future"])

# Plot context
plt.plot(range(context_len), sample["context"], label="Context", color="blue")

# Plot ground truth
plt.plot(range(context_len, context_len + horizon_len),
         sample["future"], label="Ground Truth", color="green")

# Plot prediction
plt.plot(range(context_len, context_len + horizon_len),
         prediction[0].cpu().numpy(), label="Prediction", color="red", linestyle="--")

plt.xlabel("Time")
plt.ylabel("Value")
plt.legend()
plt.title("Multimodal TimesFM Forecast")
plt.savefig("forecast_visualization.png")

5. Inference on New Data

import numpy as np

from multimodal_timesfm import MultimodalTimesFM, TimesFmHparams, MultimodalTimesFMConfig

# Load trained model for inference
hparams = TimesFmHparams(context_len=128, horizon_len=32)
config = MultimodalTimesFMConfig(text_encoder_type="english")
model = MultimodalTimesFM(hparams, config, "checkpoints/best_model.pt")

# Prepare new data
time_series_data = np.array([...])  # Your time series context

# One list of texts per patch. With context_len=128 and input_patch_len=32
# (as in the training example), the context spans four patches, so provide
# four lists; patches without text get an empty list.
text_descriptions = [[
    ["High volatility expected"],        # Texts for patch 1
    ["Market uncertainty increasing"],   # Texts for patch 2
    ["Economic indicators show growth"], # Texts for patch 3
    []                                   # No texts for patch 4
]]

# Generate forecast
forecasts, quantiles = model.forecast(
    inputs=[time_series_data],
    text_descriptions=text_descriptions,
    freq=[0],  # 0=daily, 1=weekly/monthly, 2=quarterly+
    forecast_context_len=128
)

print(f"Forecast shape: {forecasts.shape}")
print(f"Point forecast: {forecasts[0]}")
print(f"Quantiles shape: {quantiles.shape}")

Features

  • Multimodal forecasting: Combines time series data with textual context
  • Built on TimesFM: Leverages Google's state-of-the-art time series foundation model
  • Flexible text encoding: Supports English and Japanese text inputs
  • Easy integration: Simple API for adding text context to time series forecasting

Time-MMD Dataset Example

The project includes complete scripts for training and evaluating on the Time-MMD dataset.

Setup

Initialize the Time-MMD dataset submodule:

git submodule update --init

The dataset contains multimodal time series data across 10 domains: Agriculture, Climate, Economy, Energy, Environment, Health_AFR, Health_US, Security, SocialGood, and Traffic.

Training with Cross-Validation

Train a multimodal TimesFM model using cross-validation:

# Train both multimodal and fine-tuned baseline (recommended)
PYTHONPATH=. uv run python scripts/train_time_mmd_cv.py \
    --train-baseline \
    --seed 42

# Train only multimodal model
PYTHONPATH=. uv run python scripts/train_time_mmd_cv.py \
    --seed 42

Configuration:

  • Model config: TimesFM architecture parameters (layers, dimensions, context length, etc.)
  • Training config: Batch size, learning rate, domains, cross-validation settings, etc.
  • See examples/time_mmd/configs/ for configuration templates

The scripts will:

  • Create train/validation splits for each cross-validation fold
  • Train a separate model for each fold
  • Save checkpoints and cross-validation results to JSON (see the sketch below for inspecting the results file)
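
The schema of the results file isn't documented here, so a safe first step is to load the JSON and inspect its top-level structure (a minimal sketch):

import json
from pathlib import Path

# Load the cross-validation results written by the training script and
# print the top-level structure without assuming a particular schema.
results = json.loads(Path("logs/cv_results.json").read_text())
print(type(results).__name__)
if isinstance(results, dict):
    print(sorted(results.keys()))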

Evaluation

Evaluate trained models on the test set:

# Evaluate multimodal model only
PYTHONPATH=. uv run python scripts/evaluate_time_mmd_cv.py \
    --cv-results logs/cv_results.json \
    --seed 42

# Compare with pretrained baseline (no fine-tuning)
PYTHONPATH=. uv run python scripts/evaluate_time_mmd_cv.py \
    --cv-results logs/cv_results.json \
    --compare-baseline \
    --seed 42

# Compare with fine-tuned baseline
PYTHONPATH=. uv run python scripts/evaluate_time_mmd_cv.py \
    --cv-results logs/cv_results.json \
    --baseline-cv-results logs/baseline_finetuned_cv_results.json \
    --seed 42

# Compare with ARIMA baseline
PYTHONPATH=. uv run python scripts/evaluate_time_mmd_cv.py \
    --cv-results logs/cv_results.json \
    --compare-arima \
    --seed 42

# Compare with ARIMA using custom order (p, d, q)
PYTHONPATH=. uv run python scripts/evaluate_time_mmd_cv.py \
    --cv-results logs/cv_results.json \
    --compare-arima \
    --arima-order 7 1 2 \
    --seed 42

# Compare all models (recommended)
PYTHONPATH=. uv run python scripts/evaluate_time_mmd_cv.py \
    --cv-results logs/cv_results.json \
    --compare-baseline \
    --baseline-cv-results logs/baseline_finetuned_cv_results.json \
    --compare-arima \
    --seed 42

This evaluates and compares up to four model variants:

  1. Multimodal model (always evaluated): TimesFM with text encoder and fusion layer
  2. Pretrained baseline (with --compare-baseline): TimesFM without fine-tuning (pretrained weights kept frozen)
  3. Fine-tuned baseline (with --baseline-cv-results): TimesFM fine-tuned on time series only (no text)
  4. ARIMA baseline (with --compare-arima): Traditional ARIMA model with configurable order (default: 32, 1, 1); see the sketch below

You can compare all models at once by providing --compare-baseline, --baseline-cv-results, and --compare-arima flags together.
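
For reference, an ARIMA(p, d, q) baseline can be fit with statsmodels; the following is a generic sketch (not the project's internal implementation), using the custom order from the example above:

import numpy as np
from statsmodels.tsa.arima.model import ARIMA

# Fit ARIMA(p=7, d=1, q=2) on a univariate context window and forecast
# the next 32 steps, mirroring the horizon used in the Quick Start.
context = np.random.default_rng(42).standard_normal(128).cumsum()
fitted = ARIMA(context, order=(7, 1, 2)).fit()
forecast = fitted.forecast(steps=32)
print(forecast.shape)  # (32,)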

Metrics reported (computed as in the sketch after this list):

  • Mean Squared Error (MSE)
  • Mean Absolute Error (MAE)
  • Per-fold and overall metrics
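
Both are simple elementwise averages over the forecast horizon; a minimal numpy sketch:

import numpy as np

def mse_mae(y_true: np.ndarray, y_pred: np.ndarray) -> tuple[float, float]:
    """Mean squared error and mean absolute error between targets and forecasts."""
    err = y_true - y_pred
    return float(np.mean(err ** 2)), float(np.mean(np.abs(err)))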

Visualization

Visualize model predictions:

PYTHONPATH=. uv run python scripts/visualize_time_mmd_cv.py \
    --cv-results logs/cv_results.json

Output:

  • Time series plots showing context, ground truth, and predictions
  • Metric comparison bar charts (MSE and MAE)

Forecasting with Custom Parameters

Generate forecasts with manually configurable context and horizon lengths:

# Forecast on all folds
PYTHONPATH=. uv run python scripts/forecast_time_mmd.py \
    --cv-results logs/cv_results.json \
    --context-len 512 \
    --horizon-len 128

# Forecast on a specific fold
PYTHONPATH=. uv run python scripts/forecast_time_mmd.py \
    --cv-results logs/cv_results.json \
    --fold 0 \
    --context-len 512 \
    --horizon-len 128

# Forecast with custom settings
PYTHONPATH=. uv run python scripts/forecast_time_mmd.py \
    --cv-results logs/cv_results.json \
    --context-len 256 \
    --horizon-len 64 \
    --num-samples 10 \
    --output-dir custom_plots

This script compares multimodal model forecasts against baseline TimesFM forecasts, providing:

  • Time series plots comparing multimodal vs baseline predictions
  • Bar charts comparing MSE and MAE metrics
  • JSON output with all forecasts and metrics

Acknowledgments

We thank the Time-MMD team for providing the multimodal time series dataset used in our examples and experiments.

License

MIT
