Engine of OpenMMLab projects


What's New

v0.10.6 was released on 2025-01-13.

Highlights:

  • Support custom artifact_location in MLflowVisBackend #1505
  • Enable exclude_frozen_parameters for DeepSpeedEngine._zero3_consolidated_16bit_state_dict #1517

Read Changelog for more details.

Introduction

MMEngine is a foundational library for training deep learning models based on PyTorch. It serves as the training engine of all OpenMMLab codebases, which support hundreds of algorithms in various research areas. MMEngine is also generic enough to be applied to non-OpenMMLab projects. Its highlights are as follows:

  • Integrates mainstream large-scale model training frameworks
  • Supports a variety of training strategies
  • Provides a user-friendly configuration system
  • Covers mainstream training monitoring platforms

Installation

Supported PyTorch Versions
MMEngine            PyTorch        Python
main                >=1.6, <=2.1   >=3.8, <=3.11
>=0.9.0, <=0.10.4   >=1.6, <=2.1   >=3.8, <=3.11
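
If you are unsure whether an existing environment falls inside these ranges, the sketch below compares version strings against the main row of the table. The helper names (major_minor, is_supported) are hypothetical, and the sketch assumes the bounds are meant at major.minor granularity, so PyTorch 2.1.2 counts as 2.1. It is a convenience check, not an official compatibility test.

```python
import platform
import re

# Ranges copied from the "main" row of the table, as inclusive (major, minor)
# bounds. Comparing only major.minor is an assumption about the table's intent.
SUPPORTED = {
    'torch': ((1, 6), (2, 1)),
    'python': ((3, 8), (3, 11)),
}


def major_minor(version: str) -> tuple:
    """Extract (major, minor) from strings like '2.1.2+cu118' or '3.10.13'."""
    match = re.match(r'(\d+)\.(\d+)', version)
    if match is None:
        raise ValueError(f'unrecognised version string: {version!r}')
    # Integers, so that 1.13 correctly sorts after 1.6.
    return int(match.group(1)), int(match.group(2))


def is_supported(package: str, version: str) -> bool:
    lower, upper = SUPPORTED[package]
    return lower <= major_minor(version) <= upper


if __name__ == '__main__':
    py_version = platform.python_version()
    print('Python', py_version, is_supported('python', py_version))
    try:
        import torch
        print('PyTorch', torch.__version__, is_supported('torch', torch.__version__))
    except ImportError:
        print('PyTorch is not installed yet')
```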

Before installing MMEngine, please ensure that PyTorch has been successfully installed following the official guide.

Install MMEngine

pip install -U openmim
mim install mmengine

Verify the installation

python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'

Get Started

Taking the training of a ResNet-50 model on the CIFAR-10 dataset as an example, we will use MMEngine to build a complete, configurable training and validation process in fewer than 80 lines of code.

Build Models

First, we need to define a model that 1) inherits from BaseModel and 2) accepts, in addition to the arguments related to the dataset, an extra argument mode in its forward method.

  • During training, the value of mode is "loss", and the forward method should return a dict containing the key "loss".
  • During validation, the value of mode is "predict", and the forward method should return results containing both predictions and labels.
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel

class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
Build Datasets

Next, we need to create Datasets and DataLoaders for training and validation. In this case, we simply use built-in datasets supported in TorchVision.

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))
val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))
Build Metrics

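To validate and test the model, we need to define a metric called Accuracy. It inherits from BaseMetric and implements two methods: process, which is called on every batch and stores intermediate results in self.results, and compute_metrics, which aggregates those results once all batches have been processed.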

from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # Save the results of a batch to `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # Returns a dictionary with the results of the evaluated metrics,
        # where the key is the name of the metric
        return dict(accuracy=100 * total_correct / total_size)
Build a Runner

Finally, we construct a Runner from the previously defined model, dataloaders, and metric, together with a few other configs, as shown below.

from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    # a wrapper to execute back propagation and gradient update, etc.
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # set some training configs like epochs
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
Launch Training
runner.train()

Learn More

Tutorials
Advanced tutorials
Examples
Common Usage
Design
Migration guide

Contributing

We appreciate all contributions to improve MMEngine. Please refer to CONTRIBUTING.md for the contributing guideline.

Citation

If you find this project useful in your research, please consider citing:

@article{mmengine2022,
  title   = {{MMEngine}: OpenMMLab Foundational Library for Training Deep Learning Models},
  author  = {MMEngine Contributors},
  howpublished = {\url{https://github.com/open-mmlab/mmengine}},
  year={2022}
}

License

This project is released under the Apache 2.0 license.

Ecosystem

Projects in OpenMMLab

  • MIM: MIM installs OpenMMLab packages.
  • MMCV: OpenMMLab foundational library for computer vision.
  • MMEval: A unified evaluation library for multiple machine learning libraries.
  • MMPreTrain: OpenMMLab pre-training toolbox and benchmark.
  • MMagic: OpenMMLab Advanced, Generative and Intelligent Creation toolbox.
  • MMDetection: OpenMMLab detection toolbox and benchmark.
  • MMYOLO: OpenMMLab YOLO series toolbox and benchmark.
  • MMDetection3D: OpenMMLab's next-generation platform for general 3D object detection.
  • MMRotate: OpenMMLab rotated object detection toolbox and benchmark.
  • MMTracking: OpenMMLab video perception toolbox and benchmark.
  • MMPose: OpenMMLab pose estimation toolbox and benchmark.
  • MMSegmentation: OpenMMLab semantic segmentation toolbox and benchmark.
  • MMOCR: OpenMMLab text detection, recognition, and understanding toolbox.
  • MMHuman3D: OpenMMLab 3D human parametric model toolbox and benchmark.
  • MMSelfSup: OpenMMLab self-supervised learning toolbox and benchmark.
  • MMFewShot: OpenMMLab few-shot learning toolbox and benchmark.
  • MMAction2: OpenMMLab's next-generation action understanding toolbox and benchmark.
  • MMFlow: OpenMMLab optical flow toolbox and benchmark.
  • MMDeploy: OpenMMLab model deployment framework.
  • MMRazor: OpenMMLab model compression toolbox and benchmark.
  • Playground: A central hub for gathering and showcasing amazing projects built upon OpenMMLab.

Download files

Source Distribution

mmengine-0.10.7.tar.gz (378.1 kB)

Built Distribution

mmengine-0.10.7-py3-none-any.whl (452.7 kB)

File details

Details for the file mmengine-0.10.7.tar.gz.

File metadata

  • Download URL: mmengine-0.10.7.tar.gz
  • Size: 378.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.13

File hashes

Hashes for mmengine-0.10.7.tar.gz
Algorithm Hash digest
SHA256 d20ffcc31127567e53dceff132612a87f0081de06cbb7ab2bdb7439125a69225
MD5 cd55a60f6100a9a7c09f992b551e7942
BLAKE2b-256 1714959360bbd8374e23fc1b720906999add16a3ac071a501636db12c5861ff5

File details

Details for the file mmengine-0.10.7-py3-none-any.whl.

File metadata

  • Download URL: mmengine-0.10.7-py3-none-any.whl
  • Size: 452.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.13

File hashes

Hashes for mmengine-0.10.7-py3-none-any.whl
Algorithm Hash digest
SHA256 262ac976a925562f78cd5fd14dd1bc9b680ed0aa81f0d85b723ef782f99c54ee
MD5 7cd858a8dbb12c8671a0a6e5ecdde98a
BLAKE2b-256 988ef98332248aad102511bea4ae19c0ddacd2f0a994f3ca4c82b7a369e0af8b
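
To check a downloaded file against the digests listed above, a small standard-library sketch is enough. The helper names are hypothetical, and the wheel path assumes you saved the download in the current directory. Note that hashlib.new('blake2b') produces a 512-bit digest; the BLAKE2b-256 values above would need hashlib.blake2b(digest_size=32) instead.

```python
import hashlib
import os


def file_digest(path, algorithm='sha256', chunk_size=1 << 20):
    """Hash a file in fixed-size chunks so large archives need not fit in memory."""
    digest = hashlib.new(algorithm)
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_hex, algorithm='sha256'):
    """Compare a local file against a published hex digest (case-insensitive)."""
    return file_digest(path, algorithm) == expected_hex.lower()


# SHA256 published above for the wheel; the local path is an assumption.
WHEEL = 'mmengine-0.10.7-py3-none-any.whl'
WHEEL_SHA256 = '262ac976a925562f78cd5fd14dd1bc9b680ed0aa81f0d85b723ef782f99c54ee'

if os.path.exists(WHEEL):
    print('hash OK' if verify(WHEEL, WHEEL_SHA256) else 'hash MISMATCH')
else:
    print(f'{WHEEL} not found; download it first')
```

pip can also enforce this itself in hash-checking mode: list `mmengine==0.10.7 --hash=sha256:<digest>` in a requirements file and install with `pip install --require-hashes -r requirements.txt` (every dependency then needs a pinned version and hash as well).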
