torch-adata
Create PyTorch Datasets from AnnData.
Example use of the base class
The base class, AnnDataset, is a subclass of the widely used torch.utils.data.Dataset. The outputs of AnnDataset and all of its subclasses are designed to be directly compatible with torch.utils.data.DataLoader.
import anndata as a
import torch_adata as ta
adata = a.read_h5ad("/path/to/data.h5ad")
dataset = ta.AnnDataset(adata)
Indexing the dataset returns the data, X, as a torch.Tensor and the sampled rows of adata.obs as a pandas.DataFrame.
import numpy as np
# create a dummy index of 5 distinct observations
idx = np.random.choice(len(dataset), 5, replace=False)
X, obs = dataset[idx]
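To illustrate what being a torch.utils.data.Dataset subclass buys you, here is a minimal sketch of a dataset with the same (X, obs) return shape. MiniAnnDataset is a hypothetical stand-in, not torch-adata's code, and it takes a plain NumPy array and a pandas.DataFrame instead of an AnnData object. Because __getitem__ accepts a list of indices, one assumed way to batch it is the standard PyTorch pattern of passing a BatchSampler as the sampler with batch_size=None. Whether torch-adata recommends this exact setup is an assumption.

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler


class MiniAnnDataset(Dataset):
    """Hypothetical stand-in for AnnDataset (not the package's implementation).

    Wraps a dense matrix X (n_obs x n_vars) and an obs table with one row per
    observation. __getitem__ accepts an int, a list of ints, or an int array.
    """

    def __init__(self, X: np.ndarray, obs: pd.DataFrame):
        if X.shape[0] != len(obs):
            raise ValueError("X and obs must have the same number of rows")
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.obs = obs

    def __len__(self) -> int:
        # number of observations, analogous to adata.n_obs
        return self.X.shape[0]

    def __getitem__(self, idx):
        # positional indexing keeps tensor rows and obs rows aligned
        return self.X[idx], self.obs.iloc[idx]


X = np.random.rand(10, 3)
obs = pd.DataFrame({"batch": ["a", "b"] * 5})
dataset = MiniAnnDataset(X, obs)

# BatchSampler yields lists of indices; batch_size=None turns off automatic
# batching, so each list reaches __getitem__ in a single call.
sampler = BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False)
loader = DataLoader(dataset, sampler=sampler, batch_size=None)

batch_sizes = []
for X_batch, obs_batch in loader:
    # each batch pairs a (batch, n_vars) tensor with the matching obs rows
    assert X_batch.shape[0] == len(obs_batch)
    batch_sizes.append(len(obs_batch))
```

Letting __getitem__ take a whole index array means each batch is one tensor slice and one DataFrame slice, rather than many single-row lookups that must then be collated.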
Specialized classes
GroupedAnnDataset
A subclass of the base class, AnnDataset.
import anndata as a
import torch_adata as ta
adata = a.read_h5ad("/path/to/data.h5ad")
dataset = ta.GroupedAnnDataset(adata, groupby="batch")
Returns the data as a dictionary keyed by each groupby category, with torch.Tensor values. The sampled adata.obs is again returned as a pandas.DataFrame.
import numpy as np
# create a dummy index of 5 distinct observations
idx = np.random.choice(len(dataset), 5, replace=False)
X_dict, obs = dataset[idx]
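One way to picture the grouped output is to sample rows by position and then split them by the obs column named in groupby. The sketch below does that with a hypothetical helper, split_by_group. It assumes the dictionary holds only the sampled rows of each category, which may differ from how GroupedAnnDataset actually samples.

```python
import numpy as np
import pandas as pd
import torch


def split_by_group(X: torch.Tensor, obs: pd.DataFrame, idx, groupby: str):
    """Hypothetical sketch: sample rows idx, then split them by obs[groupby]."""
    idx = np.asarray(idx)
    obs_sample = obs.iloc[idx]
    X_sample = X[torch.as_tensor(idx)]
    X_dict = {}
    for group in obs_sample[groupby].unique():
        # boolean mask over the sampled rows that belong to this category
        mask = torch.as_tensor((obs_sample[groupby] == group).to_numpy())
        X_dict[group] = X_sample[mask]
    return X_dict, obs_sample


X = torch.arange(12, dtype=torch.float32).reshape(6, 2)
obs = pd.DataFrame({"batch": ["a", "a", "b", "b", "c", "c"]})
X_dict, obs_sample = split_by_group(X, obs, [0, 1, 2, 4], groupby="batch")
```

Here X_dict has keys "a", "b", and "c". "a" holds two sampled rows and the other two categories hold one row each.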
TimeResolvedAnnDataset
A subclass of GroupedAnnDataset.
import anndata as a
import torch_adata as ta
adata = a.read_h5ad("/path/to/data.h5ad")
dataset = ta.TimeResolvedAnnDataset(adata, time_key="Time point")
Returns four values: the initial datapoint, X0, as a torch.Tensor; the entire sample as a dictionary keyed by each timepoint in the time_key column, with torch.Tensor values; t, the timepoints; and the sampled adata.obs, again as a pandas.DataFrame.
import numpy as np
# create a dummy index of 5 distinct observations
idx = np.random.choice(len(dataset), 5, replace=False)
X0, X_dict, t, obs = dataset[idx]
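The four return values can be sketched by extending the grouping idea, with time_key as the grouping column. This is a hypothetical helper, time_resolved_sample, and not torch-adata's implementation. It assumes that X0 is the set of sampled rows at the earliest timepoint and that t is the sorted unique timepoints in the sample. The package may define both differently.

```python
import numpy as np
import pandas as pd
import torch


def time_resolved_sample(X: torch.Tensor, obs: pd.DataFrame, idx, time_key: str):
    """Hypothetical sketch of a time-resolved sample (assumed semantics)."""
    idx = np.asarray(idx)
    obs_sample = obs.iloc[idx]
    X_sample = X[torch.as_tensor(idx)]
    times = obs_sample[time_key].to_numpy()
    # np.unique returns the distinct timepoints already sorted
    t = np.unique(times)
    # one tensor of sampled rows per timepoint, keyed by the timepoint value
    X_dict = {tp.item(): X_sample[torch.as_tensor(times == tp)] for tp in t}
    # assumption: the "initial datapoint" is the rows at the earliest timepoint
    X0 = X_dict[t[0].item()]
    return X0, X_dict, torch.as_tensor(t, dtype=torch.float32), obs_sample


X = torch.arange(12, dtype=torch.float32).reshape(6, 2)
obs = pd.DataFrame({"Time point": [2, 2, 4, 4, 6, 6]})
X0, X_dict, t, obs_sample = time_resolved_sample(X, obs, [1, 2, 3, 5], time_key="Time point")
```

With this index, t is the tensor [2., 4., 6.] and X0 is the single sampled row at timepoint 2.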
Installation
Install from PyPI:
pip install torch-adata
Install the developer version:
git clone https://github.com/mvinyard/torch-adata.git
cd torch-adata
pip install -e .
Hashes for torch_adata-0.0.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 82722f9fdc6970d7eabebd5f70a1190af1253ad7ff69932f6ad74cf9e44698e4
MD5 | 79cf3bd2dca3ec2b2a54816bbee5a7c3
BLAKE2b-256 | f783633b700543471012eae81a20d8d557532c86546fa26fb4c007bc9bc0b3f1