
torchmix

The missing component library for PyTorch


Welcome to torchmix, a collection of PyTorch modules that aims to reduce boilerplate and improve code modularity.

Please note: torchmix is currently in development and has not been tested for production use. The API may change at any time.


Usage

To use torchmix, simply import the desired components:

import torchmix.nn as nn  # Wrapped version of torch.nn
from torchmix import (
    Add,
    Attach,
    AvgPool,
    ChannelMixer,
    Extract,
    PatchEmbed,
    PositionEmbed,
    PreNorm,
    Repeat,
    SelfAttention,
    Token,
)

You can simply compose these components to build more complex architectures, as shown in the following examples:

vit_cls = nn.Sequential(
    Add(
        Attach(
            PatchEmbed(dim=1024),
            Token(dim=1024),
        ),
        PositionEmbed(
            seq_length=196 + 1,
            dim=1024,
        ),
    ),
    Repeat(
        nn.Sequential(
            PreNorm(
                ChannelMixer(
                    dim=1024,
                    expansion_factor=4,
                    act_layer=nn.GELU.partial(),
                ),
                dim=1024,
            ),
            PreNorm(
                SelfAttention(
                    dim=1024,
                    num_heads=8,
                    head_dim=64,
                ),
                dim=1024,
            ),
        )
    ),
    Extract(0),
)

vit_avg = nn.Sequential(
    Add(
        PatchEmbed(dim=1024),
        PositionEmbed(
            seq_length=196,
            dim=1024,
        ),
    ),
    Repeat(
        nn.Sequential(
            PreNorm(
                ChannelMixer(
                    dim=1024,
                    expansion_factor=4,
                    act_layer=nn.GELU.partial(),
                ),
                dim=1024,
            ),
            PreNorm(
                SelfAttention(
                    dim=1024,
                    num_heads=8,
                    head_dim=64,
                ),
                dim=1024,
            ),
        )
    ),
    AvgPool(),
)

Integration with Hydra

Reproducibility is important, so it is always a good idea to store the configurations of your models. However, manually writing configurations for complex, deeply nested PyTorch modules is tedious and produces code that is difficult to understand and maintain: the parent class often has to accept and pass along the parameters of its child classes, leading to a large number of arguments and strong coupling between parent and child.

torchmix simplifies this process by auto-magically generating the full configuration of a PyTorch module simply by instantiating it. This enables effortless integration with the hydra ecosystem, which allows for easy storage and management of module configurations.

To generate a configuration for a typical MLP using torchmix, for example, you can do the following:

from torchmix import nn

model = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.Dropout(0.1),
    nn.GELU(),
    nn.Linear(4096, 1024),
    nn.Dropout(0.1),
)

You can then store the configuration in hydra's ConfigStore using:

model.store(group="model", name="mlp")
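
Assuming store() registers the configuration with hydra's ConfigStore under the given group and name, the stored model can then be selected like any other config group. A minimal sketch of a consuming app:

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig

# The store() call above must have executed (e.g. at import time)
# before hydra composes the config.

@hydra.main(version_base=None, config_path=None, config_name=None)
def main(cfg: DictConfig) -> None:
    model = instantiate(cfg.model)  # rebuilds the stored "mlp" Sequential
    print(model)

if __name__ == "__main__":
    main()  # run with: python app.py +model=mlp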

Alternatively, you can export it to a YAML file if you want:

model.export("mlp.yaml")

This will generate the following configuration:

_target_: torchmix.nn.Sequential
_args_:
  - _target_: torchmix.nn.Linear
    in_features: 1024
    out_features: 4096
    bias: true
    device: null
    dtype: null
  - _target_: torchmix.nn.Dropout
    p: 0.1
    inplace: false
  - _target_: torchmix.nn.GELU
    approximate: none
  - _target_: torchmix.nn.Linear
    in_features: 4096
    out_features: 1024
    bias: true
    device: null
    dtype: null
  - _target_: torchmix.nn.Dropout
    p: 0.1
    inplace: false

You can always instantiate the actual PyTorch module from its configuration using hydra's instantiate function.
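
For example, round-tripping the exported YAML back into a live module with hydra's instantiate and OmegaConf:

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("mlp.yaml")
model = instantiate(cfg)  # an ordinary torchmix.nn.Sequential, ready to use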

To create custom modules with this functionality, simply subclass MixModule and define your module as you normally would:

from torchmix import MixModule

class CustomModule(MixModule):
    def __init__(self, num_heads, dim, depth):
        pass  # define your submodules here as you normally would

custom_module = CustomModule(16, 768, 12)
custom_module.store(group="model", name="custom")
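
Following the pattern above, the stored configuration should simply record the constructor arguments, roughly like this (an illustrative sketch, not generated output; the exact _target_ path depends on where CustomModule is defined):

_target_: __main__.CustomModule
num_heads: 16
dim: 768
depth: 12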

Documentation

Documentation is currently in progress. Please stay tuned! 🚀

Contributing

We welcome contributions to the torchmix library. If you have ideas for new components or suggestions for improving the library, don't hesitate to open an issue or start a discussion. Please note that torchmix is still in the prototype phase, so any contributions should be considered experimental.

License

torchmix is licensed under the MIT License.
