
GODM


GODM is a data augmentation package for supervised graph outlier detection. It generates synthetic graph outliers with latent diffusion models. This is the official implementation of the paper "Data Augmentation for Supervised Graph Outlier Detection with Latent Diffusion Models".

[Figure: model architecture]

Installation

It is recommended to use pip for installation:

pip install godm

Alternatively, you can build from source by cloning this repository:

git clone https://github.com/kayzliu/godm.git
cd godm
pip install .

Usage

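from pygod.utils import load_data
from godm import GODM

data = load_data('weibo')  # load a benchmark graph as a torch_geometric Data object

godm = GODM(lr=0.004)      # initialize GODM
aug_data = godm(data)      # augment the graph with synthetic outliers

detector(aug_data)         # placeholder: train your own supervised detector on aug_data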

The input data should be a torch_geometric.data.Data object with the following keys:

  • x: node features,
  • edge_index: edge index,
  • edge_time: edge times (optional, name can be changed by time_attr),
  • edge_type: edge types (optional, name can be changed by type_attr),
  • y: node labels,
  • train_mask: training node mask,
  • val_mask: validation node mask,
  • test_mask: testing node mask.

Currently, no additional keys are allowed. Support for more keys (e.g., via padding) may be added in the future.
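Because extra keys are rejected, it can help to check a graph's attribute names before calling GODM. The sketch below is a hypothetical helper (check_input_keys is not part of the godm package); it only mirrors the key list above, treating the edge time and edge type attributes as optional and honoring renamed time_attr / type_attr values.

```python
# Hypothetical pre-flight check that mirrors the README's key list.
# It is NOT part of the godm package; it works on plain attribute names.
REQUIRED_KEYS = {"x", "edge_index", "y", "train_mask", "val_mask", "test_mask"}


def check_input_keys(keys, time_attr="edge_time", type_attr="edge_type"):
    """Return (missing, unexpected) sets for a graph's attribute names.

    `keys` is any iterable of attribute names. The edge time and edge type
    attributes (named by time_attr / type_attr) are optional.
    """
    keys = set(keys)
    optional = {time_attr, type_attr}
    missing = REQUIRED_KEYS - keys              # required keys that are absent
    unexpected = keys - REQUIRED_KEYS - optional  # keys GODM would not accept
    return missing, unexpected


if __name__ == "__main__":
    names = ["x", "edge_index", "y", "train_mask", "val_mask", "test_mask", "edge_time"]
    print(check_input_keys(names))  # (set(), set())
```

A non-empty second set means the graph carries keys that should be dropped before augmentation.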

Parameters

  • hid_dim (type: int, default: None): hidden dimension of the VAE, i.e., the latent embedding dimension. None means the largest power of 2 that is less than or equal to half the feature dimension.
  • diff_dim (type: int, default: None): hidden dimension of the denoiser. None means twice hid_dim.
  • vae_epochs (type: int, default: 100): number of epochs for training VAE.
  • diff_epochs (type: int, default: 100): number of epochs for training diffusion model.
  • patience (type: int, default: 50): patience for early stopping.
  • lr (type: float, default: 0.001): learning rate.
  • wd (type: float, default: 0.): weight decay.
  • batch_size (type: int, default: 2048): batch size.
  • threshold (type: float, default: 0.75): threshold for edge generation.
  • wx (type: float, default: 1.): weight for node feature reconstruction loss.
  • we (type: float, default: 0.5): weight for edge reconstruction loss.
  • beta (type: float, default: 0.001): weight for KL divergence loss.
  • wt (type: float, default: 1.): weight for time prediction loss.
  • time_attr (type: str, default: edge_time): attribute name for edge time.
  • type_attr (type: str, default: edge_type): attribute name for edge type.
  • wp (type: float, default: 0.3): weight for node prediction loss.
  • gen_nodes (type: int, default: None): number of nodes to generate. None means the same as the number of outliers in the original graph.
  • sample_steps (type: int, default: 50): number of steps for diffusion model sampling.
  • device (type: int, default: 0): GPU index, set to -1 for CPU.
  • verbose (type: bool, default: False): verbose mode, enable for logging.
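The None defaults for hid_dim, diff_dim, and gen_nodes can be worked out by hand. The sketch below is our reading of the descriptions above, not code taken from godm; in particular, default_gen_nodes assumes outliers are labeled 1 in y.

```python
def default_hid_dim(feature_dim: int) -> int:
    """Largest power of 2 that is <= feature_dim / 2 (README default for hid_dim)."""
    if feature_dim < 2:
        raise ValueError("feature_dim must be at least 2")
    half = feature_dim // 2  # flooring is safe because powers of 2 are integers
    return 1 << (half.bit_length() - 1)  # highest set bit of `half`


def default_diff_dim(hid_dim: int) -> int:
    """README default for diff_dim: twice hid_dim."""
    return 2 * hid_dim


def default_gen_nodes(labels) -> int:
    """Number of outliers in the original graph (assumes outlier label == 1)."""
    return sum(1 for label in labels if label == 1)


if __name__ == "__main__":
    hid = default_hid_dim(400)          # 400 / 2 = 200 -> largest power of 2 is 128
    print(hid, default_diff_dim(hid))   # 128 256
```

For example, a graph with 400-dimensional node features would get a 128-dimensional latent space and a 256-dimensional denoiser unless these parameters are set explicitly.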

Cite Us:

Our paper is publicly available. If you use GODM in a scientific publication, we would appreciate your citations:

@article{liu2023data,
  title={Data Augmentation for Supervised Graph Outlier Detection with Latent Diffusion Models},
  author={Liu, Kay and Zhang, Hengrui and Hu, Ziqing and Wang, Fangxin and Yu, Philip S.},
  journal={arXiv preprint arXiv:2312.17679},
  year={2023}
}

or:

Liu, K., Zhang, H., Hu, Z., Wang, F., and Yu, P.S. 2023. Data Augmentation for Supervised Graph Outlier Detection with Latent Diffusion Models. arXiv preprint arXiv:2312.17679.


Source Distribution

godm-0.1.0.tar.gz (16.5 kB), uploaded via twine/4.0.2 CPython/3.10.9

File hashes for godm-0.1.0.tar.gz:

  • SHA256: 35d090492c803bc531c46df32d9d097273fa6ec4627f8679847d00f381b6d2ff
  • MD5: 3ca24904215fd42d09a6c3b9466daf96
  • BLAKE2b-256: 55863677691bb0dd9b2fdb3eb203e7ce1b8b2b87d364cf60276f3dfa58d1f404
