LogiTorch is a PyTorch-based library for logical reasoning in natural language
LogiTorch
LogiTorch is a PyTorch-based library for logical reasoning in natural language. It consists of:
- Textual logical reasoning datasets
- Implementations of different logical reasoning neural architectures
- A simple and clean API that can be used with PyTorch Lightning
📦 Installation
Install the package from PyPI with `pip install logitorch`.
📖 Documentation
🖥️ Features
📋 Datasets
Datasets implemented in LogiTorch:
- AR-LSAT
- ConTRoL
- LogiQA
- ReClor
- RuleTaker
- ProofWriter
- SNLI
- MultiNLI
- RTE
- Negated SNLI
- Negated MultiNLI
- Negated RTE
- PARARULES Plus
- AbductionRules
🤖 Models
Models implemented in LogiTorch:
🧪 Example Usage
Training Example
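```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data.dataloader import DataLoader

from logitorch.data_collators.ruletaker_collator import RuleTakerCollator
from logitorch.datasets.qa.ruletaker_dataset import RuleTakerDataset
from logitorch.pl_models.ruletaker import PLRuleTaker

# RuleTaker depth-5 training and validation splits
train_dataset = RuleTakerDataset("depth-5", "train")
val_dataset = RuleTakerDataset("depth-5", "val")

# Collator that batches RuleTaker examples for the model
ruletaker_collate_fn = RuleTakerCollator()

train_dataloader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, collate_fn=ruletaker_collate_fn
)
# Evaluate on the validation split, not the training split
val_dataloader = DataLoader(
    val_dataset, batch_size=32, collate_fn=ruletaker_collate_fn
)

model = PLRuleTaker(learning_rate=1e-5, weight_decay=0.1)

# Keep only the checkpoint with the lowest validation loss.
# ModelCheckpoint appends ".ckpt" to `filename`, so this writes models/best_ruletaker.ckpt
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="val_loss",
    mode="min",
    dirpath="models/",
    filename="best_ruletaker",
)

# The callback only takes effect when passed to the Trainer.
# `devices` replaces the deprecated `gpus` argument.
trainer = pl.Trainer(accelerator="gpu", devices=1, callbacks=[checkpoint_callback])
trainer.fit(model, train_dataloader, val_dataloader)
```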
Testing Example
```python
from logitorch.datasets.qa.ruletaker_dataset import RULETAKER_ID_TO_LABEL
from logitorch.pl_models.ruletaker import PLRuleTaker

# Load the checkpoint written by ModelCheckpoint in the training example
model = PLRuleTaker.load_from_checkpoint("models/best_ruletaker.ckpt")

context = "Bob is smart. If someone is smart then he is kind."
question = "Bob is kind"

# predict returns a label id; map it back to a readable label
pred = model.predict(context, question)
print(RULETAKER_ID_TO_LABEL[pred])
```
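A RuleTaker-style question asks whether a statement follows from the facts and rules given in the context. To show why the example above should be answered "true", here is a toy symbolic sketch. It is not part of LogiTorch. The helpers `forward_chain` and `entails` and the `smart(Bob)` encoding are hypothetical, and the sketch assumes the English context has already been translated by hand into ground facts and rules. RuleTaker-style models are trained to answer directly from the natural-language text, without that translation step.

```python
from typing import FrozenSet, List, Set, Tuple

# Hypothetical encoding: a rule maps a set of premise facts to one conclusion fact
Rule = Tuple[FrozenSet[str], str]


def forward_chain(facts: Set[str], rules: List[Rule]) -> Set[str]:
    """Return every fact derivable from `facts` by repeatedly applying `rules`."""
    known = set(facts)
    changed = True
    while changed:
        changed = False
        for premises, conclusion in rules:
            # Fire a rule once all its premises are known and its conclusion is new
            if premises <= known and conclusion not in known:
                known.add(conclusion)
                changed = True
    return known


def entails(facts: Set[str], rules: List[Rule], query: str) -> bool:
    """Closed-world check: the query holds iff it is derivable."""
    return query in forward_chain(facts, rules)


# Hand-translated version of the testing example (an assumption, not LogiTorch input)
facts = {"smart(Bob)"}
rules = [(frozenset({"smart(Bob)"}), "kind(Bob)")]
print(entails(facts, rules, "kind(Bob)"))  # True
print(entails(facts, rules, "big(Bob)"))   # False
```

The loop repeats until no rule adds a new fact, so multi-step chains are found regardless of rule order. This mirrors the "depth" of a RuleTaker question, which is the number of rule applications needed to prove it.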
Download files
Source distribution: logitorch-0.0.0.tar.gz (35.1 kB)
Built distribution: logitorch-0.0.0-py3-none-any.whl (54.5 kB)

Hashes for logitorch-0.0.0-py3-none-any.whl:

| Algorithm | Hash digest |
|---|---|
| SHA256 | 070fca830e7373537f200d75cc66d1c8dcac750543468951dfffef3561e66bf5 |
| MD5 | f5db65aaeb27788756bf6530b0223441 |
| BLAKE2b-256 | e9bcabe7543a7a914b5cbc7d5dfe5c6975e51624ea35af2a222b41d44c35899c |
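The SHA256 digest listed for the wheel can be used to check a downloaded copy before installing it. Below is a minimal standard-library sketch. The helper names `sha256_of_stream` and `matches_expected` are hypothetical, and it assumes the wheel sits in the current directory under its original filename.

```python
import hashlib
from pathlib import Path
from typing import BinaryIO

# SHA256 digest listed for logitorch-0.0.0-py3-none-any.whl
EXPECTED_SHA256 = "070fca830e7373537f200d75cc66d1c8dcac750543468951dfffef3561e66bf5"


def sha256_of_stream(stream: BinaryIO, chunk_size: int = 1 << 16) -> str:
    """Hash a binary stream in fixed-size chunks so large files need not fit in memory."""
    digest = hashlib.sha256()
    # iter(callable, sentinel) keeps reading until read() returns b"" at end of stream
    for chunk in iter(lambda: stream.read(chunk_size), b""):
        digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """Return True if the file at `path` has the expected SHA256 hex digest."""
    with path.open("rb") as f:
        return sha256_of_stream(f) == expected.lower()


if __name__ == "__main__":
    # Assumes the wheel was downloaded into the current directory
    wheel = Path("logitorch-0.0.0-py3-none-any.whl")
    if wheel.exists():
        print("hash OK" if matches_expected(wheel) else "hash MISMATCH")
    else:
        print(f"{wheel} not found")
```

pip can perform the same check during installation through its hash-checking mode. You pin `--hash=sha256:<digest>` next to the requirement in a requirements file and install with `--require-hashes`.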