Skip to main content

Molecule Transformer X Model

Project description

MolTx

CI Coverage Status PyPI - Python Version

Installation

pip install moltx

Usage

Pretrain

import torch
import torch.nn as nn
from torch.optim import Adam

from moltx import datasets, models, nets, tokenizers

# prepare dataset from parallel lists of generic and canonical SMILES
ds = datasets.AdaMR2(device=torch.device('cpu'))
generic_smiles = ["C=CC=CC=C", "...."]
canonical_smiles = ["c1ccccc1", "..."]
tgt, out = ds(generic_smiles, canonical_smiles)

# train
## pick a model-size preset: models.AdaMR2.CONFIG_LARGE or models.AdaMR2.CONFIG_BASE
conf = models.AdaMR2.CONFIG_LARGE
model = models.AdaMR2(conf)

# token id 0 is excluded from the loss (presumably the padding id)
crt = nn.CrossEntropyLoss(ignore_index=0)
optim = Adam(model.parameters(), lr=0.1)

optim.zero_grad()
pred = model(tgt)
# flatten all leading dims so every token position is scored as one sample
loss = crt(pred.view(-1, pred.size(-1)), out.view(-1))
loss.backward()
optim.step()

# save ckpt
torch.save(model.state_dict(), '/path/to/adamr.ckpt')

Finetune

# Classifier finetune
from moltx import datasets

# maximum token length of the SMILES in the dataset; if None, the maximum
# token length among the given smiles is used
seq_len = 256
ds = datasets.AdaMR2Classifier(device=torch.device('cpu'))
smiles = ["c1ccccc1", "CC[N+](C)(C)Cc1ccccc1Br"]
labels = [0, 1]
tgt, out = ds(smiles, labels, seq_len)

from moltx import nets, models
# should match the preset used for pretraining
# (models.AdaMR2.CONFIG_LARGE or models.AdaMR2.CONFIG_BASE)
pretrained_conf = models.AdaMR2.CONFIG_LARGE
model = models.AdaMR2Classifier(num_classes=2, conf=pretrained_conf)
model.load_ckpt('/path/to/adamr.ckpt')
crt = nn.CrossEntropyLoss()
optim = Adam(model.parameters(), lr=0.1)

optim.zero_grad()
pred = model(tgt)
loss = crt(pred, out)
loss.backward()
optim.step()

torch.save(model.state_dict(), '/path/to/classifier.ckpt')

# Regression finetune
ds = datasets.AdaMR2Regression(device=torch.device('cpu'))
smiles = ["c1ccccc1", "CC[N+](C)(C)Cc1ccccc1Br"]
values = [0.23, 0.12]
tgt, out = ds(smiles, values, seq_len)

model = models.AdaMR2Regression(conf=pretrained_conf)
model.load_ckpt('/path/to/adamr.ckpt')
crt = nn.MSELoss()
# build a fresh optimizer for this model: an optimizer is bound to the
# parameters it was constructed with, so reusing the previous one would
# update the classifier instead of this model
optim = Adam(model.parameters(), lr=0.1)

optim.zero_grad()
pred = model(tgt)
loss = crt(pred, out)
loss.backward()
optim.step()

torch.save(model.state_dict(), '/path/to/regression.ckpt')

# Distributed Generation
ds = datasets.AdaMR2DistGeneration(device=torch.device('cpu'))
smiles = ["c1ccccc1", "CC[N+](C)(C)Cc1ccccc1Br"]
tgt, out = ds(smiles, seq_len)

model = models.AdaMR2DistGeneration(conf=pretrained_conf)
model.load_ckpt('/path/to/adamr.ckpt')
# token id 0 is excluded from the loss (presumably the padding id)
crt = nn.CrossEntropyLoss(ignore_index=0)
# build a fresh optimizer bound to this model's parameters
optim = Adam(model.parameters(), lr=0.1)

optim.zero_grad()
pred = model(tgt)
# flatten all leading dims so every token position is scored as one sample
loss = crt(pred.view(-1, pred.size(-1)), out.view(-1))
loss.backward()
optim.step()

torch.save(model.state_dict(), '/path/to/distgen.ckpt')

# Goal Generation
ds = datasets.AdaMR2GoalGeneration(device=torch.device('cpu'))
smiles = ["c1ccccc1", "CC[N+](C)(C)Cc1ccccc1Br"]
goals = [0.23, 0.12]
# NOTE(review): the model below takes two inputs, so this assumes the dataset
# also returns the encoded goals as its first tensor -- confirm against
# datasets.AdaMR2GoalGeneration
src, tgt, out = ds(smiles, goals, seq_len)

