Wild Relation Network for solving Raven's Progressive Matrices
Wild Relation Network
PyTorch implementation of Relation Network [1] and Wild Relation Network [2] for solving Raven's Progressive Matrices.
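As a reading aid, the two models can be summarized as follows. This is our reading of [1] and [2] with our own notation, not a formula taken from this package. A Relation Network applies an MLP g_θ to every pair of objects, sums the results, and passes the sum through a second MLP f_φ. WReN embeds the 8 context panels and one candidate answer, runs a relation network over those 9 embeddings to get a score s_k for each of the 8 candidates, and normalizes the scores with a softmax.

```latex
% Relation Network [1]: O = \{o_1, \dots, o_n\} is the set of objects
\mathrm{RN}(O) = f_\phi\Big(\sum_{i,j} g_\theta(o_i, o_j)\Big)

% WReN [2]: x_1, \dots, x_8 are context-panel embeddings, c_k is candidate k
s_k = f_\phi\Big(\sum_{y,z} g_\theta(x_y, x_z)\Big),
\quad x_y, x_z \in \{x_1, \dots, x_8, c_k\}

p(k) = \frac{\exp(s_k)}{\sum_{k'=1}^{8} \exp(s_{k'})}
```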
Setup
$ pip install wild_relation_network
Usage
Relation Network:
import torch
from wild_relation_network import RelationNetwork
x = torch.rand(4, 8, 64)
rn = RelationNetwork(
    num_objects=8,
    object_size=64,
    out_size=32,
    use_object_triples=False,
    use_layer_norm=False
)
logits = rn(x)
logits # torch.Tensor with shape (4, 32)
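To make the constructor arguments concrete, here is a minimal, self-contained PyTorch sketch of what a relation network computes on an input shaped like `x` above. It is not the `wild_relation_network` implementation. The class name `MinimalRelationNetwork`, the `hidden_size` parameter, and the two-layer MLPs used for g_θ and f_φ are assumptions for illustration. We also assume that `use_object_triples=True` means g_θ sees every ordered triple of objects instead of every ordered pair. Layer normalization is omitted.

```python
import itertools

import torch
from torch import nn


class MinimalRelationNetwork(nn.Module):
    """Hypothetical sketch of RN(O) = f_phi(sum over combos of g_theta(combo)).

    Not the wild_relation_network code; layer sizes are assumptions.
    """

    def __init__(self, num_objects, object_size, out_size,
                 hidden_size=256, use_object_triples=False):
        super().__init__()
        group = 3 if use_object_triples else 2
        # Precompute every ordered pair (or triple) of object indices,
        # including repeats such as (i, i), as in the sum over all i, j.
        combos = list(itertools.product(range(num_objects), repeat=group))
        self.register_buffer("combos", torch.tensor(combos, dtype=torch.long))
        # g_theta: applied to each concatenated pair/triple of objects.
        self.g = nn.Sequential(
            nn.Linear(group * object_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(),
        )
        # f_phi: applied once to the summed relation vector.
        self.f = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, out_size),
        )

    def forward(self, x):
        # x: (batch, num_objects, object_size)
        batch = x.size(0)
        # Gather objects per combination: (batch, num_combos, group, object_size)
        grouped = x[:, self.combos]
        # Concatenate the objects of each combination along the feature axis.
        grouped = grouped.reshape(batch, self.combos.size(0), -1)
        relations = self.g(grouped).sum(dim=1)  # (batch, hidden_size)
        return self.f(relations)                # (batch, out_size)


rn = MinimalRelationNetwork(num_objects=8, object_size=64, out_size=32)
logits = rn(torch.rand(4, 8, 64))  # shape (4, 32)
```

Gathering all index combinations in one indexing call keeps the forward pass vectorized. The number of combinations grows as num_objects² for pairs and num_objects³ for triples, which is the main cost of the triple variant.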
Wild Relation Network:
import torch
from wild_relation_network import WReN
x = torch.rand(4, 16, 160, 160)
wren = WReN(
    num_channels=32,
    use_object_triples=False,
    use_layer_norm=False
)
logits = wren(x)
y_hat = logits.log_softmax(dim=-1)
y_hat # torch.Tensor with shape (4, 8)
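The `(4, 16, 160, 160)` input above holds 16 panels per puzzle. In the setup of [2] these are 8 context panels and 8 candidate answers, and WReN scores each candidate by running a relation network over the context panels plus that candidate. The sketch below illustrates this scoring scheme in plain PyTorch. It is hypothetical: `MinimalWReNScorer` is not part of the package, a single linear layer stands in for the paper's CNN panel encoder, each embedding is tagged with a one-hot position label, and we assume the first 8 panels are context and the last 8 are candidates.

```python
import torch
from torch import nn


class MinimalWReNScorer(nn.Module):
    """Hypothetical sketch of WReN-style candidate scoring; not the library's code.

    Assumes 16 panels per puzzle: 8 context panels followed by 8 candidates.
    """

    def __init__(self, panel_size=160, embed_size=64, hidden_size=128):
        super().__init__()
        # Assumption: one linear layer stands in for the CNN panel encoder.
        self.encode = nn.Linear(panel_size * panel_size, embed_size)
        obj_size = embed_size + 9  # embedding plus a one-hot tag for 9 slots
        self.g = nn.Sequential(nn.Linear(2 * obj_size, hidden_size), nn.ReLU())
        self.f = nn.Linear(hidden_size, 1)  # one scalar score per candidate
        self.register_buffer("tags", torch.eye(9))

    def forward(self, x):
        # x: (batch, 16, H, W)
        batch = x.size(0)
        emb = self.encode(x.flatten(start_dim=2))     # (batch, 16, embed)
        context, candidates = emb[:, :8], emb[:, 8:]  # (batch, 8, embed) each
        # Append each candidate to the 8 context panels: (batch, 8, 9, embed)
        panels = torch.cat(
            [context.unsqueeze(1).expand(-1, 8, -1, -1), candidates.unsqueeze(2)],
            dim=2,
        )
        tags = self.tags.expand(batch, 8, 9, 9)
        objs = torch.cat([panels, tags], dim=-1)      # (batch, 8, 9, obj_size)
        # Build all ordered pairs (y, z) of the 9 objects for every candidate.
        n = objs.size(2)
        left = objs.unsqueeze(3).expand(-1, -1, -1, n, -1)
        right = objs.unsqueeze(2).expand(-1, -1, n, -1, -1)
        pairs = torch.cat([left, right], dim=-1)      # (batch, 8, 9, 9, 2*obj_size)
        relations = self.g(pairs).sum(dim=(2, 3))     # (batch, 8, hidden)
        return self.f(relations).squeeze(-1)          # (batch, 8) logits


scorer = MinimalWReNScorer()
logits = scorer(torch.rand(4, 16, 160, 160))
y_hat = logits.log_softmax(dim=-1)  # (4, 8) log-probabilities over candidates
```

Since the candidates differ only in the ninth object, the `(batch, 8)` log-probabilities can be trained with a negative log-likelihood loss (for example `torch.nn.functional.nll_loss`) against the index of the correct answer.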
Unit tests
$ python -m pytest tests
Bibliography
[1] Santoro, Adam, et al. "A simple neural network module for relational reasoning." Advances in Neural Information Processing Systems. 2017.
[2] Santoro, Adam, et al. "Measuring abstract reasoning in neural networks." International Conference on Machine Learning. 2018.
Citations
@inproceedings{santoro2017simple,
  title={A simple neural network module for relational reasoning},
  author={Santoro, Adam and Raposo, David and Barrett, David G and Malinowski, Mateusz and Pascanu, Razvan and Battaglia, Peter and Lillicrap, Timothy},
  booktitle={Advances in Neural Information Processing Systems},
  pages={4967--4976},
  year={2017}
}
@inproceedings{santoro2018measuring,
  title={Measuring abstract reasoning in neural networks},
  author={Santoro, Adam and Hill, Felix and Barrett, David and Morcos, Ari and Lillicrap, Timothy},
  booktitle={International Conference on Machine Learning},
  pages={4477--4486},
  year={2018}
}
Hashes for wild_relation_network-0.1.0.tar.gz
Algorithm | Hash digest
---|---
SHA256 | 35eeed78ccf327f2ca62cac4a34b18367f70d54236823858a258b8cc5c1d2f09
MD5 | 0bc4c5ddc4516d99c9ee04d5cc74f62e
BLAKE2b-256 | 251d8cf481264db60d1e8cc76bd19abb165a65c7ef928bb56135c3d696a67270
Hashes for wild_relation_network-0.1.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | d5115c9c986af4462520a51cb2ff09aa2835485c8e292a321b1d9049e170156d
MD5 | 1291bdc6433e12e1747635f3653f7b4e
BLAKE2b-256 | 82131d366a1d5f4803e823eb3c5adce1e2795f54b22f9473646dd10fe68e13f5