
A Python package for generating ML layers for MLIPs




KLay: Composable layers for KLIFF & MLIPs


KLay takes the “LEGO-block” approach to machine-learning interatomic potentials (MLIPs): every layer is a first-class component you can swap, rewire and retrain, with no copy-pasting of opaque model code.

  • Works out-of-the-box with KLIFF
  • Converts a single YAML file → torch.fx.GraphModule (ready for torch.compile, TorchScript, ONNX...)
  • Or: returns a dict of instantiated layers when you just want the bricks
  • Built-in CLI commands: layers, types, validate (with optional Graphviz visualization) and export
  • Ships NequIP blocks today; MACE, EGNN & M3GNet on the roadmap

Installation

pip install klay
# or: dev version
pip install git+https://github.com/openkim/klay.git

Usage in Python

from klay.io import load_config
from klay.builder import build_model

cfg   = load_config("example/new_model.yaml")
model = build_model(cfg)   # torch.fx.GraphModule, since cfg defines model_inputs/model_outputs
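To turn the YAML wiring into a single graph, a builder must evaluate the layers in dependency order. The sketch below shows one way to derive that order (Kahn's topological sort) from a plain-dict summary of a model_layers section. It is a hypothetical illustration, not KLay's builder: the dict layout (layer name → names it reads from), the function name topo_order, and the exact wiring of element_embedding, edge_feature0 and node_features (which the YAML cheat-sheet leaves implicit) are all assumptions.

```python
from __future__ import annotations

from collections import deque


def topo_order(deps: dict[str, list[str]], graph_inputs: set[str]) -> list[str]:
    """Order layers so each one comes after every layer it reads from.

    deps maps a layer name to its sources; sources listed in graph_inputs
    (the model_inputs) are always available and are not layers.
    Raises ValueError on unknown sources or cycles.
    """
    indegree = {name: 0 for name in deps}           # unresolved layer deps
    consumers: dict[str, list[str]] = {name: [] for name in deps}
    for name, sources in deps.items():
        for src in sources:
            if src in graph_inputs:
                continue  # provided by model_inputs
            if src not in deps:
                raise ValueError(f"{name!r} reads from unknown source {src!r}")
            indegree[name] += 1
            consumers[src].append(name)

    # Start with layers that only read graph inputs.
    ready = deque(name for name, d in indegree.items() if d == 0)
    order = []
    while ready:
        layer = ready.popleft()
        order.append(layer)
        for nxt in consumers[layer]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                ready.append(nxt)

    if len(order) != len(deps):
        stuck = sorted(n for n, d in indegree.items() if d > 0)
        raise ValueError(f"cycle among layers: {stuck}")
    return order


# Simplified wiring of the YAML cheat-sheet (assumed, see lead-in).
graph_inputs = {"atomic_numbers", "positions", "edge_index", "shifts"}
deps = {
    "element_embedding": ["atomic_numbers"],
    "edge_feature0": ["positions", "edge_index", "shifts"],
    "radial_basis_func": ["edge_feature0"],
    "node_features": ["element_embedding"],
    "conv1": ["node_features", "edge_feature0", "radial_basis_func"],
    "output_projection": ["conv1"],
}
order = topo_order(deps, graph_inputs)
print(order)
# ['element_embedding', 'edge_feature0', 'node_features', 'radial_basis_func', 'conv1', 'output_projection']
```

Kahn's algorithm is a natural fit here because it yields the evaluation order and detects cycles in the same pass.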

Need only the bricks?

from klay.builder import build_layers
from klay.io import load_config
layers = build_layers(load_config("example/new_model_layers.yaml"))
print(layers.keys())               # dict_keys([...])
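In library mode there is no wiring, only a name → layer mapping. The klay layers table suggests each layer is registered under its class name and constructed through a from_config classmethod; the plain-Python sketch below imitates that registry pattern. REGISTRY, register, build_layers_sketch, ToyCutoff and ToyEncoding are hypothetical stand-ins, not KLay's classes or API.

```python
from __future__ import annotations

from typing import Any

# Hypothetical registry: YAML "type" string -> layer class.
REGISTRY: dict[str, type] = {}


def register(cls: type) -> type:
    """Class decorator that makes a layer reachable by its class name."""
    REGISTRY[cls.__name__] = cls
    return cls


@register
class ToyCutoff:
    """Toy layer with one required and one optional argument."""

    def __init__(self, r_max: float, polynomial_degree: int = 6):
        self.r_max = r_max
        self.polynomial_degree = polynomial_degree

    @classmethod
    def from_config(cls, r_max: float, polynomial_degree: int = 6) -> "ToyCutoff":
        return cls(r_max, polynomial_degree)


@register
class ToyEncoding:
    """Toy layer with a single required argument."""

    def __init__(self, num_elems: int):
        self.num_elems = num_elems

    @classmethod
    def from_config(cls, num_elems: int) -> "ToyEncoding":
        return cls(num_elems)


def build_layers_sketch(model_layers: dict[str, dict[str, Any]]) -> dict[str, Any]:
    """Instantiate each model_layers entry from its type and config; no wiring."""
    layers = {}
    for name, spec in model_layers.items():
        cls = REGISTRY[spec["type"]]  # KeyError for an unregistered type
        layers[name] = cls.from_config(**spec.get("config", {}))
    return layers


layers = build_layers_sketch({
    "cutoff": {"type": "ToyCutoff", "config": {"r_max": 4.0}},
    "element_embedding": {"type": "ToyEncoding", "config": {"num_elems": 2}},
})
print(layers.keys())  # dict_keys(['cutoff', 'element_embedding'])
```

Keeping construction behind from_config lets a CLI introspect each layer's signature (required vs optional arguments) without instantiating anything.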

Command-line utilities

  • klay layers: prints a table of every registered layer with its inputs, outputs and from_config signature, colouring required and optional arguments. Most common flags: --type embedding, --all
  • klay types: lists all ModuleCategory values (convolution, embedding, …).
  • klay validate: checks a config for cycles, missing sources, dangling layers, alias→alias references and unused outputs, and can optionally draw a Graphviz diagram. Most common flags: --allow-dangling, -v/--visualize, --fmt svg
  • klay export: builds the full model and either TorchScripts it (.pt) or dumps its state_dict (.pth). Most common flags: -o out.pt, --format state_dict, -n 10
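To make the validate checks concrete, here is a sketch of four of them (missing sources, alias→alias chains, dangling layers, cycles) over a simplified wiring dict. It is not KLay's validator: the input format (layer → {"sources": [...], "alias": name}), the function name check_wiring, and the reading of "dangling" as "neither consumed, aliased, nor exported" are assumptions.

```python
from __future__ import annotations


def check_wiring(layers: dict[str, dict], graph_inputs: set[str],
                 outputs: set[str]) -> list[str]:
    """Return human-readable problems found in a simplified layer graph."""
    problems = []

    # Missing sources: neither a model input nor another layer.
    for name, spec in layers.items():
        for src in spec.get("sources", []):
            if src not in graph_inputs and src not in layers:
                problems.append(f"missing source: {name} <- {src}")

    # Alias -> alias: an alias must point at a concrete (non-alias) layer.
    for name, spec in layers.items():
        target = spec.get("alias")
        if target is not None and layers.get(target, {}).get("alias") is not None:
            problems.append(f"alias chain: {name} -> {target}")

    # Dangling: not read by anyone, not an alias target, not exported.
    used = {src for spec in layers.values() for src in spec.get("sources", [])}
    used |= {spec["alias"] for spec in layers.values() if spec.get("alias")}
    for name in layers:
        if name not in used and name not in outputs:
            problems.append(f"dangling layer: {name}")

    # Cycles: recursive DFS with three states (0 new, 1 on stack, 2 done).
    state = {name: 0 for name in layers}

    def visit(node: str, path: list[str]) -> None:
        state[node] = 1
        for src in layers[node].get("sources", []):
            if src not in layers:
                continue  # model inputs / missing sources handled above
            if state[src] == 1:
                cycle = path[path.index(src):] + [src]
                problems.append("cycle: " + " -> ".join(cycle))
            elif state[src] == 0:
                visit(src, path + [src])
        state[node] = 2

    for name in layers:
        if state[name] == 0:
            visit(name, [name])
    return problems


problems = check_wiring(
    {
        "embed": {"sources": ["atomic_numbers"]},
        "conv_shared": {},
        "conv1": {"alias": "conv_shared", "sources": ["embed", "edge_attrs"]},
        "conv2": {"alias": "conv1", "sources": ["conv1"]},
        "unused": {"sources": ["embed"]},
        "a": {"sources": ["b"]},
        "b": {"sources": ["a"]},
    },
    graph_inputs={"atomic_numbers"},
    outputs={"conv2"},
)
for p in problems:
    print(p)
```

With this toy input the sketch reports one issue of each kind; an --allow-dangling style switch would simply skip the dangling-layer loop.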

Examples:

# 1. Inspect all embedding layers
klay layers --type embedding

result:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer                           ┃ Inputs                 ┃ Outputs              ┃ from_config args      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩
│ BesselBasis                     │ x                      │ y                    │ r_max,                │
│                                 │                        │                      │ num_radial_basis=8,   │
│                                 │                        │                      │ trainable=True        │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ BinaryAtomicNumberEncoding      │ x                      │ representation       │ -                     │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ ElectronicConfigurationEncoding │ x                      │ representation       │ -                     │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ OneHotAtomEncoding              │ x                      │ representation       │ num_elems,            │
│                                 │                        │                      │ input_is_atomic_numb… │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ PolynomialCutoff                │ x                      │ y                    │ r_max,                │
│                                 │                        │                      │ polynomial_degree=6   │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ RadialBasisEdgeEncoding         │ edge_length            │ edge_length_embedded │ r_max,                │
│                                 │                        │                      │ num_radial_basis=8,   │
│                                 │                        │                      │ polynomial_degree=6,  │
│                                 │                        │                      │ radial_basis_trainab… │
│                                 │                        │                      │ basis='BesselBasis',  │
│                                 │                        │                      │ cutoff='PolynomialCu… │
│                                 │                        │                      │ basis_kwargs={},      │
│                                 │                        │                      │ cutoff_kwargs={}      │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ SphericalHarmonicEdgeAttrs      │ pos, edge_index, shift │ 0: edge_vec          │ lmax=1,               │
│                                 │                        │ 1: edge_length       │ normalization='compo… │
│                                 │                        │ 2: edge_sh           │                       │
└─────────────────────────────────┴────────────────────────┴──────────────────────┴───────────────────────┘
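The edge layers above chain together: SphericalHarmonicEdgeAttrs turns pos, edge_index and shift into edge vectors and lengths (plus edge_sh, not computed here), and RadialBasisEdgeEncoding expands each length in a radial basis (BesselBasis by default) multiplied by a smooth cutoff. The pure-Python sketch below reproduces that pipeline under stated assumptions: KLay's blocks follow the NequIP definitions, the cutoff is PolynomialCutoff (the default is truncated in the table), the trainable Bessel frequencies sit at their initial values n*pi, and shift is already Cartesian so edge_vec = pos[dst] - pos[src] + shift. It is a numerical illustration, not KLay's code.

```python
import math


def edge_geometry(pos, edge_index, shift):
    """Edge vectors and lengths, assuming edge_vec = pos[dst] - pos[src] + shift."""
    src, dst = edge_index
    vecs, lengths = [], []
    for i, j, s in zip(src, dst, shift):
        v = [pos[j][k] - pos[i][k] + s[k] for k in range(3)]
        vecs.append(v)
        lengths.append(math.sqrt(sum(c * c for c in v)))
    return vecs, lengths


def bessel_basis(r, r_max, num_radial_basis=8):
    """NequIP-style Bessel basis: (2 / r_max) * sin(n*pi*r / r_max) / r, n = 1..N."""
    return [
        (2.0 / r_max) * math.sin(n * math.pi * r / r_max) / r
        for n in range(1, num_radial_basis + 1)
    ]


def polynomial_cutoff(r, r_max, p=6):
    """Smooth envelope: 1 at r = 0, 0 (with vanishing derivatives) at r = r_max."""
    x = r / r_max
    if x >= 1.0:
        return 0.0
    return (1.0
            - (p + 1) * (p + 2) / 2 * x**p
            + p * (p + 2) * x**(p + 1)
            - p * (p + 1) / 2 * x**(p + 2))


def radial_edge_encoding(lengths, r_max, num_radial_basis=8, p=6):
    """Each edge length -> num_radial_basis features damped by the cutoff."""
    return [
        [b * polynomial_cutoff(r, r_max, p)
         for b in bessel_basis(r, r_max, num_radial_basis)]
        for r in lengths
    ]


# Two atoms, one edge in each direction, no periodic shift.
pos = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]
edge_index = [[0, 1], [1, 0]]
shift = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]

vecs, lengths = edge_geometry(pos, edge_index, shift)
features = radial_edge_encoding(lengths, r_max=4.0)
print(vecs)              # [[1.5, 0.0, 0.0], [-1.5, 0.0, 0.0]]
print(lengths)           # [1.5, 1.5]
print(len(features[0]))  # 8
```

The cutoff factor is what makes every edge feature go smoothly to zero as an atom leaves the r_max sphere, so the energy stays continuous when neighbour lists change.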
# 2. Strict validation + PNG diagram
klay validate example/new_model.yaml -v

Result: [diagram: Klay validate example]
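With -v, klay validate draws the layer graph through Graphviz. As a rough idea of what such a diagram encodes, the sketch below turns a simplified wiring dict into Graphviz DOT text: model inputs as ellipses, layers as boxes, one arrow per source, and exported outputs as separate nodes. The function to_dot and its input layout are hypothetical; KLay's actual diagram styling may differ. The text it returns can be rendered with the standard Graphviz CLI, for example dot -Tpng model.dot -o model.png.

```python
def to_dot(graph_inputs, layers, outputs):
    """Render a layer graph as Graphviz DOT source.

    graph_inputs: iterable of model input names
    layers: dict mapping a layer name to the list of sources it reads
    outputs: dict mapping an exported output name to the layer producing it
    """
    lines = ["digraph model {", "  rankdir=LR;"]
    for name in graph_inputs:
        lines.append(f'  "{name}" [shape=ellipse];')
    for name in layers:
        lines.append(f'  "{name}" [shape=box];')
    for name, sources in layers.items():
        for src in sources:
            lines.append(f'  "{src}" -> "{name}";')
    for out, layer in outputs.items():
        lines.append(f'  "out:{out}" [shape=doubleoctagon];')
        lines.append(f'  "{layer}" -> "out:{out}";')
    lines.append("}")
    return "\n".join(lines)


dot = to_dot(
    graph_inputs=["atomic_numbers", "positions"],
    layers={"element_embedding": ["atomic_numbers"],
            "node_features": ["element_embedding"]},
    outputs={"representation": "node_features"},
)
print(dot)
```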

Other Examples:

# 3. Script & save the model
klay export example/new_model.yaml

YAML cheat-sheet

model_params:
  r_max:      4.0
  n_channels: 32
  num_elems:  2

model_inputs:                      # (omit for "library mode")
  atomic_numbers: "Tensor (N,)"
  positions:      "Tensor (N,3)"
  edge_index:     "Tensor (2,E)"
  shifts:         "Tensor (E,3)"

model_layers:
  element_embedding:
    type: OneHotAtomEncoding
    config:
      num_elems: ${model_params.num_elems}

  edge_feature0:
    type: SphericalHarmonicEdgeAttrs
    config: {lmax: 1}
    output:
      0: vec0
      1: len0
      2: sh0

  radial_basis_func:
    type: RadialBasisEdgeEncoding
    config:
      r_max: ${model_params.r_max}
    inputs: {edge_length: len0}

  node_features:
    type: AtomwiseLinear
    config:
      irreps_in_block:
        - l: 0
          mul: ${model_params.num_elems}
      irreps_out_block:
        - l: 0
          mul: ${model_params.n_channels}

  conv_shared:
    type: ConvNetLayer
    config:
      hidden_irreps_lmax: 1
      edge_sh_lmax:       1
      conv_feature_size:  ${model_params.n_channels}

  conv1:
    alias: conv_shared
    inputs:
      h:        node_features
      edge_sh:  sh0
      edge_length_embeddings: radial_basis_func

  output_projection:
    type: AtomwiseLinear
    config:
      irreps_in_block:
        - l: 0
          mul: ${model_params.n_channels}
        - l: 1
          mul: ${model_params.n_channels}
      irreps_out_block:
        - {l: 0, mul: 1}

model_outputs:
  energy:          output_projection
  representation:  conv1.h
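The ${model_params.<name>} placeholders in the cheat-sheet look like OmegaConf-style interpolations: one value defined under model_params is reused by several layers. To make their effect concrete, here is a small stand-alone resolver for the case used above, where a value is exactly one ${dotted.path} reference into the same config. The function resolve and its behaviour are hypothetical and do not describe how load_config is implemented; a real loader may also handle references embedded in longer strings and other features.

```python
from __future__ import annotations

import re
from typing import Any

# A value that is exactly "${a.b.c}" with no surrounding text.
_REF = re.compile(r"^\$\{([A-Za-z_][\w.]*)\}$")


def _lookup(root: dict, dotted: str) -> Any:
    """Follow a dotted path such as 'model_params.r_max' from the root."""
    node: Any = root
    for key in dotted.split("."):
        node = node[key]  # KeyError if the referenced key does not exist
    return node


def resolve(node: Any, root: dict | None = None, _seen: tuple = ()) -> Any:
    """Return a copy of node with every whole-value ${path} replaced by its target."""
    if root is None:
        root = node
    if isinstance(node, dict):
        return {k: resolve(v, root, _seen) for k, v in node.items()}
    if isinstance(node, list):
        return [resolve(v, root, _seen) for v in node]
    if isinstance(node, str):
        m = _REF.match(node)
        if m:
            path = m.group(1)
            if path in _seen:
                raise ValueError(f"circular interpolation: {path}")
            # Resolve the target too, so chained references work.
            return resolve(_lookup(root, path), root, _seen + (path,))
    return node


cfg = {
    "model_params": {"r_max": 4.0, "n_channels": 32, "num_elems": 2},
    "model_layers": {
        "radial_basis_func": {
            "type": "RadialBasisEdgeEncoding",
            "config": {"r_max": "${model_params.r_max}"},
        },
        "node_features": {
            "type": "AtomwiseLinear",
            "config": {
                "irreps_in_block": [{"l": 0, "mul": "${model_params.num_elems}"}],
                "irreps_out_block": [{"l": 0, "mul": "${model_params.n_channels}"}],
            },
        },
    },
}
resolved = resolve(cfg)
print(resolved["model_layers"]["node_features"]["config"]["irreps_out_block"])
# [{'l': 0, 'mul': 32}]
```

Because a whole-value reference returns the target object itself, r_max stays a float and mul stays an int rather than becoming strings.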

Roadmap

  • EGNN & SEGNN blocks
  • Pre-trained M3GNet & GemNet-T embeddings
  • ONNX export & torch.export backend
  • cuEquivariance/OpenEquivariance backend for the layers

Pull requests welcome.



Download files


Source Distribution

klay-0.9.0.tar.gz (38.1 kB)

Uploaded Source

Built Distribution


klay-0.9.0-py3-none-any.whl (53.1 kB)

Uploaded Python 3

File details

Details for the file klay-0.9.0.tar.gz.

File metadata

  • Download URL: klay-0.9.0.tar.gz
  • Size: 38.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.22

File hashes

Hashes for klay-0.9.0.tar.gz
Algorithm Hash digest
SHA256 5781fa46c0d9a2190ecf2af9326a540791af5b624c4dcdde870cbf78e62626c8
MD5 463656b060d56500d958c93103bc6ce2
BLAKE2b-256 53a1b68a032423143958ba5bfb13f7f4c85b0f22a9b735385e1d2e90d10fa6f4


File details

Details for the file klay-0.9.0-py3-none-any.whl.

File metadata

  • Download URL: klay-0.9.0-py3-none-any.whl
  • Size: 53.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.22

File hashes

Hashes for klay-0.9.0-py3-none-any.whl
Algorithm Hash digest
SHA256 8cf44f287a4b6827223675f9fefa33564ddf0e998c35559d9351189bb99688a3
MD5 5a60fa91f65f4db5f4d09f99c10ff234
BLAKE2b-256 8a2345f9a5a1f39c6f58cb56a7ebd9c6288387045b72784ef9a3b10853a50fd9

