Investigating belief state representations of transformers trained on Hidden Markov Model emissions
belief-state-superposition
Quickstart
Install
pip install belief-state-superposition
Usage
Generate and inspect data from a Hidden Markov Model
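```python
from belief_state_superposition.hmm import sample_sequence

data = sample_sequence(16)

# Each element of `data` is a (belief, state, emission, next_belief, next_state)
# tuple; zip(*data) unpacks them into parallel tuples.
beliefs, states, emissions, next_beliefs, next_states = zip(*data)
print(beliefs)
print(states)
print(emissions)
```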
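The `beliefs` and `next_beliefs` above are presumably posterior distributions over the hidden state given the emissions observed so far. The package's own update rule and matrix conventions are not documented here, so the following is only a generic Bayesian-filter sketch under assumed conventions: `transition[i][j]` is P(next state j | current state i), `emission_probs[i][k]` is P(emission k | state i), and the observed emission is conditioned on before the transition is applied. The names `update_belief` and `belief_trajectory` and the two-state example HMM are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def update_belief(belief: List[float], emission: int,
                  transition: Matrix, emission_probs: Matrix) -> List[float]:
    """Belief over the next hidden state after observing `emission`.

    Assumed (hypothetical) conventions, not taken from the package:
      transition[i][j]     = P(s_{t+1} = j | s_t = i)
      emission_probs[i][k] = P(x_t = k | s_t = i)
    """
    n = len(belief)
    # Condition on the observed emission: p_i is proportional to b_i * P(x | s_i)
    posterior = [belief[i] * emission_probs[i][emission] for i in range(n)]
    total = sum(posterior)
    if total == 0:
        raise ValueError("emission has zero probability under the current belief")
    posterior = [p / total for p in posterior]
    # Propagate through the transition matrix: b'_j = sum_i p_i * T[i][j]
    return [sum(posterior[i] * transition[i][j] for i in range(n)) for j in range(n)]


def belief_trajectory(initial: List[float], emissions: List[int],
                      transition: Matrix, emission_probs: Matrix) -> List[List[float]]:
    """Run the filter over a whole emission sequence, returning every belief."""
    beliefs = [list(initial)]
    for x in emissions:
        beliefs.append(update_belief(beliefs[-1], x, transition, emission_probs))
    return beliefs


# Hypothetical two-state HMM, for illustration only
transition = [[0.9, 0.1], [0.2, 0.8]]
emission_probs = [[0.8, 0.2], [0.3, 0.7]]
beliefs = belief_trajectory([0.5, 0.5], [0, 1, 1], transition, emission_probs)
print([round(b, 3) for b in beliefs[1]])  # [0.709, 0.291]
```

Because each row of `transition` sums to 1, every updated belief is again a probability distribution, which is what makes it a natural regression target for a model.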
Train a model on belief states
```python
import torch
from torch.utils.data import DataLoader

from belief_state_superposition.model import init_model
from belief_state_superposition.data import get_dataset
from belief_state_superposition.train import train_model

# Use a GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"

train_dataset = get_dataset(1000)
train_data_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

model = init_model().to(device)
train_model(model, train_data_loader, n_epochs=10, show_progress_bar=True, device=device)
```
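One common way to check whether a trained model represents belief states is a linear probe: regress the belief vectors on the model's internal activations and measure how much variance is explained. The package does not document a probing utility, so this is a generic NumPy sketch; `fit_linear_probe` and `probe_r2` are hypothetical names, and it assumes you have already collected an `(N, d)` activation matrix aligned with an `(N, k)` belief matrix. The synthetic data at the bottom only sanity-checks the code and says nothing about this project's models.

```python
import numpy as np


def _with_bias(activations: np.ndarray) -> np.ndarray:
    # Append a constant column so the probe can learn an offset
    return np.hstack([activations, np.ones((activations.shape[0], 1))])


def fit_linear_probe(activations: np.ndarray, beliefs: np.ndarray) -> np.ndarray:
    """Least-squares affine map from activations (N, d) to beliefs (N, k)."""
    weights, *_ = np.linalg.lstsq(_with_bias(activations), beliefs, rcond=None)
    return weights  # shape (d + 1, k)


def probe_r2(activations: np.ndarray, beliefs: np.ndarray, weights: np.ndarray) -> float:
    """Coefficient of determination of the probe's predictions."""
    pred = _with_bias(activations) @ weights
    ss_res = float(((beliefs - pred) ** 2).sum())
    ss_tot = float(((beliefs - beliefs.mean(axis=0)) ** 2).sum())
    return 1.0 - ss_res / ss_tot


# Synthetic sanity check: activations are a noisy linear image of 3-state beliefs
rng = np.random.default_rng(0)
beliefs = rng.dirichlet(np.ones(3), size=500)
mixing = rng.normal(size=(3, 16))
activations = beliefs @ mixing + 0.01 * rng.normal(size=(500, 16))

weights = fit_linear_probe(activations, beliefs)
r2 = probe_r2(activations, beliefs, weights)
print(f"R^2 = {r2:.4f}")
```

This fits and scores on the same data for brevity; with real activations, evaluating the probe on held-out sequences gives a more trustworthy R^2.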
Development
Refer to Setup for instructions on setting up the development environment.
File details
Details for the file belief_state_superposition-0.1.0.tar.gz.
File metadata
- Download URL: belief_state_superposition-0.1.0.tar.gz
- Upload date:
- Size: 4.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | d986338b0e79dbbe0ac8aea934422e60e727acc5eea498b98bc471456cd710d0
MD5 | e7f9ac24a72d994ab1c518f041e84543
BLAKE2b-256 | f7c6d38ae8158c15075ecf123903f9617e5527d699a2977862e696435b80adf0
File details
Details for the file belief_state_superposition-0.1.0-py3-none-any.whl.
File metadata
- Download URL: belief_state_superposition-0.1.0-py3-none-any.whl
- Upload date:
- Size: 5.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8c590e7fe29e376b84053eb941316e833616b5c80744844b02bf18855dc65e7a
MD5 | ada2f27f06e0352ec94e1a11b55905e9
BLAKE2b-256 | 2ee2afe6b5aa7a0e8eb6e7fcc4e23e31143925d1fac3f2fe9466923ee6029672