Project description
ModuleGraph
Installation
pip install torch-compose
Introduction
ModuleGraph is a Python-based tool for managing and manipulating a network of PyTorch modules. It organizes modules as a directed graph, allowing each module to consume the output of another as its input, thus accommodating complex data flows within a neural network.
ModuleGraph is ideal in cases where typical sequential or parallel module arrangements don't suffice, allowing intricate dependencies between modules and handling data propagation in the correct order.
Features
DirectedModule: This versatile abstract base class represents a module that receives and produces data. It requires implementing a forward function within its subclass. It can be integrated into a module graph, providing a controlled data flow from one module to another.
The DirectedModule accepts and produces data using defined keys, providing fine control over data partitioning and labeling. It can wrap an existing PyTorch module, reference an existing function as the forward method, or be subclassed for custom forward method implementation, offering great flexibility for various deep learning tasks.
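The key-based data flow can be sketched in plain Python. This is an illustrative toy, not torch_compose's actual implementation; `KeyedModule` and all of its details are assumptions made for the example:

```python
# Illustrative sketch of key-based data flow (NOT torch_compose's actual
# implementation): a module reads named inputs from a shared dict and
# writes named outputs back, mirroring the input_keys/output_keys idea.
class KeyedModule:
    def __init__(self, input_keys, output_keys, fn):
        # Accept a single key or a list of keys for convenience.
        self.input_keys = [input_keys] if isinstance(input_keys, str) else list(input_keys)
        self.output_keys = [output_keys] if isinstance(output_keys, str) else list(output_keys)
        self.fn = fn

    def __call__(self, data):
        # Gather inputs by key, apply the wrapped function,
        # and label each output with its declared output key.
        args = [data[k] for k in self.input_keys]
        result = self.fn(*args)
        outputs = result if isinstance(result, tuple) else (result,)
        return dict(zip(self.output_keys, outputs))

double = KeyedModule('input', 'hidden', lambda x: 2 * x)
print(double({'input': 3}))  # {'hidden': 6}
```

Because every module declares which keys it reads and writes, the graph can infer dependencies between modules from the keys alone.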
ModuleGraph: This class defines a graph of DirectedModules. Data flows from one module to another according to the topological order of their dependencies; ModuleGraph sorts the modules so that each one runs only after the modules that produce its inputs.
The ModuleGraph class also provides a visualization method show_graph() to visually inspect the graph structure and dependencies using the networkx and matplotlib libraries. This visual representation shows modules as nodes and dependencies as directed edges, labeled with the corresponding keys.
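The dependency ordering described above can be sketched with the standard library's graphlib. The module names and key specs below are hypothetical, and this is not torch_compose's internal code, just an illustration of sorting modules by the keys they produce and consume:

```python
from graphlib import TopologicalSorter

# Hypothetical module specs: name -> (input_keys, output_keys).
modules = {
    'encoder': (['input'], ['hidden']),
    'decoder': (['hidden'], ['output']),
    'head':    (['output'], ['logits']),
}

# Map each key to the module that produces it.
producers = {k: name for name, (_, outs) in modules.items() for k in outs}

# A module depends on every module that produces one of its input keys.
deps = {
    name: {producers[k] for k in ins if k in producers}
    for name, (ins, _) in modules.items()
}

# Topologically sort so producers run before consumers.
order = list(TopologicalSorter(deps).static_order())
print(order)  # ['encoder', 'decoder', 'head']
```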
Usage
To use the ModuleGraph project, you need to define your DirectedModules and then pass them as a dictionary to the ModuleGraph class.
Below is a simplified example:
Example 1: Using module argument to DirectedModule
```python
import torch
import torch.nn as nn
from torch_compose import DirectedModule, ModuleGraph

# A plain nn.Module whose forward applies a ReLU activation
class MyModule(nn.Module):
    def forward(self, x):
        return torch.relu(x)

# create DirectedModules that wrap the nn.Module
mod1 = DirectedModule(input_keys='input', output_keys='hidden', module=MyModule())
mod2 = DirectedModule(input_keys='hidden', output_keys='output', module=MyModule())

# create ModuleGraph
graph = ModuleGraph({'mod1': mod1, 'mod2': mod2})

# forward pass
output = graph.forward({'input': torch.randn(10, 10)})

# visualize the graph
graph.show_graph()
```
DirectedModule can be integrated with torch code in several additional ways.
Example 2: Using forward argument to DirectedModule
```python
import torch
from torch_compose import DirectedModule, ModuleGraph

# Define a function that applies a non-linear activation (ReLU) to its input
def my_relu(input_tensor):
    return torch.nn.functional.relu(input_tensor)

# Create a DirectedModule that wraps this function
relu_module = DirectedModule(input_keys='input', output_keys='output', forward=my_relu)

# Use the module in a graph
graph = ModuleGraph({'relu': relu_module})

# forward pass
output = graph.forward({'input': torch.randn(10, 10)})
```
In this example, the relu_module does not encapsulate a PyTorch nn.Module, but rather a standalone function that performs an operation using PyTorch functionalities.
Example 3: Using DirectedModule in inheritance
```python
import torch
import torch.nn as nn
from torch_compose import DirectedModule, ModuleGraph

# Define a custom module by inheriting from both nn.Module and DirectedModule
class MyCustomModule(nn.Module, DirectedModule):
    def __init__(self, input_keys, output_keys):
        nn.Module.__init__(self)
        DirectedModule.__init__(self, input_keys, output_keys)
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

# Create an instance of the custom module
custom_module = MyCustomModule(input_keys='input', output_keys='output')

# Use the module in a graph
graph = ModuleGraph({'custom_module': custom_module})

# forward pass
output = graph.forward({'input': torch.randn(10, 10)})
```
In this example, MyCustomModule is a subclass of both nn.Module and DirectedModule. It uses nn.Module to define a simple linear layer and DirectedModule to handle input/output key mapping.
Example 4: Using DirectedModule as a mixin
```python
import torch
import torch.nn as nn
from torch_compose import DirectedModule, ModuleGraph

# Define a custom module that holds a DirectedModule as a member
class MyMixinModule(nn.Module):
    def __init__(self, input_keys, output_keys):
        super().__init__()
        self.directed_module = DirectedModule(input_keys=input_keys, output_keys=output_keys, module=self)
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

# Create an instance of the mixin module
mixin_module = MyMixinModule(input_keys='input', output_keys='output')

# Use the module in a graph
graph = ModuleGraph({'mixin_module': mixin_module.directed_module})

# forward pass
output = graph.forward({'input': torch.randn(10, 10)})
```
In this example, MyMixinModule is a subclass of nn.Module, and it uses a DirectedModule as a member to handle input/output key mapping, allowing the same functionality without multiple inheritance.
Download files
File details
Details for the file torch_compose-0.2.0.tar.gz.
File metadata
- Download URL: torch_compose-0.2.0.tar.gz
- Upload date:
- Size: 259.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.5.1 CPython/3.10.6 Linux/5.15.0-56-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `3c6d9bcc4f375d8f8be1573e49dc2b8ba8a2ae4bc0bda91557a2b539b7e16082` |
| MD5 | `801a4fe68b65c1ffea24d27b4ae79781` |
| BLAKE2b-256 | `b9565465c53cf35a82cb0957905e2897d97dbb87081430b13ea7a45184df5b72` |
File details
Details for the file torch_compose-0.2.0-py3-none-any.whl.
File metadata
- Download URL: torch_compose-0.2.0-py3-none-any.whl
- Upload date:
- Size: 257.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.5.1 CPython/3.10.6 Linux/5.15.0-56-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `354b0940f9ecd26f175b3713b4841298692ef7a51f5d6865a0136ed4dde00e80` |
| MD5 | `31488126c733a1ab6649e1aafc9b27b7` |
| BLAKE2b-256 | `0807b65be8204b877c1c9c88c3b63d3d7ec13bef68682fe8111dcd3179936a13` |