
Model hub for transformers.

Project description

Usage Sample
''''''''''''

.. code:: python

    import pandas as pd
    from sklearn.model_selection import train_test_split
    import torch
    from transformers import BertTokenizer
    from nlpx.tokenize.utils import get_df_text_labels
    from nlpx.dataset import TextDataset, text_collate
    from transformers_model import (
        AutoCNNTextClassifier,
        AutoCNNTokenClassifier,
        BertDataset,
        BertCollator,
        BertTokenizeCollator,
    )
    from nlpx.model.wrapper import ClassifyModelWrapper

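    # Replace the placeholders below with your own data
    texts = [...]      # list[str]: one raw text per sample
    labels = [...]     # list[int]: class index of each text, e.g. [0, 0, 1, 2, 1, ...]
    pretrained_path = "clue/albert_chinese_tiny"
    classes = [...]    # list[str]: class names, e.g. ['class1', 'class2', 'class3', ...]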
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_texts, test_texts, y_train, y_test = train_test_split(texts, labels, test_size=0.2, random_state=42)
    
    ######################## AutoCNNTextClassifier classification ##########################
    train_set = TextDataset(train_texts, y_train)
    test_set = TextDataset(test_texts, y_test)
    model = AutoCNNTextClassifier(pretrained_path, len(classes), device)
    wrapper = ClassifyModelWrapper(model, classes, device)
    _ = wrapper.train(train_set, test_set, collate_fn=text_collate)

    ######################### AutoCNNTokenClassifier classification ##########################
    tokenizer = BertTokenizer.from_pretrained(pretrained_path)

    ###################################### BertCollator ######################################
    # Pre-tokenize both splits up front with identical settings
    def encode(batch_texts):
        return tokenizer.batch_encode_plus(
            batch_texts,
            max_length=256,
            padding="max_length",
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

    train_encodings = encode(train_texts)
    test_encodings = encode(test_texts)

    train_set = BertDataset(train_encodings, y_train)
    test_set = BertDataset(test_encodings, y_test)

    model = AutoCNNTokenClassifier(pretrained_path, len(classes), device)
    wrapper = ClassifyModelWrapper(model, classes, device)
    _ = wrapper.train(train_set, test_set, collate_fn=BertCollator())

    ################################ BertTokenizeCollator ################################
    train_set = TextDataset(train_texts, y_train)
    test_set = TextDataset(test_texts, y_test)
    model = AutoCNNTokenClassifier(pretrained_path, len(classes), device)
    wrapper = ClassifyModelWrapper(model, classes, device)
    _ = wrapper.train(train_set, test_set, collate_fn=BertTokenizeCollator(tokenizer, 256))

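The sample appears to expect ``labels`` to be integer indices into ``classes``, since it passes ``len(classes)`` to each model as the number of classes. If your raw labels are strings, the following minimal sketch derives both lists. ``encode_labels`` is a hypothetical helper, not part of this package. ``sklearn.preprocessing.LabelEncoder``, whose ``classes_`` attribute is also sorted, does the same job.

```python
from typing import List, Sequence, Tuple


def encode_labels(raw_labels: Sequence[str]) -> Tuple[List[int], List[str]]:
    """Hypothetical helper: map string labels to integer class indices.

    Returns (labels, classes) with classes[labels[i]] == raw_labels[i].
    Classes are sorted, so the index of each class name depends only on
    the set of names, not on the order in which samples appear.
    """
    classes = sorted(set(raw_labels))
    index = {name: i for i, name in enumerate(classes)}
    labels = [index[name] for name in raw_labels]
    return labels, classes


# Example with three sentiment classes
labels, classes = encode_labels(["pos", "neg", "neu", "pos"])
print(classes)  # ['neg', 'neu', 'pos']
print(labels)   # [2, 0, 1, 2]
```

The resulting ``labels`` list can then go into ``train_test_split`` together with ``texts``, and ``classes`` into the model and ``ClassifyModelWrapper``.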
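Both ``AutoCNNTokenClassifier`` variants feed the model fixed-length token ids. The toy sketch below shows, under simplifying assumptions, what ``max_length``, ``padding="max_length"``, ``truncation=True`` and ``return_attention_mask=True`` produce. A whitespace split and a hand-made vocabulary stand in for the real WordPiece tokenizer, and special tokens such as [CLS] and [SEP] are omitted. ``pad_and_truncate`` and ``encode_batch`` are hypothetical helpers, not functions of this package or of transformers.

```python
from typing import Dict, List


def pad_and_truncate(
    batch_ids: List[List[int]], max_length: int, pad_id: int = 0
) -> Dict[str, List[List[int]]]:
    """Cut each sequence to max_length, then right-pad it to exactly max_length.

    The attention mask is 1 for real tokens and 0 for padding positions.
    """
    input_ids, attention_mask = [], []
    for ids in batch_ids:
        ids = ids[:max_length]                       # like truncation=True
        n_pad = max_length - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)     # like padding="max_length"
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def encode_batch(
    texts: List[str], vocab: Dict[str, int], max_length: int, unk_id: int = 1
) -> Dict[str, List[List[int]]]:
    """Toy encoder (assumption): whitespace split, vocab lookup, pad/truncate.

    Unknown words map to unk_id; padding uses id 0.
    """
    ids = [[vocab.get(tok, unk_id) for tok in text.split()] for text in texts]
    return pad_and_truncate(ids, max_length)
```

For example, ``encode_batch(["good movie", "bad bad bad movie"], {"good": 2, "bad": 3, "movie": 4}, max_length=3)`` returns input ids ``[[2, 4, 0], [3, 3, 3]]`` and attention mask ``[[1, 1, 0], [1, 1, 1]]``. Encoding each whole split once corresponds to the BertCollator route (pre-tokenized ``BertDataset``). Judging by its name and its ``(tokenizer, 256)`` arguments, BertTokenizeCollator appears to do the equivalent per batch of raw texts inside the collate function.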


Download files


Source Distribution

transformers-model-0.0.3.tar.gz (7.0 kB)

File details

Details for the file transformers-model-0.0.3.tar.gz.

File metadata

  • Download URL: transformers-model-0.0.3.tar.gz
  • Size: 7.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.0.0 CPython/3.9.18

File hashes

Hashes for transformers-model-0.0.3.tar.gz

  • SHA256: f6f07c7b2a86857fa80e7ae4bbb48cc63004d0c5adb98c8bf0c0e76d9b582cb5
  • MD5: 2653948ed20976ea7cba1622c5a4dc8f
  • BLAKE2b-256: 4a73c76fcca51da5f4eb008a67352b70f2621f2e8fa30ec21d1d1af26011696e
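To check a downloaded archive against the SHA256 digest listed for transformers-model-0.0.3.tar.gz, here is a minimal sketch using only the standard library. ``sha256_of`` and ``verify`` are illustrative names. The file path is assumed to be wherever you saved the tarball.

```python
import hashlib
from pathlib import Path

# SHA256 digest published on this page for transformers-model-0.0.3.tar.gz
EXPECTED_SHA256 = "f6f07c7b2a86857fa80e7ae4bbb48cc63004d0c5adb98c8bf0c0e76d9b582cb5"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash the file in chunks so it never has to be read into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """Return True if the file's SHA256 matches the expected hex digest."""
    return sha256_of(path) == expected.lower()
```

For example, ``verify(Path("transformers-model-0.0.3.tar.gz"))`` returns ``True`` only if the file matches. pip's hash-checking mode, using a requirements line such as ``transformers-model==0.0.3 --hash=sha256:<digest above>``, performs the same check at install time. In that mode, every requirement must also be pinned with a hash.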

