# belief-state-superposition

Investigating belief state representations of transformers trained on Hidden Markov Model emissions.
## Quickstart
### Install

```bash
pip install belief-state-superposition
```
### Usage

Generate and inspect data from a Hidden Markov Model:
```python
from belief_state_superposition.hmm import sample_sequence

data = sample_sequence(16)
beliefs, states, emissions, next_beliefs, next_states = zip(*data)
print(beliefs)
print(states)
print(emissions)
```
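For intuition, the belief states emitted above are the output of a standard Bayesian filter over the HMM's hidden states. A minimal illustrative sketch of that update, using a made-up 2-state HMM (the transition and emission matrices here are assumptions for illustration, not the parameters used by this package):

```python
import numpy as np

# Illustrative 2-state HMM (NOT the package's actual parameters).
T = np.array([[0.9, 0.1],   # P(next_state | state)
              [0.2, 0.8]])
E = np.array([[0.7, 0.3],   # P(emission | state)
              [0.1, 0.9]])

def update_belief(belief, emission):
    """One Bayesian filter step: propagate through T, condition on the emission."""
    predicted = belief @ T                   # push belief through transitions
    posterior = predicted * E[:, emission]   # weight by emission likelihood
    return posterior / posterior.sum()       # renormalize to a distribution

belief = np.array([0.5, 0.5])  # uniform prior over hidden states
for emission in [0, 1, 1, 0]:
    belief = update_belief(belief, emission)
print(belief)  # a valid probability distribution over the 2 hidden states
```

A transformer trained on the emission stream alone can, in principle, recover exactly these filter states, which is what the project investigates.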
Train a model on belief states:
```python
import torch
from torch.utils.data import DataLoader

from belief_state_superposition.model import init_model
from belief_state_superposition.data import get_dataset
from belief_state_superposition.train import train_model

device = "cuda" if torch.cuda.is_available() else "cpu"

train_dataset = get_dataset(1000)
train_data_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
model = init_model().to(device)
train_model(model, train_data_loader, n_epochs=10, show_progress_bar=True, device=device)
```
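After training, the project's central question is whether belief states are linearly decodable from the network's hidden activations. A self-contained sketch of such a linear probe on synthetic stand-in data (the arrays below substitute for real activations and beliefs, and the least-squares probe is an illustrative choice, not this package's API):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for real data: 500 hidden activations (dim 64) and the
# ground-truth 3-dimensional belief states they are hypothesized to encode.
n, d_model, n_states = 500, 64, 3
beliefs = rng.dirichlet(np.ones(n_states), size=n)  # points on the probability simplex
W_true = rng.normal(size=(n_states, d_model))
activations = beliefs @ W_true + 0.01 * rng.normal(size=(n, d_model))

# Fit a linear probe (activations -> beliefs) by least squares.
probe, *_ = np.linalg.lstsq(activations, beliefs, rcond=None)
pred = activations @ probe

# R^2 of the probe: values near 1 suggest beliefs are linearly decodable.
ss_res = ((beliefs - pred) ** 2).sum()
ss_tot = ((beliefs - beliefs.mean(axis=0)) ** 2).sum()
r2 = 1 - ss_res / ss_tot
print(f"probe R^2: {r2:.3f}")
```

With real data, `activations` would come from the trained transformer's residual stream and `beliefs` from the HMM filter; a high probe R^2 would support the belief-state-representation hypothesis.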
## Development

Refer to Setup for instructions on setting up the development environment.