A Python package of hackable code snippets for faster PyTorch AI development

Project description

Codebook

Installation

pip install codejournal

Note:

This project was developed to help write and train models faster. The code is clean and hackable. It is not a production-ready project!

Features:

  • Easy loading and saving of models
  • WandB integration
  • Slack integration
  • Checkpoint tracking
  • Resuming training from checkpoints
  • Debug mode
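
The checkpoint-tracking options shown later (`n_best_checkpoints`, `n_latest_checkpoints`) follow a common keep-N pattern. As a general illustration of that pattern — not codejournal's actual implementation — keeping the N best checkpoints by a metric that is minimized can be sketched with a heap:

```python
import heapq

class BestCheckpoints:
    """Track the n best checkpoints by a metric that should be minimized."""

    def __init__(self, n_best):
        self.n_best = n_best
        self._heap = []  # stores (-metric, path); the worst kept checkpoint pops first

    def update(self, metric, path):
        """Register a new checkpoint; return paths that should be deleted."""
        heapq.heappush(self._heap, (-metric, path))
        if self.n_best < 0:  # negative value: keep all checkpoints
            return []
        dropped = []
        while len(self._heap) > self.n_best:
            _, worst_path = heapq.heappop(self._heap)
            dropped.append(worst_path)
        return dropped

tracker = BestCheckpoints(n_best=2)
tracker.update(0.9, "ckpt_1.pt")
tracker.update(0.5, "ckpt_2.pt")
dropped = tracker.update(0.3, "ckpt_3.pt")  # evicts ckpt_1.pt (metric 0.9)
```

The `BestCheckpoints` class and its methods are hypothetical names for this sketch; codejournal's checkpoint manager may work differently.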

TODO:

[ ] Add scheduler support
[ ] Refactoring
[ ] Multi-GPU support

Example:

from codejournal.imports import *
from codejournal.modeling import *

import torchvision.models as models
from torchvision import datasets, transforms

os.environ["HUGGINGFACE_TOKEN"] = "" # Put your HuggingFace token here
os.environ["SLACK_WEBHOOK_URL"] = "" # Put your Slack webhook URL here

class Config(ConfigBase):
    resnet: int = 18
    pretrained: bool = True
    num_classes: int = 10

class ResNet(ModelBase):
    def __init__(self, config):
        super().__init__(config)
        self.model = getattr(models, f'resnet{self.config.resnet}')(pretrained=self.config.pretrained)  # note: recent torchvision prefers the weights= argument
        self.model.fc = nn.Linear(self.model.fc.in_features, self.config.num_classes)
        
    def forward(self, x):
        # MNIST is single-channel; repeat to 3 channels for ResNet's input
        x = x.repeat(1, 3, 1, 1)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}

    def validation_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}
    
    def get_optimizer(self, trainer):
        # Override the default optimizer if needed
        return super().get_optimizer(trainer)
    
    def get_scheduler(self, optimizer, trainer):
        args = trainer.args  # Trainer args are available for configuring the scheduler
        return super().get_scheduler(optimizer, trainer)


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean/std
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
val_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

training_args = TrainerArgs(
    # Core Training Configuration
    batch_size=32,
    max_epochs=5,
    train_steps_per_epoch=800,
    val_steps_per_epoch=400,
    grad_accumulation_steps=1,
    lr=1e-5,
    optimizer="AdamW",
    optimizer_kwargs={},
    scheduler=None,
    scheduler_kwargs={},

    # Logging and Checkpointing
    log_every_n_steps=32,
    save_every_n_steps=100,
    n_best_checkpoints=3,  # Negative value for saving all checkpoints
    n_latest_checkpoints=2,  # Negative value for saving all checkpoints
    checkpoint_metric="loss",
    checkpoint_metric_type="val",
    checkpoint_metric_minimize=True,

    # Hardware and Precision
    device="mps",  # Auto-inferred if not set
    mixed_precision=False,
    grad_clip_norm=1.0,  # False for no clipping
    num_workers=0,

    # WandB Integration
    wandb_project="Demo",  # wandb login
    wandb_run_name="MNIST",
    wandb_run_id=None,
    wandb_resume="allow",
    wandb_kwargs={},
    disable_wandb=False,

    # Resume and Debugging
    resume_from_checkpoint=True,  # None for no resuming, True for resuming latest checkpoint, or path to checkpoint
    debug_mode=False,

    # Miscellaneous
    safe_dataloader=True,
    log_grad_norm=True,
    slack_notify=bool(os.environ.get("SLACK_WEBHOOK_URL")),
    results_dir="results",
    val_data_shuffle=False,
)

config = Config()
model = ResNet(config)

trainer = Trainer(training_args)
trainer.train(model, train_dataset, val_dataset)
if os.environ.get("HUGGINGFACE_TOKEN"):
    model.push_to_hub("demo-resnet")
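
The `resume_from_checkpoint` argument is tri-state: `None` for no resuming, `True` for resuming the latest checkpoint, or an explicit path. As a minimal sketch of those documented semantics (not codejournal's internal logic — the function name, checkpoint extension, and latest-selection rule here are assumptions):

```python
import os

def resolve_resume(resume_from_checkpoint, results_dir="results"):
    """Map the tri-state resume setting to a checkpoint path (or None).

    None  -> start fresh
    True  -> most recently modified *.pt file under results_dir, if any
    str   -> use that explicit checkpoint path
    """
    if resume_from_checkpoint is None:
        return None
    if resume_from_checkpoint is True:
        if not os.path.isdir(results_dir):
            return None
        ckpts = [
            os.path.join(results_dir, f)
            for f in os.listdir(results_dir)
            if f.endswith(".pt")
        ]
        if not ckpts:
            return None
        # Assume the most recently modified file is the latest checkpoint
        return max(ckpts, key=os.path.getmtime)
    return resume_from_checkpoint  # explicit path
```

This makes the fallthrough explicit: a fresh run and a resume-with-no-checkpoints-yet both resolve to `None`.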

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

codejournal-0.2.0.1.tar.gz (23.5 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

codejournal-0.2.0.1-py3-none-any.whl (24.9 kB view details)

Uploaded Python 3

File details

Details for the file codejournal-0.2.0.1.tar.gz.

File metadata

  • Download URL: codejournal-0.2.0.1.tar.gz
  • Upload date:
  • Size: 23.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.11.11

File hashes

Hashes for codejournal-0.2.0.1.tar.gz
Algorithm Hash digest
SHA256 07e93ca18ed569652f70aaa1176f7815a05fc026ac40fe0f88bcc675f716b73f
MD5 9c2fa29337781fb74c33afc6fcc107ed
BLAKE2b-256 639188a9089957131a5e1d96438e0a4ea0a1167f7b9ac7b7253cbeb9e7f99884

See more details on using hashes here.

File details

Details for the file codejournal-0.2.0.1-py3-none-any.whl.

File metadata

  • Download URL: codejournal-0.2.0.1-py3-none-any.whl
  • Upload date:
  • Size: 24.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.11.11

File hashes

Hashes for codejournal-0.2.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 ffe89a1d78f66a7af774392051b069b93c7cb60cd682ea5f6d627d196ba83489
MD5 ded2eca84a79fbe85c5d69edce88c959
BLAKE2b-256 79451f38598b4f88e481fa22e9e9bc93544bfcb48af2cdc3372f4890e5776344

See more details on using hashes here.
