TorchArc

Build PyTorch models by specifying architectures.

Installation

Bring your own PyTorch, then install this package:

pip install torcharc

Usage

  1. specify the model architecture in a YAML spec file, e.g. at spec_filepath = "./torcharc/example/spec/basic/mlp.yaml"
  2. import torcharc.
    1. (optional) if you have a custom torch.nn.Module, e.g. MyModule, register it with torcharc.register_nn(MyModule)
  3. build with: model = torcharc.build(spec_filepath)

The returned model is a PyTorch nn.Module, fully compatible with torch.compile, and mostly compatible with PyTorch JIT script and trace (for example, the TorchFn module shown in a later example does not work with JIT script).

See more examples below, then see how it works at the end.


Example: build model from spec file

from pathlib import Path

import torch
import yaml

import torcharc


filepath = Path(".") / "torcharc" / "example" / "spec" / "basic" / "mlp.yaml"

# The following are equivalent:

# 1. build from YAML spec file
model = torcharc.build(filepath)

# 2. build from dictionary
with filepath.open("r") as f:
    spec_dict = yaml.safe_load(f)
model = torcharc.build(spec_dict)

# 3. use the underlying Pydantic validator to build the model
spec = torcharc.Spec(**spec_dict)
model = spec.build()

Example: basic MLP

spec file

File: torcharc/example/spec/basic/mlp.yaml

modules:
  mlp:
    Sequential:
      - Linear:
          in_features: 128
          out_features: 64
      - ReLU:
      - Linear:
          in_features: 64
          out_features: 10

graph:
  input: x
  modules:
    mlp: [x]
  output: mlp

model = torcharc.build(torcharc.SPEC_DIR / "basic" / "mlp.yaml")
assert isinstance(model, torch.nn.Module)

# Run the model and check the output shape
x = torch.randn(4, 128)
y = model(x)
assert y.shape == (4, 10)

# Test compatibility with compile, script and trace
compiled_model = torch.compile(model)
assert compiled_model(x).shape == y.shape
scripted_model = torch.jit.script(model)
assert scripted_model(x).shape == y.shape
traced_model = torch.jit.trace(model, (x,))
assert traced_model(x).shape == y.shape

model
model
GraphModule(
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=10, bias=True)
  )
)


Example: MLP (Lazy)

spec file

File: torcharc/example/spec/basic/mlp_lazy.yaml

# NOTE lazy modules are recommended for ease of use https://pytorch.org/docs/stable/generated/torch.nn.LazyLinear.html
modules:
  mlp:
    Sequential:
      - LazyLinear:
          out_features: 64
      - ReLU:
      - LazyLinear:
          out_features: 10

graph:
  input: x
  modules:
    mlp: [x]
  output: mlp

# PyTorch Lazy layers will infer the input size from the first forward pass. This is recommended as it greatly simplifies the model definition.
model = torcharc.build(torcharc.SPEC_DIR / "basic" / "mlp_lazy.yaml")
model  # shows LazyLinear before first forward pass
model (before forward pass)
GraphModule(
  (mlp): Sequential(
    (0): LazyLinear(in_features=0, out_features=64, bias=True)
    (1): ReLU()
    (2): LazyLinear(in_features=0, out_features=10, bias=True)
  )
)
# Run the model and check the output shape
x = torch.randn(4, 128)
y = model(x)
assert y.shape == (4, 10)

# Because it is lazy - wait till first forward pass to run compile, script or trace
compiled_model = torch.compile(model)

model  # shows Linear after first forward pass
model (after forward pass)
GraphModule(
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=10, bias=True)
  )
)
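
The shape inference shown above is plain PyTorch behavior and can be checked without TorchArc. The following is a minimal sketch that relies only on torch.nn.LazyLinear; the tensor sizes are chosen to match the example and are otherwise arbitrary.

```python
import torch
from torch import nn
from torch.nn.parameter import UninitializedParameter

layer = nn.LazyLinear(out_features=64)

# Before the first forward pass the weight has no concrete shape yet.
was_uninitialized = isinstance(layer.weight, UninitializedParameter)

# The first forward pass infers in_features from the last input dimension.
y = layer(torch.randn(4, 128))

# After that pass the layer has become a regular Linear with real weights.
became_linear = type(layer) is nn.Linear
weight_shape = tuple(layer.weight.shape)
```

This is why compile, script, or trace should only be applied after one forward pass: before it, the parameters do not exist in their final form.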


Example: MNIST conv

spec file

File: torcharc/example/spec/mnist/conv.yaml

# MNIST Conv2d model example from PyTorch https://github.com/pytorch/examples/blob/main/mnist/main.py
modules:
  conv:
    Sequential:
      - LazyConv2d:
          out_channels: 32
          kernel_size: 3
      - ReLU:
      - LazyConv2d:
          out_channels: 64
          kernel_size: 3
      - ReLU:
      - MaxPool2d:
          kernel_size: 2
      - Dropout:
          p: 0.25
      - Flatten:
      - LazyLinear:
          out_features: 128
      - ReLU:
      - Dropout:
          p: 0.5
      - LazyLinear:
          out_features: 10
      - LogSoftmax:
          dim: 1

graph:
  input: image
  modules:
    conv: [image]
  output: conv

model = torcharc.build(torcharc.SPEC_DIR / "mnist" / "conv.yaml")

# Run the model and check the output shape
x = torch.randn(4, 1, 28, 28)
y = model(x)
assert y.shape == (4, 10)

model
model
GraphModule(
  (conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Dropout(p=0.25, inplace=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=9216, out_features=128, bias=True)
    (8): ReLU()
    (9): Dropout(p=0.5, inplace=False)
    (10): Linear(in_features=128, out_features=10, bias=True)
    (11): LogSoftmax(dim=1)
  )
)
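
The in_features=9216 that the first lazy linear layer infers in the printout above can be checked by hand. The sketch below uses the standard output-size formula for convolution and pooling layers without dilation; conv_out is a helper name introduced here for illustration only.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of one spatial dimension after a conv or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                  # MNIST images are 28 x 28
size = conv_out(size, kernel=3)            # Conv2d(1, 32, kernel_size=3)  -> 26
size = conv_out(size, kernel=3)            # Conv2d(32, 64, kernel_size=3) -> 24
size = conv_out(size, kernel=2, stride=2)  # MaxPool2d(kernel_size=2)      -> 12
flat_features = 64 * size * size           # Flatten over the 64 channels
print(flat_features)  # 9216
```

With lazy layers this arithmetic is done for you at the first forward pass, which is the main reason the spec does not need any in_channels or in_features values.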


Example: MLP (Compact)

Use a compact spec that expands into a Sequential spec; this is useful for architecture search.

spec file

File: torcharc/example/spec/compact/mlp.yaml

# modules:
#   mlp:
#     Sequential:
#       - LazyLinear:
#           out_features: 64
#       - ReLU:
#       - LazyLinear:
#           out_features: 64
#       - ReLU:
#       - LazyLinear:
#           out_features: 32
#       - ReLU:
#       - LazyLinear:
#           out_features: 16
#       - ReLU:

# the above can be written compactly as follows

modules:
  mlp:
    compact:
      layer:
        type: LazyLinear
        keys: [out_features]
        args: [64, 64, 32, 16]
      postlayer:
        - ReLU:

graph:
  input: x
  modules:
    mlp: [x]
  output: mlp

model = torcharc.build(torcharc.SPEC_DIR / "compact" / "mlp.yaml")

# Run the model and check the output shape
x = torch.randn(4, 128)
y = model(x)
assert y.shape == (4, 16)

model
model
GraphModule(
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): ReLU()
    (6): Linear(in_features=32, out_features=16, bias=True)
    (7): ReLU()
  )
)
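
To make the expansion concrete, here is a rough pure-Python sketch of how a compact spec could be unrolled into the equivalent Sequential spec. The helper expand_compact is hypothetical and written for illustration; it is not TorchArc's actual implementation, and it assumes a bare value in args is allowed when there is a single key.

```python
from copy import deepcopy


def expand_compact(compact: dict) -> dict:
    """Unroll prelayer / layer / postlayer once per entry in layer.args."""
    layer = compact["layer"]
    expanded = []
    for arg in layer["args"]:
        # A bare value (e.g. 64) is treated as a one-element list of values.
        values = arg if isinstance(arg, (list, tuple)) else [arg]
        expanded.extend(deepcopy(compact.get("prelayer") or []))
        expanded.append({layer["type"]: dict(zip(layer["keys"], values))})
        expanded.extend(deepcopy(compact.get("postlayer") or []))
    return {"Sequential": expanded}


# The compact MLP spec from above, as the dict that YAML parsing would produce.
compact_mlp = {
    "layer": {"type": "LazyLinear", "keys": ["out_features"], "args": [64, 64, 32, 16]},
    "postlayer": [{"ReLU": None}],
}
sequential_mlp = expand_compact(compact_mlp)
```

The result has eight entries, alternating LazyLinear and ReLU, which matches the eight layers in the printed model.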


Example: Conv (Compact)

Use a compact spec that expands into a Sequential spec; this is useful for architecture search.

spec file

File: torcharc/example/spec/compact/conv_classifier.yaml

# modules:
#   conv:
#     Sequential:
#       - LazyBatchNorm2d:
#       - LazyConv2d:
#           out_channels: 16
#           kernel_size: 2
#       - ReLU:
#       - Dropout:
#           p: 0.1
#       - LazyBatchNorm2d:
#       - LazyConv2d:
#           out_channels: 32
#           kernel_size: 3
#       - ReLU:
#       - Dropout:
#           p: 0.1
#       - LazyBatchNorm2d:
#       - LazyConv2d:
#           out_channels: 64
#           kernel_size: 4
#       - ReLU:
#       - Dropout:
#           p: 0.1
#   classifier:
#     Sequential:
#       - Flatten:
#       - LazyLinear:
#           out_features: 10

# the above can be written compactly as follows

modules:
  conv:
    compact:
      prelayer:
        - LazyBatchNorm2d:
      layer:
        type: LazyConv2d
        keys: [out_channels, kernel_size]
        args: [[16, 2], [32, 3], [64, 4]]
      postlayer:
        - ReLU:
        - Dropout:
            p: 0.1
  classifier:
    Sequential:
      - Flatten:
      - LazyLinear:
          out_features: 10

graph:
  input: image
  modules:
    conv: [image]
    classifier: [conv]
  output: classifier

model = torcharc.build(torcharc.SPEC_DIR / "compact" / "conv_classifier.yaml")

# Run the model and check the output shape
x = torch.randn(4, 1, 28, 28)
y = model(x)
assert y.shape == (4, 10)

model
model
GraphModule(
  (conv): Sequential(
    (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Conv2d(1, 16, kernel_size=(2, 2), stride=(1, 1))
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (6): ReLU()
    (7): Dropout(p=0.1, inplace=False)
    (8): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): Conv2d(32, 64, kernel_size=(4, 4), stride=(1, 1))
    (10): ReLU()
    (11): Dropout(p=0.1, inplace=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=30976, out_features=10, bias=True)
  )
)


Example: Reuse syntax: Stereo Conv

spec file

File: torcharc/example/spec/basic/stereo_conv_reuse.yaml

# stereoscopic vision conv with shared (reused) conv
modules:
  # single conv shared for left and right
  conv:
    Sequential:
      - LazyConv2d:
          out_channels: 32
          kernel_size: 3
      - ReLU:
      - LazyConv2d:
          out_channels: 64
          kernel_size: 3
      - ReLU:
      - MaxPool2d:
          kernel_size: 2
      - Dropout:
          p: 0.25
      - Flatten:

  # separate identical mlp for left and right for processing
  left_mlp: &left_mlp
    Sequential:
      - LazyLinear:
          out_features: 256
      - ReLU:
      - LazyLinear:
          out_features: 10
      - ReLU:
  right_mlp:
    <<: *left_mlp

graph:
  input: [left_image, right_image]
  modules:
    # reuse syntax: <module>~<suffix>
    conv~left: [left_image]
    conv~right: [right_image]
    left_mlp: [conv~left]
    right_mlp: [conv~right]
  output: [left_mlp, right_mlp]

# To use the same module for many passes, specify with the reuse syntax in graph spec: <module>~<suffix>.
# E.g. stereoscopic model with `left` and `right` inputs over a shared `conv` module: `conv~left` and `conv~right`.
model = torcharc.build(torcharc.SPEC_DIR / "basic" / "stereo_conv_reuse.yaml")

# Run the model and check the output shape
left_image = right_image = torch.randn(4, 3, 32, 32)
left, right = model(left_image=left_image, right_image=right_image)
assert left.shape == (4, 10)
assert right.shape == (4, 10)

model
model
GraphModule(
  (conv): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Dropout(p=0.25, inplace=False)
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (left_mlp): Sequential(
    (0): Linear(in_features=12544, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=10, bias=True)
    (3): ReLU()
  )
  (right_mlp): Sequential(
    (0): Linear(in_features=12544, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=10, bias=True)
    (3): ReLU()
  )
)
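
The printed model has a single conv submodule even though the graph contains two conv nodes. A small sketch of how the reuse names relate to modules, assuming the part before ~ names the module and the suffix only distinguishes the call site (base_module is a hypothetical helper, not part of TorchArc):

```python
def base_module(node_name: str) -> str:
    """Return the module a graph node refers to, dropping any ~suffix."""
    return node_name.split("~", 1)[0]


# Graph node names from the stereo conv spec above.
graph_nodes = ["conv~left", "conv~right", "left_mlp", "right_mlp"]
node_to_module = {node: base_module(node) for node in graph_nodes}
unique_modules = sorted(set(node_to_module.values()))
print(unique_modules)  # ['conv', 'left_mlp', 'right_mlp']
```

Four graph nodes map onto three modules, so the left and right images pass through the same conv weights, while left_mlp and right_mlp have the same architecture but separate parameters.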


Example: transformer

spec file

File: torcharc/example/spec/transformer/transformer.yaml

modules:
  src_embed:
    LazyLinear:
      out_features: &embed_dim 128
  tgt_embed:
    LazyLinear:
      out_features: *embed_dim
  transformer:
    Transformer:
      d_model: *embed_dim
      nhead: 8
      num_encoder_layers: 2
      num_decoder_layers: 2
      dim_feedforward: 1024
      dropout: 0.1
      batch_first: true
  mlp:
    LazyLinear:
      out_features: 10

graph:
  input: [src_x, tgt_x]
  modules:
    src_embed: [src_x]
    tgt_embed: [tgt_x]
    transformer:
      src: src_embed
      tgt: tgt_embed
    mlp: [transformer]
  output: mlp

model = torcharc.build(torcharc.SPEC_DIR / "transformer" / "transformer.yaml")

# Run the model and check the output shape
src_x = torch.randn(4, 10, 64)
tgt_x = torch.randn(4, 20, 64)
y = model(src_x=src_x, tgt_x=tgt_x)
assert y.shape == (4, 20, 10)

model
model
GraphModule(
  (src_embed): Linear(in_features=64, out_features=128, bias=True)
  (tgt_embed): Linear(in_features=64, out_features=128, bias=True)
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=1024, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=1024, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerDecoderLayer(
          (self_attn): MultiheadAttention(
...
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
  (mlp): Linear(in_features=128, out_features=10, bias=True)
)

Example: Get modules: Attention

See more in torcharc/module/get.py

spec file

File: torcharc/example/spec/transformer/attn.yaml

modules:
  src_embed:
    LazyLinear:
      out_features: &embed_dim 64
  tgt_embed:
    LazyLinear:
      out_features: *embed_dim
  attn:
    MultiheadAttention:
      embed_dim: *embed_dim
      num_heads: 4
      batch_first: True
  attn_output:
    # attn output is tuple; get the first element to pass to mlp
    Get:
      key: 0
  mlp:
    LazyLinear:
      out_features: 10

graph:
  input: [src_x, tgt_x]
  modules:
    src_embed: [src_x]
    tgt_embed: [tgt_x]
    attn:
      query: src_embed
      key: tgt_embed
      value: tgt_embed
    attn_output: [attn]
    mlp: [attn_output]
  output: mlp

# Some cases need "Get" module, e.g. get first output from MultiheadAttention
model = torcharc.build(torcharc.SPEC_DIR / "transformer" / "attn.yaml")

# Run the model and check the output shape
src_x = torch.rand(4, 10, 64)  # (batch_size, seq_len, embed_dim)
tgt_x = torch.rand(4, 20, 64)  # (batch_size, seq_len, embed_dim)
# Get module gets the first output from MultiheadAttention
y = model(src_x=src_x, tgt_x=tgt_x)
assert y.shape == (4, 10, 10)

model
model
GraphModule(
  (src_embed): Linear(in_features=64, out_features=64, bias=True)
  (tgt_embed): Linear(in_features=64, out_features=64, bias=True)
  (attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
  )
  (attn_output): Get()
  (mlp): Linear(in_features=64, out_features=10, bias=True)
)
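
The reason a Get step is needed can be seen with plain PyTorch: torch.nn.MultiheadAttention returns a tuple of (attention output, attention weights), and only the first element should flow into the next layer. In the sketch below, GetItem is a hypothetical stand-in written for illustration; it is not TorchArc's Get implementation.

```python
import torch
from torch import nn


class GetItem(nn.Module):
    """Hypothetical stand-in for a Get-style module: index into its input."""

    def __init__(self, key: int):
        super().__init__()
        self.key = key

    def forward(self, x):
        return x[self.key]


attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
query = torch.rand(4, 10, 64)      # (batch_size, tgt_len, embed_dim)
key_value = torch.rand(4, 20, 64)  # (batch_size, src_len, embed_dim)

out = attn(query, key_value, key_value)  # a tuple, not a tensor
attn_output, attn_weights = out
first = GetItem(0)(out)                  # the tensor to pass downstream
```

The attention output keeps the query's sequence length, which is why the example above asserts an output shape of (4, 10, 10) after the final linear layer.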


Example: Merge modules: DLRM

See more in torcharc/module/merge.py

spec file

File: torcharc/example/spec/advanced/dlrm_sum.yaml

# basic DLRM architecture ref: https://github.com/facebookresearch/dlrm and ref: https://catalog.ngc.nvidia.com/orgs/nvidia/resources/dlrm_for_pytorch
modules:
  # NOTE the dense mlp and each embedding need to have the same output size
  dense_mlp:
    Sequential:
      - LazyLinear:
          out_features: 512
      - ReLU:
      - LazyLinear:
          out_features: 256
      - ReLU:
      - LazyLinear:
          out_features: &feat_dim 128
      - ReLU:

  cat_embed_0:
    Embedding:
      num_embeddings: 1000
      embedding_dim: *feat_dim
  cat_embed_1:
    Embedding:
      num_embeddings: 1000
      embedding_dim: *feat_dim
  cat_embed_2:
    Embedding:
      num_embeddings: 1000
      embedding_dim: *feat_dim

  # pairwise interactions (original mentions sum, pairwise dot, cat) - but modern day this can be anything, e.g. self-attention
  merge:
    MergeSum:

  # final classifier for probability of click
  classifier:
    Sequential:
      - LazyLinear:
          out_features: 256
      - ReLU:
      - LazyLinear:
          out_features: 256
      - ReLU:
      - LazyLinear:
          out_features: 1
      - Sigmoid:

graph:
  input: [dense, cat_0, cat_1, cat_2]
  modules:
    dense_mlp: [dense]
    cat_embed_0: [cat_0]
    cat_embed_1: [cat_1]
    cat_embed_2: [cat_2]
    merge: [[dense_mlp, cat_embed_0, cat_embed_1, cat_embed_2]]
    classifier: [merge]
  output: classifier

# For more complex models, merge modules can be used to combine multiple inputs into a single output.
model = torcharc.build(torcharc.SPEC_DIR / "advanced" / "dlrm_sum.yaml")

# Run the model and check the output shape
dense = torch.randn(4, 256)
cat_0 = torch.randint(0, 1000, (4,))
cat_1 = torch.randint(0, 1000, (4,))
cat_2 = torch.randint(0, 1000, (4,))
y = model(dense, cat_0, cat_1, cat_2)
assert y.shape == (4, 1)

model
model
GraphModule(
  (dense_mlp): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): ReLU()
  )
  (cat_embed_0): Embedding(1000, 128)
  (cat_embed_1): Embedding(1000, 128)
  (cat_embed_2): Embedding(1000, 128)
  (merge): MergeSum()
  (classifier): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=1, bias=True)
    (5): Sigmoid()
  )
)
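
As an illustration of why the dense branch and every embedding must share the same output size, here is a sketch of a sum-style merge in plain PyTorch. SumMerge is a hypothetical module written for this example and is only assumed to behave like MergeSum; see torcharc/module/merge.py for the real implementation.

```python
import torch
from torch import nn


class SumMerge(nn.Module):
    """Hypothetical sum merge: element-wise sum over a list of same-shape tensors."""

    def forward(self, tensors):
        return torch.stack(list(tensors), dim=0).sum(dim=0)


dense_feat = torch.randn(4, 128)  # stands in for the dense_mlp output
cat_feats = [
    nn.Embedding(1000, 128)(torch.randint(0, 1000, (4,)))  # one per categorical input
    for _ in range(3)
]
merged = SumMerge()([dense_feat, *cat_feats])
manual = dense_feat + cat_feats[0] + cat_feats[1] + cat_feats[2]
```

Because the merge is element-wise, the merged tensor keeps the shared feature size of 128, which is the in_features that the classifier's first lazy layer infers.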


Example: Merge modules: FiLM

spec file

File: torcharc/example/spec/merge/film.yaml

# FiLM https://distill.pub/2018/feature-wise-transformations/
modules:
  feat:
    LazyLinear:
      out_features: &feat_dim 10
  cond:
    LazyLinear:
      out_features: &cond_dim 4
  merge_0_1:
    MergeFiLM:
      feature_dim: *feat_dim
      conditioner_dim: *cond_dim
  tail:
    Linear:
      in_features: *feat_dim
      out_features: 1

graph:
  input: [x_0, x_1]
  modules:
    feat: [x_0]
    cond: [x_1]
    merge_0_1:
      feature: feat
      conditioner: cond
    tail: [merge_0_1]
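  output: tail

# Custom MergeFiLM module for Feature-wise Linear Modulation https://distill.pub/2018/feature-wise-transformations/
# NOTE: this builds the larger advanced/film_image_state.yaml spec (image and state inputs), not the minimal merge/film.yaml spec shown above
model = torcharc.build(torcharc.SPEC_DIR / "advanced" / "film_image_state.yaml")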

# Run the model and check the output shape
image = torch.randn(4, 3, 32, 32)
state = torch.randn(4, 4)
y = model(image=image, state=state)
assert y.shape == (4, 10)

model
model
GraphModule(
  (conv_0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
  )
  (gyro): Linear(in_features=4, out_features=10, bias=True)
  (film_0): MergeFiLM(
    (gamma): Linear(in_features=10, out_features=64, bias=True)
    (beta): Linear(in_features=10, out_features=64, bias=True)
  )
  (conv_1): Sequential(
    (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Dropout(p=0.25, inplace=False)
  )
  (film_1): MergeFiLM(
    (gamma): Linear(in_features=10, out_features=32, bias=True)
    (beta): Linear(in_features=10, out_features=32, bias=True)
  )
  (flatten): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=6272, out_features=256, bias=True)
  )
  (classifier): Sequential(
...
    (1): ReLU()
    (2): Linear(in_features=64, out_features=10, bias=True)
    (3): LogSoftmax(dim=1)
  )
)
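
The printout shows that each MergeFiLM holds two linear layers, gamma and beta, mapping the conditioner to the feature dimension. The following sketch shows the feature-wise affine transform that FiLM is generally defined as, gamma(c) * x + beta(c). FiLMSketch is a hypothetical module for illustration; the broadcasting over spatial dimensions is an assumption and may differ from TorchArc's implementation.

```python
import torch
from torch import nn


class FiLMSketch(nn.Module):
    """Hypothetical FiLM layer: scale and shift features using a conditioner."""

    def __init__(self, feature_dim: int, conditioner_dim: int):
        super().__init__()
        self.gamma = nn.Linear(conditioner_dim, feature_dim)
        self.beta = nn.Linear(conditioner_dim, feature_dim)

    def forward(self, feature, conditioner):
        gamma = self.gamma(conditioner)  # (N, C)
        beta = self.beta(conditioner)    # (N, C)
        # Add trailing singleton dims so (N, C) broadcasts over e.g. (N, C, H, W).
        extra_dims = feature.dim() - gamma.dim()
        shape = tuple(gamma.shape) + (1,) * extra_dims
        return gamma.view(shape) * feature + beta.view(shape)


film = FiLMSketch(feature_dim=64, conditioner_dim=10)
feature = torch.randn(4, 64, 30, 30)  # e.g. a conv feature map
conditioner = torch.randn(4, 10)      # e.g. an embedded state vector
out = film(feature, conditioner)
```

The output has the same shape as the feature input, so a FiLM layer can be dropped between two existing layers without changing the shapes either of them sees.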


Example: Fork modules: Chunk

See more in torcharc/module/fork.py

spec file

File: torcharc/example/spec/fork/chunk.yaml

modules:
  head:
    Linear:
      in_features: 32
      out_features: 10
  fork_0_1:
    ForkChunk:
      chunks: 2
  tail_0:
    Get:
      key: 0
  tail_1:
    Get:
      key: 1

graph:
  input: x
  modules:
    head: [x]
    fork_0_1: [head]
    tail_0: [fork_0_1]
    tail_1: [fork_0_1]
  output: [tail_0, tail_1]

# Fork module can be used to split the input into multiple outputs.
model = torcharc.build(torcharc.SPEC_DIR / "fork" / "chunk.yaml")

# Run the model and check the output shape
x = torch.randn(4, 32)
y = model(x)
assert len(y) == 2  # tail_0, tail_1

model
model
GraphModule(
  (head): Linear(in_features=32, out_features=10, bias=True)
  (fork_0_1): ForkChunk()
  (tail_0): Get()
  (tail_1): Get()
)


Example: reduction op functional modules: Reduce

See more in torcharc/module/fn.py

spec file

File: torcharc/example/spec/fn/reduce_mean.yaml

modules:
  text_token_embed:
    Embedding:
      num_embeddings: 10000
      embedding_dim: &embed_dim 128
  text_tfmr:
    Transformer:
      d_model: *embed_dim
      nhead: 8
      num_encoder_layers: 2
      num_decoder_layers: 2
      dropout: 0.1
      batch_first: true
  text_embed:
    # sentence embedding has shape [batch_size, seq_len, embed_dim],
    # one way to reduce it to a single sentence embedding is by pooling, e.g. reduce with mean over seq
    Reduce:
      name: mean
      dim: 1

graph:
  input: [text]
  modules:
    text_token_embed: [text]
    text_tfmr:
      src: text_token_embed
      tgt: text_token_embed
    text_embed: [text_tfmr]
  output: text_embed

# Sometimes we need to reduce the input tensor, e.g. pooling token embeddings into a single sentence embedding
model = torcharc.build(torcharc.SPEC_DIR / "fn" / "reduce_mean.yaml")

# Run the model and check the output shape
text = torch.randint(0, 1000, (4, 10))
y = model(text)
assert y.shape == (4, 128)  # reduced embedding

model
model
GraphModule(
  (text_token_embed): Embedding(10000, 128)
  (text_tfmr): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerDecoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
...
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
  (text_embed): Reduce()
)

Example: general functional modules: TorchFn

See more in torcharc/module/fn.py

spec file

File: torcharc/example/spec/fn/fn_topk.yaml

modules:
  mlp:
    Sequential:
      - Flatten:
      - LazyLinear:
          out_features: 512
      - ReLU:
      - LazyLinear:
          out_features: 512
      - ReLU:
      - LazyLinear:
          out_features: 10
      # use a generic torch function, with the caveat that it is incompatible with JIT script
      - TorchFn:
          name: topk
          k: 3

graph:
  input: x
  modules:
    mlp: [x]
  output: mlp

# While TorchArc uses module-based design, not all torch functions are available as modules. Use TorchFn to wrap any torch function into a module - with the caveat that it does not work with JIT script (but does work with torch compile and JIT trace).
model = torcharc.build(torcharc.SPEC_DIR / "fn" / "fn_topk.yaml")

# Run the model and check the output shape
x = torch.randn(4, 128)
y = model(x)  # topk output with k=3
assert y.indices.shape == (4, 3)
assert y.values.shape == (4, 3)

model
model
GraphModule(
  (mlp): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=128, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Linear(in_features=512, out_features=10, bias=True)
    (6): TorchFn()
  )
)
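
One way to picture a function-wrapping module is a small nn.Module that looks up a torch function by name and binds keyword arguments to it. FnSketch below is a hypothetical illustration of that idea, not the code in torcharc/module/fn.py. The same pattern covers a reduction such as mean over a dimension.

```python
import torch
from torch import nn


class FnSketch(nn.Module):
    """Hypothetical wrapper: call torch.<name>(x, **kwargs) in forward."""

    def __init__(self, name: str, **kwargs):
        super().__init__()
        self.fn = getattr(torch, name)  # e.g. torch.topk, torch.mean
        self.kwargs = kwargs

    def forward(self, x):
        return self.fn(x, **self.kwargs)


topk = FnSketch("topk", k=3)
result = topk(torch.randn(4, 10))             # named tuple with values and indices

mean_pool = FnSketch("mean", dim=1)
pooled = mean_pool(torch.randn(4, 10, 128))   # pool over the sequence dimension
```

A dynamic lookup like this is convenient for specs, and it also hints at why such a wrapper is a poor fit for JIT script, which needs to resolve calls statically.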

Example: more

More example specs are in torcharc/example/spec/ (available in code as torcharc.SPEC_DIR).

How does it work

TorchArc uses fx.GraphModule to build a model from a spec file:

  1. Spec is defined via Pydantic torcharc/validator/. This defines:
    • modules: the torch.nn.Modules as the main compute graph nodes
    • graph: the graph structure, i.e. how the modules are connected
  2. build model using torch.fx.GraphModule(modules, graph)
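
The two steps above can be sketched directly with torch.fx, without TorchArc. This is a hand-written approximation of the idea for the basic MLP spec, using only public torch.fx calls; how TorchArc itself constructs the graph from the spec may differ.

```python
import torch
from torch import fx, nn

# "modules": the torch.nn.Modules that act as compute graph nodes.
modules = nn.ModuleDict({
    "mlp": nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)),
})

# "graph": input x -> mlp -> output, mirroring the graph section of the spec.
graph = fx.Graph()
x = graph.placeholder("x")
mlp_out = graph.call_module("mlp", args=(x,))
graph.output(mlp_out)

# Combine both into a regular nn.Module.
model = fx.GraphModule(modules, graph)
y = model(torch.randn(4, 128))
```

The result is an ordinary nn.Module whose forward method is generated from the graph, which is why the examples above print as GraphModule(...).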

See more in the Pydantic spec definition under torcharc/validator/.

Guiding principles

The design of TorchArc is guided as follows:

  1. simple: the module spec is straightforward:
    1. it is simply torch.nn.Module class name with kwargs.
    2. it includes official torch.nn.Modules, Sequential, and custom-defined modules registered via torcharc.register_nn
    3. it uses modules only; torcharc is not meant to replace custom functions that are best written in PyTorch code
  2. expressive: it can be used to build both simple and advanced model architectures easily
  3. portable: it returns a torch.nn.Module that can be used anywhere; it is not a framework.
  4. performant: PyTorch-native, fully compatible with torch.compile, and mostly compatible with torch JIT script and trace
  5. parametrizable: data-based model-building unlocks fast experimentation, e.g. by building logic for hyperparameter / architecture search
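
As a sketch of the parametrizable principle, spec dictionaries can be generated in plain Python and then built one by one. The helper make_mlp_spec, the head module name, and the depth and width grid below are all hypothetical choices for illustration; the spec layout follows the compact MLP example above.

```python
from itertools import product


def make_mlp_spec(widths, out_features):
    """Return a spec dict for an MLP with the given hidden widths."""
    return {
        "modules": {
            "mlp": {
                "compact": {
                    "layer": {
                        "type": "LazyLinear",
                        "keys": ["out_features"],
                        "args": list(widths),
                    },
                    "postlayer": [{"ReLU": None}],
                }
            },
            "head": {"LazyLinear": {"out_features": out_features}},
        },
        "graph": {
            "input": "x",
            "modules": {"mlp": ["x"], "head": ["mlp"]},
            "output": "head",
        },
    }


# One candidate spec per (depth, width) combination.
candidate_specs = [
    make_mlp_spec([width] * depth, out_features=10)
    for depth, width in product([1, 2, 3], [32, 64])
]
```

Each dictionary could then be passed to torcharc.build(spec_dict), as in the build-from-dictionary example near the top, and the resulting models trained and compared.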

Development

Setup

Install uv for dependency management if you haven't already. Then run:

# setup virtualenv
uv sync

Unit Tests

uv run pytest
