Skip to main content

Multimodal extension of Google's TimesFM for time series forecasting with text

Project description

Multimodal TimesFM

A multimodal extension of Google's TimesFM for time series forecasting with text inputs.

Installation

pip install "multimodal-timesfm[all]"  # quotes keep shells such as zsh from expanding the brackets

Quick Start

1. Define Custom Dataset

Implement a custom dataset by subclassing MultimodalDatasetBase and overriding _load_data, which fills self.data with one dictionary per sample:

from pathlib import Path
from typing import Literal
import numpy as np
from multimodal_timesfm.multimodal_dataset import MultimodalDatasetBase

class CustomDataset(MultimodalDatasetBase):
    """Custom dataset for your multimodal time series data.

    The constructor forwards its arguments unchanged to ``MultimodalDatasetBase``;
    the only thing a subclass has to provide is ``_load_data``, which fills
    ``self.data``.

    ``self._read_your_data_files()`` and ``self._get_patch_index(timestamp)`` used
    below are placeholders for your own code (file reading and mapping a text
    timestamp to a context-patch index); they are not defined in this example.
    """

    def __init__(
        self,
        data_dir: Path,
        split_ratio: float = 0.8,
        split: Literal["train", "test"] = "train",
        patch_len: int = 32,
        context_len: int = 128,
        horizon_len: int = 32,
    ):
        super().__init__(data_dir, split_ratio, split, patch_len, context_len, horizon_len)

    def _load_data(self) -> None:
        """Load and process your custom data format.

        Populate self.data with dictionaries containing:
        - context: np.ndarray of shape (context_len,) - historical time series values
        - future: np.ndarray of shape (horizon_len,) - target future values
        - freq: int - frequency indicator (0=daily, 1=weekly/monthly, 2=quarterly+)
        - patched_texts: list of lists - text organized by temporal patches
          (exactly context_len // patch_len inner lists, possibly empty)
        - metadata: dict - additional sample information

        Texts whose patch index falls outside the context window are skipped.
        """
        # The patch count depends only on the dataset configuration, so compute
        # it once instead of once per sample.
        num_patches = self.context_len // self.patch_len

        # Your custom data loading logic
        for sample in self._read_your_data_files():
            # Organize texts by patches (one list per patch)
            patched_texts: list[list[str]] = [[] for _ in range(num_patches)]

            # Assign text descriptions to appropriate patches
            for text_item in sample["texts"]:
                patch_idx = self._get_patch_index(text_item["timestamp"])
                # Skip texts outside the context window. This keeps text from
                # the forecast horizon out of the inputs. It also prevents a
                # negative index from silently selecting a patch counted from
                # the end, and a too-large index from raising IndexError.
                if not 0 <= patch_idx < num_patches:
                    continue
                patched_texts[patch_idx].append(text_item["text"])

            self.data.append(
                {
                    "context": sample["historical_values"],  # shape: (context_len,)
                    "future": sample["target_values"],  # shape: (horizon_len,)
                    "freq": sample["frequency"],  # 0, 1, or 2
                    "patched_texts": patched_texts,  # list of text lists
                    "metadata": sample["info"],
                }
            )

2. Train the Model

from pathlib import Path

from multimodal_timesfm.multimodal_patched_decoder import MultimodalPatchedDecoder, MultimodalTimesFMConfig
from multimodal_timesfm.trainer import MultimodalTrainer

# Keep the dataset windowing and the model configuration in sync: the model
# must be built for the same context/horizon/patch lengths the dataset emits.
CONTEXT_LEN = 128
HORIZON_LEN = 32
PATCH_LEN = 32
DATA_DIR = Path("path/to/data")

# Create datasets.
# NOTE: CustomDataset only exposes "train" and "test" splits, so the test split
# doubles as the validation set here. If the trainer selects its best
# checkpoint by validation loss, metrics later reported on that same split will
# be optimistic; hold out a separate validation set for an unbiased estimate.
train_dataset = CustomDataset(
    data_dir=DATA_DIR,
    split="train",
    patch_len=PATCH_LEN,
    context_len=CONTEXT_LEN,
    horizon_len=HORIZON_LEN,
)
val_dataset = CustomDataset(
    data_dir=DATA_DIR,
    split="test",
    patch_len=PATCH_LEN,
    context_len=CONTEXT_LEN,
    horizon_len=HORIZON_LEN,
)

# Initialize model with pretrained TimesFM weights
config = MultimodalTimesFMConfig(
    text_encoder_type="english",  # or "japanese"
    context_len=CONTEXT_LEN,
    horizon_len=HORIZON_LEN,
    input_patch_len=PATCH_LEN,
)
model = MultimodalPatchedDecoder(config)

# Load pretrained TimesFM checkpoint
model.load_pretrained_timesfm("path/to/timesfm_checkpoint.ckpt")

# Train
trainer = MultimodalTrainer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    weight_decay=0.01,
    log_dir=Path("logs"),
    checkpoint_dir=Path("checkpoints"),
    wandb_project="my-project",
    wandb_run_name="experiment-1",
)

# Train for 10 epochs, saving checkpoints every 5 epochs
trainer.train(num_epochs=10, save_frequency=5)

3. Evaluate the Model

from pathlib import Path

import torch
from torch.utils.data import DataLoader

from multimodal_timesfm.evaluation import evaluate_multimodal_model
from multimodal_timesfm.multimodal_patched_decoder import MultimodalPatchedDecoder
from multimodal_timesfm.utils.collate import multimodal_collate_fn

# Use the GPU when present, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Create test dataset
test_dataset = CustomDataset(data_dir=Path("path/to/data"), split="test")
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=multimodal_collate_fn,
)

# Load trained model. `config` must be the same MultimodalTimesFMConfig that was
# used for training (see "Train the Model").
model = MultimodalPatchedDecoder(config)
# map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only machine.
checkpoint = torch.load("checkpoints/best_model.pt", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

# Evaluate
metrics = evaluate_multimodal_model(model, test_loader, device=device)
print(f"MSE: {metrics['mse']:.4f}")
print(f"MAE: {metrics['mae']:.4f}")

4. Visualize Predictions

import matplotlib.pyplot as plt
import numpy as np
import torch

# Get predictions for visualization
model.eval()
# Build inputs on whatever device the model's weights currently live on
# (evaluation may have run on the GPU).
device = next(model.parameters()).device
sample = test_dataset[0]
with torch.no_grad():
    prediction = model.forecast(
        # numpy arrays default to float64; cast to float32 so the input dtype
        # matches float32 model weights.
        context=torch.as_tensor(sample["context"], dtype=torch.float32, device=device).unsqueeze(0),
        text_inputs=[sample["patched_texts"]],
        freq=torch.tensor([sample["freq"]], device=device),
    )

# Plot
plt.figure(figsize=(12, 4))
context_len = len(sample["context"])
horizon_len = len(sample["future"])

# Plot context
plt.plot(range(context_len), sample["context"], label="Context", color="blue")

# Plot ground truth
plt.plot(range(context_len, context_len + horizon_len),
         sample["future"], label="Ground Truth", color="green")

# Plot prediction (moved back to CPU for matplotlib)
plt.plot(range(context_len, context_len + horizon_len),
         prediction[0].cpu().numpy(), label="Prediction", color="red", linestyle="--")

plt.xlabel("Time")
plt.ylabel("Value")
plt.legend()
plt.title("Multimodal TimesFM Forecast")
plt.savefig("forecast_visualization.png")

5. Inference on New Data

import numpy as np

from multimodal_timesfm import MultimodalTimesFM, TimesFmHparams, MultimodalTimesFMConfig

CONTEXT_LEN = 128

# Load trained model for inference
hparams = TimesFmHparams(context_len=CONTEXT_LEN, horizon_len=32)
config = MultimodalTimesFMConfig(text_encoder_type="english")
model = MultimodalTimesFM(hparams, config, "checkpoints/best_model.pt")

# Prepare new data.
# Replace this synthetic series with your own 1-D array of historical values.
# Do not use the literal np.array([...]): it builds an object array holding
# Ellipsis rather than numbers.
time_series_data = np.sin(np.linspace(0.0, 8.0 * np.pi, CONTEXT_LEN))

# Nesting: one entry per input series -> one list per context patch -> texts.
# This mirrors the training data layout, which has context_len // patch_len
# patch lists (4 here, assuming a patch length of 32 — confirm against
# TimesFmHparams). Patches without text are empty lists.
text_descriptions = [[
    ["High volatility expected"],        # Texts for patch 1
    ["Market uncertainty increasing"],   # Texts for patch 2
    ["Economic indicators show growth"], # Texts for patch 3
    [],                                  # Texts for patch 4 (none)
]]

# Generate forecast
forecasts, quantiles = model.forecast(
    inputs=[time_series_data],
    text_descriptions=text_descriptions,
    freq=[0],  # 0=daily, 1=weekly/monthly, 2=quarterly+
    forecast_context_len=CONTEXT_LEN,
)

print(f"Forecast shape: {forecasts.shape}")
print(f"Point forecast: {forecasts[0]}")
print(f"Quantiles shape: {quantiles.shape}")

Features

  • Multimodal forecasting: Combines time series data with textual context
  • Built on TimesFM: Leverages Google's state-of-the-art time series foundation model
  • Flexible text encoding: Supports English and Japanese text inputs
  • Easy integration: Simple API for adding text context to time series forecasting

Time-MMD Dataset Example

The project includes complete scripts for training and evaluating on the Time-MMD dataset.

Setup

Initialize the Time-MMD dataset submodule:

git submodule update --init

The dataset contains multimodal time series data across 10 domains: Agriculture, Climate, Economy, Energy, Environment, Health_AFR, Health_US, Security, SocialGood, and Traffic.

Training with Cross-Validation

Train a multimodal TimesFM model using cross-validation:

PYTHONPATH=. uv run python scripts/train_time_mmd_cv.py \
    --seed 42

Configuration:

  • Model config: TimesFM architecture parameters (layers, dimensions, context length, etc.)
  • Training config: Batch size, learning rate, domains, cross-validation settings, etc.
  • See examples/time_mmd/configs/ for configuration templates

The script will:

  • Create train/validation splits for each cross-validation fold
  • Train a separate model for each fold
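  • Save model checkpoints, and write the cross-validation results to a JSON file (the file passed to the evaluation and visualization scripts below via --cv-results)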

Evaluation

Evaluate trained models on the test set:

PYTHONPATH=. uv run python scripts/evaluate_time_mmd_cv.py \
    --cv-results logs/cv_results.json \
    --seed 42

This evaluates both the multimodal model and a baseline TimesFM model (without text inputs), reporting:

  • Mean Squared Error (MSE)
  • Mean Absolute Error (MAE)
  • Per-fold and overall metrics

Visualization

Visualize model predictions:

PYTHONPATH=. uv run python scripts/visualize_time_mmd_cv.py \
    --cv-results logs/cv_results.json

Output:

  • Time series plots showing context, ground truth, and predictions
  • Metric comparison bar charts (MSE and MAE)

Forecasting with Custom Parameters

Generate forecasts with manually configurable context and horizon lengths:

# Forecast on all folds
PYTHONPATH=. uv run python scripts/forecast_time_mmd.py \
    --cv-results logs/cv_results.json \
    --context-len 512 \
    --horizon-len 128

# Forecast on a specific fold
PYTHONPATH=. uv run python scripts/forecast_time_mmd.py \
    --cv-results logs/cv_results.json \
    --fold 0 \
    --context-len 512 \
    --horizon-len 128

# Forecast with custom settings
PYTHONPATH=. uv run python scripts/forecast_time_mmd.py \
    --cv-results logs/cv_results.json \
    --context-len 256 \
    --horizon-len 64 \
    --num-samples 10 \
    --output-dir custom_plots

This script compares multimodal model forecasts against baseline TimesFM forecasts, providing:

  • Time series plots comparing multimodal vs baseline predictions
  • Bar charts comparing MSE and MAE metrics
  • JSON output with all forecasts and metrics

Acknowledgments

We thank the Time-MMD team for providing the multimodal time series dataset used in our examples and experiments.

License

MIT

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

multimodal_timesfm-0.1.3.tar.gz (25.6 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

multimodal_timesfm-0.1.3-py3-none-any.whl (31.8 kB view details)

Uploaded Python 3

File details

Details for the file multimodal_timesfm-0.1.3.tar.gz.

File metadata

  • Download URL: multimodal_timesfm-0.1.3.tar.gz
  • Upload date:
  • Size: 25.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for multimodal_timesfm-0.1.3.tar.gz
Algorithm Hash digest
SHA256 e08e08175f80453692630632a56f352e88baf1121a6418c67ad2558a4bf0d4cb
MD5 2d6297f0b21f6f6650cdb3af3fefdfb0
BLAKE2b-256 d5387cf6696ec091eb4be977fae103e566776a2fd59ffff66c87fc840dfd16a4

See more details on using hashes here.

Provenance

The following attestation bundles were made for multimodal_timesfm-0.1.3.tar.gz:

Publisher: publish-to-pypi.yml on himura467/multimodal-timesfm

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file multimodal_timesfm-0.1.3-py3-none-any.whl.

File metadata

File hashes

Hashes for multimodal_timesfm-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 dc707fb4f32a6c373691148df6d4bf8dbc7ce78b81ac43e8f7442d28cd03367a
MD5 6cd3dc948922498eab150e475cd422b1
BLAKE2b-256 1bb38c57c27413d2bfd96fd6eda6e808d31dc89fc8a4149f8861f732fed296d7

See more details on using hashes here.

Provenance

The following attestation bundles were made for multimodal_timesfm-0.1.3-py3-none-any.whl:

Publisher: publish-to-pypi.yml on himura467/multimodal-timesfm

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page