A Python schema-based machine learning library

Purpose

A Python machine learning library, for creating quick and easy machine learning models. It is schema-based, and wraps scikit-learn.

Usage

Create and use a machine learning model in 3 steps:

  1. Create a schema representing your input and output features.

  2. Train a model from your data.

  3. Make predictions from your model.

Example

To get a feel for the library, consider the classic Iris dataset, where we predict the class of an iris plant from measurements of its sepal and petal.

First, we create a schema describing our inputs and outputs. Our inputs are the length and width of both the sepal and the petal, all of which are numbers. Our only output is the class of iris, which may be one of the labels Iris-setosa, Iris-versicolor, or Iris-virginica.

We define this in code as follows:

from smart_fruit import Model
from smart_fruit.feature_types import Number, Label


class Iris(Model):
    class Input:
        sepal_length_cm = Number()
        sepal_width_cm = Number()
        petal_length_cm = Number()
        petal_width_cm = Number()

    class Output:
        iris_class = Label(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'])

Then, we train a model:

model = Iris.train(Iris.features_from_csv('iris_data.csv'))

Here, iris_data.csv is a data file of the following form:

sepal_length_cm,sepal_width_cm,petal_length_cm,petal_width_cm,iris_class
5.1,3.5,1.4,0.2,Iris-setosa
...

Finally, we use our new model to make predictions:

for prediction in model.predict([Iris.Input(5.1, 3.5, 1.4, 0.2)]):
    print(prediction.iris_class)
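
To try the example above, a file named iris_data.csv with the column headings shown earlier is needed. The following is a minimal sketch, using only the Python standard library, of writing such a file and reading it back to check the headings. The helper names write_iris_csv and read_rows are hypothetical and not part of Smart Fruit. Only the first data row comes from the text above; the other two rows are the usual first versicolor and virginica samples of the classic dataset, included purely for illustration.

```python
import csv
import os
import tempfile

# Column headings: the four Input features followed by the Output feature.
FIELDNAMES = [
    'sepal_length_cm',
    'sepal_width_cm',
    'petal_length_cm',
    'petal_width_cm',
    'iris_class',
]

# Illustrative rows only; a real training file would contain many more.
ROWS = [
    (5.1, 3.5, 1.4, 0.2, 'Iris-setosa'),
    (7.0, 3.2, 4.7, 1.4, 'Iris-versicolor'),
    (6.3, 3.3, 6.0, 2.5, 'Iris-virginica'),
]


def write_iris_csv(path, rows=ROWS):
    """Write a heading line followed by one line per sample."""
    with open(path, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(FIELDNAMES)
        writer.writerows(rows)


def read_rows(path):
    """Read the file back as a list of dictionaries keyed by column heading."""
    with open(path, newline='') as csv_file:
        return list(csv.DictReader(csv_file))


# Write into a temporary directory so nothing in the working directory is touched.
csv_path = os.path.join(tempfile.mkdtemp(), 'iris_data.csv')
write_iris_csv(csv_path)
rows = read_rows(csv_path)
print(rows[0]['iris_class'])
```

The resulting path can then be passed to Iris.features_from_csv in place of 'iris_data.csv'.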

Reference

Models

  • Model.Input - Schema for defining your input features.

  • Model.Output - Schema for defining your output features.

    Define Model.Input and Model.Output as classes with FeatureType attributes.

    eg. Consider the Iris class defined above.

    These classes can then be used to create objects representing the appropriate collections of features.

    eg.

    >>> iris_input = Iris.Input(5.1, 3.5, 1.4, 0.2)
    >>> iris_input
    Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2)
    >>> iris_input.sepal_length_cm
    5.1
    
    >>> Iris.Input.from_json({'sepal_length_cm': 5.1, 'sepal_width_cm': 3.5, 'petal_length_cm': 1.4, 'petal_width_cm': 0.2})
    Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2)

  • Model.features_from_list(lists) - Deserialize an iterable of lists into an iterable of input/output feature pairs.

    eg.

    >>> list(Iris.features_from_list([[5.1, 3.5, 1.4, 0.2, 'Iris-setosa']]))
    [(Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), Output(iris_class='Iris-setosa'))]

  • Model.input_features_from_list(lists) - Deserialize an iterable of lists into an iterable of input features.

    eg.

    >>> list(Iris.input_features_from_list([[5.1, 3.5, 1.4, 0.2]]))
    [Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2)]

  • Model.features_from_json(json) - Deserialize an iterable of dictionaries into an iterable of input/output feature pairs.

    eg.

    >>> list(Iris.features_from_json([{'sepal_length_cm': 5.1, 'sepal_width_cm': 3.5, 'petal_length_cm': 1.4, 'petal_width_cm': 0.2, 'iris_class': 'Iris-setosa'}]))
    [(Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), Output(iris_class='Iris-setosa'))]

  • Model.input_features_from_json(json) - Deserialize an iterable of dictionaries into an iterable of input features.

    eg.

    >>> list(Iris.input_features_from_json([{'sepal_length_cm': 5.1, 'sepal_width_cm': 3.5, 'petal_length_cm': 1.4, 'petal_width_cm': 0.2}]))
    [Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2)]

  • Model.features_from_csv(csv_path) - Take a path to a CSV file, or a file-like object, and deserialize it into an iterable of input/output feature pairs.

    If column headings are not given in the file, assume the input features are followed by the output features, in the order they are defined in their respective classes.

    eg.

    >>> list(Iris.features_from_csv('iris_data.csv'))
    [(Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), Output(iris_class='Iris-setosa')), ...]

  • Model.input_features_from_csv(csv_path) - Take a path to a CSV file, or a file-like object, and deserialize it into an iterable of input features.

    If column headings are not given in the file, assume they are in the order they are defined in the Input class.

    eg.

    >>> list(Iris.input_features_from_csv('iris_data.csv'))
    [Input(sepal_length_cm=5.1, sepal_width_cm=3.5, petal_length_cm=1.4, petal_width_cm=0.2), ...]

  • Model.model_class - How to model the relation between the input and output data.

    Default: sklearn.linear_model.LinearRegression

    This attribute accepts any class with fit, predict, and score methods defined as for scikit-learn multi-response regression models. In particular, it accepts any scikit-learn multi-response regression model, ie. any scikit-learn regression model where the y parameter of fit accepts a numpy array of shape [n_samples, n_targets].

  • Model.train(features, train_test_split_ratio=None, test_sample_count=None, random_state=None)

    Train a new model on the given iterable of input/output pairs.

    Parameters:

    • features - An iterable of input/output pairs.

    • train_test_split_ratio - Proportion of data to use as cross-validation test data.

    • test_sample_count - Number of samples of data to use as cross-validation test data.

      If either train_test_split_ratio or test_sample_count is provided, perform cross-validation of the given data. Return both the trained model and the score of the test data on that model.

    • random_state - Either a numpy RandomState, or the seed to use for the PRNG.

      Useful for getting consistent results, for example in automated tests. Do not use this parameter when generating models you plan to use in production settings.

    eg.

    >>> iris_model = Iris.train([(Iris.Input(5.1, 3.5, 1.4, 0.2), Iris.Output('Iris-setosa'))])

  • model.predict(input_features, yield_inputs=False) - Predict the outputs for a given iterable of inputs.

    If yield_inputs is True, yield each prediction together with the input used to generate it, as (input, output) pairs. Otherwise, yield just the predictions, in the same order as the inputs were given to the model.

    eg.

    >>> list(iris_model.predict([Iris.Input(5.1, 3.5, 1.4, 0.2)]))
    [Output(iris_class='Iris-setosa')]
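
As an illustration of the interface that Model.model_class expects, here is a minimal sketch of a class with fit, predict, and score methods shaped like those of a scikit-learn multi-response regressor. MeanRegressor is a hypothetical toy written in pure Python, not part of Smart Fruit or scikit-learn: it always predicts the per-target mean of the training outputs. It assumes the class is instantiated with no arguments and is given row-oriented, array-like data. A class meant for real use would more likely work with numpy arrays throughout.

```python
class MeanRegressor:
    """Toy multi-response regressor that predicts the per-target training mean."""

    def fit(self, X, y):
        # y has shape [n_samples, n_targets]; X is ignored by this toy model.
        rows = [list(row) for row in y]
        n_targets = len(rows[0])
        self.means_ = [
            sum(row[j] for row in rows) / len(rows) for j in range(n_targets)
        ]
        return self

    def predict(self, X):
        # One row of predictions per input row, each row holding every target.
        return [list(self.means_) for _ in X]

    def score(self, X, y):
        # Coefficient of determination (R^2), averaged uniformly over targets.
        rows = [list(row) for row in y]
        predictions = self.predict(X)
        scores = []
        for j in range(len(self.means_)):
            actual = [row[j] for row in rows]
            predicted = [row[j] for row in predictions]
            mean_actual = sum(actual) / len(actual)
            residual = sum((a - p) ** 2 for a, p in zip(actual, predicted))
            total = sum((a - mean_actual) ** 2 for a in actual)
            if total == 0:
                # Constant target: perfect only if every prediction matches.
                scores.append(1.0 if residual == 0 else 0.0)
            else:
                scores.append(1 - residual / total)
        return sum(scores) / len(scores)


X = [[0.0], [1.0], [2.0]]
y = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
regressor = MeanRegressor().fit(X, y)
print(regressor.predict([[5.0]]))
print(regressor.score(X, y))
```

Under the assumptions above, such a class would be plugged in by setting model_class = MeanRegressor in the body of a Model subclass such as Iris. A model that always predicts the mean scores 0 on its own training data, which makes it a handy baseline to compare other model classes against.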

Feature Types

Smart Fruit recognizes the following data types for input and output features. Custom types may be made by extending the FeatureType class.

  • Number() - A real-valued feature.

    eg. 0, 1, 3.141592, -17, …

  • Integer() - A whole number feature.

    eg. 0, 1, 3, -17, …

  • Complex() - A complex-valued number feature.

    eg. 0, 1, 3 + 4j, -1 + 7j, …

  • Label(labels) - An enumerated feature, ie. one which may take one of a pre-defined list of available values.

    eg. For labels = ['red', 'green', 'blue'], our label may take the value 'red', but not 'purple'.

  • Vector(feature_types) - A feature made of other features. Useful for grouping conceptually related features.

    eg. For feature_types = [Number(), Label(['red', 'green', 'blue'])], we may take values such as (0, 'red'), and (1, 'blue').

  • Tag() - A feature that is ignored when making predictions. Useful for keeping track of ID numbers.

    Accepts any Python value.
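
Since the default model class is a regression model, a Label has to be turned into numbers at some point. The following is a sketch of one common way to do that, one-hot encoding, together with decoding by picking the highest-scoring label. It is an assumption for illustration only: the text above does not say how Smart Fruit encodes labels internally, and the function names one_hot and decode are hypothetical.

```python
def one_hot(value, labels):
    """Encode a label as a list with 1.0 at its position and 0.0 elsewhere."""
    if value not in labels:
        raise ValueError(f"{value!r} is not one of the allowed labels {labels!r}")
    return [1.0 if label == value else 0.0 for label in labels]


def decode(scores, labels):
    """Map one score per label back to the label with the highest score."""
    if len(scores) != len(labels):
        raise ValueError("expected exactly one score per label")
    best_index = max(range(len(labels)), key=lambda i: scores[i])
    return labels[best_index]


COLOURS = ['red', 'green', 'blue']

encoded = one_hot('red', COLOURS)           # [1.0, 0.0, 0.0]
decoded = decode([0.1, 0.7, 0.2], COLOURS)  # 'green'

try:
    one_hot('purple', COLOURS)
except ValueError as error:
    rejected = str(error)
```

With this scheme, 'red' is a valid value while 'purple' is rejected, matching the Label example above. A regression model's real-valued outputs for the three positions can be decoded back into a single label.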

Requirements

Smart Fruit requires Python 3.6+, and uses scikit-learn, scipy, and pandas.

Installation

Install and update using the standard Python package manager pip:

pip install smart-fruit
