No project description provided
Project description
ai-utils
Cross validation
Hugging Face Trainer API
import uuid
from typing import Any, Dict

from cyvidia_ai_utils import TransformerCrossValidationModel, cross_validate
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments
from transformers import DataCollatorWithPadding, Trainer
# Request ten consecutive 10% slices of the train split (0-10%, 10-20%, ...);
# each slice serves as one cross-validation fold.
folds = load_dataset(
    "dipesh/Intent-Classification-small",
    split=[f"train[{k}%:{k+10}%]" for k in range(0, 100, 10)],
)
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
# A list of split specs is expected to yield a list of datasets, one per fold.
assert isinstance(folds, list)
def create_trainer(train_ds: Dataset, val_ds: Dataset) -> Trainer:
    """Build a Hugging Face ``Trainer`` for one cross-validation fold.

    Both datasets are tokenized on their ``"text"`` column using the
    module-level ``tokenizer``. The trainer wraps the module-level ``model``,
    so ``model`` must be defined before this function is called.

    Args:
        train_ds: Dataset used for training in this fold.
        val_ds: Dataset used for evaluation in this fold.

    Returns:
        A configured ``Trainer``; training is not started here.
    """

    def preprocess_function(examples):
        # No padding here: DataCollatorWithPadding below handles it per batch.
        return tokenizer(examples["text"], truncation=True)

    tokenized_train_ds = train_ds.map(preprocess_function, batched=True)
    tokenized_val_ds = val_ds.map(preprocess_function, batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        # Fresh random directory per call so trainers never share outputs.
        output_dir=f'tests/test_models/{uuid.uuid4()}',
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        # NOTE(review): a tiny fraction of an epoch, presumably to keep this
        # example fast. Raise it for real training.
        num_train_epochs=0.001,
        weight_decay=0.01,
        # Evaluate and checkpoint at the same cadence (every epoch).
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_ds,
        eval_dataset=tokenized_val_ds,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    return trainer
# Sequence-classification head with 21 output labels on top of bert-tiny.
model = AutoModelForSequenceClassification.from_pretrained(
    "prajjwal1/bert-tiny",
    num_labels=21,
)

# Bundle the model, tokenizer and trainer factory for cross-validation.
cross_val_model = TransformerCrossValidationModel(
    model=model,
    tokenizer=tokenizer,
    create_trainer=create_trainer,
)

results = cross_validate(
    cross_val_model,
    folds,
    target_id_column="label",
    input_text_column="text",
)
Custom Trainer
from cyvidia_ai_utils import CrossValidationModel
class MyCustomCrossValidationModel(CrossValidationModel):
    """Template for plugging a custom model into ``cross_validate``.

    Every method is a placeholder that raises ``NotImplementedError`` until
    it is replaced with a real implementation.
    """

    def get_label_for_id(self, id: int) -> str:
        """Return the label string for ``id``.

        NOTE(review): presumably maps a numeric class id to its label name;
        confirm against ``CrossValidationModel``.
        """
        # Implement
        raise NotImplementedError

    def train(self, train_ds, val_ds) -> CrossValidationModel:
        """Train using ``train_ds`` and ``val_ds``; return a ``CrossValidationModel``."""
        # Implement
        raise NotImplementedError

    def predict_values(self, values) -> Dict[str, Any]:
        """Produce predictions for ``values`` as a string-keyed dictionary."""
        # Implement
        raise NotImplementedError
# Run cross-validation over the same folds with the custom model.
results = cross_validate(
    MyCustomCrossValidationModel(),
    folds,
    target_id_column="label",
    input_text_column="text",
)
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
cyvidia_ai_utils-0.3.0.tar.gz
(3.2 kB, view hashes)
Built Distribution
Close
Hashes for cyvidia_ai_utils-0.3.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6e38f575780d5fcc0a09815f1e75d7cd1b427cb53b4a3d84ff08ecd1aab350de |
| MD5 | 0f0adcfdbe6a34e872947721b9ae01c6 |
| BLAKE2b-256 | cef8e4e69386ff0da58862e6b49528588c5680d36ab7012b604889f208fef301 |