A minimal version of fastai with only what's needed for the training loop
Project description
fastai_minima
A minimal version of fastai with the bare bones needed to work with PyTorch
Install
```bash
pip install fastai_minima
```
How to use
This library is designed to bring in only the minimal pieces of fastai needed to work with raw PyTorch. This includes:
- Learner
- Callbacks
- Optimizer
- DataLoaders (but not the DataBlock)
- Metrics
Below is a very minimal example based on my Pytorch to fastai, Bridging the Gap article:
```python
import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

dset_train = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform)
dset_test = torchvision.datasets.CIFAR10(root='./data', train=False,
                                         download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(dset_train, batch_size=4,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(dset_test, batch_size=4,
                                         shuffle=False, num_workers=2)
```
```
Files already downloaded and verified
Files already downloaded and verified
```
```python
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
```python
criterion = nn.CrossEntropyLoss()
```
```python
from torch import optim

from fastai_minima.optimizer import OptimWrapper
from fastai_minima.learner import Learner, DataLoaders
from fastai_minima.callback.training import CudaCallback, ProgressCallback

def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, **kwargs))
```
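The same wrapping pattern should work for other torch optimizers as well. As a sketch (not from the original article), `adam_opt_func` below wraps `optim.Adam`; keyword arguments passed in by the Learner (such as `lr`) are simply forwarded to the underlying optimizer:

```python
# A sketch reusing the `optim` and `OptimWrapper` imports above:
# wrap a different torch optimizer in exactly the same way.
def adam_opt_func(params, **kwargs): return OptimWrapper(optim.Adam(params, **kwargs))
```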
```python
dls = DataLoaders(trainloader, testloader)

learn = Learner(dls, Net(), loss_func=criterion, opt_func=opt_func)

# To use the GPU, do
# learn = Learner(dls, Net(), loss_func=criterion, opt_func=opt_func, cbs=[CudaCallback()])

learn.fit(2, lr=0.001)
```
| epoch | train_loss | valid_loss | time |
|---|---|---|---|
| 0 | 2.269467 | 2.266472 | 01:20 |
| 1 | 1.876898 | 1.879593 | 01:21 |
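After training, the underlying model is available as `learn.model`, so inference can be done with plain PyTorch. A minimal sketch (not from the original article), assuming the CPU-only run above:

```python
# Evaluate the trained model on a single test batch with plain PyTorch.
# `learn.model` is the `Net` instance trained above.
learn.model.eval()
with torch.no_grad():
    images, labels = next(iter(testloader))
    preds = learn.model(images).argmax(dim=1)
    print((preds == labels).float().mean())  # accuracy on this batch
```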
If you want to use differential learning rates, the splitter you pass into fastai's Learner should use convert_params so that the parameter groups it returns are compatible with PyTorch optimizers:

```python
def splitter(m): return convert_params([[m.a], [m.b]])

learn = Learner(..., splitter=splitter)
```
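As a sketch of what that might look like for the `Net` model above (an illustration, not from the package docs), assuming `convert_params` lives in `fastai_minima.optimizer` and accepts lists of parameters as in the snippet above, the convolutional layers and the fully-connected head each become their own parameter group in the wrapped optimizer:

```python
from fastai_minima.optimizer import convert_params

# Hypothetical splitter for the `Net` defined earlier: one group for the
# convolutional layers, one for the fully-connected head.
def net_splitter(m):
    return convert_params([
        list(m.conv1.parameters()) + list(m.conv2.parameters()),
        list(m.fc1.parameters()) + list(m.fc2.parameters()) + list(m.fc3.parameters()),
    ])

learn = Learner(dls, Net(), loss_func=criterion, opt_func=opt_func, splitter=net_splitter)
```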
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
File details
Details for the file fastai_minima-0.0.9.tar.gz.
File metadata
- Download URL: fastai_minima-0.0.9.tar.gz
- Upload date:
- Size: 33.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.4.2 requests/2.25.1 setuptools/57.0.0 requests-toolbelt/0.9.1 tqdm/4.49.0 CPython/3.7.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d51732a8c25d837e31ba0a2ea9636f0bd648f259fb1418cb16f5d6bbe34a860a |
| MD5 | 33a11904a5b6d76e48ea9733490e4e0f |
| BLAKE2b-256 | 96e058cbe5a6bd879316d34643022b94dbe1e58b28d16be5932290bc42248b98 |
File details
Details for the file fastai_minima-0.0.9-py3-none-any.whl.
File metadata
- Download URL: fastai_minima-0.0.9-py3-none-any.whl
- Upload date:
- Size: 31.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.4.2 requests/2.25.1 setuptools/57.0.0 requests-toolbelt/0.9.1 tqdm/4.49.0 CPython/3.7.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 34954784dc900df6365d0ef55dffe060adbe523c6d10677abc2e39f403c1a47f |
| MD5 | b52e1380e947f6a92932b54ebb854f3f |
| BLAKE2b-256 | 2f413254004020421df73be154c8e41ea1826213e15b595f1c8e72bc3f0d2749 |