
A Python package for generating ML layers for MLIPs




KLay: Composable layers for KLIFF & MLIPs


KLay takes the “LEGO-block” approach to machine-learning interatomic potentials (MLIPs): every layer is a first-class citizen you can swap, re-wire and re-train — no copy-pasting opaque model code.

  • Works out-of-the-box with KLIFF
  • Converts a single YAML file → torch.fx.GraphModule (ready for torch.compile, TorchScript, ONNX...)
  • Or: returns a dict of instantiated layers when you just want the bricks
  • Built-in layers, types, validate (with optional Graphviz visualization) and export CLI commands
  • Ships NequIP blocks today; MACE, EGNN & M3GNet on the roadmap

Installation

pip install klay
# or: dev version
pip install git+https://github.com/openkim/klay.git

Usage in Python

from klay.io import load_config
from klay.builder import build_model

cfg   = load_config("example/new_model.yaml")
model = build_model(cfg)   # torch.fx.GraphModule, because cfg defines model_inputs/model_outputs

Need only the bricks?

from klay.builder import build_layers
from klay.io import load_config
layers = build_layers(load_config("example/new_model_layers.yaml"))
print(layers.keys())               # dict_keys([...])
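Conceptually, "library mode" pairs a registry of layer classes with a from_config constructor on each class. The sketch below illustrates that pattern in plain Python. LAYER_REGISTRY, register, build_layers_sketch and the toy ToyOneHot class are hypothetical stand-ins, not KLay's internals.

```python
from typing import Any, Dict, List, Type

# Hypothetical registry mapping a YAML `type:` string to a layer class
# (an illustration of the pattern, not KLay's actual registry).
LAYER_REGISTRY: Dict[str, Type] = {}


def register(cls: Type) -> Type:
    """Class decorator that records a layer class under its own name."""
    LAYER_REGISTRY[cls.__name__] = cls
    return cls


@register
class ToyOneHot:
    """Toy stand-in for an embedding layer (not KLay's OneHotAtomEncoding)."""

    def __init__(self, num_elems: int) -> None:
        self.num_elems = num_elems

    @classmethod
    def from_config(cls, num_elems: int) -> "ToyOneHot":
        return cls(num_elems=num_elems)

    def __call__(self, species: int) -> List[int]:
        # One-hot vector with a 1 at the species index.
        return [1 if i == species else 0 for i in range(self.num_elems)]


def build_layers_sketch(model_layers: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Instantiate each entry of a model_layers-style mapping, keyed by name."""
    return {
        name: LAYER_REGISTRY[spec["type"]].from_config(**spec.get("config", {}))
        for name, spec in model_layers.items()
    }


layers = build_layers_sketch(
    {"element_embedding": {"type": "ToyOneHot", "config": {"num_elems": 2}}}
)
print(layers.keys())                   # dict_keys(['element_embedding'])
print(layers["element_embedding"](1))  # [0, 1]
```

Keeping construction in a from_config classmethod keeps the builder generic. It only needs the type string and a mapping of keyword arguments.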

Command-line utilities

Command, what it does, and most common flags:

  • klay layers: prints a table of every registered layer, showing its inputs/outputs and from_config signature, with required and optional arguments colour-coded. Flags: --type embedding, --all
  • klay types: lists all ModuleCategory values (convolution, embedding, …).
  • klay validate: checks for cycles, missing sources, dangling layers, alias→alias chains and unused outputs, and can optionally render a Graphviz diagram. Flags: --allow-dangling, -v/--visualize, --fmt svg
  • klay export: builds the full model and either TorchScripts it (.pt) or dumps its state_dict (.pth). Flags: -o out.pt, --format state_dict, -n 10
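The checks that klay validate performs can be expressed as plain graph algorithms over the model_layers mapping. The sketch below is not KLay's implementation. validate and source_layer are hypothetical helpers, and the sketch assumes each layer entry is a dict with optional inputs (argument → source), output (index → name) and alias keys, mirroring the YAML cheat-sheet. Cycles are found with Kahn's topological sort. Dangling layers are those not reachable by walking backwards from model_outputs.

```python
from collections import deque
from typing import Dict, List, Set


def source_layer(ref: str, produced: Dict[str, str]) -> str:
    """Map a reference such as 'conv1.h' or 'sh0' to the layer producing it."""
    if ref in produced:              # named output, e.g. 'sh0' -> 'edge_feature0'
        return produced[ref]
    return ref.split(".", 1)[0]      # 'conv1.h' -> 'conv1'


def validate(model_inputs: Set[str], layers: Dict[str, dict],
             model_outputs: Dict[str, str]) -> List[str]:
    """Hypothetical validator (not KLay's code); returns a list of problems."""
    problems: List[str] = []
    # Named outputs declared through an `output: {index: name}` mapping.
    produced = {alias: name for name, spec in layers.items()
                for alias in spec.get("output", {}).values()}

    # Alias -> alias chains.
    for name, spec in layers.items():
        target = spec.get("alias")
        if target is not None and layers.get(target, {}).get("alias"):
            problems.append(f"alias chain: {name} -> {target}")

    # Missing sources, plus dependency edges (consumer -> set of producers).
    deps: Dict[str, Set[str]] = {name: set() for name in layers}
    used_refs: Set[str] = set(model_outputs.values())
    for name, spec in layers.items():
        for ref in spec.get("inputs", {}).values():
            used_refs.add(ref)
            if ref in model_inputs:
                continue
            src = source_layer(ref, produced)
            if src in layers:
                deps[name].add(src)
            else:
                problems.append(f"missing source: {name} <- {ref}")
    for out_name, ref in model_outputs.items():
        if ref not in model_inputs and source_layer(ref, produced) not in layers:
            problems.append(f"missing source: output {out_name} <- {ref}")

    # Kahn's algorithm: any layer left unsorted lies on (or behind) a cycle.
    indegree = {name: len(d) for name, d in deps.items()}
    consumers: Dict[str, Set[str]] = {name: set() for name in layers}
    for name, d in deps.items():
        for src in d:
            consumers[src].add(name)
    queue = deque(name for name, k in indegree.items() if k == 0)
    sorted_count = 0
    while queue:
        node = queue.popleft()
        sorted_count += 1
        for nxt in consumers[node]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                queue.append(nxt)
    if sorted_count < len(layers):
        problems.append("cycle detected")

    # Dangling layers: not reachable backwards from any model output.
    needed: Set[str] = set()
    stack = [source_layer(ref, produced) for ref in model_outputs.values()]
    while stack:
        node = stack.pop()
        if node in needed or node not in layers:
            continue
        needed.add(node)
        stack.extend(deps[node])
        if layers[node].get("alias"):          # an alias keeps its target alive
            stack.append(layers[node]["alias"])
    problems += [f"dangling layer: {n}" for n in layers if n not in needed]

    # Unused outputs: named outputs that nothing consumes.
    problems += [f"unused output: {a}" for a in produced if a not in used_refs]
    return problems


problems = validate(
    model_inputs={"pos"},
    layers={
        "a": {"inputs": {"x": "pos"}, "output": {0: "a0", 1: "a1"}},
        "b": {"inputs": {"x": "a0", "y": "c"}},   # b and c depend on each other
        "c": {"inputs": {"x": "b"}},
        "d": {"inputs": {"x": "missing"}},        # unknown source, never used
    },
    model_outputs={"energy": "c"},
)
for p in problems:
    print(p)
# missing source: d <- missing
# cycle detected
# dangling layer: d
# unused output: a1
```

A Kahn-style sort also yields a valid execution order as a by-product. A graph builder needs that order before it can emit nodes.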

Examples:

# 1. Inspect all embedding layers
klay layers --type embedding

Result:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer                           ┃ Inputs                 ┃ Outputs              ┃ from_config args      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩
│ BesselBasis                     │ x                      │ y                    │ r_max,                │
│                                 │                        │                      │ num_radial_basis=8,   │
│                                 │                        │                      │ trainable=True        │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ BinaryAtomicNumberEncoding      │ x                      │ representation       │ -                     │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ ElectronicConfigurationEncoding │ x                      │ representation       │ -                     │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ OneHotAtomEncoding              │ x                      │ representation       │ num_elems,            │
│                                 │                        │                      │ input_is_atomic_numb… │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ PolynomialCutoff                │ x                      │ y                    │ r_max,                │
│                                 │                        │                      │ polynomial_degree=6   │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ RadialBasisEdgeEncoding         │ edge_length            │ edge_length_embedded │ r_max,                │
│                                 │                        │                      │ num_radial_basis=8,   │
│                                 │                        │                      │ polynomial_degree=6,  │
│                                 │                        │                      │ radial_basis_trainab… │
│                                 │                        │                      │ basis='BesselBasis',  │
│                                 │                        │                      │ cutoff='PolynomialCu… │
│                                 │                        │                      │ basis_kwargs={},      │
│                                 │                        │                      │ cutoff_kwargs={}      │
├─────────────────────────────────┼────────────────────────┼──────────────────────┼───────────────────────┤
│ SphericalHarmonicEdgeAttrs      │ pos, edge_index, shift │ 0: edge_vec          │ lmax=1,               │
│                                 │                        │ 1: edge_length       │ normalization='compo… │
│                                 │                        │ 2: edge_sh           │                       │
└─────────────────────────────────┴────────────────────────┴──────────────────────┴───────────────────────┘
# 2. Strict validation + PNG diagram
klay validate example/new_model.yaml -v

Result: a PNG diagram of the model graph (image "Klay validate example", not reproduced here).

Other examples:

# 3. Script & save the model
klay export example/new_model.yaml
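The BesselBasis and PolynomialCutoff layers listed by klay layers --type embedding take r_max, num_radial_basis=8 and polynomial_degree=6. Those defaults match NequIP's radial embedding. The sketch below assumes KLay follows NequIP's definitions, which is an assumption; check the KLay source before relying on it. Under that assumption, the basis is e_n(r) = (2 / r_max) · sin(nπr / r_max) / r, multiplied by a degree-p polynomial envelope that vanishes smoothly at r_max. The function names are hypothetical, and the code is plain Python rather than torch.

```python
import math
from typing import List


def bessel_basis(r: float, r_max: float, num_radial_basis: int = 8) -> List[float]:
    """e_n(r) = (2 / r_max) * sin(n * pi * r / r_max) / r for n = 1..num_radial_basis.

    Assumes NequIP's normalisation; r must be > 0 (edge lengths are nonzero).
    """
    prefactor = 2.0 / r_max
    return [prefactor * math.sin(n * math.pi * r / r_max) / r
            for n in range(1, num_radial_basis + 1)]


def polynomial_cutoff(r: float, r_max: float, p: float = 6.0) -> float:
    """Envelope equal to 1 at r = 0 that goes smoothly to 0 at r = r_max."""
    u = r / r_max
    if u >= 1.0:
        return 0.0
    return (1.0
            - (p + 1.0) * (p + 2.0) / 2.0 * u ** p
            + p * (p + 2.0) * u ** (p + 1.0)
            - p * (p + 1.0) / 2.0 * u ** (p + 2.0))


def radial_edge_embedding(r: float, r_max: float,
                          num_radial_basis: int = 8, p: float = 6.0) -> List[float]:
    """Bessel basis values scaled by the cutoff: one feature per basis function."""
    fc = polynomial_cutoff(r, r_max, p)
    return [fc * b for b in bessel_basis(r, r_max, num_radial_basis)]


emb = radial_edge_embedding(1.5, r_max=4.0)
print(len(emb))                     # 8
print(polynomial_cutoff(4.0, 4.0))  # 0.0
```

Multiplying by the envelope, rather than truncating the basis at r_max, makes each feature and its first derivatives go to zero at the cutoff. This keeps energies and forces smooth as neighbours enter or leave the cutoff sphere.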

YAML cheat-sheet

model_params:
  r_max:      4.0
  n_channels: 32
  num_elems:  2

model_inputs:                      # (omit for "library mode")
  atomic_numbers: "Tensor (N,)"
  positions:      "Tensor (N,3)"
  edge_index:     "Tensor (2,E)"
  shifts:         "Tensor (E,3)"

model_layers:
  element_embedding:
    type: OneHotAtomEncoding
    config:
      num_elems: ${model_params.num_elems}

  edge_feature0:
    type: SphericalHarmonicEdgeAttrs
    config: {lmax: 1}
    output:
      0: vec0
      1: len0
      2: sh0

  radial_basis_func:
    type: RadialBasisEdgeEncoding
    config:
      r_max: ${model_params.r_max}
    inputs: {edge_length: len0}

  node_features:
    type: AtomwiseLinear
    config:
      irreps_in_block:
        - l: 0
          mul: ${model_params.num_elems}
      irreps_out_block:
        - l: 0
          mul: ${model_params.n_channels}

  conv_shared:
    type: ConvNetLayer
    config:
      hidden_irreps_lmax: 1
      edge_sh_lmax:       1
      conv_feature_size:  ${model_params.n_channels}

  conv1:
    alias: conv_shared
    inputs:
      h:        node_features
      edge_sh:  sh0
      edge_length_embeddings: radial_basis_func

  output_projection:
    type: AtomwiseLinear
    config:
      irreps_in_block:
        - l: 0
          mul: ${model_params.n_channels}
        - l: 1
          mul: ${model_params.n_channels}
      irreps_out_block:
        - l: 0
          mul: 1

model_outputs:
  energy:          output_projection
  representation:  conv1.h
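The ${model_params.…} placeholders above look like OmegaConf-style interpolations. Whatever KLay's load_config does internally, the idea can be sketched as a recursive substitution of dotted paths. resolve and lookup are hypothetical helpers. The sketch assumes that a value consisting of exactly one ${...} keeps the referenced value's type, and that ${...} embedded in a longer string is replaced by its text. It does not detect circular references.

```python
import re
from typing import Any, Dict

# Matches ${dotted.path} references (hypothetical syntax handling, not KLay's loader).
_REF = re.compile(r"\$\{([^}]+)\}")


def lookup(cfg: Dict[str, Any], path: str) -> Any:
    """Follow a dotted path such as 'model_params.r_max'."""
    node: Any = cfg
    for key in path.split("."):
        node = node[key]
    return node


def resolve(value: Any, cfg: Dict[str, Any]) -> Any:
    """Recursively replace ${a.b} references with the values they point to."""
    if isinstance(value, dict):
        return {k: resolve(v, cfg) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve(v, cfg) for v in value]
    if isinstance(value, str):
        whole = _REF.fullmatch(value)
        if whole:  # "${x}" on its own keeps x's type (e.g. int stays int)
            return resolve(lookup(cfg, whole.group(1)), cfg)
        # References embedded in a longer string are substituted as text.
        return _REF.sub(lambda m: str(resolve(lookup(cfg, m.group(1)), cfg)), value)
    return value


cfg = {
    "model_params": {"r_max": 4.0, "n_channels": 32, "num_elems": 2},
    "model_layers": {
        "element_embedding": {
            "type": "OneHotAtomEncoding",
            "config": {"num_elems": "${model_params.num_elems}"},
        },
        "node_features": {
            "type": "AtomwiseLinear",
            "config": {"irreps_out_block": [{"l": 0, "mul": "${model_params.n_channels}"}]},
        },
    },
}
resolved = resolve(cfg, cfg)
print(resolved["model_layers"]["element_embedding"]["config"])  # {'num_elems': 2}
```

Keeping shared hyperparameters such as r_max and n_channels in model_params means one edit propagates to every layer that references them.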

Roadmap

  • EGNN & SEGNN blocks
  • Pre-trained M3GNet & GemNet-T embeddings
  • ONNX export & torch.export backend
  • cuEquivariance/OpenEquivariance backend for the layers

Pull requests welcome.



Download files


Source Distribution

klay-0.9.7.tar.gz (40.2 kB)

Uploaded Source

Built Distribution


klay-0.9.7-py3-none-any.whl (55.2 kB)

Uploaded Python 3

File details

Details for the file klay-0.9.7.tar.gz.

File metadata

  • Download URL: klay-0.9.7.tar.gz
  • Size: 40.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.22

File hashes

Hashes for klay-0.9.7.tar.gz
Algorithm Hash digest
SHA256 b167ac3021caaa5879e3baea5f85227cfe7c92295131459a99f4490f8afd19dc
MD5 384392d9938a17bfae22b421677413c8
BLAKE2b-256 1c6f3bbccc9cd1126f61c498aee1eba277c2c48e802b77f25c211fb2c3a8d0d5


File details

Details for the file klay-0.9.7-py3-none-any.whl.

File metadata

  • Download URL: klay-0.9.7-py3-none-any.whl
  • Size: 55.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.22

File hashes

Hashes for klay-0.9.7-py3-none-any.whl
Algorithm Hash digest
SHA256 7dbcc731d18c443936da5ddf75a627f88a9c27d89e3a15d85b29b329677f7291
MD5 fa389a1fe50e3e22dde3e4cb0bc121a3
BLAKE2b-256 cdadd59c20d8e52557973165d14d58897806f9f07ef56f0db38c542ae36a4fd0

