
LogiTorch is a PyTorch-based library for logical reasoning in natural language


LogiTorch

LogiTorch is a PyTorch-based library for logical reasoning in natural language. It consists of:

  • Textual logical reasoning datasets
  • Implementations of different logical reasoning neural architectures
  • A simple and clean API that can be used with PyTorch Lightning

📦 Installation

📖 Documentation

🖥️ Features

📋 Datasets

Datasets implemented in LogiTorch:

🤖 Models

Models implemented in LogiTorch:

🧪 Example Usage

Training Example

import pytorch_lightning as pl
from logitorch.data_collators.ruletaker_collator import RuleTakerCollator
from logitorch.datasets.qa.ruletaker_dataset import RuleTakerDataset
from logitorch.pl_models.ruletaker import PLRuleTaker
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

# RuleTaker questions from the "depth-5" configuration
train_dataset = RuleTakerDataset("depth-5", "train")
val_dataset = RuleTakerDataset("depth-5", "val")

# The collator turns a list of dataset items into one model-ready batch
ruletaker_collate_fn = RuleTakerCollator()
train_dataloader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, collate_fn=ruletaker_collate_fn
)

# Validate on the validation split, not on the training split
val_dataloader = DataLoader(
    val_dataset, batch_size=32, collate_fn=ruletaker_collate_fn
)

model = PLRuleTaker(learning_rate=1e-5, weight_decay=0.1)

# Keep only the checkpoint with the lowest validation loss.
# Lightning appends ".ckpt" itself, so this saves models/best_ruletaker.ckpt
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="val_loss",
    mode="min",
    dirpath="models/",
    filename="best_ruletaker",
)

# The callback only takes effect once it is passed to the Trainer
trainer = pl.Trainer(accelerator="gpu", devices=1, callbacks=[checkpoint_callback])
trainer.fit(model, train_dataloader, val_dataloader)
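RuleTakerCollator is passed to DataLoader as collate_fn: for every batch, DataLoader calls it with a list of batch_size dataset items and uses whatever it returns as the batch. The sketch below illustrates that contract with a hypothetical ToyQACollator. The (context, question, label) item format, the whitespace tokenization and the 1 = true / 0 = false label encoding are all assumptions of this sketch; a real collator such as LogiTorch's would presumably tokenize with the model's tokenizer and return tensors instead of lists.

```python
from typing import Dict, List, Tuple

# Hypothetical item format: (context, question, label).
# The real RuleTakerDataset items may be structured differently.
Sample = Tuple[str, str, int]


class ToyQACollator:
    """Turns a list of samples into one batch (a dict of parallel lists).

    Uses naive whitespace tokenization and pads every sequence to the
    longest one in the batch. Illustrative only, not LogiTorch code.
    """

    def __init__(self, pad_token: str = "<pad>") -> None:
        self.pad_token = pad_token

    def __call__(self, samples: List[Sample]) -> Dict[str, list]:
        # Join context and question into one input sequence per sample.
        tokens = [f"{context} {question}".split() for context, question, _ in samples]
        max_len = max(len(t) for t in tokens)
        padded = [t + [self.pad_token] * (max_len - len(t)) for t in tokens]
        # Attention mask: 1 marks a real token, 0 marks padding.
        mask = [[1] * len(t) + [0] * (max_len - len(t)) for t in tokens]
        labels = [label for _, _, label in samples]
        return {"tokens": padded, "attention_mask": mask, "labels": labels}


# DataLoader would make this call itself; here it is made by hand.
batch = ToyQACollator()([
    ("Bob is smart.", "Bob is kind", 0),
    ("Bob is smart. If someone is smart then he is kind.", "Bob is kind", 1),
])
print(batch["attention_mask"])
```

Padding only to the longest sequence in each batch, rather than to a global maximum, is a common reason to do this work in a collator instead of in the dataset.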

Testing Example

from logitorch.datasets.qa.ruletaker_dataset import RULETAKER_ID_TO_LABEL
from logitorch.pl_models.ruletaker import PLRuleTaker

# Load the best checkpoint written by the training example
model = PLRuleTaker.load_from_checkpoint("models/best_ruletaker.ckpt")
model.eval()  # inference mode: disables dropout

context = "Bob is smart. If someone is smart then he is kind."
question = "Bob is kind"

# predict returns a label id; map it back to a human-readable label
pred = model.predict(context, question)
print(RULETAKER_ID_TO_LABEL[pred])
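The testing example asks whether a question follows from a context made of facts and rules. To make that task concrete, here is a small symbolic forward-chaining sketch in plain Python. It is not part of LogiTorch and not how PLRuleTaker answers (like the other LogiTorch models it is a neural architecture). It assumes a hypothetical, already-parsed encoding of the sentences, handles only rules about a single entity without negation, and treats anything it cannot derive as false (a closed-world assumption). It also reports the proof depth, i.e. the number of chained rule applications; as we understand it, this is the quantity the "depth-5" configuration in the training example refers to (questions needing at most five reasoning steps).

```python
from typing import Dict, FrozenSet, List, Optional, Set, Tuple

# Toy encoding (an assumption of this sketch, not a LogiTorch format):
# the fact "Bob is smart." becomes ("Bob", "smart"), and the rule
# "If someone is smart then they are kind." becomes ({"smart"}, "kind").
Fact = Tuple[str, str]
Rule = Tuple[FrozenSet[str], str]


def forward_chain(facts: Set[Fact], rules: List[Rule]) -> Dict[Fact, int]:
    """Map every derivable fact to its minimum proof depth.

    Given facts have depth 0. Rounds run breadth-first, so the round in
    which a fact is first derived equals its minimum number of chained
    rule applications.
    """
    depth: Dict[Fact, int] = {fact: 0 for fact in facts}
    entities = {entity for entity, _ in facts}
    level = 0
    while True:
        level += 1
        new_facts: Dict[Fact, int] = {}
        for entity in entities:
            # Attributes known about this entity before the current round.
            known = {attr for (e, attr) in depth if e == entity}
            for premises, conclusion in rules:
                fact = (entity, conclusion)
                if fact not in depth and premises <= known:
                    new_facts[fact] = level
        if not new_facts:  # fixed point: nothing new can be derived
            return depth
        depth.update(new_facts)


def proof_depth(facts: Set[Fact], rules: List[Rule], query: Fact) -> Optional[int]:
    """Return the query's proof depth, or None if it cannot be derived
    (under the closed-world assumption the answer is then False)."""
    return forward_chain(facts, rules).get(query)


facts = {("Bob", "smart")}
rules = [
    (frozenset({"smart"}), "kind"),           # If someone is smart then they are kind.
    (frozenset({"kind"}), "nice"),            # If someone is kind then they are nice.
    (frozenset({"kind", "smart"}), "happy"),  # If someone is kind and smart then they are happy.
]

print(proof_depth(facts, rules, ("Bob", "kind")))   # 1
print(proof_depth(facts, rules, ("Bob", "happy")))  # 2
print(proof_depth(facts, rules, ("Bob", "rough")))  # None
```

A neural model such as PLRuleTaker has to learn this behaviour from the natural-language text alone, which is why the datasets are split by the depth of reasoning they require.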



Download files


Source Distribution

logitorch-0.0.0.tar.gz (35.1 kB)

Built Distribution

logitorch-0.0.0-py3-none-any.whl (54.5 kB, Python 3)
