Implementation of Slot Attention in PyTorch

Slot Attention

Implementation of Slot Attention from the paper 'Object-Centric Learning with Slot Attention' in PyTorch. Here is a video that describes what this network can do.

Update: The official repository has been released here

Install

$ pip install slot_attention

Usage

import torch
from slot_attention import SlotAttention

slot_attn = SlotAttention(
    num_slots = 5,
    dim = 512,
    iters = 3   # iterations of attention, defaults to 3
)

inputs = torch.randn(2, 1024, 512)
slot_attn(inputs) # (2, 5, 512)
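For intuition about what each iteration does, here is a minimal sketch of the core update described in the paper: attention is normalized over the slot axis, so slots compete for each input token, and each slot then takes a weighted mean of the inputs. This is a simplification, not the package's implementation: the learned query/key/value projections, layer norms, GRU and MLP are all omitted, and the name `slot_attention_step` is hypothetical.

```python
import torch

def slot_attention_step(slots, inputs, eps=1e-8):
    """One simplified slot update (hypothetical helper, no learned weights)."""
    dim = slots.shape[-1]
    # Similarity of every slot to every input token: (batch, slots, tokens)
    logits = torch.einsum('bsd,btd->bst', slots, inputs) * dim ** -0.5
    # Softmax over the slot axis: slots compete for each token
    attn = logits.softmax(dim=1) + eps
    # Renormalize over tokens so each slot takes a weighted mean of the inputs
    attn = attn / attn.sum(dim=-1, keepdim=True)
    return torch.einsum('bst,btd->bsd', attn, inputs)

torch.manual_seed(0)
inputs = torch.randn(2, 16, 8)   # (batch, tokens, dim)
slots = torch.randn(2, 3, 8)     # (batch, slots, dim)

for _ in range(3):               # mirrors iters = 3
    slots = slot_attention_step(slots, inputs)

print(slots.shape)  # torch.Size([2, 3, 8])
```

The softmax over slots rather than over tokens is what distinguishes this from standard attention: each token's attention mass is split among the slots.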

After training, the network is reported to generalize to a slightly different number of slots (clusters). You can override the number of slots by passing the num_slots keyword to forward.

slot_attn(inputs, num_slots = 8) # (2, 8, 512)
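The slot count can vary because, in the paper, the initial slots are sampled from a single Gaussian whose mean and scale are shared by all slots, so no parameter is tied to a particular slot index. The sketch below illustrates that idea only. It assumes the package follows the same scheme, and `mu`, `log_sigma` and `sample_slots` are hypothetical names standing in for learned parameters.

```python
import torch

dim = 8
# Stand-ins for the learned, slot-independent mean and log standard deviation
mu = torch.zeros(1, 1, dim)
log_sigma = torch.zeros(1, 1, dim)

def sample_slots(batch, num_slots):
    """Draw any number of initial slots from the same shared Gaussian."""
    noise = torch.randn(batch, num_slots, dim)
    return mu + log_sigma.exp() * noise

print(sample_slots(2, 5).shape)  # torch.Size([2, 5, 8])
print(sample_slots(2, 8).shape)  # torch.Size([2, 8, 8])
```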

To use the adaptive slot method, which generates a differentiable one-hot mask indicating whether each slot is used, do the following:

import torch
from slot_attention import MultiHeadSlotAttention, AdaptiveSlotWrapper

# define slot attention

slot_attn = MultiHeadSlotAttention(
    dim = 512,
    num_slots = 5,
    iters = 3,
)

# wrap the slot attention

adaptive_slots = AdaptiveSlotWrapper(
    slot_attn,
    temperature = 0.5 # gumbel softmax temperature
)

inputs = torch.randn(2, 1024, 512)

slots, keep_slots = adaptive_slots(inputs) # (2, 5, 512), (2, 5)

# the auxiliary loss from the paper, which minimizes the number of slots used per scene, is simply

keep_aux_loss = keep_slots.sum()  # add this to your main loss with some weight
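As a rough illustration of how such a keep mask can stay differentiable and how the auxiliary term might be weighted, the sketch below uses PyTorch's `F.gumbel_softmax` with `hard = True` on stand-in logits. It does not call the package: `keep_logits`, `main_loss` and `aux_weight` are hypothetical placeholders, the weight of 0.1 is arbitrary, and the wrapper's internals may differ.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for per-slot drop/keep logits: (batch, slots, 2)
keep_logits = torch.randn(2, 5, 2, requires_grad=True)

# hard=True returns one-hot samples in the forward pass while
# gradients flow through the soft relaxation in the backward pass
one_hot = F.gumbel_softmax(keep_logits, tau=0.5, hard=True, dim=-1)
keep_slots = one_hot[..., 1]           # (2, 5), 1.0 where the slot is kept

main_loss = torch.tensor(1.0)          # placeholder for e.g. a reconstruction loss
aux_weight = 0.1                       # arbitrary illustrative weight

keep_aux_loss = keep_slots.sum()
total_loss = main_loss + aux_weight * keep_aux_loss
total_loss.backward()                  # gradients reach keep_logits

print(keep_slots.shape)  # torch.Size([2, 5])
```

A larger `aux_weight` pushes the model toward using fewer slots, so it typically needs tuning against the main objective.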

Citations

@misc{locatello2020objectcentric,
    title   = {Object-Centric Learning with Slot Attention},
    author  = {Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf},
    year    = {2020},
    eprint  = {2006.15055},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@article{Fan2024AdaptiveSA,
    title   = {Adaptive Slot Attention: Object Discovery with Dynamic Slot Number},
    author  = {Ke Fan and Zechen Bai and Tianjun Xiao and Tong He and Max Horn and Yanwei Fu and Francesco Locatello and Zheng Zhang},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2406.09196},
    url     = {https://api.semanticscholar.org/CorpusID:270440447}
}
@article{liu2025metaslot,
    title   = {MetaSlot: Break Through the Fixed Number of Slots in Object-Centric Learning},
    author  = {Liu, Hongjia and Zhao, Rongzhen and Chen, Haohan and Pajarinen, Joni},
    journal = {Advances in Neural Information Processing Systems},
    volume  = {38},
    pages   = {67319--67344},
    year    = {2026}
}
