
ktrain is a wrapper for TensorFlow Keras that makes deep learning and AI more accessible and easier to apply


Welcome to ktrain

News and Announcements

  • 2021-03-10:
    • ktrain v0.26.x is released and now supports transformers>=4.0.0.
      Note that transformers>=4.0.0 included a complete reorganization of the module's structure. This means that if you saved a transformers-based Predictor (e.g., DistilBERT) in an older version of ktrain and transformers, you will need to either generate a new tf_model.preproc file or manually edit the existing tf_model.preproc file before loading the predictor in the latest versions of ktrain and transformers.
      For instance, suppose you trained a DistilBERT model and saved the resultant predictor using an older version of ktrain with predictor.save('/tmp/my_predictor/'). After upgrading to the newest version of ktrain, you will find that ktrain.load_predictor('/tmp/my_predictor') will throw an error unless you follow one of the two approaches below:

      Approach 1: Manually edit tf_model.preproc file:
      Open tf_model.preproc with an editor like vim and edit it to replace old module locations with new module locations (example changes for a DistilBERT model shown below):

      # change transformers.configuration_distilbert to transformers.models.distilbert.configuration_distilbert
      # change transformers.modeling_tf_auto to transformers.models.auto.modeling_tf_auto
      # change transformers.tokenization_auto to transformers.models.auto.tokenization_auto

      The above was confirmed to work using the vim editor on Linux.

      Approach 2: Re-generate tf_model.preproc file:

      import ktrain
      from ktrain import text
      # Step 1: Re-create a Preprocessor instance
      # NOTES:
      # 1. If training set is large, you can use a sample containing at least one example for each class
      # 2. Labels must be in same format as you originally used
      # 3. If original training set is not easily accessible, set preproc.preprocess_train_called=True 
      #    below instead of invoking preproc.preprocess_train(x_train, y_train)
      # MODEL_NAME, class_names, x_train, and y_train are the values used when the model was originally trained
      preproc = text.Transformer(MODEL_NAME, maxlen=500, class_names=class_names)
      trn = preproc.preprocess_train(x_train, y_train)
      # Step 2: load the transformers model from predictor folder
      from transformers import TFAutoModelForSequenceClassification
      model = TFAutoModelForSequenceClassification.from_pretrained('/tmp/my_predictor/')
      # Step 3: re-create/re-save Predictor
      predictor = ktrain.get_predictor(model, preproc)
      predictor.save('/tmp/my_new_predictor')
    • If you're using PyTorch 1.8 or above with ktrain, you will need to upgrade to ktrain>=0.26.0. If you're using ktrain<0.26.0, then you will have to downgrade PyTorch with: pip install torch==1.7.1.

  • 2020-11-08:
    • ktrain v0.25.x is released and includes out-of-the-box support for text extraction via the textract package. This can be used, for example, in the SimpleQA.index_from_folder method to perform question-answering on large collections of PDFs, MS Word documents, or PowerPoint files. See the Question-Answering example notebook for more information.
# End-to-End Question-Answering in ktrain

# index documents of different types into a built-in search engine
from ktrain import text
INDEXDIR = '/tmp/myindex'
corpus_path = '/my/folder/of/documents' # contains .pdf, .docx, .pptx files in addition to .txt files
text.SimpleQA.index_from_folder(corpus_path, INDEXDIR, use_text_extraction=True, # enable text extraction
                              multisegment=True, procs=4, # these args speed up indexing
                              breakup_docs=True)          # this slows indexing but speeds up answer retrieval

# ask questions (setting higher batch size can further speed up answer retrieval)
qa = text.SimpleQA(INDEXDIR)
answers = qa.ask('What is ktrain?', batch_size=8)

# top answer snippet extracted from
#   "ktrain is a low-code platform for machine learning"


ktrain is a lightweight wrapper for the deep learning library TensorFlow Keras (and other libraries) to help build, train, and deploy neural networks and other machine learning models. Inspired by ML framework extensions like fastai and ludwig, ktrain is designed to make deep learning and AI more accessible and easier to apply for both newcomers and experienced practitioners. With only a few lines of code, ktrain allows you to easily and quickly:


Please see the following tutorial notebooks for a guide on how to use ktrain on your projects:

Some blog tutorials about ktrain are shown below:

ktrain: A Lightweight Wrapper for Keras to Help Train Neural Networks

BERT Text Classification in 3 Lines of Code

Text Classification with Hugging Face Transformers in TensorFlow 2 (Without Tears)

Build an Open-Domain Question-Answering System With BERT in 3 Lines of Code

Finetuning BERT using ktrain for Disaster Tweets Classification by Hamiz Ahmed


Tasks such as text classification and image classification can be accomplished easily with only a few lines of code.

Example: Text Classification of IMDb Movie Reviews Using BERT [see notebook]

import ktrain
from ktrain import text as txt

# load data
(x_train, y_train), (x_test, y_test), preproc = txt.texts_from_folder('data/aclImdb', maxlen=500, 
                                                                     preprocess_mode='bert', # required when using the 'bert' classifier
                                                                     train_test_names=['train', 'test'],
                                                                     classes=['pos', 'neg'])

# load model
model = txt.text_classifier('bert', (x_train, y_train), preproc=preproc)

# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model, 
                             train_data=(x_train, y_train), 
                             val_data=(x_test, y_test), 
                             batch_size=6) # small batch size, as in the DistilBERT example; adjust to your GPU memory

# find good learning rate
learner.lr_find()             # briefly simulate training to find good learning rate
learner.lr_plot()             # visually identify best learning rate

# train using 1cycle learning rate schedule for 3 epochs
learner.fit_onecycle(2e-5, 3) 
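
The 1cycle schedule named in the last step can be pictured with a small, library-free sketch. This is a generic illustration of the idea (warm up linearly to the maximum learning rate, then decay linearly), not ktrain's internal implementation; the function name one_cycle_lr and the start_div parameter are hypothetical choices made for this example.

```python
def one_cycle_lr(step, total_steps, max_lr, start_div=10.0):
    """Learning rate at `step` (0-based) for a simple triangular 1cycle shape.

    The rate rises linearly from max_lr / start_div to max_lr over the first
    half of training and falls linearly back over the second half.
    `start_div` is an assumed ratio, not a ktrain setting.
    """
    if total_steps < 2:
        return max_lr
    base_lr = max_lr / start_div
    peak = (total_steps - 1) / 2.0
    distance = abs(step - peak) / peak  # 1.0 at both ends, 0.0 at the peak
    return base_lr + (max_lr - base_lr) * (1.0 - distance)


# 9 steps with a maximum learning rate of 2e-5, as passed to fit_onecycle above
schedule = [one_cycle_lr(s, 9, 2e-5) for s in range(9)]
for step, lr in enumerate(schedule):
    print(step, f"{lr:.2e}")
```

The first argument to fit_onecycle plays the role of max_lr here, so a value read off the learning-rate plot is the peak of the cycle rather than a constant rate.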

Example: Classifying Images of Dogs and Cats Using a Pretrained ResNet50 model [see notebook]

import ktrain
from ktrain import vision as vis

# load data
(train_data, val_data, preproc) = vis.images_from_folder(
                                              datadir='data/dogscats', # assumed location of the image folders; adjust to your data
                                              data_aug = vis.get_data_aug(horizontal_flip=True),
                                              train_test_names=['train', 'valid'], 
                                              target_size=(224,224), color_mode='rgb')

# load model
model = vis.image_classifier('pretrained_resnet50', train_data, val_data, freeze_layers=80)

# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model=model, train_data=train_data, val_data=val_data, 
                             workers=8, use_multiprocessing=False, batch_size=64)

# find good learning rate
learner.lr_find()             # briefly simulate training to find good learning rate
learner.lr_plot()             # visually identify best learning rate

# train using triangular policy with ModelCheckpoint and implicit ReduceLROnPlateau and EarlyStopping
learner.autofit(1e-4, checkpoint_folder='/tmp/saved_weights') 
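
The comment on autofit mentions implicit ReduceLROnPlateau and EarlyStopping behavior. A minimal sketch of how those two rules interact is shown below, in plain Python and without Keras. The patience values and the reduction factor are assumptions for illustration only and may differ from ktrain's defaults; the real Keras callbacks also have options (such as min_delta and cooldown) that are ignored here.

```python
def simulate_callbacks(val_losses, lr, reduce_patience=2, stop_patience=5, factor=0.5):
    """Replay a list of per-epoch validation losses and report what a
    reduce-on-plateau rule and an early-stopping rule would do.

    Returns a list of (epoch, learning_rate, action) tuples, where action is
    "continue", "reduce", or "stop". All defaults are hypothetical.
    """
    best = float("inf")
    since_best = 0    # epochs since the best validation loss was seen
    since_reduce = 0  # non-improving epochs since the last LR reduction
    history = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            since_best = 0
            since_reduce = 0
        else:
            since_best += 1
            since_reduce += 1
        if since_best >= stop_patience:
            history.append((epoch, lr, "stop"))
            break
        if since_reduce >= reduce_patience:
            lr *= factor
            since_reduce = 0
            history.append((epoch, lr, "reduce"))
        else:
            history.append((epoch, lr, "continue"))
    return history


losses = [1.0, 0.8, 0.9, 0.85, 0.95, 0.9, 0.88, 0.87, 0.5]
history = simulate_callbacks(losses, lr=1e-4)
for entry in history:
    print(entry)
```

In this toy run the learning rate is halved twice before training stops at epoch 7, so the later improvement at epoch 9 is never reached. That trade-off is why a checkpoint folder is useful: the best weights seen so far can be restored afterwards.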

Example: Sequence Labeling for Named Entity Recognition using a randomly initialized Bidirectional LSTM CRF model [see notebook]

import ktrain
from ktrain import text as txt

# load data
(trn, val, preproc) = txt.entities_from_txt('data/ner_dataset.csv',
                                            sentence_column='Sentence #',
                                            word_column='Word',
                                            tag_column='Tag',
                                            data_format='gmb', # CSV with one token per row
                                            use_char=True) # enable character embeddings

# load model
model = txt.sequence_tagger('bilstm-crf', preproc)

# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model, train_data=trn, val_data=val)

# conventional training for 1 epoch using a learning rate of 0.001 (Keras default for Adam optimizer)
learner.fit(0.001, 1) 

Example: Node Classification on Cora Citation Graph using a GraphSAGE model [see notebook]

import ktrain
from ktrain import graph as gr

# load data with supervision ratio of 10%
(trn, val, preproc) = gr.graph_nodes_from_csv(
                                              'cora.content', # node attributes/labels
                                              'cora.cites',   # edge list
                                              train_pct=0.1, sep='\t')

# load model
model = gr.graph_node_classifier('graphsage', trn)

# wrap model and data in ktrain.Learner object
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=64)

# find good learning rate
learner.lr_find(max_epochs=100) # briefly simulate training to find good learning rate
learner.lr_plot()               # visually identify best learning rate

# train using triangular policy with ModelCheckpoint and implicit ReduceLROnPlateau and EarlyStopping
learner.autofit(0.01, checkpoint_folder='/tmp/saved_weights')
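
The lr_find step used in the examples above "briefly simulates training" while raising the learning rate. The sketch below illustrates that idea on a toy one-parameter problem: the rate grows exponentially each step and the run stops once the loss clearly diverges. It is a simplified stand-in, not ktrain's implementation; the function name, the toy loss w**2, and the stop_factor threshold are all assumptions made for this example.

```python
def lr_range_test(start_lr=1e-4, end_lr=10.0, num_steps=100, stop_factor=4.0):
    """Run gradient descent on loss(w) = w**2 while growing the learning
    rate exponentially, and stop once the loss exceeds stop_factor times
    the best loss seen so far.

    Returns the learning rates tried and the loss observed at each one.
    """
    w = 5.0
    growth = (end_lr / start_lr) ** (1.0 / (num_steps - 1))
    lr = start_lr
    best_loss = float("inf")
    lrs, losses = [], []
    for _ in range(num_steps):
        loss = w * w
        lrs.append(lr)
        losses.append(loss)
        best_loss = min(best_loss, loss)
        if loss > stop_factor * best_loss:
            break  # loss has diverged: stop automatically
        w -= lr * 2.0 * w  # gradient of w**2 is 2w
        lr *= growth
    return lrs, losses


lrs, losses = lr_range_test()
print("steps run:", len(lrs))
print("learning rate when stopped:", round(lrs[-1], 3))
```

Plotting losses against lrs on a log scale gives the kind of curve that lr_plot displays. A reasonable choice is a rate from the region where the loss is still falling steeply, well below the point where it turns upward.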

Example: Text Classification with Hugging Face Transformers on 20 Newsgroups Dataset Using DistilBERT [see notebook]

# load text data
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
from sklearn.datasets import fetch_20newsgroups
train_b = fetch_20newsgroups(subset='train', categories=categories, shuffle=True)
test_b = fetch_20newsgroups(subset='test',categories=categories, shuffle=True)
(x_train, y_train) = (train_b.data, train_b.target)
(x_test, y_test) = (test_b.data, test_b.target)

# build, train, and validate model (Transformer is wrapper around transformers library)
import ktrain
from ktrain import text
MODEL_NAME = 'distilbert-base-uncased'
t = text.Transformer(MODEL_NAME, maxlen=500, class_names=train_b.target_names)
trn = t.preprocess_train(x_train, y_train)
val = t.preprocess_test(x_test, y_test)
model = t.get_classifier()
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)
learner.fit_onecycle(5e-5, 4)
learner.validate(class_names=t.get_classes()) # class_names must be string values

# Output from learner.validate()
#                        precision    recall  f1-score   support
#           alt.atheism       0.92      0.93      0.93       319
#         comp.graphics       0.97      0.97      0.97       389
#               sci.med       0.97      0.95      0.96       396
#soc.religion.christian       0.96      0.96      0.96       398
#              accuracy                           0.96      1502
#             macro avg       0.95      0.96      0.95      1502
#          weighted avg       0.96      0.96      0.96      1502
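
The columns in the validation report above are standard per-class metrics. As a rough guide to how they are derived from predictions, here is a small self-contained sketch. It uses made-up labels and a hypothetical helper, per_class_report, and is not the code that learner.validate runs.

```python
def per_class_report(y_true, y_pred, labels):
    """Compute precision, recall, F1, and support for each label."""
    report = {}
    pairs = list(zip(y_true, y_pred))
    for label in labels:
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        report[label] = {"precision": precision, "recall": recall,
                         "f1": f1, "support": tp + fn}
    return report


# toy data (not from the 20 Newsgroups experiment)
y_true = ["a", "a", "a", "b", "b", "c"]
y_pred = ["a", "a", "b", "b", "b", "c"]
report = per_class_report(y_true, y_pred, labels=["a", "b", "c"])
for label, metrics in report.items():
    print(label, metrics)
```

Support is simply the number of true examples of each class, which is why the support column sums to the size of the test set.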

Example: Tabular Classification for Titanic Survival Prediction Using an MLP [see notebook]

import ktrain
from ktrain import tabular
import pandas as pd
train_df = pd.read_csv('train.csv', index_col=0)
train_df = train_df.drop(['Name', 'Ticket', 'Cabin'], axis=1)
trn, val, preproc = tabular.tabular_from_df(train_df, label_columns=['Survived'], random_state=42)
learner = ktrain.get_learner(tabular.tabular_classifier('mlp', trn), train_data=trn, val_data=val)
learner.lr_find(show_plot=True, max_epochs=5) # estimate learning rate
learner.fit_onecycle(5e-3, 10)

# evaluate held-out labeled test set
tst = preproc.preprocess_test(pd.read_csv('heldout.csv', index_col=0))
learner.evaluate(tst, class_names=preproc.get_classes())

Using ktrain on Google Colab? See these Colab examples:

Additional examples can be found here.


  1. Make sure pip is up-to-date with: pip install -U pip

  2. Install TensorFlow 2 if it is not already installed (e.g., pip install tensorflow)

  3. Install ktrain: pip install ktrain

The above should be all you need on Linux systems and cloud computing environments like Google Colab and AWS EC2. If you are using ktrain on a Windows computer, you can follow these more detailed instructions that include some extra steps.

Some important things to note about installation:

  • If using ktrain with tensorflow<=2.1, you must also downgrade the transformers library to transformers==3.1.
  • If load_predictor fails with the error "AttributeError: 'str' object has no attribute 'decode'", then downgrade h5py: pip install h5py==2.10.0
  • As of v0.21.x, ktrain no longer installs TensorFlow 2 automatically. As indicated above, you should install TensorFlow 2 yourself before installing and using ktrain. On Google Colab, TensorFlow 2 should be already installed. You should be able to use ktrain with any version of TensorFlow 2. Note, however, that there is a bug in TensorFlow 2.2 and 2.3 that affects the Learning-Rate-Finder that will not be fixed until TensorFlow 2.4. The bug causes the learning-rate-finder to complete all epochs even after loss has diverged (i.e., no automatic-stopping).
  • If using ktrain on a local machine with a GPU (versus Google Colab, for example), you'll need to install GPU support for TensorFlow 2.
  • Since some ktrain dependencies have not yet been migrated to tf.keras in TensorFlow 2 (or may have other issues), ktrain is temporarily using forked versions of some libraries. Specifically, ktrain uses forked versions of the eli5 and stellargraph libraries. If not installed, ktrain will complain when a method or function needing either of these libraries is invoked. To install these forked versions, you can do the following:
pip install git+
pip install git+
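
The version constraints in the notes above can be checked before training starts. The sketch below encodes only the rules stated in this list; the helper names major_minor and installation_warnings are hypothetical, and the checks are a convenience illustration rather than part of ktrain.

```python
import re


def major_minor(version):
    """Extract (major, minor) from a version string such as '2.3.1'."""
    match = re.match(r"(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2) or 0)


def installation_warnings(tensorflow_version, transformers_version):
    """Return a list of warnings based on the installation notes above."""
    warnings = []
    tf_mm = major_minor(tensorflow_version)
    if tf_mm[0] < 2:
        warnings.append("ktrain expects TensorFlow 2 to be installed")
    if tf_mm <= (2, 1) and major_minor(transformers_version) != (3, 1):
        warnings.append("tensorflow<=2.1 requires transformers==3.1")
    if tf_mm in ((2, 2), (2, 3)):
        warnings.append("learning-rate finder will not stop automatically "
                        "on TensorFlow 2.2 and 2.3")
    return warnings


for message in installation_warnings("2.3.1", "4.3.0"):
    print("WARNING:", message)
```

In practice the two version strings could be obtained with importlib.metadata.version("tensorflow") and importlib.metadata.version("transformers") on Python 3.8 or later.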

This code was tested on Ubuntu 18.04 LTS using TensorFlow 2.3.1 and Python 3.6.9.

How to Cite

Please cite the following paper when using ktrain:

@article{maiya2020ktrain,
    title={ktrain: A Low-Code Library for Augmented Machine Learning},
    author={Arun S. Maiya},
    year={2020},
    journal={arXiv preprint arXiv:2004.10703},
}

Creator: Arun S. Maiya

Email: arun [at] maiya [dot] net


Files for ktrain, version 0.26.3: ktrain-0.26.3.tar.gz (25.3 MB), source distribution.
