a simple autograd library

Project description

Simple Torch

implemented with NumPy

autograd

tensor

Implements basic arithmetic between tensors, along with recording of dependencies between tensors and their gradients.
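The idea behind the dependency recording can be sketched standalone: each tensor keeps its data, a grad buffer, and a list of (parent, grad_fn) pairs, so backward() can push gradients through the graph. MiniTensor and its methods below are illustrative names, not the library's exact API.

```python
import numpy as np

class MiniTensor:
    def __init__(self, data, requires_grad=False, depends_on=None):
        self.data = np.asarray(data, dtype=float)
        self.requires_grad = requires_grad
        self.depends_on = depends_on or []   # list of (parent, grad_fn)
        self.grad = np.zeros_like(self.data) if requires_grad else None

    def __add__(self, other):
        # Addition passes the upstream gradient through unchanged.
        return MiniTensor(
            self.data + other.data,
            requires_grad=self.requires_grad or other.requires_grad,
            depends_on=[(self, lambda g: g), (other, lambda g: g)],
        )

    def backward(self, grad=None):
        if grad is None:
            grad = np.ones_like(self.data)   # seed with dL/dL = 1
        if self.requires_grad:
            self.grad = self.grad + grad     # accumulate, don't overwrite
        for parent, grad_fn in self.depends_on:
            parent.backward(grad_fn(grad))

a = MiniTensor([1.0, 2.0], requires_grad=True)
b = MiniTensor([3.0, 4.0], requires_grad=True)
c = a + b
c.backward()
print(a.grad)  # [1. 1.]
```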

function

Implements activation functions. Currently only tanh is available.
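The gradient rule behind tanh is compact: if y = tanh(x), then dy/dx = 1 - y², so backward just scales the upstream gradient by (1 - y²). A minimal sketch (`tanh_forward_backward` is an illustrative helper, not the library's API):

```python
import numpy as np

def tanh_forward_backward(x, upstream_grad):
    y = np.tanh(x)
    local_grad = upstream_grad * (1.0 - y ** 2)  # d tanh(x)/dx = 1 - tanh(x)^2
    return y, local_grad

y, g = tanh_forward_backward(np.array([0.0, 1.0]), np.array([1.0, 1.0]))
print(y[0], g[0])  # 0.0 1.0  (tanh(0) = 0, derivative there is 1)
```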

parameter

Implements quick creation of random tensors with requires_grad=True.
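A sketch of what a call like Parameter(10, num_hidden) amounts to: draw a random array of the given shape and mark it trainable. MiniParameter is a hypothetical stand-in for the library's Parameter class.

```python
import numpy as np

class MiniParameter:
    def __init__(self, *shape):
        self.data = np.random.randn(*shape)  # random normal init
        self.requires_grad = True            # parameters are always trainable
        self.grad = np.zeros(shape)

w = MiniParameter(10, 50)
print(w.data.shape)  # (10, 50)
```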

module

Implements recording of all parameters in a module.
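One common way a module can record its parameters is to scan its own attributes and collect everything that is a parameter; zero_grad() then just resets each one. The classes below are illustrative, not the library's exact implementation.

```python
import inspect
import numpy as np

class MiniParameter:
    def __init__(self, *shape):
        self.data = np.random.randn(*shape)
        self.grad = np.zeros(shape)

class MiniModule:
    def parameters(self):
        # Yield every attribute of this instance that is a MiniParameter.
        for _name, value in inspect.getmembers(self):
            if isinstance(value, MiniParameter):
                yield value

    def zero_grad(self):
        for p in self.parameters():
            p.grad = np.zeros_like(p.data)

class TwoLayer(MiniModule):
    def __init__(self):
        self.w1 = MiniParameter(10, 50)
        self.b1 = MiniParameter(50)

m = TwoLayer()
print(len(list(m.parameters())))  # 2
```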

optim

Implements optimizers for modules.
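The SGD step used in the example below boils down to p.data ← p.data - lr * p.grad for every parameter the module records. A self-contained sketch (MiniSGD, MiniParam, and Net are illustrative names):

```python
import numpy as np

class MiniParam:
    def __init__(self, data):
        self.data = np.asarray(data, dtype=float)
        self.grad = np.zeros_like(self.data)

class MiniSGD:
    def __init__(self, lr: float = 0.001) -> None:
        self.lr = lr

    def step(self, module) -> None:
        # Plain gradient descent: move each parameter against its gradient.
        for p in module.parameters():
            p.data = p.data - self.lr * p.grad

class Net:
    def __init__(self):
        self.w = MiniParam([1.0, 2.0])

    def parameters(self):
        return [self.w]

net = Net()
net.w.grad = np.array([10.0, 10.0])
MiniSGD(lr=0.1).step(net)
print(net.w.data)  # [0. 1.]
```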

tests

Tests for the autograd functions.

example

fizz_buzz

import numpy as np
from typing import List
from autograd import Tensor, Parameter, Module
from autograd.optim import SGD
from autograd.function import tanh

"""
print the numbers 1 to 100,
except  
    if the number is divisible by 3 print "fizz"
    if the number is divisible by 5 print "fizz"
    if the number is divisible by 15 print "fizz_buzz"

"""


def binary_encode(x: int) -> List[int]:
    return [x >> i & 1 for i in range(10)]


def fizz_buzz_encode(x: int) -> List[int]:
    if x % 15 == 0:
        return [0, 0, 0, 1]
    elif x % 5 == 0:
        return [0, 0, 1, 0]
    elif x % 3 == 0:
        return [0, 1, 0, 0]
    else:
        return [1, 0, 0, 0]


x_train = Tensor([binary_encode(x) for x in range(101, 1024)])
y_train = Tensor([fizz_buzz_encode(x) for x in range(101, 1024)])


class FizzBuzzModule(Module):
    def __init__(self, num_hidden: int = 50) -> None:
        self.w1 = Parameter(10, num_hidden)
        self.b1 = Parameter(num_hidden)

        self.w2 = Parameter(num_hidden, 4)
        self.b2 = Parameter(4)

    def predict(self, inputs: Tensor) -> Tensor:
        # inputs: (batch_size, 10)
        x1 = inputs @ self.w1 + self.b1  # (batch_size,num_hidden)
        x2 = tanh(x1)
        x3 = x2 @ self.w2 + self.b2  # (batch_size,4)
        return x3


optimizer = SGD(lr=0.001)
batch_size = 32
module = FizzBuzzModule()

starts = np.arange(0, x_train.shape[0], batch_size)
for epoch in range(10000):
    epoch_loss = 0.0

    np.random.shuffle(starts)
    for start in starts:
        end = start + batch_size

        module.zero_grad()
        inputs = x_train[start:end]

        predicted = module.predict(inputs)
        actual = y_train[start:end]
        errors = predicted - actual
        loss = (errors * errors).sum()

        loss.backward()
        epoch_loss += loss.data

        optimizer.step(module)
    print(epoch, epoch_loss)

num_correct = 0
for x in range(1, 101):
    inputs = Tensor([binary_encode(x)])
    predicted = module.predict(inputs)[0]
    predicted_idx = np.argmax(predicted.data)
    actual_idx = np.argmax(fizz_buzz_encode(x))
    labels = [str(x), "fizz", "buzz", "fizz_buzz"]

    if predicted_idx == actual_idx:
        num_correct += 1
    print(x, labels[predicted_idx], labels[actual_idx])

print(num_correct, "/100")

Project details


Download files

Download the file for your platform.

Source Distribution

Torch-Yottaxx-0.1.3.tar.gz (6.4 kB)

Uploaded Source

Built Distribution

Torch_Yottaxx-0.1.3-py3-none-any.whl (11.3 kB)

Uploaded Python 3

File details

Details for the file Torch-Yottaxx-0.1.3.tar.gz.

File metadata

  • Download URL: Torch-Yottaxx-0.1.3.tar.gz
  • Upload date:
  • Size: 6.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/49.2.0 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6

File hashes

Hashes for Torch-Yottaxx-0.1.3.tar.gz
Algorithm Hash digest
SHA256 beda0d6763b0acbaaed7af3b0c9791be70988d46bac4b859df3c3e54b106dbf2
MD5 4472fb6c8a639821d459826c41cca280
BLAKE2b-256 0d5cb89bd6ece38abc77553d39026c0bc22a8515b14abe144fc35ab17402d5eb


File details

Details for the file Torch_Yottaxx-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: Torch_Yottaxx-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 11.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/49.2.0 requests-toolbelt/0.9.1 tqdm/4.42.1 CPython/3.7.6

File hashes

Hashes for Torch_Yottaxx-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 a19386c136e2f879ceb30b6552d1f6c1ea26979cfa456f5452f11a2afeccbacc
MD5 f0337cea1af10b94d292b2e0f18d8031
BLAKE2b-256 2e473bc804b5efa7eeecdfe0decc30c12ed806042b14264cbc713f4c48a6f498

