Constructing batched tensors for any machine learning task

Collatable

Installation

pip install collatable

Examples

The following scripts show how to tokenize/index/collate your dataset with collatable:

Text Classification

import collatable
from collatable import LabelField, MetadataField, TextField
from collatable.extras.indexer import LabelIndexer, TokenIndexer

dataset = [
    ("this is awesome", "positive"),
    ("this is a bad movie", "negative"),
    ("this movie is an awesome movie", "positive"),
    ("this movie is too bad to watch", "negative"),
]

# Set up indexers for tokens and labels
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
token_indexer = TokenIndexer[str](specials=[PAD_TOKEN, UNK_TOKEN], default=UNK_TOKEN)
label_indexer = LabelIndexer[str]()

# Load training dataset
instances = []
with token_indexer.context(train=True), label_indexer.context(train=True):
    for id_, (text, label) in enumerate(dataset):
        # Prepare each field with the corresponding field class
        text_field = TextField(
            text.split(),
            indexer=token_indexer,
            padding_value=token_indexer[PAD_TOKEN],
        )
        label_field = LabelField(
            label,
            indexer=label_indexer,
        )
        metadata_field = MetadataField({"id": id_})
        # Combine these fields into an instance
        instance = dict(
            text=text_field,
            label=label_field,
            metadata=metadata_field,
        )
        instances.append(instance)

# Collate instances and build batch
output = collatable.collate(instances)
print(output)

Execution result:

{'metadata': [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}],
 'text': {
    'token_ids': array([[ 2,  3,  4,  0,  0,  0,  0],
                        [ 2,  3,  5,  6,  7,  0,  0],
                        [ 2,  7,  3,  8,  4,  7,  0],
                        [ 2,  7,  3,  9,  6, 10, 11]]),
    'mask': array([[ True,  True,  True, False, False, False, False],
                   [ True,  True,  True,  True,  True, False, False],
                   [ True,  True,  True,  True,  True,  True, False],
                   [ True,  True,  True,  True,  True,  True,  True]])},
 'label': array([0, 1, 0, 1], dtype=int32)}
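
The 'token_ids' and 'mask' arrays in the result above follow a common right-padding pattern: shorter sequences are filled up to the longest length with the padding index, and the mask marks real tokens. The sketch below reproduces that pattern in plain NumPy for the first two rows. It is not collatable's implementation; the helper name pad_and_mask and its signature are hypothetical and exist only for illustration.

```python
import numpy as np


def pad_and_mask(sequences, padding_value=0):
    """Right-pad integer sequences to a common length and build a boolean mask.

    Hypothetical helper for illustration; it does not handle an empty batch.
    """
    max_len = max(len(seq) for seq in sequences)
    token_ids = np.full((len(sequences), max_len), padding_value, dtype=np.int64)
    mask = np.zeros((len(sequences), max_len), dtype=bool)
    for row, seq in enumerate(sequences):
        token_ids[row, : len(seq)] = seq  # real tokens on the left
        mask[row, : len(seq)] = True      # True only where a real token sits
    return token_ids, mask


# Unpadded token ids taken from the first two rows of the result above.
token_ids, mask = pad_and_mask([[2, 3, 4], [2, 3, 5, 6, 7]])
print(token_ids)
print(mask)
```

Because the padding index (0 here) is reserved for the pad token, downstream code can recover the real tokens either from the mask or by comparing against the padding index.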

Sequence Labeling

import collatable
from collatable import SequenceLabelField, TextField
from collatable.extras.indexer import LabelIndexer, TokenIndexer

dataset = [
    (["my", "name", "is", "john", "smith"], ["O", "O", "O", "B", "I"]),
    (["i", "lived", "in", "japan", "three", "years", "ago"], ["O", "O", "O", "U", "O", "O", "O"]),
]

# Set up indexers for tokens and labels
PAD_TOKEN = "<PAD>"
token_indexer = TokenIndexer[str](specials=(PAD_TOKEN,))
label_indexer = LabelIndexer[str]()

# Load training dataset
instances = []
with token_indexer.context(train=True), label_indexer.context(train=True):
    for tokens, labels in dataset:
        text_field = TextField(tokens, indexer=token_indexer, padding_value=token_indexer[PAD_TOKEN])
        label_field = SequenceLabelField(labels, text_field, indexer=label_indexer)
        instance = dict(text=text_field, label=label_field)
        instances.append(instance)

output = collatable.collate(instances)
print(output)

Execution result:

{'label': array([[0, 0, 0, 1, 2, 0, 0],
                 [0, 0, 0, 3, 0, 0, 0]]),
 'text': {
    'token_ids': array([[ 1,  2,  3,  4,  5,  0,  0],
                        [ 6,  7,  8,  9, 10, 11, 12]]),
    'mask': array([[ True,  True,  True,  True,  True, False, False],
                   [ True,  True,  True,  True,  True,  True,  True]])}}
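
In the result above, the padded positions of 'label' hold 0, which is also the index assigned to the "O" tag. A metric computed over the whole array would therefore count padding as correct "O" predictions, so it is usually safer to restrict it with the text mask. The sketch below shows one way to do that in NumPy; the predictions and the masked_accuracy helper are invented for illustration and are not part of collatable.

```python
import numpy as np

# Gold labels and mask copied from the execution result above.
gold = np.array([[0, 0, 0, 1, 2, 0, 0],
                 [0, 0, 0, 3, 0, 0, 0]])
mask = np.array([[True, True, True, True, True, False, False],
                 [True, True, True, True, True, True, True]])

# Hypothetical model predictions (one mistake: position 4 of the first row).
pred = np.array([[0, 0, 0, 1, 0, 0, 0],
                 [0, 0, 0, 3, 0, 0, 0]])


def masked_accuracy(pred, gold, mask):
    """Accuracy over real tokens only; padded positions are ignored."""
    correct = (pred == gold) & mask
    return correct.sum() / mask.sum()


naive = (pred == gold).mean()               # also counts the two padded positions
masked = masked_accuracy(pred, gold, mask)  # counts the 12 real tokens only
print(naive, masked)
```

The naive score is inflated because the two padded positions always match; the masked score divides by the number of real tokens instead.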

Relation Extraction

import collatable
from collatable.extras.indexer import LabelIndexer, TokenIndexer
from collatable import AdjacencyField, ListField, SpanField, TextField

PAD_TOKEN = "<PAD>"
token_indexer = TokenIndexer[str](specials=(PAD_TOKEN,))
label_indexer = LabelIndexer[str]()

instances = []
with token_indexer.context(train=True), label_indexer.context(train=True):
    text = TextField(
        ["john", "smith", "was", "born", "in", "new", "york", "and", "now", "lives", "in", "tokyo"],
        indexer=token_indexer,
        padding_value=token_indexer[PAD_TOKEN],
    )
    spans = ListField([SpanField(0, 2, text), SpanField(5, 7, text), SpanField(11, 12, text)])
    relations = AdjacencyField([(0, 1), (0, 2)], spans, labels=["born-in", "lives-in"], indexer=label_indexer)
    instance = dict(text=text, spans=spans, relations=relations)
    instances.append(instance)

    text = TextField(
        ["tokyo", "is", "the", "capital", "of", "japan"],
        indexer=token_indexer,
        padding_value=token_indexer[PAD_TOKEN],
    )
    spans = ListField([SpanField(0, 1, text), SpanField(5, 6, text)])
    relations = AdjacencyField([(0, 1)], spans, labels=["capital-of"], indexer=label_indexer)
    instance = dict(text=text, spans=spans, relations=relations)
    instances.append(instance)

output = collatable.collate(instances)
print(output)

Execution result:

{'text': {
    'token_ids': array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  5, 11],
                        [11, 12, 13, 14, 15, 16,  0,  0,  0,  0,  0,  0]]),
    'mask': array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
                   [ True,  True,  True,  True,  True,  True, False, False, False, False, False, False]])},
 'spans': array([[[ 0,  2],
                  [ 5,  7],
                  [11, 12]],
                 [[ 0,  1],
                  [ 5,  6],
                  [-1, -1]]]),
 'relations': array([[[-1,  0,  1],
                      [-1, -1, -1],
                      [-1, -1, -1]],
                     [[-1,  2, -1],
                      [-1, -1, -1],
                      [-1, -1, -1]]], dtype=int32)}
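
The 'relations' array above can be read as one span-by-span matrix per instance: entry (i, j) holds the label index of the relation from span i to span j, and -1 marks both "no relation" and padded spans. The following NumPy sketch rebuilds the two matrices from the edge lists in the example. It is an illustration of the layout only, not collatable's implementation; build_adjacency and its parameters are hypothetical names.

```python
import numpy as np


def build_adjacency(num_spans, max_spans, edges, label_ids, padding_value=-1):
    """Build a (max_spans, max_spans) matrix of label ids for one instance.

    Hypothetical helper: cells without a relation, and rows/columns that
    belong to padded spans, keep `padding_value`.
    """
    matrix = np.full((max_spans, max_spans), padding_value, dtype=np.int32)
    for (head, tail), label_id in zip(edges, label_ids):
        if not (0 <= head < num_spans and 0 <= tail < num_spans):
            raise IndexError(f"edge ({head}, {tail}) refers to a missing span")
        matrix[head, tail] = label_id
    return matrix


# Label ids follow the order of first appearance in the example:
# born-in -> 0, lives-in -> 1, capital-of -> 2.
first = build_adjacency(3, 3, [(0, 1), (0, 2)], [0, 1])
second = build_adjacency(2, 3, [(0, 1)], [2])  # padded from 2 spans to 3
relations = np.stack([first, second])
print(relations)
```

Using -1 for empty cells keeps 0 available as a valid label index, which matters because the first relation label seen is assigned index 0.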

Reference Implementation

The collatable.extras module provides a reference implementation that shows how to use collatable effectively. Here is an example of a text-to-text task that encodes raw texts and labels into token ids and decodes them back into raw texts and labels:

from dataclasses import dataclass
from typing import Mapping, Sequence, Union

from collatable.extras import DataLoader, Dataset, DefaultBatchSampler, LabelIndexer, TokenIndexer
from collatable.extras.datamodule import DataModule, LabelFieldTransform, TextFieldTransform
from collatable.utils import debatched


@dataclass
class Text2TextExample:
    source: Union[str, Sequence[str]]
    target: Union[str, Sequence[str]]
    language: str


text2text_dataset = [
    Text2TextExample(source="how are you?", target="I am fine.", language="en"),
    Text2TextExample(source="what is your name?", target="My name is John.", language="en"),
    Text2TextExample(source="where are you?", target="I am in New-York.", language="en"),
    Text2TextExample(source="what is the time?", target="It is 10:00 AM.", language="en"),
    Text2TextExample(source="comment ça va?", target="Je vais bien.", language="fr"),
]

shared_token_indexer = TokenIndexer(default="<unk>", specials=["<pad>", "<unk>"])
language_indexer = LabelIndexer[str]()

text2text_datamodule = DataModule[Text2TextExample](
    fields={
        "source": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
        "target": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
        "language": LabelFieldTransform(indexer=language_indexer),
    }
)

with shared_token_indexer.context(train=True), language_indexer.context(train=True):
    text2text_datamodule.build(text2text_dataset)

dataloader = DataLoader(DefaultBatchSampler(batch_size=2))

text2text_instances = Dataset.from_iterable(text2text_datamodule(text2text_dataset))

for batch in dataloader(text2text_instances):
    print("Batch:")
    print(batch)
    print("Reconstruction:")
    for item in debatched(batch):
        print(text2text_datamodule.reconstruct(item))
    print()

Execution result:

Batch:
{'target': {
    'token_ids': array([[16, 17, 18, 19,  0],
                        [20,  9,  7, 21, 19]]),
    'mask': array([[ True,  True,  True,  True, False],
                   [ True,  True,  True,  True,  True]])},
 'language': array([0, 0], dtype=int32),
 'source': {
    'token_ids': array([[2, 3, 4, 5, 0],
                        [6, 7, 8, 9, 5]]),
    'mask': array([[ True,  True,  True,  True, False],
                   [ True,  True,  True,  True,  True]])}}
Reconstruction:
{'source': ['how', 'are', 'you', '?'], 'target': ['I', 'am', 'fine', '.'], 'language': 'en'}
{'source': ['what', 'is', 'your', 'name', '?'], 'target': ['My', 'name', 'is', 'John', '.'], 'language': 'en'}

...
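
The reconstruction step above works on one item at a time, so the batch first has to be split back into per-example dictionaries. As a rough sketch of what that involves, the code below walks a nested dict of batched arrays and slices out each row, then uses the mask to drop padding. It is not the implementation of collatable's debatched; the function split_batch is a hypothetical stand-in written for illustration.

```python
import numpy as np


def split_batch(batch):
    """Yield one dict per example from a (possibly nested) dict of batched values."""

    def batch_size(value):
        # Descend into nested dicts until a batched array or list is found.
        if isinstance(value, dict):
            return batch_size(next(iter(value.values())))
        return len(value)

    def take(value, index):
        # Slice row `index` out of every leaf, keeping the dict structure.
        if isinstance(value, dict):
            return {key: take(child, index) for key, child in value.items()}
        return value[index]

    for index in range(batch_size(batch)):
        yield take(batch, index)


# Part of the batch shown in the execution result above.
batch = {
    "source": {
        "token_ids": np.array([[2, 3, 4, 5, 0], [6, 7, 8, 9, 5]]),
        "mask": np.array([[True, True, True, True, False],
                          [True, True, True, True, True]]),
    },
    "language": np.array([0, 0], dtype=np.int32),
}

items = list(split_batch(batch))
first_source = items[0]["source"]
unpadded = first_source["token_ids"][first_source["mask"]]  # drop padded positions
print(unpadded)
```

Mapping the unpadded ids back to tokens would then be a lookup in the indexer's vocabulary, which is what the reconstruction output above shows.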
