Visual arithmetic reasoning with Machine Number Sense dataset
Machine Number Sense
PyTorch implementation of neural networks for solving problems from the Machine Number Sense (MNS) dataset [1]. The dataset and the official implementation of the baseline models can be found in this repo, created by the paper's authors.
Setup
$ pip install machine_number_sense
Usage
Baseline models
MLP [1]:
import torch
from mns.model import ConvMLP
x = torch.rand(4, 3, 80, 80)
mlp = ConvMLP(image_size=80)
logits = mlp(x)
logits # torch.Tensor with shape (4, 99)
LSTM [1]:
import torch
from mns.model import ConvLSTM
x = torch.rand(4, 3, 80, 80)
lstm = ConvLSTM(image_size=80)
logits = lstm(x)
logits # torch.Tensor with shape (4, 99)
Experimental models
Scattering Compositional Learner (SCL) [2] adapted to problems from the MNS dataset:
import torch
from mns.model import SCL
x = torch.rand(4, 3, 80, 80)
scl = SCL(image_size=80)
logits = scl(x)
logits # torch.Tensor with shape (4, 99)
An implementation of SCL for solving Raven's Progressive Matrices can be found in this repo.
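The core idea of SCL [2] is weight sharing through a "scattering transformation": each object's feature vector is split into K attribute groups, a single shared attribute network is applied to every group, and a single shared relationship network is then applied, per attribute, to that attribute's representations gathered across all objects. The pure-Python sketch below illustrates only this data flow. It assumes plain lists of floats and uses hypothetical helpers (`split_groups`, `scattering`) with arbitrary callables standing in for the neural networks. It does not show how `mns.model.SCL` defines objects or groups for MNS.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def split_groups(v: Sequence[float], k: int) -> List[Vector]:
    """Split a feature vector into k equal-sized attribute groups (hypothetical helper)."""
    if len(v) % k != 0:
        raise ValueError("feature size must be divisible by k")
    size = len(v) // k
    return [list(v[i * size:(i + 1) * size]) for i in range(k)]


def scattering(
    objects: Sequence[Sequence[float]],
    k: int,
    attribute_fn: Callable[[Vector], Vector],
    relation_fn: Callable[[Vector], Vector],
) -> List[Vector]:
    """Simplified scattering data flow; the callables stand in for shared networks."""
    # 1. The SAME attribute_fn is applied to every group of every object.
    attrs = [[attribute_fn(g) for g in split_groups(obj, k)] for obj in objects]
    # 2. For each attribute index j, concatenate that attribute across all objects
    #    and apply the SAME relation_fn, so relation weights are shared across attributes.
    outputs = []
    for j in range(k):
        gathered = [x for obj_attrs in attrs for x in obj_attrs[j]]
        outputs.append(relation_fn(gathered))
    return outputs


# Two objects with 4 features each, split into k=2 attributes.
print(scattering([[1, 2, 3, 4], [5, 6, 7, 8]], 2, lambda g: [sum(g)], lambda v: v))
```

With `attribute_fn` summing each group and `relation_fn` as the identity, attribute 0 collects `[3, 11]` and attribute 1 collects `[7, 15]`, which makes the regrouping across objects visible.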
Neural Arithmetic Logic Unit (NALU) [3] adapted to MNS:
import torch
from mns.model import ConvNALU
x = torch.rand(4, 3, 80, 80)
nalu = ConvNALU(image_size=80)
logits = nalu(x)
logits # torch.Tensor with shape (4, 99)
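For reference, a single-output NALU cell [3] combines an additive path a = Wx and a multiplicative path m = exp(W log(|x| + ε)) through a learned gate g = σ(Gx), with W = tanh(Ŵ) ⊙ σ(M̂). The pure-Python sketch below implements that cell on plain lists. The function name `nalu_forward` and its parameters are hypothetical, and the sketch does not show how `ConvNALU` connects NALU layers to its convolutional feature extractor.

```python
import math
from typing import Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def nalu_forward(
    x: Sequence[float],
    w_hat: Sequence[float],
    m_hat: Sequence[float],
    g: Sequence[float],
    eps: float = 1e-7,
) -> float:
    """Single-output NALU cell on plain Python floats (illustrative sketch)."""
    # Constrained weights: tanh * sigmoid biases each entry toward -1, 0 or 1.
    w = [math.tanh(wh) * sigmoid(mh) for wh, mh in zip(w_hat, m_hat)]
    # Additive path (the NAC): a = W x
    a = sum(wi * xi for wi, xi in zip(w, x))
    # Multiplicative path in log space: m = exp(W log(|x| + eps))
    m = math.exp(sum(wi * math.log(abs(xi) + eps) for wi, xi in zip(w, x)))
    # Learned gate interpolates between addition and multiplication.
    gate = sigmoid(sum(gi * xi for gi, xi in zip(g, x)))
    return gate * a + (1.0 - gate) * m


x = [3.0, 4.0]
print(nalu_forward(x, [20.0, 20.0], [20.0, 20.0], [10.0, 10.0]))    # gate ~1: approximately 3 + 4
print(nalu_forward(x, [20.0, 20.0], [20.0, 20.0], [-10.0, -10.0]))  # gate ~0: approximately 3 * 4
```

Saturating Ŵ and M̂ drives W toward 1, so the gate alone decides whether the cell outputs the sum or the product of its inputs.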
Dataset
The MNS dataset can be obtained as described in this repo. After downloading, it can be loaded with:
from mns.dataset import MNSDataset
dataset = MNSDataset(data_dir='/path/to/dataset', image_size=80)
iterator = iter(dataset)
image, target = next(iterator)
image # torch.Tensor with shape (3, 80, 80)
target # torch.Tensor with shape ()
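Each model returns a `(batch, 99)` tensor of logits, while the dataset yields a scalar target per problem. The sketch below shows how predictions could be scored, assuming (not confirmed here) that the target is a class index aligned with the logit columns. It uses plain lists so it runs without PyTorch. The helpers `predict` and `accuracy` are hypothetical.

```python
from typing import Sequence


def predict(logits_row: Sequence[float]) -> int:
    """Index of the largest logit (the first one wins on ties)."""
    return max(range(len(logits_row)), key=lambda i: logits_row[i])


def accuracy(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of rows whose argmax matches the target class index."""
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same batch size")
    correct = sum(predict(row) == t for row, t in zip(logits, targets))
    return correct / len(targets)


print(accuracy([[0.1, 0.9, 0.0], [0.5, 0.2, 0.3]], [1, 2]))
```

With PyTorch tensors the same computation is typically written as `(logits.argmax(dim=1) == target).float().mean()`.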
Training
The file mns.module contains a PyTorch Lightning module for training models on MNS. Training can be run with Docker using the scripts in the scripts/ directory.
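The loss used in mns.module is not described here. A common choice for a 99-way classification output like this one is softmax cross-entropy on the raw logits, which the pure-Python sketch below assumes. It uses the log-sum-exp trick for numerical stability. The function names are hypothetical.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Softmax cross-entropy for one example, computed from raw logits."""
    # Subtract the max logit before exponentiating to avoid overflow.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]


def mean_cross_entropy(batch_logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Average loss over a batch, as used for a training step."""
    return sum(cross_entropy(row, t) for row, t in zip(batch_logits, targets)) / len(targets)


# Uniform logits over 99 classes give a loss of log(99), about 4.595.
print(cross_entropy([0.0] * 99, 5))
```

In PyTorch this corresponds to `torch.nn.functional.cross_entropy(logits, target)`, which likewise expects unnormalized logits and integer class targets.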
Unit tests
$ python -m pytest tests
Bibliography
[1] Zhang, Wenhe, et al. "Machine number sense: A dataset of visual arithmetic problems for abstract and relational reasoning." Proceedings of the AAAI Conference on Artificial Intelligence. 2020.
[2] Wu, Yuhuai, et al. "The Scattering Compositional Learner: Discovering Objects, Attributes, Relationships in Analogical Reasoning." arXiv preprint arXiv:2007.04212 (2020).
[3] Trask, Andrew, et al. "Neural arithmetic logic units." Advances in Neural Information Processing Systems. 2018.
Citations
@inproceedings{zhang2020machine,
title={Machine number sense: A dataset of visual arithmetic problems for abstract and relational reasoning},
author={Zhang, Wenhe and Zhang, Chi and Zhu, Yixin and Zhu, Song-Chun},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={34},
number={02},
pages={1332--1340},
year={2020}
}
@article{wu2020scattering,
title={The Scattering Compositional Learner: Discovering Objects, Attributes, Relationships in Analogical Reasoning},
author={Wu, Yuhuai and Dong, Honghua and Grosse, Roger and Ba, Jimmy},
journal={arXiv preprint arXiv:2007.04212},
year={2020}
}
@inproceedings{trask2018neural,
title={Neural arithmetic logic units},
author={Trask, Andrew and Hill, Felix and Reed, Scott E and Rae, Jack and Dyer, Chris and Blunsom, Phil},
booktitle={Advances in Neural Information Processing Systems},
pages={8035--8044},
year={2018}
}