
Text classification datasets


Textbook: Universal NLP Datasets

Dependencies

  • av==6.2.0
  • jsonnet==0.14.0
  • opencv_python==4.1.1.26
  • torch==1.3.1
  • torchvision==0.4.2
  • numpy==1.17.4
  • transformers (imported in the usage examples; no version is pinned)
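These pins are old, so a quick environment check can save debugging time. The script below is a minimal sketch, not part of the textbook package: it assumes Python 3.8+ (for the standard-library `importlib.metadata`), `PINS` simply mirrors the pinned versions listed above, and distribution names are looked up exactly as written.

```python
from importlib import metadata
from typing import Callable, Dict, List, Optional

# Pinned versions copied from the dependency list (transformers is unpinned).
PINS: Dict[str, str] = {
    "av": "6.2.0",
    "jsonnet": "0.14.0",
    "opencv_python": "4.1.1.26",
    "torch": "1.3.1",
    "torchvision": "0.4.2",
    "numpy": "1.17.4",
}


def _installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def check_pins(
    pins: Dict[str, str],
    get_version: Callable[[str], Optional[str]] = _installed_version,
) -> List[str]:
    """List human-readable mismatches between pinned and installed versions."""
    problems = []
    for dist, wanted in pins.items():
        found = get_version(dist)
        if found is None:
            problems.append(f"{dist}: not installed (want {wanted})")
        elif found != wanted:
            problems.append(f"{dist}: found {found}, want {wanted}")
    return problems


if __name__ == "__main__":
    for line in check_pins(PINS):
        print(line)
```

An empty output means every pinned distribution is installed at the listed version. The `get_version` parameter exists so the comparison logic can be exercised without touching the real environment.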

Download raw datasets

bash fetch.sh

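This script downloads the alphanli, hellaswag, physicaliqa, and socialiqa datasets from AWS. The Something-Something v2 data used in the multimodal example is not among them.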
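The usage example reads `data_cache/alphanli/eval.jsonl`, a JSON Lines file (one JSON object per line). To peek at a downloaded split before building a dataset, a plain standard-library reader is enough. This is a sketch; `read_jsonl` is a hypothetical helper, not a textbook function, and it makes no assumption about the field names inside each record.

```python
import json
from pathlib import Path
from typing import Iterator, Union


def read_jsonl(path: Union[str, Path]) -> Iterator[dict]:
    """Yield one parsed JSON object per non-blank line of a JSON Lines file."""
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing newline
            try:
                yield json.loads(line)
            except json.JSONDecodeError as err:
                # Report the file and line number of the malformed record.
                raise ValueError(f"{path}:{lineno}: invalid JSON") from err
```

For example, `sorted(next(read_jsonl("data_cache/alphanli/eval.jsonl")))` lists the keys of the first record.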

Usage

Load a multiple-choice dataset

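# The textbook names used below (TextRenderer, Configuration, alphanli_config,
# ClassificationDataset, DynamicBatchSampler, collate_fn, ...) are expected to
# come from these star imports.
from textbook.config import *
from textbook.transforms_video import *

from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer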

text_renderer = TextRenderer(
    tokenizer=GPT2Tokenizer.from_pretrained('distilgpt2'),
    special_tokens={'cls_token': '[CLS]', 'pad_token': '[PAD]', 'mask_token': '[MASK]'},
)
config_alphanli = Configuration(alphanli_config)
alphanli_dataset = ClassificationDataset("data_cache/alphanli/eval.jsonl", config_alphanli, [text_renderer])

# iter() yields collated batches; DynamicBatchSampler decides which
# examples go into each batch.
alphanli_dev_dataloader = iter(
    DataLoader(
        alphanli_dataset, batch_sampler=DynamicBatchSampler(alphanli_dataset),
        collate_fn=collate_fn))
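The README does not say how `DynamicBatchSampler` forms its batches. A common approach to dynamic batching is to cap the padded token count per batch instead of the number of examples, and the sketch below shows that idea. It is an assumption, not the textbook implementation: `token_budget_batches`, `lengths` (per-example token counts), and `max_tokens` are hypothetical names.

```python
from typing import Iterator, List, Sequence


def token_budget_batches(
    lengths: Sequence[int], max_tokens: int, sort: bool = True
) -> Iterator[List[int]]:
    """Yield lists of example indices whose padded size stays within max_tokens.

    The padded size of a batch is len(batch) * longest example in it, which is
    what a collate function that pads to the batch maximum would allocate.
    An example longer than max_tokens still gets a batch of its own.
    """
    # Sorting by length groups similar lengths together and reduces padding.
    order = sorted(range(len(lengths)), key=lengths.__getitem__) if sort else range(len(lengths))
    batch: List[int] = []
    longest = 0
    for idx in order:
        new_longest = max(longest, lengths[idx])
        if batch and new_longest * (len(batch) + 1) > max_tokens:
            yield batch  # adding idx would exceed the budget; close this batch
            batch, new_longest = [], lengths[idx]
        batch.append(idx)
        longest = new_longest
    if batch:
        yield batch
```

Because PyTorch's `DataLoader` accepts any iterable of index lists as `batch_sampler`, a materialized `list(token_budget_batches(lengths, 4096))` could be passed in the same position as `DynamicBatchSampler(...)` above.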

Load a multimodal (text and video) dataset

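import torchvision  # needed for torchvision.transforms.ToTensor below

# Scale frames to upscale_size (int(84 * 1.1) == 92), then take a random
# 84 x 84 crop of the clip.
upscale_size = int(84 * 1.1)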
transform_pre = ComposeMix([
    [Scale(upscale_size), "img"],
    [RandomCropVideo(84), "vid"],
])

transform_post = ComposeMix([
    [torchvision.transforms.ToTensor(), "img"],
])
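A video random crop differs from an image random crop in one respect: the crop window should be chosen once per clip and reused for every frame, otherwise the content jitters from frame to frame. Whether `RandomCropVideo` works exactly this way is an assumption; the sketch below only illustrates the idea on frames represented as nested lists, and `random_crop_box` and `crop_clip` are hypothetical helpers.

```python
import random
from typing import List, Optional, Sequence, Tuple

Frame = List[List[int]]  # one frame as rows of pixel values (height x width)


def random_crop_box(height: int, width: int, size: int, rng: random.Random) -> Tuple[int, int]:
    """Pick a top-left corner for a size x size crop inside a height x width frame."""
    if size > height or size > width:
        raise ValueError("crop is larger than the frame")
    # randint is inclusive on both ends, so the crop always fits.
    return rng.randint(0, height - size), rng.randint(0, width - size)


def crop_clip(frames: Sequence[Frame], size: int, rng: Optional[random.Random] = None) -> List[Frame]:
    """Crop every frame of a clip with the SAME window so motion stays aligned."""
    if rng is None:
        rng = random.Random()
    top, left = random_crop_box(len(frames[0]), len(frames[0][0]), size, rng)
    return [[row[left:left + size] for row in frame[top:top + size]] for frame in frames]
```

With the settings above, frames are 92 pixels on a side after scaling, so the top-left corner of the 84 x 84 window falls in the range 0 to 8 on each axis.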

video_renderer = VisionRenderer(
    nframe=72,
    nclip=1,
    nstep=2,
    transform_pre=transform_pre,
    transform_post=transform_post,
    data_dir="data_cache/smthsmth/20bn-something-something-v2"
)
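The README does not document what `nframe`, `nclip`, and `nstep` mean precisely. One plausible reading, used here purely as an assumption, is that the renderer samples `nclip` clips per video, each made of `nframe` frame indices spaced `nstep` source frames apart. The hypothetical `sample_clip_indices` below sketches that reading and clamps indices to the last frame when a video is too short.

```python
import random
from typing import List, Optional


def sample_clip_indices(
    num_frames: int,
    nframe: int,
    nclip: int,
    nstep: int,
    rng: Optional[random.Random] = None,
) -> List[List[int]]:
    """Return nclip lists of nframe frame indices, each taken every nstep frames.

    Assumption: a clip covers (nframe - 1) * nstep + 1 source frames. When the
    video is shorter than that, indices past the end are clamped to the last
    frame so every clip still has exactly nframe entries.
    """
    if num_frames < 1:
        raise ValueError("video has no frames")
    if rng is None:
        rng = random.Random()
    span = (nframe - 1) * nstep + 1
    max_start = max(num_frames - span, 0)  # latest start that still fits the span
    clips = []
    for _ in range(nclip):
        start = rng.randint(0, max_start)
        clips.append([min(start + i * nstep, num_frames - 1) for i in range(nframe)])
    return clips
```

Under this reading, `nframe=72, nstep=2` spans 143 source frames, so shorter videos end with repeated final frames; if the real renderer behaves differently, its parameters should be interpreted accordingly.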

config_smthsmth = Configuration(smthsmth_config)

smthsmth_dataset = ClassificationVisionDataset(
    "data_cache/smthsmth/something-something-v2-validation.json", config_smthsmth, [text_renderer, video_renderer])

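# Batches of 16; mlm=True asks collate_fn to select roughly 15% of the text
# tokens (mlm_probability=0.15) for masked-language-model training.
smthsmth_dev_dataloader = iter(DataLoader(
    smthsmth_dataset, batch_size=16, collate_fn=lambda x: collate_fn(
        x, mlm=True, mlm_probability=0.15, tokenizer=text_renderer.tokenizer)))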
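With `mlm=True`, the collate function applies masked-language-model masking. The README does not specify the exact scheme; the sketch below assumes the widely used BERT-style 80/10/10 rule (the same rule used by Hugging Face's `DataCollatorForLanguageModeling`). `mask_tokens` is a hypothetical helper working on one sequence of token ids, and the `-100` label follows the default `ignore_index` of PyTorch's `CrossEntropyLoss`.

```python
import random
from typing import List, Optional, Sequence, Tuple

IGNORE_INDEX = -100  # label value that PyTorch's CrossEntropyLoss skips by default


def mask_tokens(
    input_ids: Sequence[int],
    mask_token_id: int,
    vocab_size: int,
    special_ids: Sequence[int] = (),
    mlm_probability: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """BERT-style masking: pick about mlm_probability of tokens as targets.

    Of the picked tokens, 80% become mask_token_id, 10% a random token, and
    10% are left unchanged. Labels are IGNORE_INDEX everywhere else.
    """
    if rng is None:
        rng = random.Random()
    special = set(special_ids)  # e.g. [PAD] and [CLS] are never masked
    inputs, labels = list(input_ids), [IGNORE_INDEX] * len(input_ids)
    for i, tok in enumerate(input_ids):
        if tok in special or rng.random() >= mlm_probability:
            continue
        labels[i] = tok  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_token_id
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)
        # else: keep the original token unchanged
    return inputs, labels
```

Keeping 10% of targets unchanged and replacing 10% with random tokens means the model cannot rely on seeing `[MASK]` at every position it is scored on.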

Let's multitask

multitask_dataloader = MultiTaskDataset([alphanli_dev_dataloader, smthsmth_dev_dataloader])

# Alternate between the dataloaders; text-only batches (e.g. alphanli) may
# have no "images" entry, so check before printing it.
for batch in multitask_dataloader:
    print(batch["input_ids"].shape)
    if "images" in batch:
        print(batch["images"].shape)
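The README only says that `MultiTaskDataset` alternates through the dataloaders; its exact schedule (strict turns, random sampling, or stopping at the shortest source) is not documented. A simple round-robin is one reasonable policy, sketched below under that assumption. `round_robin` is a hypothetical helper that takes one item from each source in turn and drops sources as they run out.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def round_robin(sources: Iterable[Iterable[T]]) -> Iterator[T]:
    """Take one item from each source in turn, dropping sources as they run out."""
    active: List[Iterator[T]] = [iter(s) for s in sources]
    while active:
        still_active = []
        for it in active:
            try:
                item = next(it)
            except StopIteration:
                continue  # this source is exhausted; drop it
            still_active.append(it)
            yield item
        active = still_active
```

This policy keeps going until every source is exhausted, so a longer dataset dominates the tail of an epoch; stopping at the shortest source or sampling sources in proportion to their size are common alternatives.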


Source distribution: textbook-0.0.8.tar.gz (8.9 kB)
