Compute writhe and train Writhe-PaiNN score-based generative models with torch

Code accompanying the paper 'On the application of knot theoretic geometric descriptors to dynamical and generative models'.

Computation of the writhe and analysis of polymer coordinate data

The package contains the following:

  • Numerical routines for computing the writhe using CPU or GPU devices. In either case, computations are (optionally) parallelized over CPU / GPU cores / devices.
  • A class architecture for writhe computation and visualization.
  • An implementation of the novel writhe-based graph attention message passing layer.
  • An implementation of the SE(3)-equivariant writhe-PaiNN neural network architecture, where SE(3) equivariance is achieved by augmenting only the invariant graph features.
  • Implementations of the original PaiNN architecture and the cPaiNN architecture for comparison.
  • An implementation of a score-based diffusion model to train all architectures.
  • Classes to compute (time-lagged) canonical correlation analysis and visualize results.

For an example of how to use this package to analyze molecular dynamics simulation data, see analysis_example.ipynb in the examples folder and the mini tutorial below.

To train score-based generative models with any of the architectures listed above, see the scripts folder.


The main tool in this package is the class:

writhe_tools.writhe.Writhe

This class is instantiated with one argument, xyz, which should be a NumPy array of shape (N samples, D points or atoms, 3 coordinates). For a molecular dynamics trajectory and structure file, the required input can be obtained as shown below.

Here, we use MDTraj to load the trajectory.

import mdtraj as md 
xyz = md.load("example.xtc", top="example.pdb",
              atom_indices=md.load("example.pdb").top.select("name CA")
              ).center_coordinates().xyz

NOTE: it is generally sufficient to compute the writhe using only the coordinates of the alpha carbons. In principle, one could include all backbone atoms or any selection of contiguous atoms or points.
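
When no trajectory file is at hand, a synthetic array with the same (N, D, 3) layout can stand in for quick tests. The sketch below builds random-walk chains with NumPy. The helper name random_walk_xyz and its parameters are illustrative assumptions and are not part of writhe_tools.

```python
import numpy as np


def random_walk_xyz(n_frames: int, n_points: int, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: (n_frames, n_points, 3) array of random-walk chains."""
    rng = np.random.default_rng(seed)
    # Draw random bond vectors and normalize each one to unit length.
    steps = rng.normal(size=(n_frames, n_points, 3))
    steps /= np.linalg.norm(steps, axis=-1, keepdims=True)
    # A cumulative sum along the chain axis turns bond vectors into coordinates.
    xyz = np.cumsum(steps, axis=1)
    # Center every frame on its geometric mean.
    xyz -= xyz.mean(axis=1, keepdims=True)
    return xyz.astype(np.float32)


xyz_demo = random_walk_xyz(n_frames=100, n_points=20)
```

An array like xyz_demo has the layout described above and could be passed as the xyz argument in place of trajectory coordinates.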

We can now create an instance of the Writhe class.

from writhe_tools.writhe import Writhe
writhe = Writhe(xyz=xyz)

We can then compute the writhe at a given segment length, save the result for later, and restore the class from the saved result to continue analysis or visualization.

# compute the writhe using segment length 1 and default arguments

writhe.compute_writhe(length=1)

# results are stored to the class instance (see details in following cell about compute_writhe)

# save the result with default arguments (None, see below next code block)
writhe.save(path=None, dscr=None)

# restore the calculation at a later time using the CLASS method, load
restored_writhe = Writhe.load("./writhe_data_dict_length_1.pkl")

The results are saved as a pickled Python dictionary. The file name follows a template that can be modified using the path and dscr (description) arguments of the save function:

f"{path}/{dscr}_writhe_data_dict_length_{self.length}.pkl"

If path and dscr are left as None:

f"./writhe_data_dict_length_{self.length}.pkl"
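
The two templates above can be folded into one small function, which is handy for locating a saved file before calling Writhe.load. This is a sketch only: writhe_save_path is a hypothetical helper, and how the package behaves when only one of path and dscr is given is an assumption here.

```python
from typing import Optional


def writhe_save_path(length: int,
                     path: Optional[str] = None,
                     dscr: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the file-name templates shown above."""
    # Assumption: a missing path falls back to the current directory.
    directory = "." if path is None else path
    # Assumption: a missing description simply drops the "{dscr}_" prefix.
    prefix = "" if dscr is None else f"{dscr}_"
    return f"{directory}/{prefix}writhe_data_dict_length_{length}.pkl"


default_name = writhe_save_path(length=1)
custom_name = writhe_save_path(length=5, path="results", dscr="ca_only")
```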

**Example Usage**

  • NOTE: It is recommended to access and plot writhe data through the class to avoid duplicating data, which may cause memory issues.

The results of the writhe calculation are accessible from the class for further analysis and visualization:

import matplotlib.pyplot as plt

# compute at length 5
writhe.compute_writhe(length=5)
fig, axes = plt.subplots(1, 2, figsize=(14, 3))
ax = axes.flat
writhe.plot_writhe_matrix(index=None, ax=ax[0], label_stride=8)  #xticks=residues, yticks=residues, to match example
writhe.plot_writhe_total(window=250, ax=ax[1])
ax[1].hlines(0, 0, len(xyz), ls="--", color="gray")
fig.tight_layout()

from writhe_tools.tcca import tCCA
from writhe_tools.plots import fes2d
tcca = tCCA(writhe.writhe_features, lag=30).fit()
print(f"VAMP2 Score (dim 10) : {(tcca.svals[:10]**2).sum()}")
projection = tcca.transform(dim=2, scale=False)
fes2d(projection)
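
For context, the score printed above sums the squared singular values held in tcca.svals. The sketch below shows one common way such singular values arise in time-lagged CCA: whiten the instantaneous and time-lagged covariance matrices, then take the SVD of the whitened cross-covariance. It is a NumPy illustration under that assumption, not the tCCA class's actual implementation. The function lagged_singular_values and its eps cutoff are hypothetical.

```python
import numpy as np


def lagged_singular_values(features: np.ndarray, lag: int,
                           eps: float = 1e-10) -> np.ndarray:
    """Singular values of the whitened time-lagged cross-covariance (lag >= 1)."""
    # Pair each frame with the frame `lag` steps later and mean-center both sets.
    x0 = features[:-lag] - features[:-lag].mean(axis=0)
    x1 = features[lag:] - features[lag:].mean(axis=0)
    n = len(x0)
    c00, c11, c01 = x0.T @ x0 / n, x1.T @ x1 / n, x0.T @ x1 / n

    def inv_sqrt(c: np.ndarray) -> np.ndarray:
        # Inverse matrix square root, discarding near-zero eigenvalues.
        w, v = np.linalg.eigh(c)
        keep = w > eps
        return (v[:, keep] / np.sqrt(w[keep])) @ v[:, keep].T

    koopman = inv_sqrt(c00) @ c01 @ inv_sqrt(c11)
    return np.linalg.svd(koopman, compute_uv=False)


# Toy data: a noiseless rotation, so the lagged frames are a linear map of the originals.
t = np.linspace(0.0, 20.0 * np.pi, 2001)
demo = np.column_stack([np.sin(t), np.cos(t)])
svals = lagged_singular_values(demo, lag=30)
vamp2 = float((svals ** 2).sum())
```

Because the toy signal is perfectly predictable from its past, both singular values equal 1 up to rounding and the score is about 2. Real writhe features give values below 1 that decay with the lag.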

Writhe.compute_writhe

Description

compute_writhe is a method of the Writhe class that computes the writhe between all pairs of segments of a given length. The computation runs in parallel on CPU (Ray or Numba) or GPU (CUDA).
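
To make the quantity concrete: the writhe contribution of two straight segments has a standard closed form as a signed solid angle, built from four arcsines of dot products between unit normals. The pure-Python sketch below evaluates it for a single pair, normalized by 4π. It is an illustration only. The function name is hypothetical, and the package's own CPU and GPU kernels, its normalization, and its segment indexing may differ.

```python
import math


def _sub(a, b):
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def _dot(a, b):
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def _cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def _unit(a):
    norm = math.sqrt(_dot(a, a))
    # Degenerate (collinear) geometry yields a zero normal and contributes nothing.
    return (a[0] / norm, a[1] / norm, a[2] / norm) if norm > 0 else (0.0, 0.0, 0.0)


def segment_pair_writhe(p1, p2, p3, p4):
    """Hypothetical helper: writhe of segment p1->p2 with segment p3->p4."""
    r13, r14 = _sub(p3, p1), _sub(p4, p1)
    r23, r24 = _sub(p3, p2), _sub(p4, p2)
    # Unit normals of the four triangles spanned by the segment end points.
    normals = (_unit(_cross(r13, r14)), _unit(_cross(r14, r24)),
               _unit(_cross(r24, r23)), _unit(_cross(r23, r13)))
    # Unsigned solid angle: sum of arcsines of consecutive normal dot products.
    omega = sum(
        math.asin(max(-1.0, min(1.0, _dot(normals[k], normals[(k + 1) % 4]))))
        for k in range(4)
    )
    # The sign comes from the relative orientation of the two segments.
    orientation = _dot(_cross(_sub(p4, p3), _sub(p2, p1)), r13)
    sign = (orientation > 0) - (orientation < 0)
    return sign * omega / (4.0 * math.pi)


# Two perpendicular segments, one passing a unit distance above the other.
w_above = segment_pair_writhe((-1, 0, 0), (1, 0, 0), (0, -1, 1), (0, 1, 1))
# The mirror image flips the sign of the writhe.
w_below = segment_pair_writhe((-1, 0, 0), (1, 0, 0), (0, -1, -1), (0, 1, -1))
```

For this geometry the two values are -1/6 and +1/6. The sign change under reflection is what makes the writhe sensitive to chirality.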

Method Signature

def compute_writhe(self,
                   length: Optional[int] = None,
                   segments: Optional[np.ndarray] = None,
                   matrix: bool = False,
                   store_results: bool = True,
                   xyz: Optional[np.ndarray] = None,
                   n_points: Optional[int] = None,
                   speed_test: bool = False,
                   cpus_per_job: int = 1,
                   cuda: bool = False,
                   cuda_batch_size: Optional[int] = None,
                   multi_proc: bool = True,
                   use_cross: bool = True,
                   cpu_method: str = "ray"
                   ) -> Optional[dict]:

Arguments

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| length | Optional[int] | Required if segments is None | Segment length for the computation. Preferred way of obtaining segments. |
| segments | Optional[np.ndarray] | Required if length is None | Segments to use in the computation. In general, leave this as None and provide the length (int) argument to generate the segments automatically. |
| matrix | bool | False | If True, generates a symmetric writhe matrix. Generating the full, redundant matrix should be avoided and only done transiently for plotting. Using the plot_writhe_matrix method is preferred. |
| store_results | bool | True | If True, stores results in the Writhe instance. |
| xyz | Optional[np.ndarray] | None | Coordinate array used for the computation. If None, uses self.xyz. |
| n_points | Optional[int] | None | Number of points in the topology. Defaults to xyz.shape[1]. |
| speed_test | bool | False | If True, performs a benchmark test without storing results. |
| cpus_per_job | int | 1 | Number of CPUs allocated per batch. |
| cuda | bool | False | If True, enables CUDA acceleration for GPU computation. |
| cuda_batch_size | Optional[int] | None | Batch size for CUDA computation. |
| multi_proc | bool | True | If True, enables multiprocessing (parallel execution). |
| use_cross | bool | True | If True, uses the cross product in the computation. |
| cpu_method | str | "ray" | CPU computation method ("ray" for multiprocessing, "numba" for JIT-compiled CPU execution). "ray" is substantially faster in most cases. |

Returns

A dict containing the writhe computation results:

  • NOTE: It is best to store the results in the class (store_results=True) rather than assigning the output to a variable.

| Key | Type | Description |
| --- | --- | --- |
| length | int | The segment length used for the computation. |
| n_points | int | Number of points in the topology. |
| n | int | Number of frames in xyz. |
| writhe_features | np.ndarray | Computed writhe values for all segments. |
| segments | np.ndarray | The set of segments used in the computation. |
| writhe_matrix (optional) | np.ndarray | Symmetric writhe matrix, present only if matrix=True. |

If speed_test=True, the function returns None and doesn't store results.


Additional Notes

  • Calculation can be performed on multiple CPU cores (multi_proc=True) or GPU devices (cuda=True, multi_proc=True).
  • If using CUDA, it is recommended (but not strictly required) to:
    • Avoid interactive environments like Jupyter notebooks, as they may not properly clear GPU memory.
    • Prepare to manually set cuda_batch_size to avoid out-of-memory (OOM) errors.
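
One rough way to pick a starting value for cuda_batch_size is to bound the number of frames per batch by a memory budget. The heuristic below is an assumption-laden sketch, not something the package provides: suggest_cuda_batch_size is a hypothetical name, the float32 size is assumed, and safety_factor is an arbitrary placeholder for intermediate tensors that should be tuned empirically.

```python
def suggest_cuda_batch_size(n_features_per_frame: int,
                            free_bytes: int,
                            bytes_per_value: int = 4,
                            safety_factor: int = 10) -> int:
    """Hypothetical heuristic: frames per batch that fit in a memory budget."""
    if n_features_per_frame <= 0 or free_bytes <= 0:
        raise ValueError("n_features_per_frame and free_bytes must be positive")
    # Assumed cost of one frame: its features times a margin for intermediates.
    bytes_per_frame = n_features_per_frame * bytes_per_value * safety_factor
    # Always return at least one frame per batch.
    return max(1, free_bytes // bytes_per_frame)


batch = suggest_cuda_batch_size(n_features_per_frame=1000, free_bytes=4_000_000)
```

The result could then be passed as cuda_batch_size together with cuda=True, and lowered further if an out-of-memory error still occurs.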

The class also has plotting methods with many options:

writhe.plot_writhe_matrix(
    # (Averages the writhe matrix across frames by default)
    index=None,  # index: Optional[Union[int, List[int], str, np.ndarray]] = None
    # (Plots the average writhe matrix if index is None)
    absolute=False,  # absolute: bool = False
    # (Uses signed writhe values by default)
    xlabel=None,  # xlabel: Optional[str] = None
    # (No custom label for the x-axis, default will be used)
    ylabel=None,  # ylabel: Optional[str] = None
    # (No custom label for the y-axis, default will be used)
    xticks=None,  # xticks: Optional[np.ndarray] = None
    # (No custom xticks provided, default will be used)
    yticks=None,  # yticks: Optional[np.ndarray] = None
    # (No custom yticks provided, default will be used)
    label_stride=5,  # label_stride: int = 5
    # (Tick labels will be spaced every 5 units by default)
    dscr=None,  # dscr: Optional[str] = None
    # (No description for the subset of frames averaged)
    font_scale=1,  # font_scale: float = 1
    # (Font size will be at the default scale)
    ax=None  # ax: Optional[plt.Axes] = None
    # (No custom Axes object provided, so a new figure will be created)
)

writhe.plot_writhe_per_segment(
    # (Averages over all frames by default)
    index=None,  # index: Optional[Union[int, List[int], str, np.ndarray]] = None
    # (Plots the average writhe per segment if index is None)
    xticks=None,  # xticks: Optional[List[str]] = None
    # (No custom xticks are provided; default range is used)
    label_stride=5,  # label_stride: int = 5
    # (Tick labels are spaced every 5 segments by default)
    dscr=None,  # dscr: Optional[str] = None
    # (No description for the averaged indices)
    ax=None  # ax: Optional[plt.Axes] = None
    # (No custom Axes object provided; a new figure will be created)
)

writhe.plot_writhe_total(window=None, ax=None)
