Collatable
Constructing batched tensors for any machine learning task
Installation
pip install collatable
Examples
The following scripts show how to tokenize/index/collate your dataset with collatable:
Text Classification
import collatable
from collatable import LabelField, MetadataField, TextField
from collatable.extras.indexer import LabelIndexer, TokenIndexer

dataset = [
    ("this is awesome", "positive"),
    ("this is a bad movie", "negative"),
    ("this movie is an awesome movie", "positive"),
    ("this movie is too bad to watch", "negative"),
]

# Set up indexers for tokens and labels
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
token_indexer = TokenIndexer[str](specials=[PAD_TOKEN, UNK_TOKEN], default=UNK_TOKEN)
label_indexer = LabelIndexer[str]()

# Load training dataset
instances = []
with token_indexer.context(train=True), label_indexer.context(train=True):
    for id_, (text, label) in enumerate(dataset):
        # Prepare each field with the corresponding field class
        text_field = TextField(
            text.split(),
            indexer=token_indexer,
            padding_value=token_indexer[PAD_TOKEN],
        )
        label_field = LabelField(
            label,
            indexer=label_indexer,
        )
        metadata_field = MetadataField({"id": id_})
        # Combine these fields into an instance
        instance = dict(
            text=text_field,
            label=label_field,
            metadata=metadata_field,
        )
        instances.append(instance)

# Collate instances and build batch
output = collatable.collate(instances)
print(output)
Execution result:
{'metadata': [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}],
'text': {
'token_ids': array([[ 2, 3, 4, 0, 0, 0, 0],
[ 2, 3, 5, 6, 7, 0, 0],
[ 2, 7, 3, 8, 4, 7, 0],
[ 2, 7, 3, 9, 6, 10, 11]]),
'mask': array([[ True, True, True, False, False, False, False],
[ True, True, True, True, True, False, False],
[ True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True]])},
'label': array([0, 1, 0, 1], dtype=int32)}
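To make the shape of this output concrete, here is a minimal pure-Python sketch of what padding and masking amount to. It does not use collatable and is not its implementation: pad_and_mask is a hypothetical helper introduced only for illustration, and plain lists stand in for the NumPy arrays shown above.

```python
from typing import List, Sequence, Tuple


def pad_and_mask(
    sequences: Sequence[Sequence[int]],
    padding_value: int = 0,
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Right-pad every sequence to the longest length and build a boolean mask.

    Hypothetical helper for illustration; not part of collatable.
    """
    max_length = max((len(seq) for seq in sequences), default=0)
    token_ids: List[List[int]] = []
    mask: List[List[bool]] = []
    for seq in sequences:
        num_padding = max_length - len(seq)
        # Real tokens keep their ids; the tail is filled with the padding value.
        token_ids.append(list(seq) + [padding_value] * num_padding)
        # The mask is True on real tokens and False on padding positions.
        mask.append([True] * len(seq) + [False] * num_padding)
    return token_ids, mask


# Token ids of the first two sentences from the example above.
token_ids, mask = pad_and_mask([[2, 3, 4], [2, 3, 5, 6, 7]])
print(token_ids)
print(mask)
```

The mask lets downstream code ignore padded positions, for example when averaging token embeddings or computing a loss.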
Sequence Labeling
import collatable
from collatable import SequenceLabelField, TextField
from collatable.extras.indexer import LabelIndexer, TokenIndexer

dataset = [
    (["my", "name", "is", "john", "smith"], ["O", "O", "O", "B", "I"]),
    (["i", "lived", "in", "japan", "three", "years", "ago"], ["O", "O", "O", "U", "O", "O", "O"]),
]

# Set up indexers for tokens and labels
PAD_TOKEN = "<PAD>"
token_indexer = TokenIndexer[str](specials=(PAD_TOKEN,))
label_indexer = LabelIndexer[str]()

# Load training dataset
instances = []
with token_indexer.context(train=True), label_indexer.context(train=True):
    for tokens, labels in dataset:
        text_field = TextField(tokens, indexer=token_indexer, padding_value=token_indexer[PAD_TOKEN])
        label_field = SequenceLabelField(labels, text_field, indexer=label_indexer)
        instance = dict(text=text_field, label=label_field)
        instances.append(instance)

output = collatable.collate(instances)
print(output)
Execution result:
{'label': array([[0, 0, 0, 1, 2, 0, 0],
[0, 0, 0, 3, 0, 0, 0]]),
'text': {
'token_ids': array([[ 1, 2, 3, 4, 5, 0, 0],
[ 6, 7, 8, 9, 10, 11, 12]]),
'mask': array([[ True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True]])}}
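The ids in this output are consistent with an indexer that assigns indices in first-seen order while in training mode. The sketch below illustrates that idea in pure Python. SimpleIndexer and its context method are hypothetical names that only mimic the shape of the calls above; collatable's own TokenIndexer and LabelIndexer may differ in their details.

```python
from contextlib import contextmanager
from typing import Dict, Iterable, Iterator, List, Optional


class SimpleIndexer:
    """Toy vocabulary (hypothetical, not collatable's implementation)."""

    def __init__(self, specials: Iterable[str] = (), default: Optional[str] = None) -> None:
        self.index: Dict[str, int] = {}
        self.default = default
        self.train = False
        # Special symbols such as the padding token get the lowest ids.
        for special in specials:
            self.index.setdefault(special, len(self.index))

    @contextmanager
    def context(self, train: bool) -> Iterator[None]:
        # Temporarily switch between growing and frozen vocabulary.
        previous, self.train = self.train, train
        try:
            yield
        finally:
            self.train = previous

    def __call__(self, values: Iterable[str]) -> List[int]:
        ids = []
        for value in values:
            if value not in self.index:
                if self.train:
                    # New symbols receive the next free id in first-seen order.
                    self.index[value] = len(self.index)
                elif self.default is not None:
                    # Outside training, unknown symbols fall back to the default.
                    value = self.default
                else:
                    raise KeyError(value)
            ids.append(self.index[value])
        return ids


labels = SimpleIndexer()
with labels.context(train=True):
    first_ids = labels(["O", "O", "O", "B", "I"])
    second_ids = labels(["O", "O", "O", "U", "O", "O", "O"])
print(first_ids, second_ids)

tokens = SimpleIndexer(specials=["<PAD>", "<UNK>"], default="<UNK>")
with tokens.context(train=True):
    seen_ids = tokens(["this", "is", "awesome"])
unseen_ids = tokens(["this", "is", "terrible"])  # "terrible" maps to "<UNK>"
print(seen_ids, unseen_ids)
```

Under these assumptions the label ids reproduce the unpadded rows of the 'label' array above, with "O", "B", "I", and "U" mapped to 0, 1, 2, and 3.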
Relation Extraction
import collatable
from collatable import AdjacencyField, ListField, SpanField, TextField
from collatable.extras.indexer import LabelIndexer, TokenIndexer

PAD_TOKEN = "<PAD>"
token_indexer = TokenIndexer[str](specials=(PAD_TOKEN,))
label_indexer = LabelIndexer[str]()

instances = []
with token_indexer.context(train=True), label_indexer.context(train=True):
    text = TextField(
        ["john", "smith", "was", "born", "in", "new", "york", "and", "now", "lives", "in", "tokyo"],
        indexer=token_indexer,
        padding_value=token_indexer[PAD_TOKEN],
    )
    spans = ListField([SpanField(0, 2, text), SpanField(5, 7, text), SpanField(11, 12, text)])
    relations = AdjacencyField([(0, 1), (0, 2)], spans, labels=["born-in", "lives-in"], indexer=label_indexer)
    instance = dict(text=text, spans=spans, relations=relations)
    instances.append(instance)

    text = TextField(
        ["tokyo", "is", "the", "capital", "of", "japan"],
        indexer=token_indexer,
        padding_value=token_indexer[PAD_TOKEN],
    )
    spans = ListField([SpanField(0, 1, text), SpanField(5, 6, text)])
    relations = AdjacencyField([(0, 1)], spans, labels=["capital-of"], indexer=label_indexer)
    instance = dict(text=text, spans=spans, relations=relations)
    instances.append(instance)

output = collatable.collate(instances)
print(output)
Execution result:
{'text': {
'token_ids': array([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 5, 11],
[11, 12, 13, 14, 15, 16, 0, 0, 0, 0, 0, 0]]),
'mask': array([[ True, True, True, True, True, True, True, True, True, True, True, True],
[ True, True, True, True, True, True, False, False, False, False, False, False]])},
'spans': array([[[ 0, 2],
[ 5, 7],
[11, 12]],
[[ 0, 1],
[ 5, 6],
[-1, -1]]]),
'relations': array([[[-1, 0, 1],
[-1, -1, -1],
[-1, -1, -1]],
[[-1, 2, -1],
[-1, -1, -1],
[-1, -1, -1]]], dtype=int32)}
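The 'relations' array above can be read as one labeled adjacency matrix per instance: entry [i][j] holds the label id of the relation from span i to span j, and -1 marks both missing relations and padded spans. The following pure-Python sketch shows one way such a matrix could be built. build_adjacency is a hypothetical helper and the -1 fill value is taken from the output above; this is not collatable's AdjacencyField implementation.

```python
from typing import List, Sequence, Tuple


def build_adjacency(
    num_spans: int,
    edges: Sequence[Tuple[int, int]],
    label_ids: Sequence[int],
    size: int,
    padding_value: int = -1,
) -> List[List[int]]:
    """Build a size x size matrix of label ids, filled with padding_value elsewhere.

    Hypothetical helper for illustration; not part of collatable.
    """
    if len(edges) != len(label_ids):
        raise ValueError("edges and label_ids must have the same length")
    matrix = [[padding_value] * size for _ in range(size)]
    for (head, tail), label_id in zip(edges, label_ids):
        # Edges may only connect spans that actually exist in this instance.
        if not (0 <= head < num_spans and 0 <= tail < num_spans):
            raise ValueError(f"edge ({head}, {tail}) is out of range")
        matrix[head][tail] = label_id
    return matrix


# First instance: three spans, "born-in" -> 0 and "lives-in" -> 1.
first = build_adjacency(3, [(0, 1), (0, 2)], [0, 1], size=3)
# Second instance: two spans padded to three, "capital-of" -> 2.
second = build_adjacency(2, [(0, 1)], [2], size=3)
print(first)
print(second)
```

Padding every matrix to the largest number of spans in the batch is what allows the per-instance matrices to be stacked into a single array.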
Reference Implementation
The extras module provides a reference implementation that shows how to use collatable effectively.
Here is an example of a text-to-text task that encodes raw texts and labels into token
ids and decodes them back to raw texts and labels:
from dataclasses import dataclass
from typing import Mapping, Sequence, Union

from collatable.extras import DataLoader, Dataset, DefaultBatchSampler, LabelIndexer, TokenIndexer
from collatable.extras.datamodule import DataModule, LabelFieldTransform, TextFieldTransform
from collatable.utils import debatched


@dataclass
class Text2TextExample:
    source: Union[str, Sequence[str]]
    target: Union[str, Sequence[str]]
    language: str


text2text_dataset = [
    Text2TextExample(source="how are you?", target="I am fine.", language="en"),
    Text2TextExample(source="what is your name?", target="My name is John.", language="en"),
    Text2TextExample(source="where are you?", target="I am in New-York.", language="en"),
    Text2TextExample(source="what is the time?", target="It is 10:00 AM.", language="en"),
    Text2TextExample(source="comment ça va?", target="Je vais bien.", language="fr"),
]

shared_token_indexer = TokenIndexer(default="<unk>", specials=["<pad>", "<unk>"])
language_indexer = LabelIndexer[str]()

text2text_datamodule = DataModule[Text2TextExample](
    fields={
        "source": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
        "target": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
        "language": LabelFieldTransform(indexer=language_indexer),
    }
)

with shared_token_indexer.context(train=True), language_indexer.context(train=True):
    text2text_datamodule.build(text2text_dataset)

dataloader = DataLoader(DefaultBatchSampler(batch_size=2))
text2text_instances = Dataset.from_iterable(text2text_datamodule(text2text_dataset))

for batch in dataloader(text2text_instances):
    print("Batch:")
    print(batch)
    print("Reconstruction:")
    for item in debatched(batch):
        print(text2text_datamodule.reconstruct(item))
    print()
Execution result:
Batch:
{'target': {
'token_ids': array([[16, 17, 18, 19, 0],
[20, 9, 7, 21, 19]]),
'mask': array([[ True, True, True, True, False],
[ True, True, True, True, True]])},
'language': array([0, 0], dtype=int32),
'source': {
'token_ids': array([[2, 3, 4, 5, 0],
[6, 7, 8, 9, 5]]),
'mask': array([[ True, True, True, True, False],
[ True, True, True, True, True]])}}
Reconstruction:
{'source': ['how', 'are', 'you', '?'], 'target': ['I', 'am', 'fine', '.'], 'language': 'en'}
{'source': ['what', 'is', 'your', 'name', '?'], 'target': ['My', 'name', 'is', 'John', '.'], 'language': 'en'}
...
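Conceptually, the reconstruction step above first splits a batch back into per-example items and then drops padded positions before mapping ids back to tokens. The sketch below illustrates those two steps in pure Python on the 'source' part of the batch shown above. split_batch and strip_padding are hypothetical helpers written for this illustration; they are not the implementation of debatched or reconstruct.

```python
from typing import Any, Dict, Iterator, List, Sequence


def _batch_size(value: Any) -> int:
    # Nested dicts are assumed to be non-empty and to share one batch size.
    if isinstance(value, dict):
        return _batch_size(next(iter(value.values())))
    return len(value)


def _pick(value: Any, i: int) -> Any:
    # Select the i-th example, recursing into nested dicts of columns.
    if isinstance(value, dict):
        return {key: _pick(column, i) for key, column in value.items()}
    return value[i]


def split_batch(batch: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
    """Yield one dict per example from a dict of equally long columns."""
    for i in range(_batch_size(batch)):
        yield _pick(batch, i)


def strip_padding(token_ids: Sequence[int], mask: Sequence[bool]) -> List[int]:
    """Keep only the ids whose mask entry is True."""
    return [token_id for token_id, keep in zip(token_ids, mask) if keep]


batch = {
    "source": {
        "token_ids": [[2, 3, 4, 5, 0], [6, 7, 8, 9, 5]],
        "mask": [[True, True, True, True, False], [True, True, True, True, True]],
    },
    "language": [0, 0],
}
items = list(split_batch(batch))
unpadded = [strip_padding(item["source"]["token_ids"], item["source"]["mask"]) for item in items]
print(unpadded)
```

Mapping the remaining ids back to strings then only requires the inverse of the indexer's vocabulary.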