Implementation of Slot Attention in PyTorch
Slot Attention
Implementation of Slot Attention from the paper 'Object-Centric Learning with Slot Attention' in PyTorch. Here is a video that describes what this network can do.
Update: The official repository has been released here
Install
$ pip install slot_attention
Usage
import torch
from slot_attention import SlotAttention
slot_attn = SlotAttention(
    num_slots = 5,
    dim = 512,
    iters = 3   # iterations of attention, defaults to 3
)
inputs = torch.randn(2, 1024, 512)
slot_attn(inputs) # (2, 5, 512)
After training, the network is reported to generalize to a slightly different number of slots (clusters). You can override the number of slots at call time with the num_slots keyword argument of forward.
slot_attn(inputs, num_slots = 8) # (2, 8, 512)
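To illustrate why the slot count can change at call time, here is a simplified sketch of the attention step described in the paper. It is not this package's implementation: the learned projections, layer norms, GRU, and MLP are omitted, and the names simplified_slot_update, sample_slots, mu, and sigma are hypothetical. The point it shows is that slots are sampled from one shared Gaussian, so no parameter depends on how many slots are drawn.

```python
import torch

def simplified_slot_update(slots, inputs, eps=1e-8):
    """One attention step: slots (b, n_slots, d), inputs (b, n_inputs, d)."""
    dim = slots.shape[-1]
    # Scaled dot-product similarity between every slot and every input
    logits = torch.einsum('bid,bjd->bij', slots, inputs) * dim ** -0.5
    # Softmax over the slot axis, so slots compete for each input
    attn = logits.softmax(dim=1)
    # Renormalize over inputs so each slot becomes a weighted mean of inputs
    weights = attn / (attn.sum(dim=-1, keepdim=True) + eps)
    return torch.einsum('bij,bjd->bid', weights, inputs)

torch.manual_seed(0)
inputs = torch.randn(2, 1024, 512)

# Shared Gaussian parameters (learned in the paper, fixed here for illustration)
mu = torch.zeros(1, 1, 512)
sigma = torch.ones(1, 1, 512)

def sample_slots(batch, num_slots):
    # Every slot is drawn from the same distribution, whatever num_slots is
    return mu + sigma * torch.randn(batch, num_slots, 512)

for num_slots in (5, 8):
    slots = sample_slots(2, num_slots)
    for _ in range(3):  # three iterations, matching the default above
        slots = simplified_slot_update(slots, inputs)
    print(num_slots, tuple(slots.shape))
```

Because the softmax runs over the slot axis, adding or removing slots only changes how the inputs are divided among them.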
To use the adaptive slot method, which generates a differentiable one-hot mask indicating whether each slot should be used, wrap the slot attention module as follows
import torch
from slot_attention import MultiHeadSlotAttention, AdaptiveSlotWrapper
# define slot attention
slot_attn = MultiHeadSlotAttention(
    dim = 512,
    num_slots = 5,
    iters = 3,
)
# wrap the slot attention
adaptive_slots = AdaptiveSlotWrapper(
    slot_attn,
    temperature = 0.5   # gumbel softmax temperature
)
inputs = torch.randn(2, 1024, 512)
slots, keep_slots = adaptive_slots(inputs) # (2, 5, 512), (2, 5)
# the auxiliary loss in the paper for minimizing number of slots used for a scene would simply be
keep_aux_loss = keep_slots.sum() # add this to your main loss with some weight
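As a rough illustration of the two ideas above, the following sketch builds a keep mask with a hard Gumbel softmax and adds the weighted auxiliary loss to a main loss. It makes several assumptions: the linear head to_keep_logits, the placeholder main_loss, and the weight aux_weight are hypothetical, and the wrapper's actual internals may differ. It uses only torch.nn.functional.gumbel_softmax from PyTorch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for the slots returned by slot attention: (batch, num_slots, dim)
slots = torch.randn(2, 5, 512, requires_grad=True)

# Hypothetical head producing two logits per slot: [drop, keep]
to_keep_logits = torch.nn.Linear(512, 2)
logits = to_keep_logits(slots)                      # (2, 5, 2)

# hard=True gives one-hot samples in the forward pass, while gradients
# flow through the soft relaxation (straight-through estimator)
one_hot = F.gumbel_softmax(logits, tau=0.5, hard=True, dim=-1)
keep_slots = one_hot[..., -1]                       # (2, 5), entries near 0 or 1

# Placeholder task loss; replace with your reconstruction or task objective
main_loss = slots.pow(2).mean()

# Penalize the number of kept slots with an assumed weight
aux_weight = 0.1
keep_aux_loss = keep_slots.sum()
total_loss = main_loss + aux_weight * keep_aux_loss
total_loss.backward()
```

The weight trades off task quality against the number of slots kept, so it usually needs tuning per dataset.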
Citations
@misc{locatello2020objectcentric,
title = {Object-Centric Learning with Slot Attention},
author = {Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf},
year = {2020},
eprint = {2006.15055},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@article{Fan2024AdaptiveSA,
title = {Adaptive Slot Attention: Object Discovery with Dynamic Slot Number},
author = {Ke Fan and Zechen Bai and Tianjun Xiao and Tong He and Max Horn and Yanwei Fu and Francesco Locatello and Zheng Zhang},
journal = {ArXiv},
year = {2024},
volume = {abs/2406.09196},
url = {https://api.semanticscholar.org/CorpusID:270440447}
}
@article{liu2025metaslot,
title = {MetaSlot: Break Through the Fixed Number of Slots in Object-Centric Learning},
author = {Liu, Hongjia and Zhao, Rongzhen and Chen, Haohan and Pajarinen, Joni},
journal = {Advances in Neural Information Processing Systems},
volume = {38},
pages = {67319--67344},
year = {2026}
}
File details
Details for the file slot_attention-1.5.0.tar.gz.
File metadata
- Download URL: slot_attention-1.5.0.tar.gz
- Upload date:
- Size: 6.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 9b08722be1f916a55652ce2d9875ca1a0f928b0a7e5f99baea6e2dbc98fce1ff |
| MD5 | 638044f8d2804b8918c4b75dbccec166 |
| BLAKE2b-256 | 86d3998ffc0a3252d1901510a46895fd5345e4f57d387d8c5533dd64db8b5803 |
File details
Details for the file slot_attention-1.5.0-py2.py3-none-any.whl.
File metadata
- Download URL: slot_attention-1.5.0-py2.py3-none-any.whl
- Upload date:
- Size: 9.2 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.8.17
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b9d5581237eede93af76d36372380f70abcf87a5550c7182c0252d7394dffed9 |
| MD5 | ff5afabb1783c07349e2796cfb931702 |
| BLAKE2b-256 | f1203c8723d6a05d0bbe316ed35e0828b4026c073bf11f1e0c7e4d140fcd10f7 |