
Investigating belief state representations of transformers trained on Hidden Markov Model emissions

Project description

belief-state-superposition

Tooling: GitHub Actions (CI), Ruff (linting), PDM (package management), pyright (type checking)

Quickstart

Install

pip install belief-state-superposition
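To confirm the install worked, one option is to query the installed distribution's version with the standard library's importlib.metadata. This is a minimal sketch: installed_version is a hypothetical helper name, not part of the package, and the distribution name is the one used in the pip command above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent.

    installed_version is a hypothetical helper, not part of belief-state-superposition.
    """
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


print(installed_version("belief-state-superposition"))
```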

Usage

Generate and inspect data from a Hidden Markov Model

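from belief_state_superposition.hmm import sample_sequence

# Sample a sequence from the HMM (here with argument 16)
data = sample_sequence(16)
# Each step is a (belief, state, emission, next_belief, next_state) tuple;
# zip(*data) regroups the steps into one tuple per field
beliefs, states, emissions, next_beliefs, next_states = zip(*data)
print(beliefs)
print(states)
print(emissions)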
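For readers new to belief states: in an HMM, the belief state is the posterior distribution over hidden states given the emissions observed so far, and each new emission updates it by Bayes' rule. The sketch below illustrates that update for a textbook HMM in which each hidden state emits a symbol. The matrices are made-up values, and belief_state_superposition.hmm may use a different parameterization (for example, emissions attached to transitions), so treat this as an illustration rather than the package's implementation.

```python
"""Minimal sketch of a Bayesian belief-state update for a discrete HMM.

Hypothetical: the transition/emission matrices below are illustrative and
are NOT the parameters used by belief_state_superposition.hmm.
"""
from typing import List, Sequence

Matrix = List[List[float]]

# Hypothetical 2-state HMM with 2 emission symbols.
TRANSITION: Matrix = [[0.9, 0.1],  # row s: P(next_state | state = s)
                      [0.2, 0.8]]
EMISSION: Matrix = [[0.7, 0.3],    # row s: P(emission | state = s)
                    [0.1, 0.9]]


def update_belief(belief: Sequence[float], emission: int,
                  transition: Matrix = TRANSITION,
                  emission_probs: Matrix = EMISSION) -> List[float]:
    """Advance the belief one step through the transition matrix, then
    condition on the newly observed emission and renormalize."""
    n = len(belief)
    # Predict: P(s') = sum_s b(s) * T[s][s']
    predicted = [sum(belief[s] * transition[s][t] for s in range(n)) for t in range(n)]
    # Update: weight each state by the likelihood of the observed emission
    unnormalized = [predicted[t] * emission_probs[t][emission] for t in range(n)]
    total = sum(unnormalized)
    return [p / total for p in unnormalized]


belief = [0.5, 0.5]  # uniform prior over the two hypothetical states
for x in [0, 0, 1]:
    belief = update_belief(belief, x)
    print(x, [round(p, 3) for p in belief])
```

Separating the predict and update steps keeps the two sources of information explicit: the transition matrix says where the process tends to go, and the emission likelihood says which states are consistent with what was just seen.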

Train a model on belief states

import torch
from torch.utils.data import DataLoader
from belief_state_superposition.model import init_model
from belief_state_superposition.data import get_dataset
from belief_state_superposition.train import train_model

# Use a GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the training dataset and a shuffled, batched loader over it
train_dataset = get_dataset(1000)
train_data_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Initialize the model and move it to the selected device
model = init_model().to(device)
train_model(model, train_data_loader, n_epochs=10, show_progress_bar=True, device=device)
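Transformers in this setting are commonly trained with next-token prediction on emission sequences: the input is the sequence so far and the target is the following emission. We have not confirmed how get_dataset builds its examples; the sketch below only illustrates the generic one-position shift on emissions drawn from a hypothetical two-state HMM. sample_emissions and make_next_token_pair are illustrative names, not package APIs.

```python
import random
from typing import List, Tuple

# Hypothetical HMM parameters (illustrative only)
TRANSITION = [[0.9, 0.1], [0.2, 0.8]]  # P(next_state | state)
EMISSION = [[0.7, 0.3], [0.1, 0.9]]    # P(emission | state)


def sample_emissions(length: int, rng: random.Random) -> List[int]:
    """Sample `length` emissions, starting from a uniformly random hidden state."""
    state = rng.randrange(len(TRANSITION))
    emissions = []
    for _ in range(length):
        # Emit a symbol from the current state, then move to the next state
        emissions.append(rng.choices(range(len(EMISSION[state])), weights=EMISSION[state])[0])
        state = rng.choices(range(len(TRANSITION[state])), weights=TRANSITION[state])[0]
    return emissions


def make_next_token_pair(seq: List[int]) -> Tuple[List[int], List[int]]:
    """Shift by one: the target at position i is the emission at position i + 1."""
    return seq[:-1], seq[1:]


rng = random.Random(0)  # fixed seed for reproducibility
dataset = [make_next_token_pair(sample_emissions(16, rng)) for _ in range(4)]
inputs, targets = dataset[0]
print(inputs)
print(targets)
```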

Development

Refer to Setup for how to set up the development environment.

Download files

Source distribution: belief_state_superposition-0.1.0.tar.gz (4.5 kB)

Built distribution: belief_state_superposition-0.1.0-py3-none-any.whl

File hashes

Hashes for belief_state_superposition-0.1.0.tar.gz

Algorithm    Hash digest
SHA256       d986338b0e79dbbe0ac8aea934422e60e727acc5eea498b98bc471456cd710d0
MD5          e7f9ac24a72d994ab1c518f041e84543
BLAKE2b-256  f7c6d38ae8158c15075ecf123903f9617e5527d699a2977862e696435b80adf0

Hashes for belief_state_superposition-0.1.0-py3-none-any.whl

Algorithm    Hash digest
SHA256       8c590e7fe29e376b84053eb941316e833616b5c80744844b02bf18855dc65e7a
MD5          ada2f27f06e0352ec94e1a11b55905e9
BLAKE2b-256  2ee2afe6b5aa7a0e8eb6e7fcc4e23e31143925d1fac3f2fe9466923ee6029672
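To check a downloaded archive against the SHA256 digests listed above, the file can be hashed in chunks with Python's standard hashlib module. This is a minimal sketch: sha256_of is a hypothetical helper name, and the path assumes the source archive sits in the current directory.

```python
import hashlib
from pathlib import Path

# SHA256 digest published for the source distribution
EXPECTED_SHA256 = "d986338b0e79dbbe0ac8aea934422e60e727acc5eea498b98bc471456cd710d0"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash a file in fixed-size chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


archive = Path("belief_state_superposition-0.1.0.tar.gz")
if archive.exists():
    print("OK" if sha256_of(archive) == EXPECTED_SHA256 else "MISMATCH")
else:
    print(f"{archive} not found")
```

pip can also enforce these digests itself in hash-checking mode, by adding `--hash=sha256:...` options (one per distribution file) to the package's line in a requirements file.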
