The EmbeddingRWKV Model

https://github.com/howard-hou/EmbeddingRWKV

Tokenizer

The package ships with a byte-level trie tokenizer, rwkv_emb.reference.utils.TRIE_TOKENIZER, which loads its vocabulary from rwkv_emb/reference/rwkv_vocab_v20230424.txt. The tokenizer operates on UTF-8 bytes and converts raw text into the token IDs expected by the embedding model.

import os

import rwkv_emb.reference
from rwkv_emb.reference.utils import TRIE_TOKENIZER

# locate the directory of the installed rwkv_emb.reference package
reference_dir = os.path.dirname(os.path.abspath(rwkv_emb.reference.__file__))

# the vocabulary file ships inside that directory
vocab_path = os.path.join(reference_dir, "rwkv_vocab_v20230424.txt")

tokenizer = TRIE_TOKENIZER(vocab_path)

text = "hello world"
tokens = tokenizer.encode(text)          # list of int token IDs

EOS_INDEX = 65535                        # end-of-sequence token
tokens_with_eos = tokens + [EOS_INDEX]

The encode method returns a list of integer token IDs. For embedding inference, append the end-of-sequence token (EOS_INDEX = 65535) to mark the end of the input before feeding the tokens to the model.
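The matching step of a byte-level trie tokenizer can be pictured with a small sketch. It assumes greedy longest-match encoding: at each position the tokenizer walks the trie byte by byte and emits the longest vocabulary entry it finds. The classes TrieNode and ToyByteTrie and the tiny vocabulary are hypothetical stand-ins for illustration; this is not the implementation of TRIE_TOKENIZER.

```python
from typing import Dict, List, Optional


class TrieNode:
    """One node per byte; token_id is set if the path from the root is a vocab entry."""

    def __init__(self) -> None:
        self.children: Dict[int, "TrieNode"] = {}
        self.token_id: Optional[int] = None


class ToyByteTrie:
    """Hypothetical greedy longest-match tokenizer over UTF-8 bytes (illustration only)."""

    def __init__(self, vocab: Dict[bytes, int]) -> None:
        self.root = TrieNode()
        self.id_to_bytes = {token_id: b for b, token_id in vocab.items()}
        for token_bytes, token_id in vocab.items():
            node = self.root
            for byte in token_bytes:  # iterating over bytes yields ints
                node = node.children.setdefault(byte, TrieNode())
            node.token_id = token_id

    def encode(self, text: str) -> List[int]:
        data = text.encode("utf-8")
        ids: List[int] = []
        pos = 0
        while pos < len(data):
            node, i = self.root, pos
            best_id: Optional[int] = None
            best_end = pos
            # walk as deep as the input allows, remembering the longest vocab match
            while i < len(data) and data[i] in node.children:
                node = node.children[data[i]]
                i += 1
                if node.token_id is not None:
                    best_id, best_end = node.token_id, i
            if best_id is None:
                raise ValueError(f"no vocab entry starts with byte {data[pos]} at offset {pos}")
            ids.append(best_id)
            pos = best_end
        return ids

    def decode(self, ids: List[int]) -> str:
        return b"".join(self.id_to_bytes[t] for t in ids).decode("utf-8")


# hypothetical toy vocabulary (the real one is loaded from rwkv_vocab_v20230424.txt)
vocab = {b"h": 1, b"e": 2, b"l": 3, b"o": 4, b" ": 5, b"w": 6, b"r": 7, b"d": 8,
         b"he": 9, b"hell": 10, b"hello": 11, b"wor": 12}
tok = ToyByteTrie(vocab)
ids = tok.encode("hello world")
print(ids)  # [11, 5, 12, 3, 8]
```

Greedy longest match keeps encoding to a single left-to-right pass; here "world" splits into "wor", "l", "d" only because the toy vocabulary lacks longer entries.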

# !!! set these environment variables before importing rwkv_emb.model !!!
import os

os.environ["RWKV_CUDA_ON"] = '1'  # '1' compiles the CUDA kernel (10x faster); requires a C++ compiler and the CUDA libraries

from rwkv_emb.model import EmbeddingRWKV

EOS_INDEX = 65535

# download models: to be announced
model = EmbeddingRWKV(model_path='path-to-model')

# !!! model.forward(tokens, state) will modify state in-place !!!
# single-sample inference
emb, state = model.forward([187, 510, 1563, 310, 247, EOS_INDEX], None)
print(emb.detach().cpu().numpy())                   # embedding vector

# streaming a single sequence
emb, state = model.forward([187, 510], None)
emb, state = model.forward([1563], state)           # RNN carries state (use copy.deepcopy to clone states)
emb, state = model.forward([310, 247, EOS_INDEX], state)
print(emb.detach().cpu().numpy())                   # same result as above

# batch inference (all sequences must share the same length)
batch_tokens = [
    [187, 510, 1563, 310],
    [247, EOS_INDEX, 187, 310],
]
emb_batch, batch_state = model.forward(batch_tokens, None, full_output=False)
print(emb_batch.detach().cpu().numpy())             # one embedding row per sequence in the batch
print(len(batch_state), batch_state[-2].shape)      # batched state shapes
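Because batch inference requires every sequence in a batch to have the same length, one way to embed inputs of mixed lengths is to group them by length first. The sketch below only builds the batches and assumes each input already ends with EOS_INDEX; the helper name bucket_by_length and its max_batch_size parameter are hypothetical. Each resulting batch could then be passed to model.forward(batch, None, full_output=False) as in the example above.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

EOS_INDEX = 65535


def bucket_by_length(
    sequences: List[List[int]], max_batch_size: int = 32
) -> List[Tuple[List[int], List[List[int]]]]:
    """Group token sequences into equal-length batches (hypothetical helper).

    Returns (original_indices, batch) pairs so that embeddings can be
    mapped back to the input order after inference.
    """
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be at least 1")
    by_length: Dict[int, List[int]] = defaultdict(list)
    for idx, seq in enumerate(sequences):
        by_length[len(seq)].append(idx)

    batches = []
    for length in sorted(by_length):
        indices = by_length[length]
        # split each length group into chunks of at most max_batch_size
        for start in range(0, len(indices), max_batch_size):
            chunk = indices[start:start + max_batch_size]
            batches.append((chunk, [sequences[i] for i in chunk]))
    return batches


inputs = [
    [187, 510, EOS_INDEX],
    [1563, 310, 247, EOS_INDEX],
    [247, 310, EOS_INDEX],
]
batches = bucket_by_length(inputs, max_batch_size=2)
for indices, batch in batches:
    print(indices, batch)
```

Keeping the original indices alongside each batch is what lets the per-batch outputs be reassembled in input order afterwards.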

Download files

Source distribution

rwkv_emb-0.0.4.tar.gz (394.5 kB)

Built distribution

rwkv_emb-0.0.4-py3-none-any.whl (393.8 kB, Python 3)

File details

Details for the file rwkv_emb-0.0.4.tar.gz.

File metadata

  • Download URL: rwkv_emb-0.0.4.tar.gz
  • Upload date:
  • Size: 394.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for rwkv_emb-0.0.4.tar.gz

  • SHA256: 9536269d28f5fe0e7d4d803867bc7d1ea72f996dade7a08d7b45936474364a7b
  • MD5: e6d5fc242bf4c485921ebd01db59aeba
  • BLAKE2b-256: 39b7e16ddb56db59c72f7d3664c85ba398af31a76d1ecd4845f6bb8883cd66c0

See more details on using hashes here.

File details

Details for the file rwkv_emb-0.0.4-py3-none-any.whl.

File metadata

  • Download URL: rwkv_emb-0.0.4-py3-none-any.whl
  • Upload date:
  • Size: 393.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for rwkv_emb-0.0.4-py3-none-any.whl

  • SHA256: 1ac485658000773df2c48f5d58e7c9be159c21a5f12fe6e8889a7fd47b340ccf
  • MD5: a59bfe1b5e5d8e9ea5e9cee986732b22
  • BLAKE2b-256: 9f0621881cb68c05a6dfc816cb833007868d628de7645b6d59095df419b834ab
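
The published SHA256 digests make it possible to confirm that a downloaded file is intact. A minimal sketch using Python's standard hashlib module follows; the helper names sha256_of_file and verify_download are hypothetical, and the lookup assumes the file keeps its original name.

```python
import hashlib
import os

# SHA256 digests copied from the file details for release 0.0.4
EXPECTED_SHA256 = {
    "rwkv_emb-0.0.4.tar.gz": "9536269d28f5fe0e7d4d803867bc7d1ea72f996dade7a08d7b45936474364a7b",
    "rwkv_emb-0.0.4-py3-none-any.whl": "1ac485658000773df2c48f5d58e7c9be159c21a5f12fe6e8889a7fd47b340ccf",
}


def sha256_of_file(path: str, chunk_size: int = 1 << 16) -> str:
    """Hash the file in chunks so it never has to fit in memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: str) -> bool:
    """Return True if the file's SHA256 matches the published digest for its name."""
    name = os.path.basename(path)
    if name not in EXPECTED_SHA256:
        raise KeyError(f"no published SHA256 digest for {name!r}")
    return sha256_of_file(path) == EXPECTED_SHA256[name]
```

For an untampered copy of the source distribution, verify_download("rwkv_emb-0.0.4.tar.gz") should return True. pip's hash-checking mode (--require-hashes, with --hash=sha256:... entries in a requirements file) performs the same check at install time.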

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page