ModuleGraph

Installation

pip install torch-compose

Introduction

ModuleGraph is a Python library for managing and manipulating networks of PyTorch modules. It arranges modules in a directed graph in which each module can take the outputs of other modules as its inputs, so it can express complex data flows within a neural network.

ModuleGraph is useful when a plain sequential or parallel arrangement of modules is not enough: it supports intricate dependencies between modules and propagates data through them in the correct order.

Features

DirectedModule: An abstract base class representing a module that receives data and produces data. A subclass must implement the forward method. DirectedModules are the nodes of a module graph, and they control how data flows from one module to the next.

A DirectedModule reads its inputs and writes its outputs under named keys, which gives fine control over how data is partitioned and labeled. It can wrap an existing PyTorch module, use an existing function as its forward method, or be subclassed with a custom forward method.
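
To make the keyed input/output idea concrete, here is a minimal, framework-free sketch of a keyed wrapper. The class `KeyedNode` and its `run` method are hypothetical illustrations, not the torch-compose API, and plain Python callables stand in for PyTorch modules so that the snippet runs without any dependencies.

```python
from typing import Any, Callable, Dict, Sequence, Tuple, Union

Keys = Union[str, Sequence[str]]


def _as_tuple(keys: Keys) -> Tuple[str, ...]:
    # Accept a single key (e.g. 'input') or a sequence of keys
    return (keys,) if isinstance(keys, str) else tuple(keys)


class KeyedNode:
    """Hypothetical stand-in for DirectedModule: reads named inputs, writes named outputs."""

    def __init__(self, input_keys: Keys, output_keys: Keys, forward: Callable[..., Any]):
        self.input_keys = _as_tuple(input_keys)
        self.output_keys = _as_tuple(output_keys)
        self.forward = forward

    def run(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Pull positional arguments out of the shared dict by key
        args = [data[k] for k in self.input_keys]
        result = self.forward(*args)
        # A single output key receives the whole result; several keys unpack a tuple
        if len(self.output_keys) == 1:
            result = (result,)
        return dict(zip(self.output_keys, result))


# Usage: a function returning two values, labeled with two output keys
split = KeyedNode("x", ("low", "high"), lambda x: (min(x), max(x)))
print(split.run({"x": [3, 1, 4, 1, 5]}))  # {'low': 1, 'high': 5}
```

Keys that a node does not list are simply ignored, which is what lets many nodes share one dictionary of intermediate results.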

ModuleGraph: This class defines a graph of DirectedModules. Data flows from one module to another in the topological order of their dependencies: ModuleGraph sorts the modules so that each one runs only after the modules that produce its inputs.
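
The dependency ordering can be sketched with the standard library's `graphlib.TopologicalSorter` (Python 3.9+). This is an assumption about how such ordering could work, not ModuleGraph's actual implementation: each node is a hypothetical `(input_keys, output_key, fn)` triple, a node depends on whichever node produces one of its input keys, and keys supplied by the caller have no producer. The diamond-shaped graph in the usage lines is a flow that a purely sequential chain of modules cannot express.

```python
from graphlib import TopologicalSorter
from typing import Any, Callable, Dict, Tuple

# Hypothetical node spec: (input_keys, output_key, function); one output per node for brevity
Node = Tuple[Tuple[str, ...], str, Callable[..., Any]]


def run_graph(nodes: Dict[str, Node], inputs: Dict[str, Any]) -> Dict[str, Any]:
    # Map each produced key to the node that produces it
    producer = {out: name for name, (_, out, _) in nodes.items()}
    # A node depends on the producers of its inputs; caller-supplied keys have none
    deps = {
        name: {producer[k] for k in ins if k in producer}
        for name, (ins, _, _) in nodes.items()
    }
    data = dict(inputs)
    # static_order() yields every node after all of its dependencies
    for name in TopologicalSorter(deps).static_order():
        ins, out, fn = nodes[name]
        data[out] = fn(*(data[k] for k in ins))
    return data


# Diamond: 'a' feeds both 'double' and 'square', whose outputs 'add' combines
nodes = {
    "add": (("d", "s"), "out", lambda d, s: d + s),
    "double": (("a",), "d", lambda a: 2 * a),
    "square": (("a",), "s", lambda a: a * a),
}
print(run_graph(nodes, {"a": 3})["out"])  # 15
```

Note that the dictionary order of `nodes` does not matter: `add` is listed first but runs last. A cyclic dependency makes `static_order()` raise `graphlib.CycleError`.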

ModuleGraph also provides a show_graph() method for visually inspecting the graph structure and dependencies, using the networkx and matplotlib libraries. Modules are drawn as nodes and dependencies as directed edges labeled with the corresponding keys.
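
The labeled edges such a drawing shows can be derived from the key declarations alone. The sketch below uses a hypothetical helper `labeled_edges` (not part of torch-compose) that lists `(producer, consumer, key)` triples. With networkx installed, each triple could be added via `DiGraph.add_edge(producer, consumer, label=key)` and drawn with matplotlib, but the snippet itself needs only the standard library.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical spec: module name -> (input_keys, output_keys)
Spec = Dict[str, Tuple[Sequence[str], Sequence[str]]]


def labeled_edges(spec: Spec) -> List[Tuple[str, str, str]]:
    """Return (producer, consumer, key) for every key passed between two modules."""
    producers = {}
    for name, (_, outs) in spec.items():
        for key in outs:
            producers[key] = name
    edges = []
    for name, (ins, _) in spec.items():
        for key in ins:
            # Keys supplied by the caller have no producing module, so they add no edge
            if key in producers:
                edges.append((producers[key], name, key))
    return edges


# Mirrors the mod1 -> mod2 chain from Example 1
spec = {"mod1": (["input"], ["hidden"]), "mod2": (["hidden"], ["output"])}
print(labeled_edges(spec))  # [('mod1', 'mod2', 'hidden')]
```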

Usage

To use ModuleGraph, define your DirectedModules and pass them to the ModuleGraph class as a dictionary that maps names to modules.

Below is a simplified example:

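Example 1: Using the module argument of DirectedModule

import torch
import torch.nn as nn
from torch_compose import DirectedModule, ModuleGraph  # import path assumed from the package name

# A plain PyTorch module, to be wrapped by DirectedModule below
class MyModule(nn.Module):
    def forward(self, x):
        return torch.relu(x)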

# create DirectedModules
mod1 = DirectedModule(input_keys='input', output_keys='hidden', module=MyModule())
mod2 = DirectedModule(input_keys='hidden', output_keys='output', module=MyModule())

# create ModuleGraph
graph = ModuleGraph({'mod1': mod1, 'mod2': mod2})

# forward pass
output = graph.forward({'input': torch.randn(10, 10)})

# visualize the graph
graph.show_graph()

DirectedModule can also be integrated with PyTorch code in several other ways.

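Example 2: Using the forward argument of DirectedModule

import torch
from torch_compose import DirectedModule, ModuleGraph  # import path assumed from the package name

# Define a function that applies a non-linear activation (ReLU) to its input
def my_relu(input_tensor):
    return torch.nn.functional.relu(input_tensor)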

# Create a DirectedModule that wraps this function
relu_module = DirectedModule(input_keys='input', output_keys='output', forward=my_relu)

# Use the module in a graph
graph = ModuleGraph({'relu': relu_module})

# forward pass
output = graph.forward({'input': torch.randn(10, 10)})

In this example, relu_module does not wrap a PyTorch nn.Module but a standalone function that is itself built from PyTorch operations.

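Example 3: Subclassing DirectedModule

import torch
import torch.nn as nn
from torch_compose import DirectedModule, ModuleGraph  # import path assumed from the package name

# Define a custom module by inheriting from both nn.Module and DirectedModule
class MyCustomModule(nn.Module, DirectedModule):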
    def __init__(self, input_keys, output_keys):
        nn.Module.__init__(self)
        DirectedModule.__init__(self, input_keys, output_keys)
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

# Create an instance of the custom module
custom_module = MyCustomModule(input_keys='input', output_keys='output')

# Use the module in a graph
graph = ModuleGraph({'custom_module': custom_module})

# forward pass
output = graph.forward({'input': torch.randn(10, 10)})

In this example, MyCustomModule is a subclass of both nn.Module and DirectedModule. It uses nn.Module to define a simple linear layer and DirectedModule to handle input/output key mapping.

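Example 4: Using DirectedModule as a member (composition)

import torch
import torch.nn as nn
from torch_compose import DirectedModule, ModuleGraph  # import path assumed from the package name

# Define a custom module that inherits only from nn.Module and holds a DirectedModule as a member
class MyMixinModule(nn.Module):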
    def __init__(self, input_keys, output_keys):
        super().__init__()
        self.directed_module = DirectedModule(input_keys=input_keys, output_keys=output_keys, module=self)
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

# Create an instance of the mixin module
mixin_module = MyMixinModule(input_keys='input', output_keys='output')

# Use the module in a graph
graph = ModuleGraph({'mixin_module': mixin_module.directed_module})

# forward pass
output = graph.forward({'input': torch.randn(10, 10)})

In this example, MyMixinModule is a subclass of nn.Module, and it uses a DirectedModule as a member to handle input/output key mapping, allowing the same functionality without multiple inheritance.