model = models.AdaMR2GoalGeneration(conf=pretrained_conf)
model.load_ckpt('/path/to/adamr.ckpt')
# token id 0 is excluded from the loss (presumably the padding id)
crt = nn.CrossEntropyLoss(ignore_index=0)
# build a fresh optimizer bound to this model's parameters
optim = Adam(model.parameters(), lr=0.1)

optim.zero_grad()
pred = model(src, tgt)
# flatten all leading dims so every token position is scored as one sample
loss = crt(pred.view(-1, pred.size(-1)), out.view(-1))
loss.backward()
optim.step()

torch.save(model.state_dict(), '/path/to/goalgen.ckpt')

Inference

from moltx import nets, models, pipelines, tokenizers
# AdaMR2
conf = models.AdaMR2.CONFIG_LARGE # or models.AdaMR2.CONFIG_BASE
model = models.AdaMR2(conf)
model.load_ckpt('/path/to/adamr.ckpt')
pipeline = pipelines.AdaMR2(model)
pipeline("C=CC=CC=C")
# example output: {"smiles": ["c1ccccc1"], "probabilities": [0.9]}

# Classifier
conf = models.AdaMR2.CONFIG_LARGE # or models.AdaMR2.CONFIG_BASE
model = models.AdaMR2Classifier(2, conf)
model.load_ckpt('/path/to/classifier.ckpt')
pipeline = pipelines.AdaMR2Classifier(model)
pipeline("C=CC=CC=C")
# example output: {"label": [1], "probability": [0.67]}

# Regression
conf = models.AdaMR2.CONFIG_LARGE # or models.AdaMR2.CONFIG_BASE
# the regression model takes only the config (no class count)
model = models.AdaMR2Regression(conf)
model.load_ckpt('/path/to/regression.ckpt')
pipeline = pipelines.AdaMR2Regression(model)
pipeline("C=CC=CC=C")
# example output: {"value": [0.467], "probability": [0.67]}

# DistGeneration
conf = models.AdaMR2.CONFIG_LARGE # or models.AdaMR2.CONFIG_BASE
model = models.AdaMR2DistGeneration(conf)
model.load_ckpt('/path/to/distgen.ckpt')
pipeline = pipelines.AdaMR2DistGeneration(model)
pipeline(k=2)
# example output: {"smiles": ["c1ccccc1", "...."], "probabilities": [0.9, 0.1]}

# GoalGeneration
conf = models.AdaMR2.CONFIG_LARGE # or models.AdaMR2.CONFIG_BASE
model = models.AdaMR2GoalGeneration(conf)
model.load_ckpt('/path/to/goalgen.ckpt')
pipeline = pipelines.AdaMR2GoalGeneration(model)
pipeline(0.48, k=2)
# example output: {"smiles": ["c1ccccc1", "...."], "probabilities": [0.9, 0.1]}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

moltx-2.0.0.tar.gz (43.7 kB view details)

Uploaded Source

Built Distribution

moltx-2.0.0-py3-none-any.whl (41.5 kB view details)

Uploaded Python 3

File details

Details for the file moltx-2.0.0.tar.gz.

File metadata

  • Download URL: moltx-2.0.0.tar.gz
  • Upload date:
  • Size: 43.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.6

File hashes

Hashes for moltx-2.0.0.tar.gz
Algorithm Hash digest
SHA256 9a1dc7c141d3ee9a94e0e1cd18bd55420e96c7258bbd8fcbf1baeb134cac7d50
MD5 a41e0ce97a1ef52225669fb77cbcd37d
BLAKE2b-256 c1023d833b7c097b2bc45de14bd7922cd35ecbba876e1e17b97a8711744be34d

See more details on using hashes here.

File details

Details for the file moltx-2.0.0-py3-none-any.whl.

File metadata

  • Download URL: moltx-2.0.0-py3-none-any.whl
  • Upload date:
  • Size: 41.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.1 CPython/3.12.6

File hashes

Hashes for moltx-2.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 446ecd4c3e215f4e25c3313ab89b9ee1520e76e35518c5ca430534b8a6aa3e00
MD5 e4f6459a1023cc3c4b78baddc4132120
BLAKE2b-256 7796ed0db05546e141f3b97f4de9028994de665221e6db612d1b68e659c6902e

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page