A Python package of code snippets for faster PyTorch AI development

Project description

Codebook

Installation

pip install codejournal

Note:

This project was developed to make writing and training models faster. The code is clean and hackable. It is not a production-ready project!

Features:

  • Easy loading and saving of models
  • WandB integration
  • Slack integration
  • Checkpoint tracking
  • Resuming training from checkpoints
  • Debug mode
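
As a rough illustration of the Slack integration listed above, the sketch below posts a message to a Slack incoming webhook using only the standard library. It assumes the webhook URL is read from the SLACK_WEBHOOK_URL environment variable that the example in this README mentions; the function names build_slack_request and notify are hypothetical, and this is not necessarily how codejournal sends its notifications.

```python
import json
import os
import urllib.request


def build_slack_request(webhook_url: str, text: str) -> urllib.request.Request:
    """Build (but do not send) a POST request for a Slack incoming webhook."""
    # Slack incoming webhooks accept a JSON body with a "text" field.
    payload = json.dumps({"text": text}).encode("utf-8")
    return urllib.request.Request(
        webhook_url,
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def notify(text: str) -> bool:
    """Send text to the webhook in SLACK_WEBHOOK_URL; return False if it is unset."""
    # Assumption: the URL comes from this environment variable.
    webhook_url = os.environ.get("SLACK_WEBHOOK_URL", "")
    if not webhook_url:
        return False
    request = build_slack_request(webhook_url, text)
    with urllib.request.urlopen(request, timeout=10) as response:
        return response.status == 200
```

Keeping request construction separate from sending makes the payload easy to inspect without network access, and returning False when the variable is unset lets a training script run unchanged on machines without Slack configured.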

TODO:

  • [ ] Add scheduler support
  • [ ] Refactoring
  • [ ] Multi-GPU support

Example:

from codejournal.imports import * # All important imports, import at once
from codejournal.modeling import * # All modeling tools: ConfigBase, Trainer, TrainerArgs, ModelBase

import torchvision.models as models
from torchvision import datasets, transforms

# os.environ["SLACK_WEBHOOK_URL"] = ""

class Config(ConfigBase):
    resnet: int = 18
    pretrained: bool = True
    num_classes: int = 10

class ResNet(ModelBase):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = getattr(models, f'resnet{config.resnet}')(pretrained=config.pretrained)
        self.model.fc = nn.Linear(self.model.fc.in_features, config.num_classes)
        
    def forward(self, x):
        x = x.repeat(1, 3, 1, 1)  # MNIST images are single-channel; tile to the 3 channels ResNet expects
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}

    def validation_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}


config = Config()
model = ResNet(config)

transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])

train_dataset = datasets.MNIST('./data', train=True, download=True,
                    transform=transform)
val_dataset = datasets.MNIST('./data', train=False, download=True,
                    transform=transform)

training_args = TrainerArgs(
    # Core Training Configuration
    batch_size=32,
    max_epochs=11,
    train_steps_per_epoch=1000,
    val_steps_per_epoch=500,
    grad_accumulation_steps=1,
    lr=1e-5,
    optimizer="AdamW",
    optimizer_kwargs={},
    scheduler=None,
    scheduler_kwargs={},

    # Logging and Checkpointing
    log_every_n_steps=32,
    save_every_n_steps=100,
    n_best_checkpoints=3,  # Negative value for saving all checkpoints
    n_latest_checkpoints=2,  # Negative value for saving all checkpoints
    checkpoint_metric="loss",
    checkpoint_metric_type="val",
    checkpoint_metric_minimize=True,

    # Hardware and Precision
    device="mps",  # Auto inferred
    mixed_precision=False,
    grad_clip_norm=1.0,  # False for no clipping
    num_workers=0,

    # WandB Integration
    wandb_project="Demo",  # wandb login
    wandb_run_name="MNIST",
    wandb_run_id=None,
    wandb_resume="allow",
    wandb_kwargs={},
    disable_wandb=False,

    # Resume and Debugging
    resume_from_checkpoint=True,  # None for no resuming, True for resuming latest checkpoint, or path to checkpoint
    debug_mode=False,

    # Miscellaneous
    safe_dataloader=True,
    log_grad_norm=True,
    slack_notify=True,
    results_dir="results",
    val_data_shuffle=False,
)


trainer = Trainer(training_args)
trainer.train(model, train_dataset, val_dataset)
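
The n_best_checkpoints, n_latest_checkpoints, and checkpoint_metric_minimize arguments above suggest a retention policy that keeps the best few checkpoints by metric plus the most recent few. The sketch below shows one way such a policy could work. The Checkpoint record and the checkpoints_to_keep function are hypothetical names for illustration, and treating a negative value of either argument as "keep everything" is an assumption based on the comments in the example, not a description of codejournal's internals.

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass
class Checkpoint:
    # Hypothetical record; codejournal's own bookkeeping may differ.
    path: str
    step: int
    metric: float


def checkpoints_to_keep(
    checkpoints: List[Checkpoint],
    n_best: int,
    n_latest: int,
    minimize: bool = True,
) -> Set[str]:
    """Return the paths that survive pruning under a best-N plus latest-N policy."""
    if n_best < 0 or n_latest < 0:
        # Assumption: a negative value means every checkpoint is kept.
        return {c.path for c in checkpoints}
    # Best first: ascending metric when minimizing, descending otherwise.
    by_metric = sorted(checkpoints, key=lambda c: c.metric, reverse=not minimize)
    # Most recent first.
    by_step = sorted(checkpoints, key=lambda c: c.step, reverse=True)
    return {c.path for c in by_metric[:n_best] + by_step[:n_latest]}


history = [
    Checkpoint("step_100.pt", 100, 0.90),
    Checkpoint("step_200.pt", 200, 0.40),
    Checkpoint("step_300.pt", 300, 0.55),
    Checkpoint("step_400.pt", 400, 0.35),
    Checkpoint("step_500.pt", 500, 0.60),
    Checkpoint("step_600.pt", 600, 0.70),
]
keep = checkpoints_to_keep(history, n_best=3, n_latest=2, minimize=True)
keep_all = checkpoints_to_keep(history, n_best=-1, n_latest=2)
```

With these made-up values, the three lowest-loss checkpoints (steps 400, 200, 300) and the two newest (steps 600, 500) are retained, so only step_100.pt would be deleted. A checkpoint that is both among the best and among the latest is counted once, so fewer than n_best + n_latest files may remain.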


Download files

Source Distribution

codejournal-0.1.1.tar.gz (19.3 kB)

Uploaded Source

Built Distribution

codejournal-0.1.1-py3-none-any.whl (20.7 kB)

Uploaded Python 3

File details

Details for the file codejournal-0.1.1.tar.gz.

File metadata

  • Download URL: codejournal-0.1.1.tar.gz
  • Upload date:
  • Size: 19.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.9.21

File hashes

Hashes for codejournal-0.1.1.tar.gz
Algorithm    Hash digest
SHA256       ee94513689e9d988a394a3f53920bb5d894ef03ab51146e748dd83fcd8e2ad4b
MD5          6a617fafc5f2aa4cbf05a7b819efa2d5
BLAKE2b-256  6b9084143854fe663ba2fac328187bc80021391f1caf1131dd0bf256b1afb055

File details

Details for the file codejournal-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: codejournal-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 20.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.9.21

File hashes

Hashes for codejournal-0.1.1-py3-none-any.whl
Algorithm    Hash digest
SHA256       025ef7b24df316203967d0624ce84dd75266b9859b5f6e635e1b0613c63a3680
MD5          201083441fd9205af81560633b728e67
BLAKE2b-256  574221a2ad82ae4426acb8c3d2b9194e5f9949bd92b5c48d5b4033e2f92fdb23
