A Python package of hackable code snippets for faster PyTorch AI development
Project description
Codebook
Installation
pip install codejournal
Note:
This project was developed to make coding and training models faster. The code is clean and hackable. This is not a production-ready project!
Features:
- Easy loading and saving of models
- WandB integration
- Slack integration
- Checkpoint tracking
- Resuming training from checkpoints
- Debug mode
TODO:
- [ ] Add scheduler support
- [ ] Refactoring
- [ ] Multi-GPU support
Example:
# Project wildcard imports bring os, torch, nn, F, ConfigBase, ModelBase,
# TrainerArgs, Trainer, etc. into scope (presumably — verify against codejournal).
from codejournal.imports import *
from codejournal.modeling import *
import torchvision.models as models
from torchvision import datasets, transforms

# Credentials read from the environment by the Hub/Slack integrations.
os.environ["HUGGINGFACE_TOKEN"] = "" # Put your HuggingFace token here
os.environ["SLACK_WEBHOOK_URL"] = "" # Put your Slack webhook URL here
class Config(ConfigBase):
    """Hyper-parameters for the ResNet-on-MNIST example."""
    resnet: int = 18          # torchvision depth suffix: 18, 34, 50, ...
    pretrained: bool = True   # load ImageNet-pretrained backbone weights
    num_classes: int = 10     # MNIST digit classes
class ResNet(ModelBase):
    """ResNet classifier wrapping a torchvision backbone.

    Config fields used: ``resnet`` (depth suffix, e.g. 18), ``pretrained``
    (load ImageNet weights) and ``num_classes`` (size of the new fc head).
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): `pretrained=` is deprecated in recent torchvision in
        # favor of `weights=`; kept as-is for backward compatibility.
        self.model = getattr(models, f'resnet{self.config.resnet}')(pretrained=self.config.pretrained)
        # Swap the ImageNet head for one sized to our class count.
        self.model.fc = nn.Linear(self.model.fc.in_features, self.config.num_classes)

    def forward(self, x):
        # Input is single-channel (MNIST); tile to 3 channels for the RGB backbone.
        x = x.repeat(1, 3, 1, 1)
        return self.model(x)

    def _shared_step(self, batch):
        """Compute loss and accuracy for one batch; train and val share this."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch)

    def get_optimizer(self, trainer):
        # Override the default optimizer if needed
        return super().get_optimizer(trainer)

    def get_scheduler(self, optimizer, trainer):
        args = trainer.args  # For configuring the scheduler if needed
        return super().get_scheduler(optimizer, trainer)
# Standard MNIST preprocessing: tensor conversion followed by normalization
# with the conventional MNIST mean/std statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Train/val splits, downloaded on first use into ./data.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
# Full trainer configuration; field semantics below are per the inline docs —
# confirm against codejournal's TrainerArgs where hedged.
training_args = TrainerArgs(
    # --- Core training configuration ---
    batch_size=32,
    max_epochs=5,
    train_steps_per_epoch=800,   # cap on train batches per epoch
    val_steps_per_epoch=400,     # cap on validation batches per pass
    grad_accumulation_steps=1,   # 1 = optimizer step after every batch
    lr=1e-5,
    optimizer="AdamW",           # optimizer selected by name — presumably resolved on torch.optim; verify
    optimizer_kwargs={},
    scheduler=None,              # no LR scheduler (scheduler support is still on the TODO list)
    scheduler_kwargs={},
    # --- Logging and checkpointing ---
    log_every_n_steps=32,
    save_every_n_steps=100,
    n_best_checkpoints=3, # Negative value for saving all checkpoints
    n_latest_checkpoints=2, # Negative value for saving all checkpoints
    checkpoint_metric="loss",          # metric name used to rank "best" checkpoints
    checkpoint_metric_type="val",      # rank on the validation value of the metric
    checkpoint_metric_minimize=True,   # lower loss is better
    # --- Hardware and precision ---
    device="mps", # NOTE(review): comment originally said "Auto inferred", but a device is hard-coded here — confirm intended behavior
    mixed_precision=False,
    grad_clip_norm=1.0, # False for no clipping
    num_workers=0,
    # --- WandB integration ---
    wandb_project="Demo", # requires prior `wandb login`
    wandb_run_name="MNIST",
    wandb_run_id=None,
    wandb_resume="allow",
    wandb_kwargs={},
    disable_wandb=False,
    # --- Resume and debugging ---
    resume_from_checkpoint=True, # None for no resuming, True for resuming latest checkpoint, or path to checkpoint
    debug_mode=False,
    # --- Miscellaneous ---
    safe_dataloader=True,
    log_grad_norm=True,
    # Slack notifications only when a webhook URL was provided above.
    slack_notify=True if os.environ.get("SLACK_WEBHOOK_URL") else False,
    results_dir="results",
    val_data_shuffle=False,
)
# Build the model from its config, run the training loop, then optionally
# publish the trained weights when a HuggingFace token was configured above.
config = Config()
model = ResNet(config)

trainer = Trainer(training_args)
trainer.train(model, train_dataset, val_dataset)

hf_token = os.environ.get("HUGGINGFACE_TOKEN")
if hf_token:
    model.push_to_hub("demo-resnet")
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
codejournal-0.2.0.1.tar.gz
(23.5 kB
view details)
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file codejournal-0.2.0.1.tar.gz.
File metadata
- Download URL: codejournal-0.2.0.1.tar.gz
- Upload date:
- Size: 23.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.11.11
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
07e93ca18ed569652f70aaa1176f7815a05fc026ac40fe0f88bcc675f716b73f
|
|
| MD5 |
9c2fa29337781fb74c33afc6fcc107ed
|
|
| BLAKE2b-256 |
639188a9089957131a5e1d96438e0a4ea0a1167f7b9ac7b7253cbeb9e7f99884
|
File details
Details for the file codejournal-0.2.0.1-py3-none-any.whl.
File metadata
- Download URL: codejournal-0.2.0.1-py3-none-any.whl
- Upload date:
- Size: 24.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.0.1 CPython/3.11.11
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
ffe89a1d78f66a7af774392051b069b93c7cb60cd682ea5f6d627d196ba83489
|
|
| MD5 |
ded2eca84a79fbe85c5d69edce88c959
|
|
| BLAKE2b-256 |
79451f38598b4f88e481fa22e9e9bc93544bfcb48af2cdc3372f4890e5776344
|