Tools for training torch models on gravitational wave data

ML4GW

Torch utilities for training neural networks in gravitational wave physics applications.

Installation

Pip installation

You can install ml4gw with pip:

pip install ml4gw

To build against a specific version of PyTorch/CUDA, see the PyTorch installation instructions for how to specify the desired torch version and the --extra-index-url flag. For example, to install with torch 1.12 and CUDA 11.6 support, you would run:

pip install ml4gw torch==1.12.0 --extra-index-url=https://download.pytorch.org/whl/cu116
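
After installing, it can be worth confirming which versions actually ended up in your environment. The following is a minimal sketch using only the Python standard library; the helper name `installed_version` is ours for illustration and is not part of ml4gw.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# report what is installed for the two packages this README cares about
for name in ("ml4gw", "torch"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

A torch wheel installed from the cu116 index typically reports a version string carrying a `+cu116` suffix, which is a quick way to tell it apart from a CPU-only build.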

Poetry installation

ml4gw is also fully compatible with Poetry, with your pyproject.toml set up like:

[tool.poetry.dependencies]
python = "^3.8"  # python versions 3.8-3.10 are supported
ml4gw = "^0.1.0"

To build against a specific PyTorch/CUDA combination, consult the PyTorch installation documentation above and specify the extra-index-url via the tool.poetry.source table in your pyproject.toml. For example, to build against CUDA 11.6, you would do something like:

[tool.poetry.dependencies]
python = "^3.8"
ml4gw = "^0.1.0"
torch = {version = "^1.12", source = "torch"}

[[tool.poetry.source]]
name = "torch"
url = "https://download.pytorch.org/whl/cu116"
secondary = true
default = false

Use cases

This library provides utilities for both data iteration and transformation, via dataloaders defined in ml4gw/dataloading and transform layers exposed in ml4gw/transforms. Lower-level functions and utilities are defined at the top level of the library and in the utils library.
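
As a rough picture of what sampling windows from a multi-channel timeseries involves, here is a NumPy-only sketch. It is not ml4gw's implementation; the function `sample_windows` and its arguments are hypothetical. When `coincident` is true, every channel shares one start index per window; when it is false, each channel draws its own start index independently.

```python
import numpy as np


def sample_windows(x, kernel_size, batch_size, coincident, rng):
    """Draw `batch_size` random windows of length `kernel_size` from `x`.

    `x` has shape (num_channels, num_samples); the result has shape
    (batch_size, num_channels, kernel_size).
    """
    num_channels, num_samples = x.shape
    max_start = num_samples - kernel_size  # last valid start index

    if coincident:
        # one start per window, shared by every channel
        starts = rng.integers(0, max_start + 1, size=(batch_size, 1))
        starts = np.repeat(starts, num_channels, axis=1)
    else:
        # an independent start for every (window, channel) pair
        starts = rng.integers(0, max_start + 1, size=(batch_size, num_channels))

    # offsets within a window, broadcast against the start indices
    idx = starts[:, :, None] + np.arange(kernel_size)  # (batch, channel, kernel)
    channels = np.arange(num_channels)[None, :, None]  # (1, channel, 1)
    return x[channels, idx]


rng = np.random.default_rng(0)
x = np.random.randn(2, 2048 * 8)  # 2 channels, 8 s at 2048 Hz
batch = sample_windows(x, kernel_size=2048, batch_size=4, coincident=False, rng=rng)
print(batch.shape)  # (4, 2, 2048)
```

Sampling channels independently breaks any time alignment between them, which is why it suits background data but not signals that must arrive coherently across detectors.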

For example, to train a simple autoencoder using a cost function in frequency space, you might do something like:

import numpy as np
import torch
from ml4gw.dataloading import InMemoryDataset
from ml4gw.transforms import SpectralDensity

SAMPLE_RATE = 2048
NUM_IFOS = 2
DATA_LENGTH = 128
KERNEL_LENGTH = 4
DEVICE = "cuda"  # or "cpu", wherever you want to run

BATCH_SIZE = 32
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10

dummy_data = np.random.randn(NUM_IFOS, DATA_LENGTH * SAMPLE_RATE)

# this will create a dataloader that iterates through your
# timeseries data sampling 4s long windows of data randomly
# and non-coincidentally: i.e. the background from each IFO
# will be sampled independently
dataset = InMemoryDataset(
    dummy_data,
    kernel_size=KERNEL_LENGTH * SAMPLE_RATE,
    batch_size=BATCH_SIZE,
    batches_per_epoch=50,
    coincident=False,
    shuffle=True,
    device=DEVICE  # this will move your dataset to GPU up-front if "cuda"
)


nn = torch.nn.Sequential(
    torch.nn.Conv1d(
        in_channels=NUM_IFOS,
        out_channels=8,
        kernel_size=7
    ),
    torch.nn.ConvTranspose1d(
        in_channels=8,
        out_channels=NUM_IFOS,
        kernel_size=7
    )
).to(DEVICE)

optimizer = torch.optim.Adam(nn.parameters(), lr=LEARNING_RATE)

spectral_density = SpectralDensity(SAMPLE_RATE, fftlength=2).to(DEVICE)

def loss_function(X, y):
    """
    MSE in frequency domain. Obviously this doesn't
    give you much on its own, but you can imagine doing
    something like masking to just the bins you care about.
    """
    X = spectral_density(X)
    y = spectral_density(y)
    return ((X - y)**2).mean()


for i in range(NUM_EPOCHS):
    epoch_loss = 0
    for X in dataset:
        optimizer.zero_grad(set_to_none=True)
        assert X.shape == (BATCH_SIZE, NUM_IFOS, KERNEL_LENGTH * SAMPLE_RATE)
        y = nn(X)

        loss = loss_function(X, y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    epoch_loss /= len(dataset)
    print(f"Epoch {i + 1}/{NUM_EPOCHS} Loss: {epoch_loss:0.3e}")
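
The docstring of `loss_function` mentions masking the loss to just the frequency bins you care about. One way this could look is sketched below. It assumes the last axis of the spectral density output indexes evenly spaced frequency bins from 0 Hz up to the Nyquist frequency, which you should check against the actual `SpectralDensity` output; `frequency_mask`, `masked_mse`, and the 20 to 500 Hz band are illustrative choices, not part of ml4gw.

```python
import torch


def frequency_mask(num_bins, sample_rate, low, high):
    """Boolean mask over bins spanning 0 Hz to Nyquist, True within [low, high]."""
    freqs = torch.linspace(0, sample_rate / 2, num_bins)
    return (freqs >= low) & (freqs <= high)


def masked_mse(X_psd, y_psd, mask):
    """MSE over only the frequency bins selected by `mask` (last axis)."""
    mask = mask.to(X_psd.device)
    return ((X_psd[..., mask] - y_psd[..., mask]) ** 2).mean()


# stand-in spectra (assumption): a one-sided spectrum of 2 s of data
# sampled at 2048 Hz has 2 * 2048 // 2 + 1 = 2049 bins
X_psd = torch.rand(32, 2, 2049)
y_psd = torch.rand(32, 2, 2049)
mask = frequency_mask(X_psd.shape[-1], sample_rate=2048, low=20, high=500)
print(masked_mse(X_psd, y_psd, mask))
```

In the training loop, you would build the mask once outside the loop and have `loss_function` return `masked_mse(spectral_density(X), spectral_density(y), mask)` instead of the unmasked mean.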

Development

As this library is still very much a work in progress, we anticipate that novel use cases will encounter errors stemming from a lack of robustness. We encourage users who encounter these difficulties to file issues on GitHub, and we'll be happy to offer support to extend our coverage to new or improved functionality. We also strongly encourage ML users in the GW physics space to try their hand at working on these issues and to join us as collaborators! For more information about how to get involved, feel free to reach out to ml4gw@ligo.mit.edu. By bringing in new users with new use cases, we hope to develop this library into a truly general-purpose tool that makes DL more accessible for gravitational wave physicists everywhere.
