
Transforming conditional density estimation into a single nonparametric regression task.

Project description

condensité

condensité transforms conditional density estimation into a single nonparametric regression task. The package implements this approach and can turn regressors from common libraries such as sklearn or torch into conditional density estimators. A complete example is given below; see example.py for further examples using the reference implementations in src/example_models.

import torch
import condensite as cde
from sklearn.ensemble import HistGradientBoostingRegressor

# create a condensité predictor based on sklearn's HistGradientBoostingRegressor
class TreePredictor(cde.CondensitePredictor):
    def __init__(self, params):
        self.tree = HistGradientBoostingRegressor(**params)

    def device(self):
        # sklearn estimators run on the CPU, so report a fixed device
        return "cpu"

    def to(self, device):
        # no-op: a CPU-bound sklearn model cannot be moved to another device
        return self

    def forward(self, x):
        return self.tree.predict(x)

    def predict(self, x):
        return torch.tensor(self.forward(x), dtype=torch.float32)

    def fit(self, dataset: cde.RepeatDataset):
        # create data loader (c.f. cde.RepeatDataset and cde.concat_collate)
        train_DL = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=dataset.n * dataset.M,
                                               shuffle=True,
                                               collate_fn=cde.concat_collate)
        xtorch, ytorch = next(iter(train_DL))  # single batch of all data
        self.tree.fit(xtorch.numpy(), ytorch.numpy())

# instantiate and wrap using Condensite
tree_predictor = TreePredictor(params={
    'max_bins': 40,
    'max_leaf_nodes': 100,
    'l2_regularization': 0.1,
})
sklearn_model = cde.Condensite(tree_predictor, h=0.01, M=100, name='sklearn_model')

# generate some data and fit
def sample_data(n_obs=1000, n_features=10):
    x = torch.empty((n_obs, n_features)).normal_(0,1)
    x1 = x[:, 0]
    std = torch.sqrt(0.25 + x1**2)
    noise = torch.distributions.Normal(torch.zeros_like(x1), std).sample()
    y = x1 + noise
    return x, y

# fit conditional density estimator
x_train, y_train = sample_data()
sklearn_model.fit(x_train, y_train, train_frac=0.8, n_grid_ISE=100)

# evaluate fit out-of-sample
x_test, y_test = sample_data()
test_ISE = cde.utils.ISE(sklearn_model, x_test, y_test, n_grid=100).item()
print(f'{sklearn_model.name}:\t {test_ISE=:.4f}')

If you find our algorithms useful, please consider citing our work.

Project details


Download files

Download the file for your platform.

Source Distribution

condensite-0.0.1.tar.gz (10.4 kB)

Uploaded Source

Built Distribution


condensite-0.0.1-py3-none-any.whl (11.8 kB)

Uploaded Python 3

File details

Details for the file condensite-0.0.1.tar.gz.

File metadata

  • Download URL: condensite-0.0.1.tar.gz
  • Upload date:
  • Size: 10.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for condensite-0.0.1.tar.gz
  • SHA256: acadfde59625026e9ecbd346e895fe6fafc25819e29fd273d0a3a4481bc2a9a6
  • MD5: 5b8a49f60c8c00d015ace2b5b0789507
  • BLAKE2b-256: 913c32189b6811e3211b462f30ee5ea61c81eed03a60fd4421e620c1d4d87149


File details

Details for the file condensite-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: condensite-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 11.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for condensite-0.0.1-py3-none-any.whl
  • SHA256: aea83e766697cce49ff8004b77b49418c29b69d301e77bac3e16cd1667beb35d
  • MD5: ad5bad235b010d534621741ea88a0723
  • BLAKE2b-256: ec78b273c72eea78378ca0844d3047c68628c72ed0887d8bb7f467f150670188

