No project description provided
Project description
ai-utils
Installation
pip install cyvidia-ai-utils
Cross validation
Huggingface Trainer API
import uuid
from typing import Any, Dict

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification,TrainingArguments
from transformers import DataCollatorWithPadding,Trainer

from cyvidia_ai_utils import TransformerCrossValidationModel, cross_validate, EvaluationResult
# Ten consecutive 10% slices of the "train" split, one slice per fold.
fold_slices = [f"train[{k}%:{k+10}%]" for k in range(0, 100, 10)]
folds = load_dataset("dipesh/Intent-Classification-small", split=fold_slices)
# A list of split expressions is expected to yield a list of datasets.
assert isinstance(folds, list)
def create_trainer(model, tokenizer, train_ds: Dataset, val_ds: Dataset):
    """Build a Hugging Face ``Trainer`` for one train/validation pair.

    Both datasets are tokenized on their ``"text"`` column (with truncation)
    before being handed to the trainer; padding is applied per batch by a
    ``DataCollatorWithPadding``.

    Args:
        model: Model instance passed straight to ``Trainer``.
        tokenizer: Tokenizer used to encode ``"text"`` and to pad batches.
        train_ds: Dataset to train on; must contain a ``"text"`` column.
        val_ds: Dataset to evaluate on; must contain a ``"text"`` column.

    Returns:
        The configured ``Trainer``; ``train()`` has not been called on it.
    """

    def preprocess_function(examples):
        # Runs on batches (see ``batched=True`` below).
        return tokenizer(examples["text"], truncation=True)

    tokenized_train_ds = train_ds.map(preprocess_function, batched=True)
    tokenized_val_ds = val_ds.map(preprocess_function, batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        # A fresh random directory per call, so separate trainers do not
        # share an output directory.
        output_dir=f'tests/test_models/{uuid.uuid4()}',
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        # Tiny fraction of an epoch -- presumably to keep the example fast;
        # raise this for real training.
        num_train_epochs=0.001,
        weight_decay=0.01,
        # NOTE(review): recent transformers releases rename this argument to
        # ``eval_strategy`` -- confirm against the installed version.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_ds,
        eval_dataset=tokenized_val_ds,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    return trainer
# num_labels=21 is presumably the number of intent classes in the dataset
# loaded above -- confirm before reusing with another dataset.
model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=21)
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")

# Bundle the model, tokenizer and trainer factory for cross_validate.
cross_val_model = TransformerCrossValidationModel(
    model=model,
    tokenizer=tokenizer,
    create_trainer=create_trainer,
)

results = cross_validate(cross_val_model, folds, target_id_column="label", input_text_column="text")
# ``results`` is a mapping; its values are combined into one aggregate result.
aggregated_result = EvaluationResult.aggregate(list(results.values()))
Custom Trainer
from cyvidia_ai_utils import CrossValidationModel
class MyCustomCrossValidationModel(CrossValidationModel):
    """Template for using a custom training loop with ``cross_validate``.

    Replace each ``raise NotImplementedError`` with a real implementation.
    """

    def get_label_for_id(self, id: int) -> str:
        """Return the label string for ``id``.

        NOTE(review): ``id`` is presumably a class index -- confirm against
        ``CrossValidationModel``.
        """
        # Implement
        raise NotImplementedError

    def train(self, train_ds, val_ds) -> CrossValidationModel:
        """Train on ``train_ds`` / ``val_ds`` and return a ``CrossValidationModel``.

        NOTE(review): whether this should return ``self`` or a new instance
        is not visible here -- confirm against ``CrossValidationModel``.
        """
        # Implement
        raise NotImplementedError

    def predict_values(self, values) -> Dict[str, Any]:
        """Return predictions for ``values`` as a string-keyed dict.

        NOTE(review): the expected keys are not visible here -- confirm
        against ``CrossValidationModel``.
        """
        # Implement
        raise NotImplementedError
# Run cross-validation with the custom model over the same folds.
custom_model = MyCustomCrossValidationModel()
results = cross_validate(
    custom_model,
    folds,
    target_id_column="label",
    input_text_column="text",
)
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
cyvidia_ai_utils-0.5.0.tar.gz
(3.7 kB
view hashes)
Built Distribution
Close
Hashes for cyvidia_ai_utils-0.5.0-py3-none-any.whl
Algorithm | Hash digest |
---|---|
SHA256 | 717dfab4fb77a7026604c406084dd232ddb7cb6c9d1acb2f3dcdafb907804fcb |
MD5 | 8ff65f3d276e62f4f7e35524f231ab98 |
BLAKE2b-256 | e2c8c5e43e568bd87f8709e84e249993695ccdd731586028dacdef94df6c1a49 |