
eindex

Concept of multidimensional indexing for tensors

Example of K-means clustering

Plain numpy

import numpy as np
from scipy.spatial.distance import cdist

def kmeans(init_centers, X, n_iterations: int):
    n_clusters, n_dim = init_centers.shape
    n_observations, n_dim = X.shape

    centers = init_centers.copy()
    for _ in range(n_iterations):
        d = cdist(centers, X)  # pairwise distances, [n_clusters, n_observations]
        clusters = np.argmin(d, axis=0)  # nearest center for every point
        new_centers_sum = np.zeros_like(centers)
        indices_dim = np.arange(n_dim)[None, :]
        np.add.at(new_centers_sum, (clusters[:, None], indices_dim), X)
        cluster_counts = np.bincount(clusters, minlength=n_clusters)
        centers = new_centers_sum / cluster_counts[:, None]
    return centers
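A quick sanity check of the plain-numpy version on toy data (a sketch: the blob locations and seed are arbitrary, and cdist is replaced with a numpy equivalent to keep the snippet dependency-free):

```python
import numpy as np

def kmeans(init_centers, X, n_iterations: int):
    n_clusters, _ = init_centers.shape
    centers = init_centers.copy()
    for _ in range(n_iterations):
        # numpy stand-in for scipy's cdist, [n_clusters, n_observations]
        d = np.linalg.norm(centers[:, None, :] - X[None, :, :], axis=-1)
        clusters = np.argmin(d, axis=0)          # nearest center for every point
        sums = np.zeros_like(centers)
        np.add.at(sums, clusters, X)             # per-cluster sums of points
        counts = np.bincount(clusters, minlength=n_clusters)
        centers = sums / counts[:, None]
    return centers

rng = np.random.default_rng(0)
# two well-separated blobs around (0, 0) and (5, 5)
X = np.concatenate([rng.normal([0., 0.], 0.1, size=(50, 2)),
                    rng.normal([5., 5.], 0.1, size=(50, 2))])
centers = kmeans(np.array([[1., 1.], [4., 4.]]), X, n_iterations=10)
# centers end up close to the true blob locations
```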

With eindex

def kmeans_eindex(init_centers, X, n_iterations: int):
    # EX refers to an eindex implementation (numpy or array-api based)
    centers = init_centers
    for _ in range(n_iterations):
        d = cdist(centers, X)
        clusters = EX.argmin(d, 'cluster i -> [cluster] i')
        centers = EX.scatter(X, clusters, 'i c, [cluster] i -> cluster c',
                             agg='mean', cluster=len(centers))
    return centers

Tutorial notebook

Goals

  • Form a helpful 'language' to think about indexing and index-related operations. Tools shape minds
  • Cover the most common cases of multidimensional indexing that are hard to implement using the standard API
  • The approach should be applicable to most common tensor frameworks; autograd should work out-of-the-box
  • Aim for readable and reliable code
  • Allow simple adoption in existing codebases
  • Implementation should be based on fairly common tensor operations. No custom kernels allowed.
  • Complexity should be visible: the execution plan for every operation should form a static graph. The structure of the graph depends on the pattern, but not on the tensor arguments.

Non-goals: there is no goal to develop 'the shortest notation', 'the most advanced/comprehensive tool for indexing', 'cover as many operations as possible', or 'completely replace default indexing'.

Examples

Follow the tutorial first to learn about all the operations provided.

- how do I select a single embedding from every image in a batch?

Let's say you have pairs of images and captions, and for every caption token you want to take the closest image embedding:

score = einsum(images_bhwc, sentences_btc, 'b h w c, b token c -> b h w token')
closest_index = argmax(score, 'b h w token -> [h, w] b token')
closest_emb = gather('b h w c, [h, w] b token -> b token c', images_bhwc, closest_index)

To adjust this example for video instead of image, replace 'h w' with 'h w t'. Yes, that simple.
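For comparison, the same selection can be sketched in plain numpy via argmax over the flattened grid plus advanced indexing (shapes and names below are illustrative, not from the library):

```python
import numpy as np

b, h, w, c, tokens = 2, 4, 5, 3, 6
rng = np.random.default_rng(0)
images_bhwc = rng.random((b, h, w, c))
score = rng.random((b, h, w, tokens))        # stand-in for the einsum result

# argmax over the joint (h, w) grid, unraveled into two index tensors
flat_idx = score.reshape(b, h * w, tokens).argmax(axis=1)    # [b, token]
h_idx, w_idx = np.unravel_index(flat_idx, (h, w))            # each [b, token]

# advanced indexing plays the role of gather
closest_emb = images_bhwc[np.arange(b)[:, None], h_idx, w_idx]   # [b, token, c]
```

Note how the batch axis needs its own broadcastable index tensor; that bookkeeping is exactly what the pattern notation hides.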

- how to collect top-1 or top-3 predicted word for every position in audio/text?

[most_likely_words] = argmax(prob_tbc, 't b w -> [w] t b')
[top_words] = argsort(prob_tbc, 't b w -> [w] t b order')[..., -3:]
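The plain-numpy counterparts, for reference (the probability tensor here is random, just to show shapes):

```python
import numpy as np

rng = np.random.default_rng(0)
prob_tbc = rng.random((7, 3, 100))   # time, batch, vocabulary

# top-1 word index at every (t, b) position
most_likely_words = np.argmax(prob_tbc, axis=-1)       # [t, b]

# top-3: argsort is ascending, so the three largest come last
top_words = np.argsort(prob_tbc, axis=-1)[..., -3:]    # [t, b, 3]
```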

- how to average embeddings over neighbors in a graph?

# without batch (single graph)
gather('vin c, [vin, vout] edge -> vout', embeddings, edges)
# with batch (multiple graphs)
gather('b vin c, [b, vin, vout] edge -> b vout', embeddings, edges)
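Without eindex, the single-graph case could be sketched with np.add.at (the toy graph below is made up: a 4-vertex directed cycle, so every vertex has exactly one in-neighbor):

```python
import numpy as np

n_vertices, n_channels = 4, 3
embeddings = np.arange(n_vertices * n_channels, dtype=float).reshape(n_vertices, n_channels)
# one directed edge per column: message flows edges[0] -> edges[1]
edges = np.array([[0, 1, 2, 3],
                  [1, 2, 3, 0]])

sums = np.zeros_like(embeddings)
np.add.at(sums, edges[1], embeddings[edges[0]])        # sum embeddings of in-neighbors
counts = np.bincount(edges[1], minlength=n_vertices)
averaged = sums / np.maximum(counts, 1)[:, None]       # guard against isolated vertices
```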

- can eindex help with (complex) positional embeddings?

If we're speaking about trainable absolute positional embeddings, those can simply be stored as emb_hwc and added to every batch. There is no need for indexing.

But eindex can be very helpful in complex scenarios: for example, T5-style relative position bias, where a bias is added to every logit before softmax-ing to compute attention. That's simple to implement in 1d, and much harder in 2d/3d. Let's implement the 2d case with eindex:

N = None
pos  # shape '[I, J] i j': stacked (x, y) coordinates of every pixel
pos1 = pos[:, :, :, N, N]
pos2 = pos[:, N, N, :, :]
xy_diff = (pos1 - pos2) % image_side  # we make shifts positive by wrapping
attention_bias = gather('i j head, [i, j] i1 j1 i2 j2 -> i1 j1 i2 j2 head', biases, xy_diff)

Note that we index by the true 2d relative position (shift in x and y), which most implementations avoid by dealing with a flat sequence instead.

In a similar way we could produce vector-shift attention (another typical version of relpos):

vector_shift = gather('i j head c, [i, j] i1 j1 i2 j2 -> i1 j1 i2 j2 head c', biases, xy_diff)
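The index-tensor construction above is plain tensor algebra, so it can be checked in numpy alone (the grid size here is arbitrary):

```python
import numpy as np

image_side = 3
ii, jj = np.meshgrid(np.arange(image_side), np.arange(image_side), indexing='ij')
pos = np.stack([ii, jj])                   # '[I, J] i j': coordinates of every pixel

pos1 = pos[:, :, :, None, None]            # [2, i, j, 1, 1]
pos2 = pos[:, None, None, :, :]            # [2, 1, 1, i, j]
xy_diff = (pos1 - pos2) % image_side       # [2, i, j, i, j], shifts wrapped to be non-negative
```

Every entry of xy_diff is a valid non-negative index into a biases tensor of spatial size image_side × image_side, which is exactly what the gather pattern expects.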

Implementation

The repo provides two implementations:

  • array api standard. This implementation is based on a standard that multiple frameworks have agreed to follow. The implementation uses only API from the standard, so every available operation supports all frameworks that follow the standard.

    At some point this should become the one and only implementation.

    Here is the catch: current support of the array api standard is poor, which is why the second implementation exists.

  • numpy implementation

    This independent implementation works right now.

    The numpy implementation is great for testing things out, and is handy for a number of non-DL applications as well.

Development Status

The API looks solid, but breaking changes are still possible, so lock the version in your projects (e.g. eindex==0.1.0).

Related projects

Other projects you likely want to look at:

  • tullio by Michael Abbott (@mcabbott) provides Julia macros with a high level of flexibility. Resulting operations are then compiled.
  • torchdim by Zachary DeVito (@zdevito) introduces "dimension objects", which in particular allow convenient multi-dim indexing
  • einindex is an einops-inspired prototype by Jonathan Malmaud (@malmaud) to develop multi-dim indexing notation. (Also, that's why this package isn't called einindex)

Contributing

We welcome the following contributions:

  • next time you deal with multidimensional indexing, do it with eindex
    Worked? → great, let us know; didn't work or unclear how to implement → post in discussions
  • if you feel you're already fluent in eindex, help others
  • guides/tutorials/video-guides are very welcome, and will be linked
  • if you want to translate the tutorial to another language and post it somewhere - welcome

Discussions

Use github discussions for this project: https://github.com/arogozhnikov/eindex/discussions
