
Predict materials properties using only the composition information.

Project description

Compositionally-Restricted Attention-Based Network (CrabNet)

This software package implements the Compositionally-Restricted Attention-Based Network (CrabNet), which uses only composition information to predict material properties.

Table of Contents

  • How to cite
  • Installation
  • Reproduce publication results
  • Train or predict materials properties using CrabNet or DenseNet

How to cite

Please cite the following work if you want to use CrabNet:

@article{Wang2021crabnet,
 author = {Wang, Anthony Yu-Tung and Kauwe, Steven K. and Murdock, Ryan J. and Sparks, Taylor D.},
 year = {2021},
 title = {Compositionally restricted attention-based network for materials property predictions},
 pages = {77},
 volume = {7},
 number = {1},
 doi = {10.1038/s41524-021-00545-1},
 publisher = {{Nature Publishing Group}},
 shortjournal = {npj Comput. Mater.},
 journal = {npj Computational Materials}
}

Installation

This code uses PyTorch to create the neural network models. For fast model training and inference, we suggest using an NVIDIA GPU with the most recent drivers.

Windows users should be able to install all required Python packages via Anaconda by following the steps below.

Linux users will additionally need to manually install CUDA and cuDNN.

Clone or download this GitHub repository

Do one of the following:

Install dependencies via Anaconda:

  1. Download and install Anaconda.
  2. Navigate to the project directory (from above).
  3. Open Anaconda prompt in this directory.
  4. Run the following command from Anaconda prompt to automatically create an environment from the conda-env.yml file:
    • conda env create --file conda-env.yml
    • conda env create --file conda-env-cpuonly.yml if you only have a CPU and no GPU in your system
  5. Run the following command from Anaconda prompt to activate the environment:
    • conda activate crabnet

For more information about creating, managing, and working with Conda environments, please consult the relevant help page.

Install dependencies via pip:

Open conda-env.yml and pip install all of the packages listed there. We recommend that you create a separate Python environment for this project.

Reproduce publication results

To reproduce the publication results, please follow the steps below. Results will vary slightly; it is a known phenomenon that PyTorch model training can differ slightly across computers and hardware.

Trained weights are provided at: http://doi.org/10.5281/zenodo.4633866.

As a reference, on a desktop computer with an Intel™ i9-9900K processor, 32 GB of RAM, and two NVIDIA RTX 2080 Ti GPUs, training our largest network (OQMD) takes roughly two hours.

Train CrabNet

  1. To train CrabNet you need train.csv and val.csv files, and optionally a test.csv file.
    • train.csv is used to fit the model weights.
    • val.csv ensures the model does not overfit.
    • test.csv is evaluated with the trained model for performance assessment.
  2. Place the csv files in the data/materials_data directory.
    • Each csv file must contain two columns, formula and target.
    • formula must be a string containing valid element symbols, numbers, and parentheses.
    • target is the target material property and should be provided as a number.
    • Additional csv files can be saved here. For inference with no known targets, you may fill the target column with 0's.
  3. Run train_crabnet.py to train CrabNet using the default parameters.
    • To perform inference with additional csv files, you may add code to train_crabnet.py of the form
    _, mae_added_data = save_results(data_dir, mat_prop, classification,
                                     'my_added_data.csv', verbose=False)
    
  4. Note that your trained network will be associated with the mat_prop folder you specify. To predict with this model, you must use the same mat_prop.
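As a minimal sketch of the expected csv layout, the following writes a tiny train.csv with Python's standard csv module (the formulas and target values are invented purely for illustration):

```python
import csv

# Two required columns: "formula" (a composition string) and "target"
# (a numeric property value). These rows are made-up examples.
rows = [
    ("NaCl", 3.07),
    ("Al2O3", 8.80),
    ("Mg(OH)2", 5.51),
]

with open("train.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["formula", "target"])  # required header row
    writer.writerows(rows)
```

val.csv and test.csv follow the same two-column layout; for inference files with no known targets, the target column can simply be filled with 0's.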

Plot results

  1. Inference outputs using the provided saved weights are in the predictions folder.
  2. Data are in the publication_predictions folder.
  3. Run Paper_{FIG|TABLE}_{X}.py to produce the tables and figures shown in the manuscript.

IMPORTANT - if you want to reproduce the publication Figures 1 and 2:

The PyTorch built-in function for computing the multi-headed attention operation defaults to averaging the attention matrix across all heads. Thus, to obtain the per-head attention information, we have to edit a bit of PyTorch's source code so that the individual attention matrices are returned.

To properly export the attention heads from the PyTorch nn.MultiheadAttention implementation within the transformer encoder layer, you will need to manually modify some of the source code of the PyTorch library. This applies to PyTorch v1.6.0, v1.7.0, and v1.7.1 (potentially to other untested versions as well).

For this, open the file: C:\Users\{USERNAME}\Anaconda3\envs\{ENVIRONMENT}\Lib\site-packages\torch\nn\functional.py (where USERNAME is your Windows user name and ENVIRONMENT is your conda environment name; if you followed the steps above, it should be crabnet).
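The path above is Windows-specific. A quick way to locate functional.py on any platform is to ask Python where the module lives; this sketch uses importlib so it degrades gracefully if PyTorch is not installed:

```python
import importlib.util

# Check for the top-level torch package first, so this does not raise
# an ImportError in environments without PyTorch.
spec = importlib.util.find_spec("torch")
if spec is not None:
    spec = importlib.util.find_spec("torch.nn.functional")
    print(spec.origin)  # full path to functional.py in your environment
else:
    print("PyTorch is not installed in this environment")
```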

At the end of the function definition of multi_head_attention_forward (line numbers may differ slightly):

L4011 def multi_head_attention_forward(
# ...
# ... [some lines omitted]
# ...
L4291    if need_weights:
L4292        # average attention weights over heads
L4293        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
L4294        return attn_output, attn_output_weights.sum(dim=1) / num_heads
L4295    else:
L4296        return attn_output, None

Change the specific line

return attn_output, attn_output_weights.sum(dim=1) / num_heads

to:

return attn_output, attn_output_weights

This prevents the attention values from being returned as an average over all heads; instead, each head's attention matrix is returned individually.
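The effect of the one-line change can be sketched with NumPy. As the quoted source shows, PyTorch computes the weights with shape (bsz * num_heads, tgt_len, src_len) and then views them as (bsz, num_heads, tgt_len, src_len); the sizes below are arbitrary:

```python
import numpy as np

bsz, num_heads, tgt_len, src_len = 2, 4, 3, 3

# Dummy attention weights, laid out as PyTorch computes them internally.
attn = np.random.rand(bsz * num_heads, tgt_len, src_len)

# Expose the head dimension, as functional.py does with .view().
attn = attn.reshape(bsz, num_heads, tgt_len, src_len)

# Original return value: averaged over heads -> shape (bsz, tgt_len, src_len)
averaged = attn.sum(axis=1) / num_heads

# Modified return value: one matrix per head -> shape (bsz, num_heads, tgt_len, src_len)
per_head = attn
```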


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

crabnet-1.0.0.tar.gz (46.6 MB)


Built Distribution

crabnet-1.0.0-py2.py3-none-any.whl (34.5 kB)


File details

Details for the file crabnet-1.0.0.tar.gz.

File metadata

  • Download URL: crabnet-1.0.0.tar.gz
  • Upload date:
  • Size: 46.6 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.25.1

File hashes

Hashes for crabnet-1.0.0.tar.gz:

  • SHA256: 6d3ca3cab9a44b38661f68a8cb3f8d8a742e3f7a82bf327749b31cb1f8fcb923
  • MD5: 2d22c6235736da8a914d927696a73e40
  • BLAKE2b-256: 751082adf47ba8dbd44cf9731969676f0ecf3919d5cd1d3205ceacc2cacb3dd0

See more details on using hashes here.


File details

Details for the file crabnet-1.0.0-py2.py3-none-any.whl.

File metadata

  • Download URL: crabnet-1.0.0-py2.py3-none-any.whl
  • Upload date:
  • Size: 34.5 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: python-requests/2.25.1

File hashes

Hashes for crabnet-1.0.0-py2.py3-none-any.whl:

  • SHA256: 7f4ef7c8cdeae5d917a214587fd017560d2a06fe03e6edbb61a41063407974e1
  • MD5: f01c12bc97053f791ff7be685b771989
  • BLAKE2b-256: c60ffce6a0477abbd1e1155ed5801a601879f3a2c9b2c66fd5226a60d8d057db

See more details on using hashes here.

