
A Python package of code snippets for faster PyTorch AI development


Codebook

Installation

pip install codejournal

Note:

This project was developed to speed up writing and training models. The code is clean and hackable, but this is not a production-ready project!

Features:

  • Easy loading and saving of models
  • WandB integration
  • Slack integration
  • Checkpoint tracking
  • Resuming training from checkpoints
  • Debug mode
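Checkpoint tracking is configured in the example below through `n_best_checkpoints`, `n_latest_checkpoints`, `checkpoint_metric` and `checkpoint_metric_minimize`. The following is a minimal, hypothetical sketch of how such a retention policy can work. It is not codejournal's implementation. It keeps the N best checkpoints by a metric plus the N most recent ones, treats a negative count as "keep all", and reports the newest checkpoint, which is the one a resume would start from. The names `Checkpoint` and `CheckpointTracker` are illustrative only.

```python
from dataclasses import dataclass, field


@dataclass
class Checkpoint:
    step: int
    metric: float
    path: str


@dataclass
class CheckpointTracker:
    """Hypothetical best-N + latest-N retention policy (not codejournal's API)."""
    n_best: int = 3         # negative value keeps every checkpoint
    n_latest: int = 2       # negative value keeps every checkpoint
    minimize: bool = True   # True when a lower metric (e.g. val loss) is better
    checkpoints: list = field(default_factory=list)

    def add(self, step, metric, path):
        """Record a new checkpoint and return the paths that may now be deleted."""
        self.checkpoints.append(Checkpoint(step, metric, path))
        if self.n_best < 0 or self.n_latest < 0:
            return []  # a negative count means "keep everything"
        # Best first: ascending metric when minimizing, descending otherwise
        by_metric = sorted(self.checkpoints, key=lambda c: c.metric, reverse=not self.minimize)
        # Newest first
        by_step = sorted(self.checkpoints, key=lambda c: c.step, reverse=True)
        keep = {c.path for c in by_metric[: self.n_best]} | {c.path for c in by_step[: self.n_latest]}
        removed = [c.path for c in self.checkpoints if c.path not in keep]
        self.checkpoints = [c for c in self.checkpoints if c.path in keep]
        return removed

    def latest(self):
        """Path of the newest tracked checkpoint, or None if there is none."""
        if not self.checkpoints:
            return None
        return max(self.checkpoints, key=lambda c: c.step).path
```

Keeping the union of the two sets means the best-scoring checkpoint survives even when it is old, while the most recent one is always available to resume from.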

TODO:

  • [ ] Scheduler support
  • [ ] Refactoring
  • [ ] Multi-GPU support
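Scheduler support is still on the TODO list, although the example already exposes a `get_scheduler` hook and `scheduler`/`scheduler_kwargs` arguments. Linear warmup followed by linear decay is one common schedule that could be plugged into such a hook. The sketch below only computes the learning-rate multiplier as a plain function, and `warmup_linear_decay` is a hypothetical name. Wrapping it with PyTorch's `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=...)` inside `get_scheduler` is an assumption about how it would be wired in, not existing codejournal behavior.

```python
from functools import partial


def warmup_linear_decay(step, warmup_steps, total_steps):
    """LR multiplier: ramps up linearly to 1 over warmup_steps, then decays linearly to 0 at total_steps."""
    if warmup_steps > 0 and step < warmup_steps:
        return (step + 1) / warmup_steps
    decay_span = max(1, total_steps - warmup_steps)
    return max(0.0, (total_steps - step) / decay_span)


# LambdaLR multiplies the optimizer's base lr by this value at each scheduler step.
lr_lambda = partial(warmup_linear_decay, warmup_steps=100, total_steps=4000)
```

The `total_steps=4000` here assumes one optimizer step per batch: `max_epochs=5` × `train_steps_per_epoch=800` with `grad_accumulation_steps=1`.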

Example:

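import os

import torch
import torch.nn as nn
import torch.nn.functional as F

# ConfigBase, ModelBase, TrainerArgs and Trainer come from these star imports
from codejournal.imports import *
from codejournal.modeling import *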

import torchvision.models as models
from torchvision import datasets, transforms

# Put your credentials here, or export them in your shell (setdefault keeps values that are already set)
os.environ.setdefault("HUGGINGFACE_TOKEN", "")  # HuggingFace token
os.environ.setdefault("SLACK_WEBHOOK_URL", "")  # Slack webhook URL

class Config(ConfigBase):
    resnet: int = 18
    pretrained: bool = True
    num_classes: int = 10

class ResNet(ModelBase):
    def __init__(self, config):
        super().__init__(config)
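        # torchvision >= 0.13 selects pretrained weights via `weights=`; "DEFAULT" is the best available ImageNet set
        weights = "DEFAULT" if self.config.pretrained else None
        self.model = getattr(models, f"resnet{self.config.resnet}")(weights=weights)
        # Replace the 1000-class ImageNet head with one sized for num_classes
        self.model.fc = nn.Linear(self.model.fc.in_features, self.config.num_classes)

    def forward(self, x):
        # MNIST images have a single channel; tile it to the 3 channels ResNet expects
        x = x.repeat(1, 3, 1, 1)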
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}

    def validation_step(self, batch, batch_idx):
        x,y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = torch.sum(preds == y).item() / y.size(0)
        return {'loss': loss, 'acc': acc}
    
    def get_optimizer(self, trainer):
        # Override the default optimizer if needed
        return super().get_optimizer(trainer)
    
    def get_scheduler(self, optimizer, trainer):
        # Override the default scheduler if needed
        args = trainer.args  # training arguments, available for configuring a custom scheduler
        return super().get_scheduler(optimizer, trainer)


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean and std
])

train_dataset = datasets.MNIST('./data', train=True, download=True,
                    transform=transform)
val_dataset = datasets.MNIST('./data', train=False, download=True,
                    transform=transform)

training_args = TrainerArgs(
    # Core Training Configuration
    batch_size=32,
    max_epochs=5,
    train_steps_per_epoch=800,
    val_steps_per_epoch=400,
    grad_accumulation_steps=1,
    lr=1e-5,
    optimizer="AdamW",
    optimizer_kwargs={},
    scheduler=None,
    scheduler_kwargs={},

    # Logging and Checkpointing
    log_every_n_steps=32,
    save_every_n_steps=100,
    n_best_checkpoints=3,  # Negative value for saving all checkpoints
    n_latest_checkpoints=2,  # Negative value for saving all checkpoints
    checkpoint_metric="loss",
    checkpoint_metric_type="val",
    checkpoint_metric_minimize=True,

    # Hardware and Precision
    device="mps",  # Apple Silicon GPU here; auto-inferred if omitted
    mixed_precision=False,
    grad_clip_norm=1.0,  # False for no clipping
    num_workers=0,

    # WandB Integration
    wandb_project="Demo",  # run `wandb login` first
    wandb_run_name="MNIST",
    wandb_run_id=None,
    wandb_resume="allow",
    wandb_kwargs={},
    disable_wandb=False,

    # Resume and Debugging
    resume_from_checkpoint=True,  # None for no resuming, True for resuming latest checkpoint, or path to checkpoint
    debug_mode=False,

    # Miscellaneous
    safe_dataloader=True,
    log_grad_norm=True,
    slack_notify=bool(os.environ.get("SLACK_WEBHOOK_URL")),
    results_dir="results",
    val_data_shuffle=False,
)

config = Config()
model = ResNet(config)

trainer = Trainer(training_args)
trainer.train(model, train_dataset, val_dataset)
if os.environ.get("HUGGINGFACE_TOKEN"):
    model.push_to_hub("demo-resnet")
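The example turns on `slack_notify` only when `SLACK_WEBHOOK_URL` is set. Slack incoming webhooks accept an HTTP POST with a JSON body such as `{"text": "..."}`. Below is a minimal standard-library sketch of sending such a notification. It is not codejournal's own notifier, and `build_slack_request` and `notify_slack` are hypothetical helpers.

```python
import json
import os
import urllib.request


def build_slack_request(webhook_url, text):
    """Build a POST request carrying a Slack incoming-webhook JSON payload."""
    payload = json.dumps({"text": text}).encode("utf-8")
    return urllib.request.Request(
        webhook_url,
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def notify_slack(text, timeout=10):
    """Send `text` to the webhook in SLACK_WEBHOOK_URL; return False without sending if it is unset."""
    url = os.environ.get("SLACK_WEBHOOK_URL")
    if not url:
        return False
    with urllib.request.urlopen(build_slack_request(url, text), timeout=timeout) as resp:
        return resp.status == 200
```

For example, `notify_slack("MNIST run finished")` after `trainer.train(...)` posts one message and is a no-op when no webhook is configured.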



Download files

Download the file for your platform.

Source Distribution

codejournal-0.1.9.1.tar.gz (22.2 kB)

Uploaded Source

Built Distribution


codejournal-0.1.9.1-py3-none-any.whl (23.8 kB)

Uploaded Python 3

File details

Details for the file codejournal-0.1.9.1.tar.gz.

File metadata

  • Download URL: codejournal-0.1.9.1.tar.gz
  • Upload date:
  • Size: 22.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.11.11

File hashes

Hashes for codejournal-0.1.9.1.tar.gz
Algorithm Hash digest
SHA256 4062b9d0a6a755b7139899a8a189ad385cbb3e42c3583b9cd107f4c151eb2d8e
MD5 24b9684dc8771e572e131569115b919d
BLAKE2b-256 7619c04735baf06e9fbe1751ceb3a00a66241303b1947bc79faf0c0a85dad858
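These digests let you check a downloaded file before installing it. Below is a minimal sketch using Python's `hashlib`. It assumes the source distribution was saved locally as `codejournal-0.1.9.1.tar.gz`, and `sha256_of` and `verify` are hypothetical helper names.

```python
import hashlib

# SHA256 digest published above for codejournal-0.1.9.1.tar.gz
EXPECTED_SHA256 = "4062b9d0a6a755b7139899a8a189ad385cbb3e42c3583b9cd107f4c151eb2d8e"


def sha256_of(path, chunk_size=1 << 16):
    """Stream the file in chunks and return its hex SHA-256 digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA-256 matches the expected digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()
```

For example, `verify("codejournal-0.1.9.1.tar.gz")` should return `True` for an untampered download. Reading in chunks keeps memory use flat for large files. pip can enforce the same check at install time with its hash-checking mode (`--require-hashes`).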


File details

Details for the file codejournal-0.1.9.1-py3-none-any.whl.

File metadata

  • Download URL: codejournal-0.1.9.1-py3-none-any.whl
  • Upload date:
  • Size: 23.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.0.1 CPython/3.11.11

File hashes

Hashes for codejournal-0.1.9.1-py3-none-any.whl
Algorithm Hash digest
SHA256 734bb2bbfc043b6483ee41be1a96c2ac8540b33cc9078e9ded279f552cf1eab2
MD5 4cccd0300899b4d801cf9a2301bd1748
BLAKE2b-256 b4f017f531ff0b5d538ed610ce06c1d48bc57e554318fbcd730188fdcde6c0b7

