
Implementation of Slot Attention in PyTorch


Slot Attention

Implementation of Slot Attention, from the paper 'Object-Centric Learning with Slot Attention', in PyTorch. Here is a video that describes what this network can do.

Update: The official repository has been released here
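For readers new to the method, the core of one Slot Attention iteration fits in a few lines. The sketch below is a minimal NumPy illustration, not this package's code: it assumes identity key/query/value projections, omits the layer norms, and replaces the paper's GRU and MLP slot update with the plain attention-weighted mean. It does keep the defining detail of the paper: the softmax is taken over the slot axis, so slots compete for each input, and the weights are then renormalized over the inputs before aggregation. The helper name `slot_attention_step` is hypothetical.

```python
import numpy as np

def softmax(x, axis):
    # numerically stable softmax along the given axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def slot_attention_step(slots, inputs, eps=1e-8):
    """One simplified Slot Attention iteration (hypothetical helper).

    slots:  (num_slots, dim)
    inputs: (num_inputs, dim)
    Assumes identity key/query/value projections and uses the weighted
    mean directly as the new slots (the paper uses a GRU + MLP update).
    """
    dim = inputs.shape[-1]
    # scaled dot-product logits between every input and every slot: (num_inputs, num_slots)
    logits = inputs @ slots.T / np.sqrt(dim)
    # softmax over the SLOT axis: slots compete to explain each input
    attn = softmax(logits, axis=1)
    # renormalize over the INPUT axis so each slot takes a weighted mean
    weights = attn / (attn.sum(axis=0, keepdims=True) + eps)
    # weighted mean of the inputs for each slot: (num_slots, dim)
    updates = weights.T @ inputs
    return updates, attn

rng = np.random.default_rng(0)
inputs = rng.normal(size=(1024, 64))
slots = rng.normal(size=(5, 64))
for _ in range(3):  # 3 iterations, mirroring iters = 3 in the usage example
    slots, attn = slot_attention_step(slots, inputs)

print(slots.shape, attn.shape)  # (5, 64) (1024, 5)
```

Because the softmax runs over the slot axis, each input's attention across the five slots sums to one; repeating the step corresponds to the `iters` argument shown in the usage example.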

Install

$ pip install slot_attention

Usage

import torch
from slot_attention import SlotAttention

slot_attn = SlotAttention(
    num_slots = 5,
    dim = 512,
    iters = 3   # iterations of attention, defaults to 3
)

inputs = torch.randn(2, 1024, 512)
slot_attn(inputs) # (2, 5, 512)
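The usage example passes a tensor of shape (batch, num_inputs, dim). The README does not say where such a set comes from; a common choice, and the one used in the paper, is the output of a CNN encoder flattened over its spatial positions, so a 32 × 32 feature map with 512 channels would give the (2, 1024, 512) shape above. The NumPy sketch below shows only that flattening step; the feature map is random and the helper name `to_set_of_vectors` is hypothetical. The paper also adds positional embeddings to the feature map, which is omitted here.

```python
import numpy as np

# hypothetical encoder output: batch of 2, 512 channels, 32 x 32 spatial grid
feature_map = np.random.default_rng(0).normal(size=(2, 512, 32, 32))

def to_set_of_vectors(fmap):
    """(batch, channels, height, width) -> (batch, height * width, channels)."""
    b, c, h, w = fmap.shape
    # move channels last, then merge the two spatial axes into one set axis
    return fmap.transpose(0, 2, 3, 1).reshape(b, h * w, c)

inputs = to_set_of_vectors(feature_map)
print(inputs.shape)  # (2, 1024, 512)
```

In PyTorch the same reshape can be written as `fmap.flatten(2).transpose(1, 2)`.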

After training, the network is reported to generalize to a slightly different number of slots (clusters). You can override the number of slots by passing the num_slots keyword to forward:

slot_attn(inputs, num_slots = 8) # (2, 8, 512)
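Why the slot count can change at call time is worth spelling out. In the paper, the initial slots are sampled independently from a Gaussian whose mean and diagonal standard deviation are learned and shared by all slots, so no weight depends on the number of slots. The NumPy sketch below illustrates that sampling step, assuming this package initializes slots the same way; the function and parameter names (`init_slots`, `mu`, `log_sigma`) are hypothetical, and the zero vectors stand in for learned parameters.

```python
import numpy as np

def init_slots(mu, log_sigma, batch, num_slots, rng):
    """Sample initial slots from a Gaussian with shared mean and diagonal std.

    mu, log_sigma: (dim,) parameters shared by every slot (hypothetical names).
    Returns (batch, num_slots, dim); num_slots is free because no parameter
    has a slot axis.
    """
    noise = rng.normal(size=(batch, num_slots, mu.shape[0]))
    # broadcast the (dim,) mean and std over the batch and slot axes
    return mu + np.exp(log_sigma) * noise

rng = np.random.default_rng(0)
dim = 512
mu = np.zeros(dim)         # stands in for a learned mean
log_sigma = np.zeros(dim)  # stands in for a learned log standard deviation

print(init_slots(mu, log_sigma, 2, 5, rng).shape)  # (2, 5, 512)
print(init_slots(mu, log_sigma, 2, 8, rng).shape)  # (2, 8, 512)
```

Since the same two vectors are broadcast over the slot axis, asking for 8 slots instead of 5 only changes how many samples are drawn.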

To use the adaptive slot method (Adaptive Slot Attention), which produces a differentiable one-hot mask indicating whether each slot is kept, do the following:

import torch
from slot_attention import MultiHeadSlotAttention, AdaptiveSlotWrapper

# define slot attention

slot_attn = MultiHeadSlotAttention(
    dim = 512,
    num_slots = 5,
    iters = 3,
)

# wrap the slot attention

adaptive_slots = AdaptiveSlotWrapper(
    slot_attn,
    temperature = 0.5 # gumbel softmax temperature
)

inputs = torch.randn(2, 1024, 512)

slots, keep_slots = adaptive_slots(inputs) # (2, 5, 512), (2, 5)

# the auxiliary loss from the Adaptive Slot Attention paper, which penalizes the number of slots kept per scene, is simply

keep_aux_loss = keep_slots.sum()  # add this to your main loss with some weight
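The `temperature` argument refers to a Gumbel-softmax relaxation of the per-slot keep/drop decision. The NumPy sketch below shows only the forward computation for two scenes with five slots each: Gumbel noise is added to two logits per slot ([drop, keep]), a temperature-scaled softmax gives a soft sample, and an argmax turns it into a hard 0/1 mask playing the role of `keep_slots`. It is an illustration under assumptions, not this package's code: the `keep_logits` input and the helper name are hypothetical, and NumPy has no autograd, so the straight-through gradient that makes the mask differentiable is only described in a comment.

```python
import numpy as np

def gumbel_softmax_keep_mask(keep_logits, temperature, rng):
    """Forward pass of a hard Gumbel-softmax over [drop, keep] for each slot.

    keep_logits: (batch, num_slots, 2) hypothetical scores for [drop, keep].
    Returns the hard 0/1 keep mask (batch, num_slots) and the soft sample.
    """
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1)
    u = rng.uniform(low=1e-10, high=1.0, size=keep_logits.shape)
    gumbel = -np.log(-np.log(u))
    # temperature-scaled softmax over the last axis (numerically stabilized)
    y = (keep_logits + gumbel) / temperature
    y = np.exp(y - y.max(axis=-1, keepdims=True))
    soft = y / y.sum(axis=-1, keepdims=True)
    # hard decision: 1.0 where "keep" (index 1) wins, else 0.0.
    # With autograd, the straight-through trick (y_hard - y_soft.detach() + y_soft)
    # would let gradients flow through the soft sample.
    hard = (soft.argmax(axis=-1) == 1).astype(soft.dtype)
    return hard, soft

rng = np.random.default_rng(0)
keep_logits = rng.normal(size=(2, 5, 2))  # 2 scenes, 5 slots, [drop, keep]
keep_slots, soft = gumbel_softmax_keep_mask(keep_logits, temperature=0.5, rng=rng)

print(keep_slots.shape)  # (2, 5)
keep_aux_loss = keep_slots.sum()  # number of kept slots across the whole batch
```

PyTorch provides this relaxation as `torch.nn.functional.gumbel_softmax(logits, tau=..., hard=True)`, which returns a one-hot sample with straight-through gradients; whether this package uses it is not stated here. Note also that `keep_slots.sum()` sums over the whole batch, so its scale grows with batch size, which matters when choosing the loss weight.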

Citations

@misc{locatello2020objectcentric,
    title   = {Object-Centric Learning with Slot Attention},
    author  = {Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf},
    year    = {2020},
    eprint  = {2006.15055},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@article{Fan2024AdaptiveSA,
    title   = {Adaptive Slot Attention: Object Discovery with Dynamic Slot Number},
    author  = {Ke Fan and Zechen Bai and Tianjun Xiao and Tong He and Max Horn and Yanwei Fu and Francesco Locatello and Zheng Zhang},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2406.09196},
    url     = {https://api.semanticscholar.org/CorpusID:270440447}
}
@article{liu2025metaslot,
    title   = {MetaSlot: Break Through the Fixed Number of Slots in Object-Centric Learning},
    author  = {Liu, Hongjia and Zhao, Rongzhen and Chen, Haohan and Pajarinen, Joni},
    journal = {Advances in Neural Information Processing Systems},
    volume  = {38},
    pages   = {67319--67344},
    year    = {2026}
}
@inproceedings{Touska2026OrthoRF,
    title   = {OrthoRF: Exploring Orthogonality in Object-Centric Representations},
    author  = {Despoina Touska and Bastiaan Onne Fagginger Auer and Alexandru Onose and Tejaswi Kasarla and Luis Armando P{\'e}rez Rey and Maximilian Lipp and Lyubov Amitonova and Martin R. Oswald and Pascal Cerfontaine},
    booktitle = {International Conference on Learning Representations},
    year    = {2026}
}
