Skip to main content

A Python package for code organization and utilities

Project description

Codebook

Installation

pip install codejournal

Note:

This project was developed to code and train models faster. The code is clean and hackable. This is not a production ready project!

Features:

  • Easy loading and saving of models
  • WandB integration
  • Slack integration
  • Checkpoint tracking
  • Resuming training from checkpoints
  • Debug mode

TODO:

[] Adding schedulers support [] Refactoring [] Multi-GPU support

Example:

from codejournal.imports import * # All important imports, import at once
from codejournal.modeling import * # All modeling tools: ConfigBase, Trainer, TrainerArgs, ModelBase

import torchvision.models as models
from torchvision import datasets, transforms

# os.envor["SLACK_WEBHOOK_URL"] = ""

class Config(ConfigBase):
    resnet: int = 18
    pretrained: bool = True
    num_classes: int = 10

class ResNet(ModelBase):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = getattr(models, f'resnet{config.resnet}')(pretrained=config.pretrained)
        self.model.fc = nn.Linear(self.model.fc.in_features, config.num_classes)
        
    def forward(self, x):
        x = x.repeat(1,3,1,1)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}

    def validation_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}


config = Config()
model = ResNet(config)

transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])

train_dataset = datasets.MNIST('./data', train=True, download=True,
                    transform=transform)
val_dataset = datasets.MNIST('./data', train=False, download=True,
                    transform=transform)

training_args = TrainerArgs(
    # Core Training Configuration
    batch_size=32,
    max_epochs=11,
    train_steps_per_epoch=1000,
    val_steps_per_epoch=500,
    grad_accumulation_steps=1,
    lr=1e-5,
    optimizer="AdamW",
    optimizer_kwargs={},
    scheduler=None,
    scheduler_kwargs={},

    # Logging and Checkpointing
    log_every_n_steps=32,
    save_every_n_steps=100,
    n_best_checkpoints=3,  # Negative value for saving all checkpoints
    n_latest_checkpoints=2,  # Negative value for saving all checkpoints
    checkpoint_metric="loss",
    checkpoint_metric_type="val",
    checkpoint_metric_minimize=True,

    # Hardware and Precision
    device="mps",  # Auto inferred
    mixed_precision=False,
    grad_clip_norm=1.0,  # False for no clipping
    num_workers=0,

    # WandB Integration
    wandb_project="Demo",  # wandb login
    wandb_run_name="MNIST",
    wandb_run_id=None,
    wandb_resume="allow",
    wandb_kwargs={},
    disable_wandb=False,

    # Resume and Debugging
    resume_from_checkpoint=True,  # None for no resuming, True for resuming latest checkpoint, or path to checkpoint
    debug_mode=False,

    # Miscellaneous
    safe_dataloader=True,
    log_grad_norm=True,
    slack_notify=True,
    results_dir="results",
    val_data_shuffle=False,
)


trainer = Trainer(training_args)
trainer.train(model, train_dataset, val_dataset)

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

codejournal-0.1.0.tar.gz (19.3 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

codejournal-0.1.0-py3-none-any.whl (20.7 kB view details)

Uploaded Python 3

File details

Details for the file codejournal-0.1.0.tar.gz.

File metadata

  • Download URL: codejournal-0.1.0.tar.gz
  • Upload date:
  • Size: 19.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.9.21

File hashes

Hashes for codejournal-0.1.0.tar.gz
Algorithm Hash digest
SHA256 56375fcc9247f709de6ee7cca56066f64431c158a01010f60a59ba1f6439eabd
MD5 ca3210cc408414f17140bcd5bb3bb2c6
BLAKE2b-256 78e76c3a8dc6190ef6a6e43edd6af0bd3a702de4ce79f7df6c7012a40e704dc1

See more details on using hashes here.

File details

Details for the file codejournal-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: codejournal-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 20.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.9.21

File hashes

Hashes for codejournal-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 8e7a3825e1d20a32abf90496a7be7d9622a5028887da7bdbadae139b66b343c3
MD5 ea0354e2f30c6b372b2cc0e327ed6836
BLAKE2b-256 0edd0a910e4fc4737ffe62de04fb46581690636209d66a0c798eb4ae1b36adcc

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page