No project description provided
Project description
ai-utils
Cross validation
Hugging Face Trainer API
import uuid

from datasets import Dataset, load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

from cyvidia_ai_utils import TransformerCrossValidationModel, cross_validate
# Split the "train" split into ten contiguous 10% slices, one per fold:
# train[0%:10%], train[10%:20%], ..., train[90%:100%].
folds = load_dataset(
    "dipesh/Intent-Classification-small",
    split=[f"train[{k}%:{k+10}%]" for k in range(0, 100, 10)],
)

# A list passed as `split` should yield one dataset per slice, i.e. a list.
assert isinstance(folds, list)
def create_trainer(model, tokenizer, train_ds: Dataset, val_ds: Dataset):
    """Build a Hugging Face ``Trainer`` for a single cross-validation fold.

    Args:
        model: Model to fine-tune; passed unchanged to ``Trainer``.
        tokenizer: Tokenizer used to encode the ``"text"`` column and to pad
            batches through ``DataCollatorWithPadding``.
        train_ds: Training split for this fold; must have a ``"text"`` column.
        val_ds: Validation split for this fold; must have a ``"text"`` column.

    Returns:
        A configured ``transformers.Trainer`` (not yet trained).
    """

    def preprocess_function(examples):
        # Only truncate here; padding is left to the data collator so each
        # batch is padded dynamically rather than to a fixed length.
        return tokenizer(examples["text"], truncation=True)

    tokenized_train_ds = train_ds.map(preprocess_function, batched=True)
    tokenized_val_ds = val_ds.map(preprocess_function, batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        # Unique directory per call so checkpoints from different folds
        # never overwrite each other (requires `import uuid`).
        output_dir=f'tests/test_models/{uuid.uuid4()}',
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        # A tiny fraction of an epoch -- presumably this example is only a
        # smoke test of the pipeline; raise this for real training.
        num_train_epochs=0.001,
        weight_decay=0.01,
        # Evaluation and save strategies must match when
        # load_best_model_at_end=True.
        # NOTE(review): newer transformers releases renamed this argument to
        # `eval_strategy`; confirm against the installed version.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
    )

    return Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_ds,
        eval_dataset=tokenized_val_ds,
        # NOTE(review): newer transformers releases prefer `processing_class=`
        # over `tokenizer=`; confirm against the installed version.
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
# Load a tiny BERT checkpoint and its tokenizer. The classification head is
# sized for 21 labels -- presumably the number of intent classes in the
# dataset; confirm if the dataset changes.
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = AutoModelForSequenceClassification.from_pretrained(
    "prajjwal1/bert-tiny",
    num_labels=21,
)

# Wrap the model so cross_validate can build a Trainer via create_trainer.
cross_val_model = TransformerCrossValidationModel(
    model=model,
    tokenizer=tokenizer,
    create_trainer=create_trainer,
)

results = cross_validate(
    cross_val_model,
    folds,
    target_id_column="label",
    input_text_column="text",
)
Custom model (subclass `CrossValidationModel`)
from typing import Any, Dict

from cyvidia_ai_utils import CrossValidationModel, cross_validate


class MyCustomCrossValidationModel(CrossValidationModel):
    """Template for plugging a custom model into ``cross_validate``.

    Replace each ``raise NotImplementedError`` with a real implementation.
    The stubs raise instead of silently returning ``None``, so a forgotten
    method fails loudly.
    """

    def get_label_for_id(self, id: int) -> str:
        # Implement: presumably maps a numeric class id to its label name.
        raise NotImplementedError

    def train(self, train_ds, val_ds) -> CrossValidationModel:
        # Implement: fit on train_ds (optionally validating on val_ds) and
        # return the trained CrossValidationModel.
        raise NotImplementedError

    def predict_values(self, values) -> Dict[str, Any]:
        # Implement: return predictions for `values` as a dict.
        raise NotImplementedError


# `folds` is the list of dataset slices built in the previous example.
results = cross_validate(
    MyCustomCrossValidationModel(),
    folds,
    target_id_column="label",
    input_text_column="text",
)
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
cyvidia_ai_utils-0.4.0.tar.gz
(3.2 kB, view hashes)
Built Distribution
cyvidia_ai_utils-0.4.0-py3-none-any.whl
Hashes for cyvidia_ai_utils-0.4.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | e42630ec90fdf74c62b045deacde7110760d1f71426b14092313c29eb6eb2e54 |
| MD5 | 67b87cb6ec8e11658d1cdf03daebde1d |
| BLAKE2b-256 | 8d08d39cd80a3ab9db3dcf809815d95e96555a73cfc306d2ebfa3b71911cadbe |