The Mixture Adaptive Design (MAD): An experimental design for anytime-valid causal inference on Multi-Armed Bandits.
pyssed
The goal of pyssed is to implement the Mixture Adaptive Design (MAD), as proposed by Liang and Bojinov. MAD is an experimental design for multi-armed bandit algorithms that enables anytime-valid inference on the Average Treatment Effect (ATE).
Intuitively, MAD “mixes” any bandit algorithm with a Bernoulli design, where at each time step, the probability of assigning a unit via the Bernoulli design is determined by a user-specified deterministic sequence that can converge to zero. This sequence lets managers directly control the trade-off between regret minimization and inferential precision. Under mild conditions on the rate the sequence converges to zero, [MAD] provides a confidence sequence that is asymptotically anytime-valid and guaranteed to shrink around the true ATE. Hence, when the true ATE converges to a non-zero value, the MAD confidence sequence is guaranteed to exclude zero in finite time. Therefore, the MAD enables managers to stop experiments early while ensuring valid inference, enhancing both the efficiency and reliability of adaptive experiments.
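The mixing step described above can be sketched in a few lines. This is a minimal illustration, not pyssed's internal code: the helper name `mad_probabilities` is hypothetical, and the sketch assumes the Bernoulli design assigns each of the $K$ arms with equal probability $1/K$, so that the mixed probability for an arm is $\delta_t / K + (1 - \delta_t)\, p_t$, where $p_t$ is the bandit's own assignment probability.

```python
import numpy as np


def mad_probabilities(bandit_probs: np.ndarray, delta_t: float) -> np.ndarray:
    """Mix bandit assignment probabilities with a uniform Bernoulli design.

    Hypothetical helper for illustration only; assumes the Bernoulli
    design is uniform over the arms.
    """
    if not 0.0 <= delta_t <= 1.0:
        raise ValueError("delta_t must lie in [0, 1]")
    probs = np.asarray(bandit_probs, dtype=float)
    uniform = np.full(len(probs), 1.0 / len(probs))
    # Weight delta_t on the Bernoulli design, the remainder on the bandit
    return delta_t * uniform + (1.0 - delta_t) * probs


# Made-up bandit probabilities for four arms
bandit_probs = np.array([0.02, 0.03, 0.35, 0.60])
mixed_early = mad_probabilities(bandit_probs, delta_t=1.0)  # pure randomization
mixed_late = mad_probabilities(bandit_probs, delta_t=0.1)   # mostly the bandit
```

With `delta_t=1.0` every arm receives probability 0.25, and as `delta_t` shrinks the mixed probabilities move toward the bandit's own, while every arm keeps a probability of at least `delta_t / K`.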
Installation
pyssed can be installed from PyPI with:
pip install pyssed
or from GitHub with:
pip install git+https://github.com/dmolitor/pyssed
Usage
We’ll simulate an experiment with three treatment arms and one control arm using Thompson Sampling (TS) as the bandit algorithm. We’ll demonstrate how MAD enables unbiased ATE estimation for all treatments while maintaining valid confidence sequences.
First, import the necessary packages:
import numpy as np
import pandas as pd
import plotnine as pn
from pyssed import Bandit, MAD
from typing import Callable, Dict
generator = np.random.default_rng(seed=123)
Treatment arm outcomes
Next, define a function to generate outcomes (rewards) for each experiment arm:
def reward_fn(arm: int) -> float:
    values = {
        0: generator.binomial(1, 0.5),   # Control arm
        1: generator.binomial(1, 0.6),   # ATE = 0.1
        2: generator.binomial(1, 0.7),   # ATE = 0.2
        3: generator.binomial(1, 0.72),  # ATE = 0.22
    }
    return values[arm]
We design the experiment so Arm 1 has a small ATE (0.1), while Arms 2 and 3 have larger ATEs (0.2 and 0.22) that are very similar.
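As a quick sanity check on these ground-truth values, the sketch below draws a large batch of outcomes per arm and compares each treatment arm's empirical mean against the control arm. It does not use pyssed; the `success_probs` dictionary and the separate seeded generator are introduced here for illustration only and mirror the probabilities in `reward_fn`.

```python
import numpy as np

rng = np.random.default_rng(seed=0)  # separate generator, illustration only
success_probs = {0: 0.5, 1: 0.6, 2: 0.7, 3: 0.72}  # same values as reward_fn
n_draws = 200_000

# Empirical mean outcome for each arm
empirical_means = {
    arm: rng.binomial(1, p, size=n_draws).mean()
    for arm, p in success_probs.items()
}

# Empirical ATE of each treatment arm relative to the control arm (arm 0)
empirical_ates = {
    arm: empirical_means[arm] - empirical_means[0]
    for arm in success_probs
    if arm != 0
}

for arm, ate in empirical_ates.items():
    print(f"Arm {arm}: empirical ATE = {ate:.3f}")
```

The gap between Arms 2 and 3 is only 0.02, which is what makes this a demanding test case for an adaptive design.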
Thompson Sampling bandit
We’ll now implement TS for binary data, modeling each arm’s outcomes as drawn from a Bernoulli with an unknown parameter $\theta$, where $\theta$ follows a Beta($\alpha$=1, $\beta$=1) prior (a uniform prior).
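The conjugate update behind this model is short enough to sketch on its own. Under a Beta($\alpha$, $\beta$) prior, each observed success increments $\alpha$ by one, each failure increments $\beta$ by one, and the posterior mean is $\alpha / (\alpha + \beta)$. The helper name `update_beta` is hypothetical and not part of pyssed.

```python
from typing import Iterable, Tuple


def update_beta(
    alpha: float, beta: float, outcomes: Iterable[int]
) -> Tuple[float, float]:
    """Return Beta posterior parameters after observing binary outcomes."""
    for y in outcomes:
        if y == 1:
            alpha += 1  # one more success
        else:
            beta += 1   # one more failure
    return alpha, beta


# Start from the uniform Beta(1, 1) prior and observe 3 successes, 1 failure
alpha, beta = update_beta(1, 1, [1, 0, 1, 1])
posterior_mean = alpha / (alpha + beta)
```

Starting from the uniform prior, three successes and one failure give a Beta(4, 2) posterior with mean 2/3.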
To use MAD, pyssed requires the bandit algorithm to be a class
inheriting from pyssed.Bandit and implementing the following key methods:
- control(): Returns the control arm index.
- k(): Returns the number of arms.
- probabilities(): Computes arm assignment probabilities.
- reward(arm): Computes the reward for a selected arm.
- t(): Returns the current time step.
For full details, see the pyssed.Bandit documentation.
The following is an example TS implementation:
class TSBernoulli(Bandit):
    """
    A class for implementing Thompson Sampling on Bernoulli data
    """
    def __init__(self, k: int, control: int, reward: Callable[[int], float]):
        self._active_arms = [x for x in range(k)]
        self._control = control
        self._k = k
        self._means = {x: 0. for x in range(k)}
        self._params = {x: {"alpha": 1, "beta": 1} for x in range(k)}
        self._rewards = {x: [] for x in range(k)}
        self._reward_fn = reward
        self._t = 1

    def control(self) -> int:
        return self._control

    def k(self) -> int:
        return self._k

    def probabilities(self) -> Dict[int, float]:
        assert self.k() == len(self._active_arms), "Mismatch in `len(self._active_arms)` and `self.k()`"
        samples = np.column_stack([
            np.random.beta(
                a=self._params[idx]["alpha"],
                b=self._params[idx]["beta"],
                size=1000
            )
            for idx in self._active_arms
        ])
        max_indices = np.argmax(samples, axis=1)
        probs = {
            idx: np.sum(max_indices == i) / 1000
            for i, idx in enumerate(self._active_arms)
        }
        return probs

    def reward(self, arm: int) -> float:
        outcome = self._reward_fn(arm)
        self._rewards[arm].append(outcome)
        if outcome == 1:
            self._params[arm]["alpha"] += 1
        else:
            self._params[arm]["beta"] += 1
        self._means[arm] = (
            self._params[arm]["alpha"]
            / (self._params[arm]["alpha"] + self._params[arm]["beta"])
        )
        return outcome

    def t(self) -> int:
        step = self._t
        self._t += 1
        return step
With our TS bandit algorithm implemented, we can now wrap it in the MAD experimental design for inference on the ATEs!
The MAD
For our MAD design, we need a function that takes the time step $t$ and returns $\delta_t$, the probability of assigning a unit via the Bernoulli design at that step. The sequence $\delta_t$ may converge to 0, but the key requirement is that it must decay slower than $1/t^{1/4}$.
Intuitively, a sequence of $1/t^0 = 1$ corresponds to Bernoulli randomization, while a sequence of $1/(t^{0.24})$ closely follows the TS assignment policy.
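To get a feel for how slowly such a sequence decays, the sketch below evaluates $1/t^{0.24}$ at a few time steps and compares it with the constant sequence and with the boundary rate $1/t^{1/4}$. The helper name `delta` and the chosen time steps are illustrative assumptions; nothing here calls pyssed.

```python
import numpy as np


def delta(t: np.ndarray, exponent: float) -> np.ndarray:
    """Evaluate the sequence 1 / t**exponent at the given time steps."""
    return 1.0 / np.power(t, exponent)


steps = np.array([1, 10, 100, 1_000, 20_000], dtype=float)

delta_mad = delta(steps, 0.24)       # sequence used in this example
delta_bernoulli = delta(steps, 0.0)  # constant 1: pure Bernoulli design
delta_boundary = delta(steps, 0.25)  # boundary rate 1 / t**(1/4)

# The ratio equals t**0.01, which grows with t: the chosen sequence
# decays strictly slower than the boundary rate.
ratio = delta_mad / delta_boundary

for t, d in zip(steps, delta_mad):
    print(f"t={int(t):>6}: delta_t={d:.3f}")
```

Even at $t = 20{,}000$ the sequence is still about 0.09, so a non-trivial share of units continues to be assigned by the Bernoulli design late in the experiment.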
In this example, we use $1/(t^{0.24})$ as our sequence. Additionally, we
estimate 95% confidence sequences, setting our test size $\alpha=0.05$.
We run the experiment for 20,000 iterations (t_star = 20000).
experiment = MAD(
    bandit=TSBernoulli(k=4, control=0, reward=reward_fn),
    alpha=0.05,
    delta=lambda x: 1./(x**0.24),
    t_star=int(20e3)
)
experiment.fit(verbose=False)
Point estimates and confidence bands
Now, to examine the results we can print a summary of the ATEs and their corresponding confidence sequences at the end of the experiment:
experiment.summary()
Treatment effect estimates:
- Arm 1: 0.055 (-0.09394, 0.20355)
- Arm 2: 0.157 (0.03118, 0.28264)
- Arm 3: 0.197 (0.09508, 0.29887)
We can also extract this summary into a pandas DataFrame:
experiment.estimates()
| | arm | ate | lb | ub |
|---|---|---|---|---|
| 0 | 1 | 0.054806 | -0.093939 | 0.203551 |
| 1 | 2 | 0.156910 | 0.031179 | 0.282641 |
| 2 | 3 | 0.196978 | 0.095082 | 0.298873 |
3 rows × 4 columns
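A table in this shape makes it easy to check which confidence sequences exclude zero. The sketch below rebuilds the table by hand from the values printed above, so it runs without fitting an experiment; the `excludes_zero` and `width` columns are additions for illustration and are not produced by pyssed.

```python
import pandas as pd

# Values copied from the summary table above
estimates = pd.DataFrame({
    "arm": [1, 2, 3],
    "ate": [0.054806, 0.156910, 0.196978],
    "lb": [-0.093939, 0.031179, 0.095082],
    "ub": [0.203551, 0.282641, 0.298873],
})

# A confidence sequence excludes zero when both bounds share the same sign
estimates["excludes_zero"] = (estimates["lb"] > 0) | (estimates["ub"] < 0)

# Interval width as a rough measure of precision
estimates["width"] = estimates["ub"] - estimates["lb"]

print(estimates)
```

In this run, Arms 2 and 3 exclude zero while Arm 1 does not. Arm 3 also has the narrowest interval, consistent with it receiving the most samples.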
Plotting results
We can also visualize the ATE estimates and confidence sequences for each treatment arm over time.
(
    experiment.plot_ate()
    + pn.coord_cartesian(ylim=(-.5, 1.0))
    + pn.geom_hline(
        mapping=pn.aes(yintercept="ate", color="factor(arm)"),
        data=pd.DataFrame({"arm": [1, 2, 3], "ate": [0.1, 0.2, 0.22]}),
        linetype="dotted"
    )
    + pn.theme(strip_text=pn.element_blank())
)
The ATE estimates converge toward the ground truth, and the confidence sequences maintain valid coverage!
We can also examine the algorithm’s sample assignment strategy over time.
experiment.plot_sample_assignment()
Due to the TS algorithm, most samples go to the optimal Arm 3 and secondary Arm 2, with some random exploration in Arms 0 and 1.
Similarly, we can plot the total sample assignments per arm.
experiment.plot_n()
Equivalence to a completely randomized design
As noted above, setting the time-diminishing sequence $\delta_t = 1/t^0 = 1$ results in a fully randomized design. We can easily demonstrate this:
exp_bernoulli = MAD(
    bandit=TSBernoulli(k=4, control=0, reward=reward_fn),
    alpha=0.05,
    delta=lambda _: 1.,
    t_star=int(20e3)
)
exp_bernoulli.fit(verbose=False)
As before, we can plot the convergence of the estimated ATEs to the ground truth:
(
    exp_bernoulli.plot_ate()
    + pn.coord_cartesian(ylim=(-.1, 0.6))
    + pn.geom_hline(
        mapping=pn.aes(yintercept="ate", color="factor(arm)"),
        data=pd.DataFrame({"arm": [1, 2, 3], "ate": [0.1, 0.2, 0.22]}),
        linetype="dotted"
    )
    + pn.theme(strip_text=pn.element_blank())
)
And we can verify fully random assignment:
exp_bernoulli.plot_n()