
Labelled graph networks for machine learning of crystals.


CrysNet

CrysNet is a neural network package that allows researchers to train custom models for crystal modeling tasks. It aims to accelerate research and applications in materials science.


Highlights

  • Easy to install.
  • Three steps for a fast test.
  • Flexible and adaptable to the user's training task.

Installation

CrysNet can be installed easily through Anaconda, as follows:

  • Create a new conda environment named "crysnet", then activate it:
      conda create -n crysnet python=3.8  
      conda activate crysnet 
  • Configure dependencies of crysnet:
      conda install tensorflow-gpu==2.6.0  # for CPU: conda install tensorflow==2.6.0

If your conda can't find tensorflow-gpu==2.6.0, you can add a new channel, e.g.:

      conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/  
      conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/  
  • Install pymatgen:
      conda install --channel conda-forge pymatgen  
  • Install other dependencies:
      pip install atom2vec  
      pip install mendeleev  
      conda install graphviz # or pip install graphviz  
      conda install pydot # or pip install pydot  
  • Install crysnet:
      pip install crysnet  
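
After installation, you can quickly check that the environment works before moving on. The short script below is only a minimal sketch (the file name check_env.py is illustrative); it assumes the versions installed above.

# check_env.py -- minimal sanity check of the installed environment (illustrative file name)
import tensorflow as tf
from pymatgen.core.structure import Structure   # the type crysnet datasets expect (see Usage below)
import crysnet

print("tensorflow:", tf.__version__)                            # expected: 2.6.0
print("GPUs visible:", tf.config.list_physical_devices("GPU"))  # empty list on a CPU-only install
print("crysnet imported from:", crysnet.__file__)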

Usage

Fast testing

CrysNet is very easy to use!
Just three steps are needed to run a fast test with crysnet:

  • download test data
    Get the test data from https://github.com/huzongxiang/CrysNetwork/datas/
    There are three json files in datas: dataset_classification.json, dataset_multiclassification.json and dataset_regression.json.
  • prepare workdir
    Download the data and put it in your training working directory; the test.py file should also be placed in that directory.
  • run command
      python test.py

You have now finished a multi-classification training test. The training results and model weights are saved to /results and /models, respectively.
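
For orientation, test.py simply strings together the steps explained in the next section. The sketch below is only an illustration of that flow, not the shipped script; the task type, model choice and hyperparameters are assumptions taken from the snippets that follow.

# test.py -- illustrative sketch of the fast test flow (see the next section for details)
from pathlib import Path
from crysnet.data import Dataset
from crysnet.data.generator import GraphGenerator
from crysnet.models import GNN
from crysnet.models.graphmodel import MpnnBaseModel

ModulePath = Path(__file__).parent.absolute()        # workdir containing datas/

dataset = Dataset(task_type='multiclassfication', data_path=ModulePath)
Generators = GraphGenerator(dataset, data_size=None, batch_size=128, cutoff=2.5)

gnn = GNN(model=MpnnBaseModel,
          batch_size=128,
          regression=dataset.regression,
          multiclassification=Generators.multiclassification)
gnn.train(Generators.train_generator, Generators.valid_generator, Generators.test_generator,
          epochs=700, lr=3e-3, warm_up=True, load_weights=False, verbose=1,
          checkpoints=None, save_weights_only=True, workdir=ModulePath)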

Understanding the training script

You can use crysnet through the provided training scripts in user_easy_trainscript alone, but understanding the script will help you customize your own training task!

  • get data
    Get the current working directory of the running training script; the script reads data from 'workdir/datas/' and saves results and models to 'workdir/results/' and 'workdir/models/'.
from pathlib import Path
ModulePath = Path(__file__).parent.absolute() # workdir
  • feed training data
    The Dataset module reads data from 'ModulePath/datas/dataset.json'; 'task_type' defines regression/classification/multi-classification, and 'data_path' is the path to the training data.
from crysnet.data import Dataset
dataset = Dataset(task_type='multiclassfication', data_path=ModulePath)
  • generator
    The GraphGenerator module feeds data into the model during training. It splits the data into train, validation and test sets, transforms the structure data into labelled graphs, and returns three generators. BATCH_SIZE is the batch size during training, DATA_SIZE is the number of samples used from the entire dataset, and CUTOFF is the cutoff radius for graph edges in the crystal.
from crysnet.data.generator import GraphGenerator
BATCH_SIZE = 128
DATA_SIZE = None
CUTOFF = 2.5
Generators = GraphGenerator(dataset, data_size=DATA_SIZE, batch_size=BATCH_SIZE, cutoff=CUTOFF)
train_data = Generators.train_generator
valid_data = Generators.valid_generator
test_data = Generators.test_generator

# if the task is multi-classification, the variable multiclassification should be defined
multiclassification = Generators.multiclassification  
  • building the model
    The GNN module defines a training framework that accepts a series of user-defined models. We provide GraphModel, MpnnBaseModel, TransformerBaseModel, MpnnModel, TransformerModel, DirectionalMpnnModel, DirectionalTransformerModel and a CGCNN model, according to your needs. TransformerModel, GraphModel and MpnnModel are different models: TransformerModel is a graph transformer, MpnnModel is a message passing neural network, and GraphModel is a combination of TransformerModel and MpnnModel. MpnnBaseModel and TransformerBaseModel do not take the directional information of the crystal into account, so they run faster. MpnnBaseModel is the fastest model and its accuracy is sufficient for most tasks, while TransformerModel achieves the highest accuracy on most tasks. The CGCNN model is a crystal graph convolutional neural network.
from crysnet.models import GNN
from crysnet.models.graphmodel import GraphModel, MpnnBaseModel, TransformerBaseModel, MpnnModel, TransformerModel, DirectionalMpnnModel, DirectionalTransformerModel
gnn = GNN(model=MpnnBaseModel,
      atom_dim=16,
      bond_dim=64,
      num_atom=118,
      state_dim=16,
      sp_dim=230,
      units=32,
      edge_steps=1,
      message_steps=1,
      transform_steps=1,
      num_attention_heads=8,
      dense_units=64,
      output_dim=64,
      readout_units=64,
      dropout=0.0,
      reg0=0.00,
      reg1=0.00,
      reg2=0.00,
      reg3=0.00,
      reg_rec=0.00,
      batch_size=BATCH_SIZE,
      spherical_harmonics=True,
      regression=dataset.regression,
      optimizer='Adam',
      )
  • training
    Use the model's train function to train. Common training parameters can be defined; workdir is the current directory of the training script, where the results produced during training are saved. If test_data exists, the model will also predict on test_data.
gnn.train(train_data, valid_data, test_data, epochs=700, lr=3e-3, warm_up=True, load_weights=False, verbose=1, checkpoints=None, save_weights_only=True, workdir=ModulePath)
  • prediction
    The simplest way to predict is to use the script predict.py in /user_easy_train_scripts.
    Use the predict_datas function to predict.
gnn.predict_datas(test_data, workdir=ModulePath)    # predict on test data with labels
y_pred_keras = gnn.predict(datas)                   # predict on new data without labels
  • preparing your custom data
    If you have your own structures (and labels), Dataset accepts the pymatgen.core.Structure type, so you should convert your POSCAR or cif files to pymatgen.core.Structure.
import os
from pymatgen.core.structure import Structure
structures = []                                      # your structure list
for cif in os.listdir(cif_path):
      structures.append(Structure.from_file(os.path.join(cif_path, cif)))    # works for POSCAR files too

# construct your dataset
from crysnet.data import Dataset
dataset = Dataset(task_type='my_classification', data_path=ModulePath)  # task_type could be my_regression, my_classification, my_multiclassification
dataset.prepare_x(structures)
dataset.prepare_y(labels)   # labels used to train the model; labels can be None when predicting on new data without labels

# alternatively, you can construct the dataset as follows
dataset.structures = structures
dataset.labels = labels

# save your structures and labels to the dataset file dataset_my*.json
dataset.save_datasets(structures, labels)

# for prediction on new data without labels, Generators has no multiclassification attribute, so a definite value should be assigned
Generators = GraphGenerator(dataset, data_size=DATA_SIZE, batch_size=BATCH_SIZE, cutoff=CUTOFF)     # dataset.labels is None
Generators.multiclassification = 5
multiclassification = Generators.multiclassification  # multiclassification = 5
      
  • customizing your model and training
    The GNN module provides a flexible training framework that accepts a user-defined tensorflow.keras.models.Model. You can customize your model and train it as in the following example.
from tensorflow.keras.models import Model
from tensorflow.keras import layers
from crysnet.layers import MessagePassing
from crysnet.layers import PartitionPadding

def MyModel(
            bond_dim,
            atom_dim=16,
            num_atom=118,
            state_dim=16,
            sp_dim=230,
            units=32,
            message_steps=1,
            readout_units=64,
            batch_size=16,
            regression=False,
            multiclassification=None,
            ):
            atom_features = layers.Input((), dtype="int32", name="atom_features_input")
            atom_features_ = layers.Embedding(num_atom, atom_dim, dtype="float32", name="atom_features")(atom_features)
            bond_features = layers.Input((bond_dim), dtype="float32", name="bond_features")
            local_env = layers.Input((6), dtype="float32", name="local_env")
            state_attrs = layers.Input((), dtype="int32", name="state_attrs_input")   
            state_attrs_ = layers.Embedding(sp_dim, state_dim, dtype="float32", name="state_attrs")(state_attrs)

            pair_indices = layers.Input((2), dtype="int32", name="pair_indices")

            atom_graph_indices = layers.Input(
            (), dtype="int32", name="atom_graph_indices"
            )

            bond_graph_indices = layers.Input(
            (), dtype="int32", name="bond_graph_indices"
            )

            pair_indices_per_graph = layers.Input((2), dtype="int32", name="pair_indices_per_graph")

            x = MessagePassing(message_steps)(
            [atom_features_, bond_features, state_attrs_, pair_indices,
                  atom_graph_indices, bond_graph_indices]
            )

            x = x[0]

            x = PartitionPadding(batch_size)([x, atom_graph_indices])

            x = layers.BatchNormalization()(x)

            x = layers.GlobalAveragePooling1D()(x)

            x = layers.Dense(readout_units, activation="relu", name='readout0')(x)

            x = layers.Dense(readout_units//2, activation="relu", name='readout1')(x)

            if regression:
                  x = layers.Dense(1, name='final')(x)
            elif multiclassification is not None:
                  x = layers.Dense(multiclassification, activation="softmax", name='final_softmax')(x)
            else:
                  x = layers.Dense(1, activation="sigmoid", name='final')(x)

            model = Model(
            inputs=[atom_features, bond_features, local_env, state_attrs, pair_indices, atom_graph_indices,
                        bond_graph_indices, pair_indices_per_graph],
            outputs=[x],
            )
            return model

from crysnet.models import GNN
gnn = GNN(model=MyModel,        
            atom_dim=16,
            bond_dim=64,
            num_atom=118,
            state_dim=16,
            sp_dim=230,
            units=32,
            message_steps=1,
            readout_units=64,
            batch_size=16,
            optimizer='Adam',
            regression=False,
            multiclassification=None,)
gnn.train(train_data, valid_data, test_data, epochs=700, lr=3e-3, warm_up=True, load_weights=False, verbose=1, checkpoints=None, save_weights_only=True, workdir=ModulePath)
  You can also use edges as your model output.
from crysnet.layers import EdgeMessagePassing
from tensorflow.keras.regularizers import l2

def MyModel(
            bond_dim,
            atom_dim=16,
            num_atom=118,
            state_dim=16,
            sp_dim=230,
            units=32,
            edge_steps=1,
            message_steps=1,
            readout_units=64,
            batch_size=16,
            spherical_harmonics=True,
            reg0=0.00,
            regression=False,
            multiclassification=None,
            ):
            atom_features = layers.Input((), dtype="int32", name="atom_features_input")
            atom_features_ = layers.Embedding(num_atom, atom_dim, dtype="float32", name="atom_features")(atom_features)
            bond_features = layers.Input((bond_dim), dtype="float32", name="bond_features")
            local_env = layers.Input((6), dtype="float32", name="local_env")
            state_attrs = layers.Input((), dtype="int32", name="state_attrs_input")   
            state_attrs_ = layers.Embedding(sp_dim, state_dim, dtype="float32", name="state_attrs")(state_attrs)

            pair_indices = layers.Input((2), dtype="int32", name="pair_indices")

            atom_graph_indices = layers.Input(
            (), dtype="int32", name="atom_graph_indices"
            )

            bond_graph_indices = layers.Input(
            (), dtype="int32", name="bond_graph_indices"
            )

            pair_indices_per_graph = layers.Input((2), dtype="int32", name="pair_indices_per_graph")

            x = EdgeMessagePassing(units,
                                    edge_steps,
                                    kernel_regularizer=l2(reg0),
                                    sph=spherical_harmonics
                                    )([bond_features, local_env, pair_indices])

            x = x[1]

            x = PartitionPadding(batch_size)([x, bond_graph_indices])

            x = layers.BatchNormalization()(x)

            x = layers.GlobalAveragePooling1D()(x)

            x = layers.Dense(readout_units, activation="relu", name='readout0')(x)

            x = layers.Dense(readout_units//2, activation="relu", name='readout1')(x)

            if regression:
                  x = layers.Dense(1, name='final')(x)

            model = Model(
            inputs=[atom_features, bond_features, local_env, state_attrs, pair_indices, atom_graph_indices,
                        bond_graph_indices, pair_indices_per_graph],
            outputs=[x],
            )
            return model
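
Since GNN accepts extra keyword arguments (note the **kwargs in its constructor below) and the MpnnBaseModel example above already passes edge_steps, reg0 and spherical_harmonics through it, the edge-output model can presumably be driven the same way. The following is a hedged sketch under that assumption.

# Sketch only: assumes GNN forwards its extra keyword arguments to the wrapped model,
# as suggested by the MpnnBaseModel example above and the **kwargs in the GNN constructor below.
from crysnet.models import GNN

gnn = GNN(model=MyModel,
          atom_dim=16,
          bond_dim=64,
          num_atom=118,
          state_dim=16,
          sp_dim=230,
          units=32,
          edge_steps=1,
          reg0=0.00,
          spherical_harmonics=True,
          message_steps=1,
          readout_units=64,
          batch_size=16,
          optimizer='Adam',
          regression=True,            # the edge-output model above only defines a regression head
          multiclassification=None,)
gnn.train(train_data, valid_data, test_data, epochs=700, lr=3e-3, warm_up=True,
          load_weights=False, verbose=1, checkpoints=None, save_weights_only=True, workdir=ModulePath)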
  The GNN module has some basic parameters that must be defined but are not necessarily used by your model:
class GNN:
      def __init__(self,
            model: Model,
            atom_dim=16,
            bond_dim=32,
            num_atom=118,
            state_dim=16,
            sp_dim=230,
            batch_size=16,
            regression=True,
            optimizer = 'Adam',
            multiclassification=None,
            **kwargs,
            ):
            """
            pass
            """

Framework

[CrysNet framework diagram]

Contributors

Zongxiang Hu

