Online Deep Learning for river
Project description
DeepRiver is a Python library for online deep learning. DeepRiver's ambition is to enable online machine learning for neural networks. It combines the river API with PyTorch's capabilities for designing neural networks.
💈 Installation
pip install deepriver
You can install the latest development version from GitHub as follows:
pip install git+https://github.com/online-ml/river-torch --upgrade
Or, through SSH:
pip install git+ssh://git@github.com/online-ml/river-torch.git --upgrade
🍫 Quickstart
We build neural network development on top of the river API and follow river's design principles. The following example creates a simple MLP architecture based on PyTorch and incrementally predicts on, and trains with, the website phishing dataset. For further examples, check out the documentation.
Classification
from river import datasets
from river import metrics
from river import preprocessing
from river import compose
from DeepRiver import classification
from torch import nn
from torch import optim
from torch import manual_seed
# Seed PyTorch's global RNG so the network's weight initialization is
# reproducible across runs; the returned generator is not needed.
_ = manual_seed(0)
def build_torch_mlp_classifier(n_features):
    """Build a small MLP for binary classification.

    Args:
        n_features: Number of input features; sets the input width of the
            first ``nn.Linear`` layer.

    Returns:
        An ``nn.Sequential`` mapping ``n_features`` inputs to a single
        sigmoid output in (0, 1), compatible with a binary cross-entropy loss.
    """
    # ReLU between the linear layers: without a non-linearity, a stack of
    # Linear layers composes into one affine map, so the extra depth would
    # add parameters but no expressive power.
    net = nn.Sequential(
        nn.Linear(n_features, 5),
        nn.ReLU(),
        nn.Linear(5, 5),
        nn.ReLU(),
        nn.Linear(5, 5),
        nn.ReLU(),
        nn.Linear(5, 5),
        nn.ReLU(),
        nn.Linear(5, 1),
        nn.Sigmoid(),
    )
    return net
# Standardize the incoming features, then hand them to the PyTorch-backed
# classifier built by build_torch_mlp_classifier.
model = compose.Pipeline(
    preprocessing.StandardScaler(),
    classification.Classifier(
        build_fn=build_torch_mlp_classifier,
        loss_fn='bce',
        optimizer_fn=optim.Adam,
        learning_rate=1e-3,
    ),
)

dataset = datasets.Phishing()
metric = metrics.Accuracy()

# Progressive validation: predict on each sample before learning from it.
for x, y in dataset:
    y_pred = model.predict_one(x)  # make a prediction
    # update() and learn_one() mutate their object in place. Recent river
    # releases return None from them, so rebinding `metric`/`model` to the
    # return value would break the loop on the next sample.
    metric.update(y, y_pred)  # update the metric
    model.learn_one(x, y)  # make the model learn

print(f'Accuracy: {metric.get()}')
Anomaly Detection
import math
from river import datasets, metrics
from DeepRiver.anomaly.nn_builder import get_fc_autoencoder
from DeepRiver.base import AutoencodedAnomalyDetector
from DeepRiver.utils import get_activation_fn
from torch import manual_seed, nn
# Seed PyTorch's global RNG so weight initialization and dropout masks are
# reproducible across runs; the returned generator is not needed.
_ = manual_seed(0)
def get_fully_conected_autoencoder(activation_fn="selu", dropout=0.5, n_features=3):
    """Build the encoder and decoder halves of a fully connected autoencoder.

    The encoder narrows the input to ``ceil(n_features / 2)`` and then
    ``ceil(n_features / 4)`` units; the decoder mirrors those widths back up
    to ``n_features``.

    Args:
        activation_fn: Activation identifier resolved by ``get_activation_fn``
            into a callable that is invoked once per hidden layer to create
            that layer's activation module.
        dropout: Dropout probability applied to the encoder's input.
        n_features: Dimensionality of the input and of the reconstruction.

    Returns:
        A ``(encoder, decoder)`` tuple of ``nn.Sequential`` modules.
    """
    activation = get_activation_fn(activation_fn)
    half_width = math.ceil(n_features / 2)
    quarter_width = math.ceil(n_features / 4)

    # Layers are created in the same order as before so that seeded weight
    # initialization stays identical.
    encoder_layers = [
        nn.Dropout(p=dropout),
        nn.Linear(in_features=n_features, out_features=half_width),
        activation(),
        nn.Linear(in_features=half_width, out_features=quarter_width),
        activation(),
    ]
    decoder_layers = [
        nn.Linear(in_features=quarter_width, out_features=half_width),
        activation(),
        nn.Linear(in_features=half_width, out_features=n_features),
    ]
    return nn.Sequential(*encoder_layers), nn.Sequential(*decoder_layers)
if __name__ == '__main__':
    dataset = datasets.CreditCard().take(5000)
    metric = metrics.ROCAUC()
    model = AutoencodedAnomalyDetector(
        build_fn=get_fully_conected_autoencoder,
        lr=0.01,
    )

    # Score each sample before learning from it, so the ROC AUC is computed
    # on scores for samples the detector has not yet trained on.
    for x, y in dataset:
        metric.update(y_true=y, y_pred=model.score_one(x))
        model.learn_one(x=x)

    print(f'ROCAUC: {metric.get()}')
🏫 Affiliations
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Hashes for river_torch-0.0.12-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 6b2e539516b318a917c291d9a816e045ad595bb78d15015bb854b76e30b6e399
MD5 | 56b3cdfb985356326718268c29dd7527
BLAKE2b-256 | 31f865e712af04068141d5c833c2f79c2a5673e7e616069cd2d0b8afdc433230