
A Python package for generating ML layers for MLIPs




KLay: Composable layers for KLIFF & MLIPs


KLay takes the “LEGO-block” approach to machine-learning interatomic potentials (MLIPs): every layer is a first-class citizen that you can swap, rewire and retrain, with no copy-pasting of opaque model code.

  • Works out-of-the-box with KLIFF
  • Converts a single YAML file → torch.fx.GraphModule (ready for torch.compile, TorchScript, ONNX...)
  • Or: returns a dict of instantiated layers when you just want the bricks
  • Built-in CLI commands: layers, types, validate (with optional visualization) and export
  • Ships NequIP blocks today; MACE, EGNN & M3GNet on the roadmap

Installation

pip install klay
# or: dev version
pip install git+https://github.com/openkim/klay.git

Usage in Python

from klay.io import load_config
from klay.builder import build_model

cfg   = load_config("example/new_model.yaml")
model = build_model(cfg)   # GraphModule (because cfg has inputs/outputs)
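The comment above notes that build_model returns a GraphModule because the config declares inputs and outputs: the YAML wiring describes a dataflow graph that is evaluated in dependency order. The following is a minimal pure-Python sketch of that idea, assuming plain callables in place of torch modules. It is not KLay's implementation; run_graph and LayerSpec are hypothetical names, ordering uses the standard-library graphlib module (Python 3.9+), and the output_names mapping plays the role of the `output:` key in the YAML cheat-sheet.

```python
from graphlib import TopologicalSorter
from typing import Any, Callable, Dict, Tuple

# Hypothetical layer description: (fn, wiring, output_names).
LayerSpec = Tuple[Callable[..., Any], Dict[str, str], Dict[int, str]]


def run_graph(
    layers: Dict[str, LayerSpec],
    inputs: Dict[str, Any],
    outputs: Dict[str, str],
) -> Dict[str, Any]:
    """Evaluate a toy dataflow graph (hypothetical helper, not part of KLay).

    - wiring maps each keyword argument of fn to the name of a value
    - output_names maps positions of fn's tuple result to value names
    Each layer's full result is also stored under the layer's own name.
    """
    # Record which layer produces each value name.
    producer = {}
    for name, (_, _, out_names) in layers.items():
        producer[name] = name
        for value in out_names.values():
            producer[value] = name

    # A layer depends on the layers producing its sources; graph inputs are leaves.
    deps = {
        name: {producer[src] for src in wiring.values() if src in producer}
        for name, (_, wiring, _) in layers.items()
    }

    values = dict(inputs)
    # static_order() yields each layer after all of its dependencies
    # and raises graphlib.CycleError if the wiring contains a cycle.
    for name in TopologicalSorter(deps).static_order():
        fn, wiring, out_names = layers[name]
        result = fn(**{arg: values[src] for arg, src in wiring.items()})
        values[name] = result
        for idx, value in out_names.items():
            values[value] = result[idx]
    return {key: values[src] for key, src in outputs.items()}


# Toy wiring loosely modelled on the YAML cheat-sheet, with numbers instead of tensors.
toy_layers: Dict[str, LayerSpec] = {
    "edge_feature0": (lambda pos: (pos * 2, pos + 1, pos - 1),
                      {"pos": "positions"}, {0: "vec0", 1: "len0", 2: "sh0"}),
    "radial_basis_func": (lambda edge_length: edge_length * 10,
                          {"edge_length": "len0"}, {}),
}
out = run_graph(toy_layers, {"positions": 3}, {"rbf": "radial_basis_func", "sh": "sh0"})
print(out)  # {'rbf': 40, 'sh': 2}
```

KLay itself emits a torch.fx.GraphModule rather than running an eager loop like this one, which is what the feature list means by "ready for torch.compile, TorchScript".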

Need only the bricks?

from klay.builder import build_layers
from klay.io import load_config
layers = build_layers(load_config("example/new_model_layers.yaml"))
print(layers.keys())               # dict_keys([...])
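Per the feature list, build_layers returns a dict of instantiated layers, and the klay layers table shows each layer's from_config signature. A common way to implement such a lookup is a registry that maps `type:` names to classes exposing a from_config classmethod. The sketch below assumes that pattern; REGISTRY, register, ToyCutoff and build_layers_sketch are hypothetical names rather than KLay's API, and `alias` entries are not handled.

```python
from typing import Any, Dict, Type

REGISTRY: Dict[str, Type] = {}  # hypothetical: layer type name -> class


def register(cls: Type) -> Type:
    """Class decorator that records a layer class under its class name."""
    REGISTRY[cls.__name__] = cls
    return cls


@register
class ToyCutoff:
    """Stand-in for a real layer; it only stores its configuration."""

    def __init__(self, r_max: float, polynomial_degree: int = 6):
        self.r_max = r_max
        self.polynomial_degree = polynomial_degree

    @classmethod
    def from_config(cls, r_max: float, polynomial_degree: int = 6) -> "ToyCutoff":
        # Required args have no default; optional ones fall back to defaults.
        return cls(r_max, polynomial_degree)


def build_layers_sketch(spec: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Instantiate every entry of an already-parsed `model_layers` mapping."""
    layers = {}
    for name, entry in spec.items():
        cls = REGISTRY[entry["type"]]
        layers[name] = cls.from_config(**entry.get("config", {}))
    return layers


layers = build_layers_sketch({"cutoff": {"type": "ToyCutoff", "config": {"r_max": 4.0}}})
print(layers.keys())  # dict_keys(['cutoff'])
```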

Command-line utilities

  • klay layers: prints a table of every registered layer, showing its inputs/outputs and its from_config signature, with required and optional arguments coloured. Most common flags: --type embedding, --all
  • klay types: lists all ModuleCategory values (convolution, embedding, …).
  • klay validate: checks for cycles, missing sources, dangling layers, alias→alias chains and unused outputs, and can optionally draw a Graphviz diagram. Most common flags: --allow-dangling, -v/--visualize, --fmt svg
  • klay export: builds the full model and either TorchScripts it (.pt) or dumps its state_dict (.pth). Most common flags: -o out.pt, --format state_dict, -n 10
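klay validate reports cycles, missing sources, dangling layers and alias→alias chains. The following is a minimal pure-Python sketch of how such checks can be expressed over an already-parsed config. It assumes each layer's `inputs` maps argument names to source names, that `output:` names count as values produced by that layer, and that alias targets are not dangling; check_wiring is a hypothetical name, this is not KLay's validator, and dotted references such as conv1.h are not handled.

```python
from graphlib import CycleError, TopologicalSorter
from typing import Any, Dict, List


def check_wiring(
    model_inputs: List[str],
    model_layers: Dict[str, Dict[str, Any]],
    model_outputs: Dict[str, str],
) -> List[str]:
    """Return human-readable wiring problems (hypothetical helper)."""
    problems = []

    # Every value a layer produces: its own name plus any `output:` names.
    producer = {}
    for name, spec in model_layers.items():
        producer[name] = name
        for value in spec.get("output", {}).values():
            producer[value] = name

    # Missing sources: neither a graph input nor a produced value.
    deps = {}
    for name, spec in model_layers.items():
        deps[name] = set()
        for src in spec.get("inputs", {}).values():
            if src in producer:
                deps[name].add(producer[src])
            elif src not in model_inputs:
                problems.append(f"{name}: missing source '{src}'")

    # Alias -> alias chains: an alias must point at a concrete layer.
    aliased = set()
    for name, spec in model_layers.items():
        target = spec.get("alias")
        if target is None:
            continue
        aliased.add(target)
        if target not in model_layers:
            problems.append(f"{name}: alias target '{target}' does not exist")
        elif "alias" in model_layers[target]:
            problems.append(f"{name}: alias of an alias ('{target}')")

    # Dangling layers: not consumed downstream, not a model output, not aliased.
    used = {dep for d in deps.values() for dep in d} | aliased
    used |= {producer[src] for src in model_outputs.values() if src in producer}
    for name in model_layers:
        if name not in used:
            problems.append(f"{name}: dangling (output never used)")

    # Cycles: topological ordering fails if the wiring loops back on itself.
    try:
        list(TopologicalSorter(deps).static_order())
    except CycleError as err:
        problems.append("cycle: " + " -> ".join(err.args[1]))
    return problems


problems = check_wiring(
    ["positions"],
    {
        "a": {"inputs": {"x": "positions"}},
        "b": {"inputs": {"x": "a", "y": "c"}},
        "c": {"inputs": {"x": "b"}},
        "d": {"inputs": {"x": "nowhere"}},
    },
    {"energy": "c"},
)
for p in problems:
    print(p)
```

The --allow-dangling flag suggests KLay can tolerate dangling layers on request; this sketch always reports them.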

Examples:

# 1. Inspect all embedding layers
klay layers --type embedding

Result:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer                           ┃ Inputs                 ┃ Outputs              ┃ from_config args      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩
│ BesselBasis                     │ x                      │ y                    │ r_max,                │
│                                 │                        │                      │ num_radial_basis=8,   │
│                                 │                        │                      │ trainable=True        │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ BinaryAtomicNumberEncoding      │ x                      │ representation       │ -                     │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ ElectronicConfigurationEncoding │ x                      │ representation       │ -                     │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ OneHotAtomEncoding              │ x                      │ representation       │ num_elems,            │
│                                 │                        │                      │ input_is_atomic_numb… │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ PolynomialCutoff                │ x                      │ y                    │ r_max,                │
│                                 │                        │                      │ polynomial_degree=6   │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ RadialBasisEdgeEncoding         │ edge_length            │ edge_length_embedded │ r_max,                │
│                                 │                        │                      │ num_radial_basis=8,   │
│                                 │                        │                      │ polynomial_degree=6,  │
│                                 │                        │                      │ radial_basis_trainab… │
│                                 │                        │                      │ basis='BesselBasis',  │
│                                 │                        │                      │ cutoff='PolynomialCu… │
│                                 │                        │                      │ basis_kwargs={},      │
│                                 │                        │                      │ cutoff_kwargs={}      │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ SphericalHarmonicEdgeAttrs      │ pos, edge_index, shift │ 0: edge_vec          │ lmax=1,               │
│                                 │                        │ 1: edge_length       │ normalization='compo… │
│                                 │                        │ 2: edge_sh           │                       │
└─────────────────────────────────┴────────────────────────┴──────────────────────┴───────────────────────┘

# 2. Strict validation + PNG diagram
klay validate example/new_model.yaml -v

Result: (image: KLay validate example diagram)

Other examples:

# 3. Script & save the model
klay export example/new_model.yaml

YAML cheat-sheet

model_params:
  r_max:      4.0
  n_channels: 32
  num_elems:  2

model_inputs:                      # (omit for "library mode")
  atomic_numbers: "Tensor (N,)"
  positions:      "Tensor (N,3)"
  edge_index:     "Tensor (2,E)"
  shifts:         "Tensor (E,3)"

model_layers:
  element_embedding:
    type: OneHotAtomEncoding
    config:
      num_elems: ${model_params.num_elems}

  edge_feature0:
    type: SphericalHarmonicEdgeAttrs
    config: {lmax: 1}
    output:
      0: vec0
      1: len0
      2: sh0

  radial_basis_func:
    type: RadialBasisEdgeEncoding
    config:
      r_max: ${model_params.r_max}
    inputs: {edge_length: len0}

  node_features:
    type: AtomwiseLinear
    config:
      irreps_in_block:
        - l: 0
          mul: ${model_params.num_elems}
      irreps_out_block:
        - l: 0
          mul: ${model_params.n_channels}

  conv_shared:
    type: ConvNetLayer
    config:
      hidden_irreps_lmax: 1
      edge_sh_lmax:       1
      conv_feature_size:  ${model_params.n_channels}

  conv1:
    alias: conv_shared
    inputs:
      h:        node_features
      edge_sh:  sh0
      edge_length_embeddings: radial_basis_func

  output_projection:
    type: AtomwiseLinear
    config:
      irreps_in_block:
        - l: 0
          mul: ${model_params.n_channels}
        - l: 1
          mul: ${model_params.n_channels}
      irreps_out_block:
        - {l: 0, mul: 1}

model_outputs:
  energy:          output_projection
  representation:  conv1.h
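The ${model_params.r_max}-style references follow the interpolation syntax popularised by OmegaConf; this README does not say whether KLay resolves them with OmegaConf or with its own code. As a minimal stdlib-only sketch (the helpers lookup and resolve are hypothetical), dotted references can be substituted recursively through the parsed config. A value that is exactly one reference keeps the referenced type, so ${model_params.n_channels} becomes the integer 32 rather than the string "32".

```python
import re
from typing import Any, Dict

# Matches ${a.b.c}; the captured group is the dotted path.
REF = re.compile(r"\$\{([A-Za-z_][\w.]*)\}")


def lookup(root: Dict[str, Any], dotted: str) -> Any:
    """Follow a dotted path such as 'model_params.r_max' through nested dicts."""
    node: Any = root
    for key in dotted.split("."):
        node = node[key]
    return node


def resolve(node: Any, root: Dict[str, Any]) -> Any:
    """Recursively replace ${...} references in dicts, lists and strings.

    Note: self-referencing values would recurse forever; no cycle check here.
    """
    if isinstance(node, dict):
        return {k: resolve(v, root) for k, v in node.items()}
    if isinstance(node, list):
        return [resolve(v, root) for v in node]
    if isinstance(node, str):
        whole = REF.fullmatch(node)
        if whole:  # the entire value is one reference: keep the target's type
            return resolve(lookup(root, whole.group(1)), root)
        # otherwise splice the string form of each reference into the text
        return REF.sub(lambda m: str(resolve(lookup(root, m.group(1)), root)), node)
    return node


cfg = {
    "model_params": {"r_max": 4.0, "n_channels": 32},
    "model_layers": {
        "radial_basis_func": {"config": {"r_max": "${model_params.r_max}"}},
        "proj": {"config": {"irreps_in_block": [{"l": 0, "mul": "${model_params.n_channels}"}]}},
    },
}
resolved = resolve(cfg, cfg)
print(resolved["model_layers"]["proj"]["config"]["irreps_in_block"])  # [{'l': 0, 'mul': 32}]
```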

Roadmap

  • EGNN & SEGNN blocks
  • Pre-trained M3GNet & GemNet-T embeddings
  • ONNX export & torch.export backend
  • cuEquivariance/OpenEquivariance backend for the layers

Pull requests welcome.



Download files


Source Distribution

klay-0.8.0.tar.gz (37.8 kB)

Uploaded Source

Built Distribution


klay-0.8.0-py3-none-any.whl (52.9 kB)

Uploaded Python 3

File details

Details for the file klay-0.8.0.tar.gz.

File metadata

  • Download URL: klay-0.8.0.tar.gz
  • Upload date:
  • Size: 37.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.22

File hashes

Hashes for klay-0.8.0.tar.gz
Algorithm Hash digest
SHA256 a26345de528d149ed17002786068ae885f19df47552d0bbe9fb145d8582f45c1
MD5 00e876d97d1d9bda4bc09c7a56e5a6f0
BLAKE2b-256 9c26404e804d1c9f99c1b703424102322365cee307fd887a68b3adff981dce52


File details

Details for the file klay-0.8.0-py3-none-any.whl.

File metadata

  • Download URL: klay-0.8.0-py3-none-any.whl
  • Upload date:
  • Size: 52.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.22

File hashes

Hashes for klay-0.8.0-py3-none-any.whl
Algorithm Hash digest
SHA256 ba8bf104c71f3fb589068ad4956b7df392bcfbcd47c32d06aa73ff8a93955923
MD5 f36088e29307431f10180409ac3936f2
BLAKE2b-256 110c06ac1815a4be0efab6304abd72f0428df5876ce1fb43f37a2d055d4cac28

