

Project description

Floral


The best neural network library

Floral is a neural network library built on JAX by Cameron Ryan. In floral, every tensor and operation is a graph node, and graphs are both inferenced and optimized through the same probe-tracing algorithm. The benefit of floral is that its simple, efficient graph algorithm provides an easy interface to low-level features.

installation

install with pip

pip install floral

getting started

To use floral, you must create a graph by linking nodes together. Let's first define a neural network using the floral.graph.GraphModule class.

from floral import nn, graph, datasets, loss, optim

class Model(graph.GraphModule):
    def __init__(self):
        self.input = nn.Input()
        self.linear1 = nn.Linear(self.input, [64, 784])
        self.relu1 = nn.ReLU(self.linear1.link)
        self.linear2 = nn.Linear(self.relu1, [64, 64])
        self.relu2 = nn.ReLU(self.linear2.link)
        self.linear3 = nn.Linear(self.relu2, [10, 64])

        self.crossentropy = loss.CategoricalCrossEntropy(self.linear3.link)
        
model = Model()

When constructing a graph in floral, there are two kinds of objects: floral.graph.GraphNode and floral.graph.GraphModule. All of a graph's functionality comes from the floral.graph.GraphNode objects, which either store data or perform functions, and are linked to parent nodes. The floral.graph.GraphModule objects simply contain nodes and exist only for abstraction. Every floral.graph.GraphModule must have a link attribute, which is a reference to the last node in its graph.
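To make the node/module split concrete, here is an illustrative sketch in plain Python (not the floral API — ToyNode and ToyModule are hypothetical names): nodes either store data or apply a function to their parents' outputs, and a module is just a named bundle of nodes whose link attribute points at its last node.

```python
class ToyNode:
    def __init__(self, parents=(), fn=None, value=None):
        self.parents = list(parents)  # upstream nodes this node depends on
        self.fn = fn                  # function nodes: compute from parents
        self.value = value            # data nodes: store a value directly

    def evaluate(self):
        # Data nodes return their stored value; function nodes recurse.
        if self.fn is None:
            return self.value
        return self.fn(*[p.evaluate() for p in self.parents])


class ToyModule:
    """A module only groups nodes; `link` references its last node."""
    def __init__(self):
        a = ToyNode(value=2.0)
        b = ToyNode(value=3.0)
        self.link = ToyNode(parents=[a, b], fn=lambda x, y: x * y)


m = ToyModule()
print(m.link.evaluate())  # 6.0
```

The module itself does no computation; evaluating its link node walks the underlying node graph.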

Let's load the MNIST dataset to train our neural network on.

mnist = datasets.MNIST()

To run inference on our graph, we attach the variable tensors to their respective nodes (here, the model's input node and loss node) and call the floral.graph.forward_trace(node) method to get that node's output, which in this case is the model's loss.

def inference(input_link, loss_link, x, y):
    input_link.attach(x)
    loss_link.attach(y)
    out = graph.forward_trace(loss_link)
    graph.clear_cache(loss_link)
    return out
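A forward trace with caching might look roughly like the following sketch (hypothetical, not floral's actual implementation): recurse to each node's parents, memoize each node's output so shared subgraphs are computed once, and clear the memo before the next trace.

```python
class Node:
    def __init__(self, parents=(), fn=None, value=None):
        self.parents, self.fn, self.value = list(parents), fn, value
        self.cache = None  # memoized output for the current trace


def forward_trace(node):
    # Compute each node at most once per trace.
    if node.cache is None:
        if node.fn is None:
            node.cache = node.value
        else:
            node.cache = node.fn(*[forward_trace(p) for p in node.parents])
    return node.cache


def clear_cache(node):
    # Reset memoized outputs so the graph can be traced again.
    node.cache = None
    for p in node.parents:
        clear_cache(p)


x = Node(value=4.0)
y = Node(parents=[x], fn=lambda v: v + 1)
z = Node(parents=[y, y], fn=lambda a, b: a * b)  # y is evaluated once, then cached
print(forward_trace(z))  # 25.0
clear_cache(z)
```

This also shows why clearing the cache matters: stale memoized values would otherwise be reused on the next trace with new inputs.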

Let's grab a sample image and label and run inference through the graph.

sample_image, sample_label = mnist[0]
print(inference(model.input, model.crossentropy, sample_image, sample_label))

After inferencing a graph, we can use the floral.graph.gradient_trace(node) method to calculate gradients for each tensor in the graph, and then optimize them with a floral.graph.OptimizationProbe object. It is also very important to clear the graph's cache with the floral.graph.clear_cache(node) method before the graph is traced again.

def optimize(optim_probe, input_link, loss_link, x, y):
    input_link.attach(x)
    loss_link.attach(y)

    loss = graph.forward_trace(loss_link)
    graph.gradient_trace(loss_link)
    optim_probe.trace(loss_link)

    graph.clear_cache(loss_link)
    return loss
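The idea behind a gradient trace can be sketched with reverse-mode differentiation over the same kind of node graph: starting from the output with gradient 1, each function node pushes gradients to its parents through its local derivatives. This is a scalar illustration of the concept only, not floral's gradient_trace; the class and field names are hypothetical.

```python
class GradNode:
    def __init__(self, parents=(), fn=None, local_grads=None, value=None):
        self.parents = list(parents)
        self.fn = fn                    # forward function of parent values
        self.local_grads = local_grads  # d(output)/d(parent) for each parent
        self.value = value
        self.grad = 0.0                 # accumulated gradient


def forward(node):
    if node.fn is None:
        return node.value
    return node.fn(*[forward(p) for p in node.parents])


def gradient_trace(node, upstream=1.0):
    # Accumulate the upstream gradient, then propagate to parents
    # scaled by each local derivative (the chain rule).
    node.grad += upstream
    if node.fn is None:
        return
    vals = [forward(p) for p in node.parents]
    for p, dg in zip(node.parents, node.local_grads):
        gradient_trace(p, upstream * dg(*vals))


w = GradNode(value=3.0)
x = GradNode(value=2.0)
y = GradNode(parents=[w, x], fn=lambda a, b: a * b,
             local_grads=[lambda a, b: b, lambda a, b: a])  # d/dw = x, d/dx = w
gradient_trace(y)
print(w.grad, x.grad)  # 2.0 3.0
```

After such a trace, every data node holds a gradient that an optimizer probe can use to update it.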

To make an optimization probe, we need a floral.optim.Optimizer object. For this, we will use floral.optim.StochasticGradientDescent.

optim_probe = graph.OptimizationProbe(optim.StochasticGradientDescent(lr=0.01))
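For reference, the update stochastic gradient descent applies to each parameter is the standard rule theta ← theta − lr · gradient; a minimal sketch (illustrative, not floral's internal code):

```python
def sgd_update(param, grad, lr=0.01):
    # theta <- theta - lr * gradient
    return param - lr * grad


print(sgd_update(1.0, 0.5, lr=0.01))  # 0.995
```

The optimization probe presumably walks the graph and applies an update like this to every tensor node that received a gradient.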

Now let's optimize the loss on our sample image and sample label.

optimize(optim_probe, model.input, model.crossentropy, sample_image, sample_label)
print(inference(model.input, model.crossentropy, sample_image, sample_label))

Let's also write an evaluation function.

def evaluate(test_set, input_link, loss_link):
    image_set, label_set = test_set
    total_loss = 0
    for i in range(len(image_set)):
        image, label = image_set[i], label_set[i]
        total_loss += inference(input_link, loss_link, image, label)
    return total_loss / len(image_set)
    
test_images, test_labels = mnist[:2000]
print("starting loss: ", evaluate((test_images, test_labels), model.input, model.crossentropy))

Now we can train our model for one epoch. For the purposes of this tutorial, one epoch should be enough to reach a reasonable loss.

train_images, train_labels = mnist[2000:10000]
for i in range(len(train_images)):
    image, label = train_images[i], train_labels[i]
    optimize(optim_probe, model.input, model.crossentropy, image, label)
    if i % 100 == 0:
        loss = evaluate((test_images, test_labels), model.input, model.crossentropy)
        print("step {}, loss: {}".format(i, loss))
print("final loss: {}".format(evaluate((test_images, test_labels), model.input, model.crossentropy)))

contact

If you have any questions, comments, concerns, or wish to collaborate, please email Cameron Ryan.



Download files

Download the file for your platform.

Source Distribution

floral-1.0.3.tar.gz (10.7 MB)

Uploaded Source

Built Distribution

floral-1.0.3-py3-none-any.whl (10.8 MB)

Uploaded Python 3

File details

Details for the file floral-1.0.3.tar.gz.

File metadata

  • Download URL: floral-1.0.3.tar.gz
  • Upload date:
  • Size: 10.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.1

File hashes

Hashes for floral-1.0.3.tar.gz
Algorithm Hash digest
SHA256 9b9dab20c7ca773e5ededc1f71073fb623234d32465164d2dfeae580d2902b68
MD5 49d3121d1ec61154e2bd109977f901c3
BLAKE2b-256 fd480920f5b21624c5245def89348fc7c2701c11d31d8e41323a6169c02d6107


File details

Details for the file floral-1.0.3-py3-none-any.whl.

File metadata

  • Download URL: floral-1.0.3-py3-none-any.whl
  • Upload date:
  • Size: 10.8 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.1

File hashes

Hashes for floral-1.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 79f467770a5278f3c3397b5d36f90ec0f0b2ab041b245bf172130b9f083fee28
MD5 93d547190aef3696ddf33bbcf3be3ca8
BLAKE2b-256 ee69ee7b442722510725903f483fd5ed2fd99e3e628e7bf3ed093a74a1b991e9

