Python wrapper for damo, a set of fast and robust hash functions.
Project description
Damo-Embedding
Quick Install
pip install damo-embedding
Example
Embedding
import damo
import torch
import numpy as np
from typing import Union
from collections import defaultdict
class Storage(object):
    """Process-wide singleton owning the shared damo ``PyStorage`` backend.

    The first construction fixes the configuration ("dir", "ttl" kwargs);
    every later call ignores its arguments and returns the already-built
    instance.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Lazily build the shared instance on first use only.
        if cls._instance is None:
            inst = object.__new__(cls)
            inst.dir = kwargs.get("dir", "./embeddings")
            inst.ttl = kwargs.get("ttl", 8640000)
            inst.storage = damo.PyStorage(inst.dir, inst.ttl)
            cls._instance = inst
        return cls._instance

    @staticmethod
    def checkpoint(path: str):
        """Write a checkpoint of the shared storage to *path*."""
        assert Storage._instance is not None
        Storage._instance.storage.checkpoint(path)

    @staticmethod
    def dump(path: str):
        """Dump the shared storage contents to *path*."""
        assert Storage._instance is not None
        Storage._instance.storage.dump(path)

    @staticmethod
    def load_from_checkpoint(path: str):
        """Restore the shared storage from the checkpoint at *path*."""
        assert Storage._instance is not None
        Storage._instance.storage.load_from_checkpoint(path)
class Embedding(torch.nn.Module):
    """Sparse embedding layer backed by the damo key-value storage.

    Keys are uint64 values; key 0 is treated as padding and always maps to
    a zero vector. The backward pass bypasses torch parameters: a tensor
    hook aggregates per-key gradients and hands them to the damo optimizer
    through ``apply_gradients``.
    """

    # Auto-incremented to hand out a fresh group id per Embedding when the
    # caller does not choose one explicitly.
    _group = -1

    def __init__(self, dim: int, initializer=None, optimizer=None, group=-1, **kwargs):
        """Create an embedding table.

        Args:
            dim: width of each embedding vector.
            initializer: damo initializer configuration dict (defaults to
                an empty configuration).
            optimizer: damo optimizer configuration dict (defaults to an
                empty configuration).
            group: embedding group id in [0, 256); -1 auto-assigns the
                next free group.
            **kwargs: forwarded to Storage ("dir", "ttl").
        """
        super(Embedding, self).__init__()
        # FIX: replace shared mutable default arguments ({} in the def line)
        # with None sentinels.
        initializer = {} if initializer is None else initializer
        optimizer = {} if optimizer is None else optimizer
        self.dim = dim
        if group != -1:
            self.group = group
        else:
            Embedding._group += 1
            self.group = Embedding._group
        assert 0 <= self.group < 256
        self.storage = Storage(**kwargs).storage
        # Build the initializer from its parameter dict.
        init_params = damo.Parameters()
        for k, v in initializer.items():
            init_params.insert(k, v)
        self.initializer = damo.PyInitializer(init_params)
        # Build the optimizer from its parameter dict.
        opt_params = damo.Parameters()
        for k, v in optimizer.items():
            opt_params.insert(k, v)
        self.optimizer = damo.PyOptimizer(opt_params)
        self.embedding = damo.PyEmbedding(
            self.storage, self.optimizer, self.initializer, self.dim, self.group
        )

    def forward(self, inputs: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """embedding lookup

        Args:
            inputs (Union[torch.Tensor, np.ndarray]): integer key matrix of
                shape (batch_size, width); entries equal to 0 are padding.

        Returns:
            torch.Tensor: embedding values (inputs.shape[0], inputs.shape[1], self.dim)
        """
        data = inputs
        if isinstance(inputs, torch.Tensor):
            data = inputs.numpy().astype(np.uint64)
        elif isinstance(inputs, np.ndarray):
            # FIX: the original checked ``data.type``, which does not exist
            # on ndarrays (raises AttributeError); ``dtype`` is correct.
            if data.dtype != np.uint64:
                data = inputs.astype(np.uint64)
        batch_size, width = data.shape
        # Look each distinct key up exactly once.
        keys = np.unique(np.concatenate(data)).astype(np.uint64)
        length = keys.shape[0]
        weights = np.zeros(length * self.dim, dtype=np.float32)
        self.embedding.lookup(keys, weights)
        weights = weights.reshape((length, self.dim))
        weight_dict = dict(zip(keys, weights))
        values = np.zeros(shape=(batch_size, width, self.dim), dtype=np.float32)
        for i in range(batch_size):
            for j in range(width):
                key = data[i][j]
                # 0 is padding value: its row stays all-zero.
                if key != 0:
                    values[i][j] = weight_dict[key]

        def apply_gradients(gradients):
            # Aggregate gradients per key, average over the batch, and push
            # them to the damo optimizer.
            grad = gradients.numpy()
            grad = grad.reshape((batch_size, width, self.dim))
            grad_dict = defaultdict(lambda: np.zeros(self.dim, dtype=np.float32))
            for i in range(batch_size):
                for j in range(width):
                    key = data[i][j]
                    if key != 0:
                        grad_dict[key] += grad[i][j]
            flat = np.zeros(length * self.dim, dtype=np.float32)
            for i in range(length):
                flat[i * self.dim : (i + 1) * self.dim] = (
                    grad_dict[keys[i]] / batch_size
                )
            self.embedding.apply_gradients(keys, flat)

        ret = torch.from_numpy(values)
        ret.requires_grad_()
        ret.register_hook(apply_gradients)
        return ret
DeepFM
import torch
import torch.nn as nn
import numpy as np
from typing import Union
from embedding import Embedding
class DeepFM(torch.nn.Module):
    """DeepFM ranking model on top of damo-backed embeddings.

    Combines a factorization-machine part (first-order weights ``w``,
    second-order interactions from factors ``v``) with a feed-forward DNN
    over the flattened embeddings, squashed through a sigmoid.
    """

    def __init__(
        self,
        emb_size: int,
        fea_size: int,
        hid_dims=None,
        num_classes=1,
        dropout=None,
        **kwargs,
    ):
        """Build the model.

        Args:
            emb_size: embedding width for the second-order factors.
            fea_size: number of input feature slots per example.
            hid_dims: DNN hidden-layer sizes (default [256, 128]).
            num_classes: output dimension (default 1).
            dropout: dropout rate per hidden layer (default [0.2, 0.2]).
            **kwargs: initializer/optimizer/storage overrides (mean,
                stddev, gamma, beta1, beta2, lambda, epsilon, dir, ttl).
        """
        super(DeepFM, self).__init__()
        # FIX: replace mutable default arguments with None sentinels.
        hid_dims = [256, 128] if hid_dims is None else hid_dims
        dropout = [0.2, 0.2] if dropout is None else dropout
        self.emb_size = emb_size
        self.fea_size = fea_size
        initializer = {
            "name": "truncate_normal",
            "mean": float(kwargs.get("mean", 0.0)),
            "stddev": float(kwargs.get("stddev", 0.0001)),
        }
        optimizer = {
            "name": "adam",
            "gamma": float(kwargs.get("gamma", 0.001)),
            "beta1": float(kwargs.get("beta1", 0.9)),
            "beta2": float(kwargs.get("beta2", 0.999)),
            "lambda": float(kwargs.get("lambda", 0.0)),
            "epsilon": float(kwargs.get("epsilon", 1e-8)),
        }
        # First-order weights (dim=1, group 0) and second-order factors
        # (dim=emb_size, group 1).
        self.w = Embedding(
            1,
            initializer=initializer,
            optimizer=optimizer,
            group=0,
            **kwargs,
        )
        self.v = Embedding(
            self.emb_size,
            initializer=initializer,
            optimizer=optimizer,
            group=1,
            **kwargs,
        )
        # FIX: register the global bias as an nn.Parameter; the original
        # bare tensor with requires_grad=True was never registered, so it
        # was invisible to model.parameters() and device moves.
        self.w0 = nn.Parameter(torch.zeros(1, dtype=torch.float32))
        self.dims = [fea_size * emb_size] + hid_dims
        self.layers = nn.ModuleList()
        for i in range(1, len(self.dims)):
            self.layers.append(nn.Linear(self.dims[i - 1], self.dims[i]))
            # FIX: the original appended BatchNorm1d twice per hidden
            # layer; a single normalization is intended.
            self.layers.append(nn.BatchNorm1d(self.dims[i]))
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout[i - 1]))
        self.layers.append(nn.Linear(self.dims[-1], num_classes))
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """forward

        Args:
            inputs (Union[torch.Tensor, np.ndarray]): key matrix of shape
                (batch_size, fea_size).

        Returns:
            torch.Tensor: sigmoid scores, shape (batch_size, num_classes).
        """
        assert inputs.shape[1] == self.fea_size
        w = self.w.forward(inputs)  # (batch, fea_size, 1)
        v = self.v.forward(inputs)  # (batch, fea_size, emb_size)
        # FM second-order term: 0.5 * ((sum v)^2 - sum v^2), summed over
        # the embedding axis.
        square_of_sum = torch.pow(torch.sum(v, dim=1), 2)
        sum_of_square = torch.sum(v * v, dim=1)
        fm_out = (
            torch.sum((square_of_sum - sum_of_square) * 0.5, dim=1, keepdim=True)
            + torch.sum(w, dim=1)
            + self.w0
        )
        # DNN part over flattened embeddings.
        dnn_out = torch.flatten(v, 1)
        for layer in self.layers:
            dnn_out = layer(dnn_out)
        out = fm_out + dnn_out
        out = self.sigmoid(out)
        return out
Document
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
damo-embedding-1.0.4.tar.gz
(86.7 kB
view hashes)
Built Distributions
Close
Hashes for damo_embedding-1.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 0e9ceaa767e12341a1ce255bf502c36335c3bead3dd3b470dc7810ea866849a2 |
|
MD5 | 0d796ff21bf087166cf4fd6415a2c0bb |
|
BLAKE2b-256 | f34aabe0eb05ce355666f77a7fe3e19f1831a70d5e0a0dde70e79462163800ec |
Close
Hashes for damo_embedding-1.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 7b8923bace372220d0808d201473f523eba0186a792f8d6fef3d065651f1fb90 |
|
MD5 | 4d56b6814654a2f65d283d7994ee5688 |
|
BLAKE2b-256 | f86e9d55cc2c0aa7fdf525f219509e1c3be40268cddf00d89079096f3c90243c |
Close
Hashes for damo_embedding-1.0.4-cp310-cp310-macosx_11_0_arm64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 7ce23ed7310e101e36bb9c0228140ece3e4ee72eb9d8b180745359e426acb28b |
|
MD5 | 04a5413c8eedb76494eb75ee36f63b4b |
|
BLAKE2b-256 | e74c7649f5bd7f1baa40484bca836efcb8418d6dd36435ffa556b4a080840622 |
Close
Hashes for damo_embedding-1.0.4-cp310-cp310-macosx_10_9_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 0b2792edce7090b1aba607f2a9f12f6f87184d257f5a232b5ce991b214ab9d26 |
|
MD5 | 35f654f8fff0f11ce017b66a273050b9 |
|
BLAKE2b-256 | 6b86efff0a490d2107becb6add6d8647a03d3cb9a17b5daf49a751b8691b593f |
Close
Hashes for damo_embedding-1.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 8ec4df89fa9436095cf1614d674e5db43b18ccb965202d243a19067125b013ae |
|
MD5 | 8ed78b1b38856b128f1a37bc73f147fb |
|
BLAKE2b-256 | b73971de6c716fdfb287190d08aff4c0d2fda8fa8f16c027eaf9c7270332e210 |
Close
Hashes for damo_embedding-1.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 66a733e3e67e7d6b1ce54d1f117aa56ac7da5ceaaea18f70a2ed3ad318e1bab7 |
|
MD5 | ea58c3ff535461c0707bd01c43c5fc77 |
|
BLAKE2b-256 | 989563c14b51b2a47ca1b6f3ede22701db33854fe22999b683d962cc571e2dd6 |
Close
Hashes for damo_embedding-1.0.4-cp39-cp39-macosx_11_0_arm64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 0a8f551376f40625bbbd0d12dbf60246c927de505edc74b39e2f59c705613a81 |
|
MD5 | 9b83f397d52a9451fb68d349f6f33b44 |
|
BLAKE2b-256 | e2ecfbbb77446c85b25204866d0699eb451fb93f274ce93795a3bccc1a7c9baa |
Close
Hashes for damo_embedding-1.0.4-cp39-cp39-macosx_10_9_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 4276100919ac1692ec258c8861e835956abc33bc2ae6bf1d8707e125457e0ea0 |
|
MD5 | 1657adf5ec4bfd2558bd2dbfa2372fb3 |
|
BLAKE2b-256 | 700762617d9db03db782fed74c02523e206f9e620db96c16426e02df3c68576f |
Close
Hashes for damo_embedding-1.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 71fb99179844a21854f33f98edcfde9a02f73e59a478b8d142155584b5d9a871 |
|
MD5 | 3bb2cc03ea1778e67b390ddedd7e55e1 |
|
BLAKE2b-256 | e396a2df4329d8a1091d8226410b7cc5fd7c290701c51ac5d99b4bf5984ba5a9 |
Close
Hashes for damo_embedding-1.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 560b64bd8808bfbab96b8ade57c029e359c9d96f1d72d06232a8dee29a1bdfed |
|
MD5 | a35eaaddea800deabfe1272499f7a775 |
|
BLAKE2b-256 | c409e2731b5532be00d2cfcfa8576281da7fcf3a5844b4f1eeec201007864236 |
Close
Hashes for damo_embedding-1.0.4-cp38-cp38-macosx_11_0_arm64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | e7d75a649eeb469320aef53c9caf2c233a18b9311b8e2890af12a84db5c832e0 |
|
MD5 | 7a0ea57aa8b6ec16fe86efd495e97adf |
|
BLAKE2b-256 | f94f9217083b36eb57a2c0331b015f055c2c79a52e06ad25393980470b1ddb4b |
Close
Hashes for damo_embedding-1.0.4-cp38-cp38-macosx_10_9_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 519e4a592c8ce63b29a68aecc771436f842e6e35db1c2c7d795adacd70b2d867 |
|
MD5 | 3d69d11510e9d65dfdffec5e2a6d25ff |
|
BLAKE2b-256 | b3fb3de7aafbfdaef6d0a0568a069c1c5d043777d9ff395824ec8ea3121319bc |
Close
Hashes for damo_embedding-1.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | f2adbc5afef5008ad21d2e0447032ab452cac4b9957ae84c86de54ded3f094da |
|
MD5 | 3ebe3e00b20bb0275de90e40dd6378ac |
|
BLAKE2b-256 | 6946fddd8ea84230e711dac1be82791b407a610682d55b63c84060494b87a640 |
Close
Hashes for damo_embedding-1.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | f4ab5aec702aa19e79cb111af6fd4f3806f29d4b4aed06b3ccdc01d783d472d2 |
|
MD5 | 7b6bc8eb463168d72f567f5198b2ef62 |
|
BLAKE2b-256 | 8cc98cf0c081893db51b526704522679a9ef05159efc000d7fa3994a8a58f0fe |
Close
Hashes for damo_embedding-1.0.4-cp37-cp37m-macosx_10_9_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 3a5f3f26127e932f350ef35bfee932ae3e164a0b5f0ddfd9cfa9f1a2b419a99b |
|
MD5 | 79ad0ec11cf62e6d256d5536e7c9ce89 |
|
BLAKE2b-256 | e2576c298a3e8472ca44edc63ff21a95575a315f9ee1512148e3f4919ea1d01a |
Close
Hashes for damo_embedding-1.0.4-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 850dbef593b407bd3df21ad04f7a46bd6399331a0066b25461349579538c15f6 |
|
MD5 | 6958480ae4c6d12a394facffabdf6cb9 |
|
BLAKE2b-256 | d8c4474bc6a14472a89d006d73b191124e568e98ab1dc085d91f0fed01c0f9a3 |
Close
Hashes for damo_embedding-1.0.4-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 3c3f81fa9b1ed7bc06182080490ee7e00c8dc8a416a7a15f45f6a2720df171f5 |
|
MD5 | d6ab030dac62b913b9b51fd844a44e03 |
|
BLAKE2b-256 | 6b4b23061b95c2b041dfd2332c98c253561b0dca2c8965d141fd09614db22d04 |
Close
Hashes for damo_embedding-1.0.4-cp36-cp36m-macosx_10_9_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | dc26640cea3f025a477659425aec84fe5e0a9c5822537201a127685e00f7c8ec |
|
MD5 | 71e884171ec6c56f78e8b70209b2abc4 |
|
BLAKE2b-256 | 36b5ffcdc9e12e566e6bff92c3bc8a46c040449e4ac429aa6a77fc229f2a2726 |