
Model hub for transformers.

Project description

Usage Sample
''''''''''''

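The sample below expects `labels` to hold integer class indices that point into `classes`, so a label of `2` means `classes[2]`. This reading is inferred from `len(classes)` being passed as the number of outputs. If your annotations are class-name strings, here is a minimal sketch for converting them. `encode_labels` is a hypothetical helper and is not part of `nlpx` or `transformers_model`. Sorting the class names is only one way to get a stable order.

```python
from typing import List, Sequence, Tuple


def encode_labels(raw_labels: Sequence[str]) -> Tuple[List[int], List[str]]:
    """Map class-name labels to integer indices into a `classes` list.

    Hypothetical helper for illustration; not part of nlpx or transformers_model.
    """
    classes = sorted(set(raw_labels))  # deterministic class order
    index = {name: i for i, name in enumerate(classes)}
    labels = [index[name] for name in raw_labels]  # labels[i] indexes classes
    return labels, classes


raw = ["sports", "sports", "finance", "tech", "finance"]
labels, classes = encode_labels(raw)
print(labels)   # [1, 1, 0, 2, 0]
print(classes)  # ['finance', 'sports', 'tech']
```

Keep the returned `classes` list together with the trained model. It records which index corresponds to which class name.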
.. code:: python

    from sklearn.model_selection import train_test_split
    from transformers import BertTokenizer
    from nlpx.dataset import TextDataset, text_collate
    from nlpx.model.wrapper import ClassifyModelWrapper
    from transformers_model import (
        AutoCNNTextClassifier,
        AutoCNNTokenClassifier,
        BertDataset,
        BertCollator,
        BertTokenizeCollator,
    )

    # Placeholders: replace with your own data.
    texts = [...]    # list[str]: one raw text per sample
    labels = [...]   # list[int]: one class index per text, e.g. [0, 0, 1, 2, 1, ...]
    classes = [...]  # list[str]: class names, e.g. ['class1', 'class2', 'class3', ...]
    pretrained_path = "clue/albert_chinese_tiny"

    train_texts, test_texts, y_train, y_test = train_test_split(texts, labels, test_size=0.2)

    # Raw-text datasets (texts are not tokenized yet).
    train_set = TextDataset(train_texts, y_train)
    test_set = TextDataset(test_texts, y_test)

    ################################### TextClassifier ##################################
    # Raw-text batches are assembled by nlpx's text_collate.
    model = AutoCNNTextClassifier(pretrained_path, len(classes))
    wrapper = ClassifyModelWrapper(model, classes)
    _ = wrapper.train(train_set, test_set, collate_fn=text_collate)

    ################################### TokenClassifier #################################
    tokenizer = BertTokenizer.from_pretrained(pretrained_path)

    ##################### BertTokenizeCollator #########################
    # Reuses the raw-text datasets; the collator tokenizes each batch on the fly
    # (256 is presumably the maximum sequence length).
    model = AutoCNNTokenClassifier(pretrained_path, len(classes))
    wrapper = ClassifyModelWrapper(model, classes)
    _ = wrapper.train(train_set, test_set, collate_fn=BertTokenizeCollator(tokenizer, 256))

    ##################### BertCollator ##################################
    # Tokenizes both splits up front and wraps the encodings in BertDataset.
    def tokenize(batch_texts):
        return tokenizer.batch_encode_plus(
            batch_texts,
            max_length=256,
            padding="max_length",
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

    train_set = BertDataset(tokenize(train_texts), y_train)
    test_set = BertDataset(tokenize(test_texts), y_test)

    model = AutoCNNTokenClassifier(pretrained_path, len(classes))
    wrapper = ClassifyModelWrapper(model, classes)
    _ = wrapper.train(train_set, test_set, collate_fn=BertCollator())

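Both TokenClassifier variants feed padded token ids and attention masks to the model. They differ in when tokenization happens. The pure-Python sketch here mimics the two paths with a toy character-level tokenizer. It assumes that `BertTokenizeCollator` tokenizes each batch inside the collate function, as its arguments suggest, and that `BertCollator` only batches pre-tokenized encodings. `toy_tokenize`, `PAD_ID`, and the sample strings are hypothetical. A real `BertTokenizer` uses a WordPiece vocabulary and adds special tokens.

```python
from typing import Dict, List, Sequence

PAD_ID = 0  # hypothetical padding id for this toy tokenizer


def toy_tokenize(texts: Sequence[str], max_length: int) -> Dict[str, List[List[int]]]:
    """Toy stand-in for tokenizer(..., padding="max_length", truncation=True).

    Each character becomes one token id (its Unicode code point). A real
    BertTokenizer uses a WordPiece vocabulary and adds [CLS]/[SEP] tokens.
    """
    input_ids: List[List[int]] = []
    attention_mask: List[List[int]] = []
    for text in texts:
        ids = [ord(ch) for ch in text][:max_length]  # truncation=True
        pad = max_length - len(ids)
        input_ids.append(ids + [PAD_ID] * pad)  # padding="max_length"
        attention_mask.append([1] * len(ids) + [0] * pad)  # 1 = real token, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


texts = ["你好", "hello world"]
max_length = 4
batch_size = 1

# Path A (like BertCollator): tokenize the whole split once, then slice batches.
encoded = toy_tokenize(texts, max_length)
batches_a = [
    {key: rows[start:start + batch_size] for key, rows in encoded.items()}
    for start in range(0, len(texts), batch_size)
]

# Path B (like BertTokenizeCollator): keep raw texts and tokenize per batch.
batches_b = [
    toy_tokenize(texts[start:start + batch_size], max_length)
    for start in range(0, len(texts), batch_size)
]

assert batches_a == batches_b
print(batches_a[0]["attention_mask"])  # [[1, 1, 0, 0]]
```

Both paths pad to a fixed `max_length`, so they yield identical batches. The trade-off is that pre-tokenizing keeps every encoding in memory at once. Per-batch tokenization instead repeats the work on every epoch.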

Download files


Source Distribution

transformers-model-0.1.1.tar.gz (8.4 kB)

File details

Details for the file transformers-model-0.1.1.tar.gz.

File metadata

  • Download URL: transformers-model-0.1.1.tar.gz
  • Size: 8.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for transformers-model-0.1.1.tar.gz:

  • SHA256: 2a3db6ee437c86daf326b18ca86d31b2607e4a0390bc6f723dc2172c8404c7fc
  • MD5: a24f2fb89a7164c75386531e21e7eda6
  • BLAKE2b-256: 549ee2f6d5fe418f8b2b97f9b9187f7f9785478e060392c46bfca4949a861b4e

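You can check a downloaded archive against the digests listed under File hashes with a short script. This minimal sketch uses only Python's standard `hashlib`. BLAKE2b-256 corresponds to `hashlib.blake2b(digest_size=32)`. The local file name is an assumption, so point it at wherever you saved the download.

```python
import hashlib
from pathlib import Path
from typing import Dict

# Digests published for transformers-model-0.1.1.tar.gz.
EXPECTED = {
    "sha256": "2a3db6ee437c86daf326b18ca86d31b2607e4a0390bc6f723dc2172c8404c7fc",
    "md5": "a24f2fb89a7164c75386531e21e7eda6",
    "blake2b_256": "549ee2f6d5fe418f8b2b97f9b9187f7f9785478e060392c46bfca4949a861b4e",
}


def file_digests(path: Path, chunk_size: int = 1 << 16) -> Dict[str, str]:
    """Compute SHA256, MD5 and BLAKE2b-256 hex digests, reading the file in chunks."""
    hashers = {
        "sha256": hashlib.sha256(),
        "md5": hashlib.md5(),
        "blake2b_256": hashlib.blake2b(digest_size=32),  # 32 bytes = 256-bit digest
    }
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            for h in hashers.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def verify(path: Path) -> bool:
    """Return True only if every digest matches the published value."""
    return file_digests(path) == EXPECTED


if __name__ == "__main__":
    archive = Path("transformers-model-0.1.1.tar.gz")  # assumed download location
    if archive.exists():
        print("OK" if verify(archive) else "MISMATCH: do not install")
```

pip can enforce the same SHA256 check at install time through its hash-checking mode. To enable it, add a `--hash=sha256:...` entry for the package in a requirements file.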