
Project description

rope-conformer

A clean, plug-and-play PyTorch Conformer encoder with rotary positional embeddings (RoPE).

import torch
from rope_conformer import RoPEConformer

model = RoPEConformer(dim=256, depth=6, heads=8, dim_head=32)
x = torch.randn(2, 100, 256)        # [B, N, dim]
y = model(x)                        # [B, N, dim]

Install

pip install rope-conformer

Optional Apple-silicon acceleration (experimental):

pip install "rope-conformer[mps-sdpa]"

Usage with a padding mask

mask = torch.zeros(2, 100, dtype=torch.bool)
mask[:, 80:] = True                 # last 20 positions padded
y = model(x, key_padding_mask=mask) # masked positions are ignored in attention
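
If you have per-example lengths, the same mask can be built by comparing positions against them (a small convenience sketch, not part of the package API):

lengths = torch.tensor([80, 100])                      # valid frames per example
mask = torch.arange(100)[None, :] >= lengths[:, None]  # True where padded
y = model(x, key_padding_mask=mask)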

API

Argument                Default   Description
dim                     required  Model channels (input and output, unless output_dim is set).
depth                   required  Number of stacked Conformer blocks.
dim_head                64        Per-head dimension.
heads                   8         Number of attention heads.
ff_mult                 4         Feedforward expansion factor.
conv_expansion_factor   2         Pointwise expansion in the conv module.
conv_kernel_size        31        Depthwise conv kernel (Conformer paper default).
attn_dropout            0.0       Dropout inside attention (off by default).
proj_dropout            0.1       Dropout after the attention output projection.
ff_dropout              0.1       Dropout in the feedforward sublayers.
conv_dropout            0.1       Dropout at the end of the conv module.
conv_causal             False     Left-pad the depthwise conv (no future-frame leakage).
conv_norm_type          "rms"     "rms" (per-token, causal-safe), "group", or "batch" (cross-time stats; both fall back to Identity in causal mode).
use_attn_gates          False     Add per-head sigmoid output gates on attention.
flash_attn              True      Use F.scaled_dot_product_attention; False falls back to einsum.
use_mps_sdpa            False     Route attention through mps-sdpa (experimental, MPS only).
norm_output             True      Apply a final RMSNorm before projection.
output_dim              None      Optional output projection to a different dimension.

The forward signature is model(x, key_padding_mask=None):

  • x: [B, N, dim]
  • key_padding_mask: [B, N] bool, True for padded positions (optional).
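
As a rough sketch, a non-default configuration built from the arguments above might look like this (the argument names come from the table; the values themselves are only illustrative):

model = RoPEConformer(
    dim=256,
    depth=8,
    heads=8,
    dim_head=64,
    conv_causal=True,        # depthwise conv sees only past frames
    conv_norm_type="rms",    # per-token norm, safe in causal mode
    output_dim=128,          # project the encoder output to a new size
)
x = torch.randn(2, 100, 256)
y = model(x)                 # [B, N, 128] once output_dim is set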

Granular dropout

The four dropout knobs (attn_dropout, proj_dropout, ff_dropout, conv_dropout) target different stages so you can tune each independently. Attention dropout is off by default because zeroing attention weights distorts the softmax-normalized probability distribution and creates a training/inference mismatch in the attention pattern; the other three knobs act on unnormalized intermediate features and don't have this issue, so they default to a small 0.1.
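
For example, a hypothetical tuning that leaves the attention distribution untouched while regularizing the feedforward path more heavily (values are illustrative, not recommendations):

model = RoPEConformer(
    dim=256,
    depth=6,
    attn_dropout=0.0,   # keep attention weights intact
    proj_dropout=0.1,   # after the attention output projection
    ff_dropout=0.2,     # heavier dropout in the feedforward sublayers
    conv_dropout=0.1,   # end of the conv module
)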

Causal use

Set conv_causal=True for a depthwise conv that only sees past frames. This handles the conv path; the self-attention path is not causally masked by this flag — pass your own causal attn_mask (or use a causal-aware downstream stack) if you need full causal behavior.
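
A minimal sketch: the conv path is made causal at construction time, and a standard upper-triangular mask (plain PyTorch, not a package helper) can serve as the causal attention mask for whatever stack you feed it to, since the documented forward only accepts key_padding_mask:

model = RoPEConformer(dim=256, depth=6, conv_causal=True)

# Standard causal mask (True = blocked); applying it is up to the surrounding stack.
N = 100
causal_mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)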

Optional mps-sdpa integration (experimental)

PyTorch's MPS backend does not currently dispatch scaled_dot_product_attention to Apple's fused MPSGraph.scaledDotProductAttention op; it builds a naive matmul → softmax → matmul graph instead. The mps-sdpa package wraps the fused op directly, giving roughly 5–7× faster inference and 2–2.5× faster training on Apple silicon (M1+, macOS 15+). When installed via the [mps-sdpa] extra, attention can be routed through it per-instance:

model = RoPEConformer(dim=256, depth=6, use_mps_sdpa=True)

If the mps-sdpa package is not installed, the flag silently no-ops and the standard SDPA path is used. The flag also has no effect on non-MPS devices.
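
One way to use it defensively is to gate the flag on device availability (a sketch; per the note above, the flag is harmless when the extra is missing or the device is not MPS):

use_mps = torch.backends.mps.is_available()
model = RoPEConformer(dim=256, depth=6, use_mps_sdpa=use_mps)
model = model.to("mps" if use_mps else "cpu")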

Acknowledgements

The architecture mirrors the modernized transformer conventions used in lucidrains/BS-RoFormer (RoPE, RMSNorm, SDPA, GELU FFN, no biases on QKV/output projections), applied to the Conformer block (Gulati et al., 2020).

License

MIT.

Download files


Source Distribution

rope_conformer-0.1.0.tar.gz (13.4 kB)

Uploaded Source

Built Distribution


rope_conformer-0.1.0-py3-none-any.whl (11.6 kB)

Uploaded Python 3

File details

Details for the file rope_conformer-0.1.0.tar.gz.

File metadata

  • Download URL: rope_conformer-0.1.0.tar.gz
  • Upload date:
  • Size: 13.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for rope_conformer-0.1.0.tar.gz
Algorithm Hash digest
SHA256 ae63871942cf8ce109c4d0b4b0e0004d4f0a67229814eb5aba06469ac3e6d41c
MD5 95c1b62c4c6bcce0774bf096b25a562d
BLAKE2b-256 1b27ab5f0e2f7511733a10cbe4b333ab99b06f10783086f6e8b1243a51bf3ecb


Provenance

The following attestation bundles were made for rope_conformer-0.1.0.tar.gz:

Publisher: pypi.yml on crlandsc/rope-conformer


File details

Details for the file rope_conformer-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: rope_conformer-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 11.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for rope_conformer-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 a30f355d655119aa2083cf31f58d50e532dd86f6790bd66cb073323749c69c6d
MD5 109410734b9dacc0758c6f9ce0fe8774
BLAKE2b-256 f0033602cd38c218716dbb09f3242d5d8d65c47ab0960a6065b8f24e4f748a1d


Provenance

The following attestation bundles were made for rope_conformer-0.1.0-py3-none-any.whl:

Publisher: pypi.yml on crlandsc/rope-conformer

