SKLX
A scikit-learn-compatible neural network library that wraps MLX, heavily inspired by skorch.
[!WARNING] This is still under development, and none of the following examples work yet.
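Being scikit-learn compatible means following scikit-learn's estimator conventions. Under those conventions, `__init__` only stores hyperparameters, `get_params`/`set_params` expose them, and `fit` returns `self`. skorch also accepts the module *class* and instantiates it inside `fit`, forwarding `module__`-prefixed arguments to the module's constructor. The example below passes `MyModule` to `NeuralNetClassifier` in the same way.

The following stdlib-only sketch assumes SKLX follows that pattern. `SketchNet`, `DummyModule`, and the prefix handling are hypothetical illustrations, not SKLX's API.

```python
class DummyModule:
    """Hypothetical stand-in for an MLX module class."""

    def __init__(self, num_units=10):
        self.num_units = num_units


class SketchNet:
    """Hypothetical skorch-style wrapper illustrating scikit-learn's estimator conventions."""

    def __init__(self, module, max_epochs=10, lr=0.1, **kwargs):
        # scikit-learn convention: store constructor arguments unchanged and do no work here.
        self.module = module
        self.max_epochs = max_epochs
        self.lr = lr
        for key, value in kwargs.items():
            setattr(self, key, value)

    def _module_kwargs(self):
        # Collect module__* attributes and strip the prefix for the module's constructor.
        prefix = "module__"
        return {k[len(prefix):]: v for k, v in vars(self).items() if k.startswith(prefix)}

    def get_params(self, deep=True):
        # `deep` is accepted for API compatibility; nested estimators are not handled here.
        params = {"module": self.module, "max_epochs": self.max_epochs, "lr": self.lr}
        params.update({"module__" + k: v for k, v in self._module_kwargs().items()})
        return params

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def fit(self, X, y):
        # Instantiate the module class only now, so hyperparameters can change before fitting.
        self.module_ = self.module(**self._module_kwargs())
        # A real implementation would run `max_epochs` of training with learning rate `lr` here.
        return self
```

For example, `SketchNet(DummyModule, module__num_units=32).fit(X, y)` builds `DummyModule(num_units=32)` inside `fit`. Deferring construction this way is what lets tools such as grid search clone the estimator and vary module hyperparameters.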
Examples
import numpy as np
from sklearn.datasets import make_classification
import mlx.core as mx
from mlx import nn
from sklx import NeuralNetClassifier

# Toy binary classification data: 1000 samples, 20 features (10 informative).
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)  # float32 features
y = y.astype(np.int64)    # integer class labels

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=nn.ReLU()):
        super().__init__()
        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, 2)

    # MLX modules implement the forward pass in __call__, not forward().
    def __call__(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        # Softmax over the last (class) axis.
        return mx.softmax(self.output(X), axis=-1)

# Pass the module class itself (not an instance), as in the skorch examples.
net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
)
net.fit(X, y)
y_proba = net.predict_proba(X)
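To make the expected output of `predict_proba` concrete, here is a NumPy-only sketch of the computation `MyModule` performs at inference time, with dropout acting as the identity. The random weights and the `forward_eval` helper are hypothetical stand-ins for trained MLX parameters; this is not SKLX code.

Weights are stored here as `(in, out)` for readability. `mlx.nn.Linear` stores the transpose, `(out, in)`, and computes `x @ W.T + b`, which is the same affine map.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def softmax(x):
    # Subtract the row max for numerical stability before exponentiating.
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def forward_eval(X, params):
    # Mirrors MyModule: dense0 -> ReLU -> (dropout is identity at inference)
    # -> dense1 -> ReLU -> output -> softmax over the class axis.
    h = relu(X @ params["W0"] + params["b0"])
    h = relu(h @ params["W1"] + params["b1"])
    return softmax(h @ params["W2"] + params["b2"])

rng = np.random.default_rng(0)
num_units = 10
# Hypothetical random parameters with the same shapes as MyModule's layers.
params = {
    "W0": rng.normal(size=(20, num_units)).astype(np.float32),
    "b0": np.zeros(num_units, dtype=np.float32),
    "W1": rng.normal(size=(num_units, num_units)).astype(np.float32),
    "b1": np.zeros(num_units, dtype=np.float32),
    "W2": rng.normal(size=(num_units, 2)).astype(np.float32),
    "b2": np.zeros(2, dtype=np.float32),
}

X = rng.normal(size=(5, 20)).astype(np.float32)
proba = forward_eval(X, params)   # one probability row per sample
y_pred = proba.argmax(axis=1)     # most probable class per sample
print(proba.shape)  # (5, 2)
```

Each row of `proba` sums to 1 over the two classes. For a probabilistic scikit-learn classifier, `predict` typically returns the argmax of `predict_proba`, as `y_pred` does here; whether SKLX's `predict` behaves exactly this way is an assumption.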
Download files
Source Distribution
sklx-0.0.1.tar.gz (51.4 kB)
Built Distribution
sklx-0.0.1-py3-none-any.whl (5.2 kB)
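The wheel's filename encodes its compatibility tags. The wheel filename convention (PEP 427) is `{distribution}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl`. Under it, `py3-none-any` means any Python 3 interpreter, no ABI dependency, and any platform, which matches the "Tags: Python 3" entry in the file details.

`split_wheel_name` below is a hypothetical, simplified parser. It does not expand compressed tag sets such as `py2.py3`. The `packaging` library's `packaging.utils.parse_wheel_filename` is the robust option.

```python
def split_wheel_name(filename):
    """Split a wheel filename into its PEP 427 components (simplified sketch)."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename!r}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        name, version, python_tag, abi_tag, platform_tag = parts
        build = None
    elif len(parts) == 6:
        name, version, build, python_tag, abi_tag, platform_tag = parts
    else:
        raise ValueError(f"unexpected number of components in {filename!r}")
    return {
        "name": name,
        "version": version,
        "build": build,
        "python": python_tag,
        "abi": abi_tag,
        "platform": platform_tag,
    }

info = split_wheel_name("sklx-0.0.1-py3-none-any.whl")
print(info["python"], info["abi"], info["platform"])  # py3 none any
```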
File details
Details for the file sklx-0.0.1.tar.gz.
File metadata
- Download URL: sklx-0.0.1.tar.gz
- Upload date:
- Size: 51.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.20
File hashes
Algorithm | Hash digest
---|---
SHA256 | d4cd06bf44e0e86645c39a80ce55356b364eedbd347393903313ee94447b8175
MD5 | 0d05ef17e6dbe86ac12a960097be788d
BLAKE2b-256 | 3132a9d53f5ff78863207cf0e7a6e7896a99ab867a30ead02c6666720b4feb7c
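The published digests let you check a downloaded file before installing it. Below is a minimal sketch using Python's standard `hashlib`; the helper names are hypothetical. BLAKE2b-256 is BLAKE2b with a 32-byte digest, which `hashlib.blake2b(digest_size=32)` produces.

```python
import hashlib

# Digests for sklx-0.0.1.tar.gz, copied from the table above.
EXPECTED_SDIST = {
    "sha256": "d4cd06bf44e0e86645c39a80ce55356b364eedbd347393903313ee94447b8175",
    "md5": "0d05ef17e6dbe86ac12a960097be788d",
    "blake2b_256": "3132a9d53f5ff78863207cf0e7a6e7896a99ab867a30ead02c6666720b4feb7c",
}

def compute_digests(fileobj, chunk_size=1 << 16):
    """Hash a binary file object with SHA256, MD5, and BLAKE2b-256 in a single pass."""
    hashers = {
        "sha256": hashlib.sha256(),
        "md5": hashlib.md5(),
        "blake2b_256": hashlib.blake2b(digest_size=32),
    }
    # Read in chunks so large files are never loaded into memory at once.
    for chunk in iter(lambda: fileobj.read(chunk_size), b""):
        for h in hashers.values():
            h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}

def verify_file(path, expected=EXPECTED_SDIST):
    """Return True only if every expected digest matches the file's contents."""
    with open(path, "rb") as f:
        actual = compute_digests(f)
    return all(actual[name] == digest for name, digest in expected.items())

# Usage, assuming the sdist was downloaded to the current directory:
# verify_file("sklx-0.0.1.tar.gz")
```

SHA256 is the digest to rely on; MD5 is no longer collision-resistant and is best treated as a legacy checksum.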
File details
Details for the file sklx-0.0.1-py3-none-any.whl.
File metadata
- Download URL: sklx-0.0.1-py3-none-any.whl
- Upload date:
- Size: 5.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.20
File hashes
Algorithm | Hash digest
---|---
SHA256 | 374902458898ad7de61882dbe95a9aff9c894086026f618ae34db1df83d6468d
MD5 | 7280712613491c7b3d674cfbb9714e94
BLAKE2b-256 | 556f818c54089f3dc0a835bee13227ef4fb8f564d00729c8ebb6495db20110c1
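pip can enforce these digests at install time through its hash-checking mode. A requirements-file entry of the form `name==version --hash=sha256:<digest>`, with one `--hash` per acceptable file, makes pip reject any download whose hash does not match. The sketch below builds such an entry from the two SHA256 digests listed on this page; `requirement_with_hashes` is a hypothetical helper.

Hash-checking mode is all-or-nothing. pip then requires every requirement, including SKLX's own dependencies, to be pinned with `==` and hashed, so this single line is not sufficient on its own.

```python
def requirement_with_hashes(name, version, sha256_digests):
    """Format a pinned requirements-file entry for pip's hash-checking mode."""
    hash_opts = " ".join(f"--hash=sha256:{d}" for d in sha256_digests)
    return f"{name}=={version} {hash_opts}"

SKLX_SHA256 = [
    # sklx-0.0.1.tar.gz (source distribution)
    "d4cd06bf44e0e86645c39a80ce55356b364eedbd347393903313ee94447b8175",
    # sklx-0.0.1-py3-none-any.whl (wheel)
    "374902458898ad7de61882dbe95a9aff9c894086026f618ae34db1df83d6468d",
]

line = requirement_with_hashes("sklx", "0.0.1", SKLX_SHA256)
print(line)
```

Written to a `requirements.txt`, the entry would be installed with `pip install --require-hashes -r requirements.txt`.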