Deep learning framework built with numpy
NeuralFlow
Deep learning framework built with NumPy (CuPy for GPU support)
This version supports CUDA 11.x.
Install
$ git clone https://github.com/augustinLib/neuralflow.git
or
# cpu-only
$ pip install neuralflow-cpu
# gpu (cuda 11.x)
$ pip install neuralflow
Quick guide
You can build a model like this:
from neuralflow.model import Model, DenseLayer, ConvLayer, MaxPoolingLayer
from neuralflow.function_class import ReLU

# fully connected model: 784 input features (e.g. flattened 28x28 images), 10 output classes
model = Model(
    DenseLayer(784, 50),
    ReLU(),
    DenseLayer(50, 10)
)

# convolutional model; 4320 = 30 channels * 12 * 12, which corresponds to a 1x28x28 input
conv_model = Model(
    ConvLayer(input_channel=1, output_channel=30, kernel_size=5, stride=1, padding=0),
    ReLU(),
    MaxPoolingLayer(kernel_size=2, stride=2),
    DenseLayer(4320, 100),
    ReLU(),
    DenseLayer(100, 10)
)
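The 4320 input features of the first DenseLayer in conv_model follow from the standard output-size formula for convolution and pooling windows, (size + 2 * padding - kernel_size) // stride + 1. The sketch below checks that arithmetic in plain Python. It assumes a single-channel 28x28 input, which the README does not state explicitly, and conv_output_size is a hypothetical helper, not part of neuralflow.

```python
def conv_output_size(input_size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling window along one axis."""
    return (input_size + 2 * padding - kernel_size) // stride + 1


# Assumption: a single-channel 28x28 input (e.g. MNIST).
height = width = 28
out_channels = 30

# ConvLayer(kernel_size=5, stride=1, padding=0): 28 -> 24
conv_h = conv_output_size(height, kernel_size=5, stride=1, padding=0)
conv_w = conv_output_size(width, kernel_size=5, stride=1, padding=0)

# MaxPoolingLayer(kernel_size=2, stride=2): 24 -> 12
pool_h = conv_output_size(conv_h, kernel_size=2, stride=2)
pool_w = conv_output_size(conv_w, kernel_size=2, stride=2)

# number of features the following DenseLayer receives after flattening
flattened = out_channels * pool_h * pool_w
print(flattened)  # 4320
```

If the input size changes, the first DenseLayer's input dimension has to be recomputed the same way.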
Training then proceeds as follows:
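from neuralflow.function_class import CrossEntropyLoss
from neuralflow.optimizer import Adam

critic = CrossEntropyLoss()
optim = Adam()

# x: a batch of inputs, y: the matching labels (prepared beforehand)
pred = model(x)           # forward pass
loss = critic(pred, y)    # compute the loss
model.backward(critic)    # backpropagate gradients from the loss
optim.update(model)       # update the model parameters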
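To illustrate what the four training lines above do, here is a minimal NumPy sketch of the same forward, loss, backward and update cycle for a single dense layer, using softmax cross-entropy and the textbook Adam update. It is not neuralflow's implementation: softmax_cross_entropy and AdamSketch are hypothetical helpers, the hyperparameters are the common Adam defaults rather than neuralflow's, and the data is random.

```python
import numpy as np


def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy loss and its gradient w.r.t. the logits (labels are class indices)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    n = logits.shape[0]
    loss = -np.log(probs[np.arange(n), labels] + 1e-12).mean()
    grad = probs.copy()
    grad[np.arange(n), labels] -= 1.0  # d(loss)/d(logits) = softmax - one_hot
    return loss, grad / n


class AdamSketch:
    """Textbook Adam update over a dict of parameter arrays (hypothetical helper)."""

    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m, self.v, self.t = {}, {}, 0

    def update(self, params, grads):
        self.t += 1
        for key in params:
            if key not in self.m:
                self.m[key] = np.zeros_like(params[key])
                self.v[key] = np.zeros_like(params[key])
            # running averages of the gradient and the squared gradient
            self.m[key] = self.beta1 * self.m[key] + (1 - self.beta1) * grads[key]
            self.v[key] = self.beta2 * self.v[key] + (1 - self.beta2) * grads[key] ** 2
            # bias correction for the zero-initialized averages
            m_hat = self.m[key] / (1 - self.beta1 ** self.t)
            v_hat = self.v[key] / (1 - self.beta2 ** self.t)
            params[key] -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 784))     # random stand-in for a batch of flattened 28x28 inputs
y = rng.integers(0, 10, size=64)   # integer class labels
params = {"W": rng.normal(scale=0.01, size=(784, 10)), "b": np.zeros(10)}
optim = AdamSketch(lr=0.001)

losses = []
for step in range(100):
    pred = x @ params["W"] + params["b"]                    # forward pass of one dense layer
    loss, dlogits = softmax_cross_entropy(pred, y)          # loss and dL/dlogits
    grads = {"W": x.T @ dlogits, "b": dlogits.sum(axis=0)}  # backward pass
    optim.update(params, grads)                             # Adam parameter update
    losses.append(loss)

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

The gradient of softmax cross-entropy with respect to the logits reduces to softmax(logits) minus the one-hot labels, which is why the backward pass can start directly from the loss.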
You can also train a model with a trainer:
from neuralflow.trainer import ClassificationTrainer
from neuralflow.data import DataLoader

# train_data and epochs are defined by the user; model, critic and optim come from the previous steps
dataloader = DataLoader(train_data)
trainer = ClassificationTrainer(model,
                                critic,
                                optim,
                                epochs,
                                init_lr=0.001)
trainer.train(dataloader)
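A trainer of this kind typically wraps the training step in a loop over epochs and minibatches. The sketch below shows one way a loader could shuffle and batch in-memory NumPy arrays. iterate_minibatches is a hypothetical stand-in; how neuralflow's DataLoader actually splits train_data is not described in this README.

```python
import numpy as np


def iterate_minibatches(x, y, batch_size=32, shuffle=True, seed=None):
    """Yield (x_batch, y_batch) pairs that cover the dataset exactly once (one epoch)."""
    indices = np.arange(len(x))
    if shuffle:
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, len(x), batch_size):
        batch = indices[start:start + batch_size]
        yield x[batch], y[batch]


x_train = np.arange(100 * 4, dtype=float).reshape(100, 4)  # stand-in data: 100 samples, 4 features
y_train = np.arange(100) % 10

epochs = 3
for epoch in range(epochs):
    seen = 0
    for x_batch, y_batch in iterate_minibatches(x_train, y_train, batch_size=32, seed=epoch):
        # a trainer would run the forward pass, loss, backward pass and optimizer update here
        seen += len(x_batch)
    print(f"epoch {epoch + 1}: {seen} samples")
```

Reshuffling with a different seed each epoch changes the batch composition while still visiting every sample once per epoch.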
To switch between GPU and CPU, set config.GPU as follows:
# using gpu
from neuralflow import config
config.GPU = True
# using cpu
from neuralflow import config
config.GPU = False
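Frameworks that run on both NumPy and CuPy commonly pick the array module based on a flag like this one. The following is a sketch of that pattern, assuming CuPy may or may not be installed; select_array_module is a hypothetical helper and not necessarily how neuralflow reads config.GPU. CuPy itself also provides cupy.get_array_module, which chooses the module from existing arrays.

```python
import numpy as np

try:
    import cupy as cp  # only available when CuPy (and a matching CUDA setup) is installed
except ImportError:
    cp = None


def select_array_module(use_gpu: bool):
    """Return cupy when the GPU is requested and CuPy is importable, otherwise numpy."""
    if use_gpu and cp is not None:
        return cp
    return np


xp = select_array_module(use_gpu=False)
a = xp.ones((2, 3))  # the same code creates NumPy or CuPy arrays depending on the flag
print(type(a))
```

Writing the numerical code against xp instead of np directly is what lets a single code path serve both the CPU and GPU builds.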
Structure
- neuralflow
- __init__.py
- data.py
- function.py
- function_class.py
- model.py
- optimizer.py
- trainer.py
- utils.py
- nlp
- utils.py
- epoch_notice
- send_message.py
- token_generator.py
- dataset
- test
- README.md
- .gitignore
Download files
Source Distribution
neuralflow-cpu-0.0.4.tar.gz (16.0 kB)
Built Distribution
Hashes for neuralflow_cpu-0.0.4-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 27b34773f8fe73e4eeac5678b202fd47efe1eccf68e089211681d1ce5cfb6ff1
MD5 | 9dd630a496de0c29bb9a6d9e2a1bde26
BLAKE2b-256 | 1e492d7f4c31b73153b4afc79e5a15285b25d97d88751259b7c5d61a219d7c80
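The published digests can be used to check a downloaded file before installing it. A minimal sketch using Python's hashlib, assuming the wheel was saved to the current directory under its original name; file_sha256 is a hypothetical helper. pip can compute the same digest with pip hash.

```python
import hashlib
from pathlib import Path


def file_sha256(path, chunk_size=1 << 16):
    """Compute the SHA256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# SHA256 published above for neuralflow_cpu-0.0.4-py3-none-any.whl
EXPECTED_SHA256 = "27b34773f8fe73e4eeac5678b202fd47efe1eccf68e089211681d1ce5cfb6ff1"

wheel = Path("neuralflow_cpu-0.0.4-py3-none-any.whl")  # assumed download location
if wheel.exists():
    matches = file_sha256(wheel) == EXPECTED_SHA256
    print("SHA256 matches" if matches else "SHA256 mismatch: do not install this file")
else:
    print(f"{wheel} not found")
```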