A hackable Python package of code snippets for faster PyTorch AI development

Codebook

Installation

pip install codejournal
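
After installing, it can help to confirm which version is active in the current environment. The sketch below uses only the standard library; the helper name `installed_version` is made up for illustration and is not part of codejournal.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Prints the installed codejournal version, or None if it is not installed.
    print(installed_version("codejournal"))
```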

Note:

This project was developed for coding and training models faster. The code is clean and hackable, but this is not a production-ready project!

Features:

  • Easy loading and saving of models
  • WandB integration
  • Slack integration
  • Checkpoint tracking
  • Resuming training from checkpoints
  • Debug mode
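
To make the checkpoint-tracking feature more concrete, here is a minimal sketch of a "keep the N best plus the N latest" retention policy, similar in spirit to the `n_best_checkpoints` and `n_latest_checkpoints` arguments in the example further down. This is an assumption about how such a policy could work, not codejournal's actual implementation; the function `checkpoints_to_keep` and its signature are hypothetical.

```python
def checkpoints_to_keep(checkpoints, n_best=3, n_latest=2, minimize=True):
    """Return the steps of the checkpoints a best-N / latest-N policy would retain.

    `checkpoints` is a list of (step, metric) pairs. A negative n_best or
    n_latest is treated as "keep everything" (hypothetical behaviour).
    """
    if n_best < 0 or n_latest < 0:
        return sorted(step for step, _ in checkpoints)

    # Rank by metric: ascending when lower is better (e.g. loss), else descending.
    by_metric = sorted(checkpoints, key=lambda c: c[1], reverse=not minimize)
    best = {step for step, _ in by_metric[:n_best]}

    # The most recent checkpoints are the ones with the largest step numbers.
    by_step = sorted(checkpoints, key=lambda c: c[0])
    latest = {step for step, _ in by_step[-n_latest:]} if n_latest else set()

    # A checkpoint survives if it is among the best or among the latest.
    return sorted(best | latest)


history = [(100, 0.90), (200, 0.50), (300, 0.70), (400, 0.40), (500, 0.60)]
print(checkpoints_to_keep(history))  # [200, 400, 500]
```

Steps 400 and 500 are both among the best three and the latest two, so only three files remain rather than five.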

TODO:

  • [ ] Add scheduler support
  • [ ] Refactoring
  • [ ] Multi-GPU support

Example:

from codejournal.imports import *
from codejournal.modeling import *

import torchvision.models as models
from torchvision import datasets, transforms

os.environ["HUGGINGFACE_TOKEN"] = "" # Put your HuggingFace token here
os.environ["SLACK_WEBHOOK_URL"] = "" # Put your Slack webhook URL here

class Config(ConfigBase):
    resnet: int = 18
    pretrained: bool = True
    num_classes: int = 10

class ResNet(ModelBase):
    def __init__(self, config):
        super().__init__(config)
        self.model = getattr(models, f'resnet{self.config.resnet}')(pretrained=self.config.pretrained)
        self.model.fc = nn.Linear(self.model.fc.in_features, self.config.num_classes)
        
    def forward(self, x):
        # MNIST images have one channel; repeat it three times to match ResNet's RGB input
        x = x.repeat(1,3,1,1)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}

    def validation_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}
    
    def get_optimizer(self, trainer):
        # Override the default optimizer if needed
        return super().get_optimizer(trainer)
    
    def get_scheduler(self, optimizer, trainer):
        args = trainer.args  # Training arguments, available for configuring the scheduler
        return super().get_scheduler(optimizer, trainer)


transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])

train_dataset = datasets.MNIST('./data', train=True, download=True,
                    transform=transform)
val_dataset = datasets.MNIST('./data', train=False, download=True,
                    transform=transform)

training_args = TrainerArgs(
    # Core Training Configuration
    batch_size=32,
    max_epochs=5,
    train_steps_per_epoch=800,
    val_steps_per_epoch=400,
    grad_accumulation_steps=1,
    lr=1e-5,
    optimizer="AdamW",
    optimizer_kwargs={},
    scheduler=None,
    scheduler_kwargs={},

    # Logging and Checkpointing
    log_every_n_steps=32,
    save_every_n_steps=100,
    n_best_checkpoints=3,  # Negative value for saving all checkpoints
    n_latest_checkpoints=2,  # Negative value for saving all checkpoints
    checkpoint_metric="loss",
    checkpoint_metric_type="val",
    checkpoint_metric_minimize=True,

    # Hardware and Precision
    device="mps",  # Auto inferred
    mixed_precision=False,
    grad_clip_norm=1.0,  # False for no clipping
    num_workers=0,

    # WandB Integration
    wandb_project="Demo",  # wandb login
    wandb_run_name="MNIST",
    wandb_run_id=None,
    wandb_resume="allow",
    wandb_kwargs={},
    disable_wandb=False,

    # Resume and Debugging
    resume_from_checkpoint=True,  # None for no resuming, True for resuming latest checkpoint, or path to checkpoint
    debug_mode=False,

    # Miscellaneous
    safe_dataloader=True,
    log_grad_norm=True,
    slack_notify=bool(os.environ.get("SLACK_WEBHOOK_URL")),
    results_dir="results",
    val_data_shuffle=False,
)

config = Config()
model = ResNet(config)

trainer = Trainer(training_args)
trainer.train(model, train_dataset, val_dataset)
if os.environ.get("HUGGINGFACE_TOKEN"):
    model.push_to_hub("demo-resnet")
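
As a rough sanity check on the numbers in `TrainerArgs`, the arithmetic below estimates how much work the example run does. It rests on assumptions that the source does not confirm: that `train_steps_per_epoch` caps the number of batches consumed per epoch, that one optimizer update happens every `grad_accumulation_steps` batches, and that `save_every_n_steps` counts training steps. The function `training_budget` is hypothetical and not part of codejournal. The MNIST training split has 60,000 images.

```python
import math


def training_budget(batch_size, max_epochs, train_steps_per_epoch,
                    grad_accumulation_steps, save_every_n_steps, dataset_size):
    """Estimate step and sample counts for a run under the stated assumptions."""
    # Batches in one full pass over the data (a final partial batch counts as one).
    batches_in_dataset = math.ceil(dataset_size / batch_size)
    # Assumption: the per-epoch step limit caps how many batches are consumed.
    steps_per_epoch = min(train_steps_per_epoch, batches_in_dataset)
    total_steps = steps_per_epoch * max_epochs
    return {
        "steps_per_epoch": steps_per_epoch,
        "samples_per_epoch": min(steps_per_epoch * batch_size, dataset_size),
        "total_steps": total_steps,
        "optimizer_updates": total_steps // grad_accumulation_steps,
        "checkpoints_written": total_steps // save_every_n_steps,
    }


budget = training_budget(
    batch_size=32, max_epochs=5, train_steps_per_epoch=800,
    grad_accumulation_steps=1, save_every_n_steps=100, dataset_size=60_000,
)
print(budget)
```

Under these assumptions each epoch sees 25,600 of the 60,000 training images, the run takes 4,000 steps in total, and 40 checkpoints are written before the best-N and latest-N limits prune them.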

Download files

Source Distribution

codejournal-0.1.9.3.tar.gz (22.3 kB)

Uploaded Source

Built Distribution

codejournal-0.1.9.3-py3-none-any.whl (23.8 kB)

Uploaded Python 3

File details

Details for the file codejournal-0.1.9.3.tar.gz.

File metadata

  • Download URL: codejournal-0.1.9.3.tar.gz
  • Size: 22.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.11.11

File hashes

Hashes for codejournal-0.1.9.3.tar.gz

Algorithm    Hash digest
SHA256       038d29c4daf04e2c823b94ca2d54b3545d579d63b09bd41b2f9262db258b0af0
MD5          f847144f1184d7afef08e39d7a3a684d
BLAKE2b-256  6546ac9d727b6713870eab488bcc143bcf08ac53dc8d10e41be2878c5a70c903
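
A downloaded file can be checked against the published SHA256 digest with the standard library alone. The following is a minimal sketch; the helper names `sha256_of_file` and `matches_published_hash` are illustrative, and the local file path is whatever you saved the download as.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of the file at `path`, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks so large files never need to fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published_hash(path, expected_hex):
    """Compare a file's SHA256 digest to a published hex digest, ignoring case."""
    return sha256_of_file(path) == expected_hex.strip().lower()
```

For example, `matches_published_hash("codejournal-0.1.9.3.tar.gz", "038d29c4daf04e2c823b94ca2d54b3545d579d63b09bd41b2f9262db258b0af0")` should return True for an intact copy of the source distribution.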

File details

Details for the file codejournal-0.1.9.3-py3-none-any.whl.

File metadata

  • Download URL: codejournal-0.1.9.3-py3-none-any.whl
  • Size: 23.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.11.11

File hashes

Hashes for codejournal-0.1.9.3-py3-none-any.whl

Algorithm    Hash digest
SHA256       114db2044667c285b11772ab890a2ab22c8aab579c304de79354925c63ad3b85
MD5          0a53d06402d6de01009ae5c1227bae7a
BLAKE2b-256  f24ded18a208356e4ae1f52ae0d416966860f4def8e398bbd1f16aa6e9e95d06
