
A Python package of code snippets for faster PyTorch AI development


Codebook

Installation

pip install codejournal

Note:

This project was developed to write and train models faster. The code is clean and hackable. This is not a production-ready project!

Features:

  • Easy loading and saving of models
  • WandB integration
  • Slack integration
  • Checkpoint tracking
  • Resuming training from checkpoints
  • Debug mode
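Checkpoint tracking, as configured by `n_best_checkpoints`, `n_latest_checkpoints` and `checkpoint_metric_minimize` in the example below, amounts to a retention policy. The following is a minimal sketch of such a policy, assuming that a negative count means "keep everything" (as the example's comments state) and that a checkpoint survives if it is among the best N by metric or the latest N by step. `Checkpoint` and `checkpoints_to_keep` are hypothetical names for illustration, not part of codejournal's API.

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass
class Checkpoint:
    # Hypothetical record of one saved checkpoint
    path: str
    step: int      # global training step at which it was saved
    metric: float  # value of the tracked metric, e.g. validation loss


def checkpoints_to_keep(
    checkpoints: List[Checkpoint],
    n_best: int,
    n_latest: int,
    minimize: bool = True,
) -> Set[str]:
    """Return the paths that survive a keep-best-N / keep-latest-N policy.

    Assumption: a negative n_best or n_latest keeps every checkpoint.
    """
    if n_best < 0 or n_latest < 0:
        return {c.path for c in checkpoints}

    # Best N by metric: ascending when minimizing, descending otherwise
    by_metric = sorted(checkpoints, key=lambda c: c.metric, reverse=not minimize)
    # Latest N by training step
    by_step = sorted(checkpoints, key=lambda c: c.step, reverse=True)

    best = {c.path for c in by_metric[:n_best]}
    latest = {c.path for c in by_step[:n_latest]}
    return best | latest


if __name__ == "__main__":
    history = [
        Checkpoint("ckpt-100", 100, 0.52),
        Checkpoint("ckpt-200", 200, 0.41),
        Checkpoint("ckpt-300", 300, 0.45),
        Checkpoint("ckpt-400", 400, 0.47),
        Checkpoint("ckpt-500", 500, 0.44),
    ]
    keep = checkpoints_to_keep(history, n_best=3, n_latest=2)
    # Checkpoints that would be deleted from disk
    print(sorted(c.path for c in history if c.path not in keep))  # ['ckpt-100']
```

Taking the union of both sets means the most recent checkpoint is never deleted because of a poor metric (as long as `n_latest` is at least 1), which keeps resuming from the latest checkpoint possible.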

TODO:

[ ] Add scheduler support
[ ] Refactoring
[ ] Multi-GPU support

Example:

from codejournal.imports import *
from codejournal.modeling import *

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import datasets, transforms

# setdefault keeps any value already exported in your shell
os.environ.setdefault("HUGGINGFACE_TOKEN", "")  # Put your HuggingFace token here
os.environ.setdefault("SLACK_WEBHOOK_URL", "")  # Put your Slack webhook URL here

class Config(ConfigBase):
    resnet: int = 18
    pretrained: bool = True
    num_classes: int = 10

class ResNet(ModelBase):
    def __init__(self, config):
        super().__init__(config)
        self.model = getattr(models, f'resnet{self.config.resnet}')(pretrained=self.config.pretrained)
        self.model.fc = nn.Linear(self.model.fc.in_features, self.config.num_classes)
        
    def forward(self, x):
        # MNIST images are grayscale (1 channel); ResNet expects 3, so tile the channel
        x = x.repeat(1, 3, 1, 1)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}
    
    def get_optimizer(self, trainer):
        # Override the default optimizer if needed
        return super().get_optimizer(trainer)
    
    def get_scheduler(self, optimizer, trainer):
        # Override the default scheduler if needed; trainer.args exposes the training arguments for configuring it
        args = trainer.args
        return super().get_scheduler(optimizer, trainer)


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # Standard MNIST mean and std
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

training_args = TrainerArgs(
    # Core Training Configuration
    batch_size=32,
    max_epochs=5,
    train_steps_per_epoch=800,
    val_steps_per_epoch=400,
    grad_accumulation_steps=1,
    lr=1e-5,
    optimizer="AdamW",
    optimizer_kwargs={},
    scheduler=None,
    scheduler_kwargs={},

    # Logging and Checkpointing
    log_every_n_steps=32,
    save_every_n_steps=100,
    n_best_checkpoints=3,  # Negative value for saving all checkpoints
    n_latest_checkpoints=2,  # Negative value for saving all checkpoints
    checkpoint_metric="loss",
    checkpoint_metric_type="val",
    checkpoint_metric_minimize=True,

    # Hardware and Precision
    device="mps",  # Apple Silicon GPU; auto-inferred if not set
    mixed_precision=False,
    grad_clip_norm=1.0,  # False for no clipping
    num_workers=0,

    # WandB Integration
    wandb_project="Demo",  # Run `wandb login` first
    wandb_run_name="MNIST",
    wandb_run_id=None,
    wandb_resume="allow",
    wandb_kwargs={},
    disable_wandb=False,

    # Resume and Debugging
    resume_from_checkpoint=True,  # None for no resuming, True for resuming latest checkpoint, or path to checkpoint
    debug_mode=False,

    # Miscellaneous
    safe_dataloader=True,
    log_grad_norm=True,
    slack_notify=bool(os.environ.get("SLACK_WEBHOOK_URL")),  # Notify only if a webhook URL is set
    results_dir="results",
    val_data_shuffle=False,
)

config = Config()
model = ResNet(config)

trainer = Trainer(training_args)
trainer.train(model, train_dataset, val_dataset)
if os.environ.get("HUGGINGFACE_TOKEN"):
    model.push_to_hub("demo-resnet")
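The step-based settings in `TrainerArgs` are easy to sanity-check with a little arithmetic. This sketch assumes that one step consumes one batch of `batch_size` samples and that `save_every_n_steps` counts global steps; codejournal's exact step semantics are not documented on this page.

```python
import math

# Values taken from the TrainerArgs in the example
batch_size = 32
grad_accumulation_steps = 1
max_epochs = 5
train_steps_per_epoch = 800
val_steps_per_epoch = 400
save_every_n_steps = 100

# Standard torchvision MNIST split sizes
n_train, n_val = 60_000, 10_000

# Assumption: one "step" consumes one batch of batch_size samples
effective_batch_size = batch_size * grad_accumulation_steps
train_samples_per_epoch = train_steps_per_epoch * batch_size
train_fraction = train_samples_per_epoch / n_train

# 10,000 is not a multiple of 32, so the last validation batch is partial
val_batches_available = math.ceil(n_val / batch_size)

total_train_steps = max_epochs * train_steps_per_epoch
# Assumption: save_every_n_steps counts global training steps
n_periodic_saves = total_train_steps // save_every_n_steps

print(f"effective batch size:    {effective_batch_size}")
print(f"train samples per epoch: {train_samples_per_epoch} ({train_fraction:.1%} of MNIST train)")
print(f"val batches available:   {val_batches_available} (requested: {val_steps_per_epoch})")
print(f"total train steps:       {total_train_steps}, periodic saves: {n_periodic_saves}")
```

With these values each epoch draws 25,600 of MNIST's 60,000 training images, and the 10,000-image test split yields only 313 batches, fewer than the 400 requested by `val_steps_per_epoch`. How the trainer handles that shortfall (stopping early or cycling the loader) is not stated here.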



Download files

Download the file for your platform.

Source Distribution

codejournal-0.1.5.tar.gz (22.0 kB)

Uploaded Source

Built Distribution


codejournal-0.1.5-py3-none-any.whl (23.4 kB)

Uploaded Python 3

File details

Details for the file codejournal-0.1.5.tar.gz.

File metadata

  • Download URL: codejournal-0.1.5.tar.gz
  • Upload date:
  • Size: 22.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.11.11

File hashes

Hashes for codejournal-0.1.5.tar.gz
Algorithm Hash digest
SHA256 e7cecaa74861debac5195c9fd2a5ae5eaf25ca10bc6aa1abb996fd246ee108a5
MD5 95f489ba6afc434e533049da1997eb66
BLAKE2b-256 f7118bd5cf1710bd6d003d7ba0406e7407aba23e4823292abbe065deea4498cf
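The published digests can be checked locally before installing. A minimal sketch using only the standard library's `hashlib`, assuming the sdist has been downloaded into the current directory under its original name; `sha256_of` and `verify` are illustrative helpers, not part of codejournal or pip.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    """Hash a file in chunks so large archives never need to fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # iter() with a sentinel keeps reading until read() returns b""
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_sha256):
    """Return True if the file's SHA-256 matches the published hex digest."""
    return sha256_of(path) == expected_sha256.strip().lower()


if __name__ == "__main__":
    # SHA256 published above for codejournal-0.1.5.tar.gz
    expected = "e7cecaa74861debac5195c9fd2a5ae5eaf25ca10bc6aa1abb996fd246ee108a5"
    archive = Path("codejournal-0.1.5.tar.gz")
    if archive.exists():
        print("OK" if verify(archive, expected) else "MISMATCH: do not install")
    else:
        print(f"{archive} not found; download it first")
```

pip can enforce the same check itself in hash-checking mode, using `--require-hashes` with `--hash=sha256:...` entries in a requirements file.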


File details

Details for the file codejournal-0.1.5-py3-none-any.whl.

File metadata

  • Download URL: codejournal-0.1.5-py3-none-any.whl
  • Upload date:
  • Size: 23.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.11.11

File hashes

Hashes for codejournal-0.1.5-py3-none-any.whl
Algorithm Hash digest
SHA256 4ded673e15d6c67e42cb6d571aac88365a598ea697204500ef05be26f3522a0a
MD5 7f4a79acc5d6d340c42c4ab8cac3bee7
BLAKE2b-256 9b8c97a3078e336993a9b87674af8098de53f2a2f7cf32ea19c42c14b77d6128

