Damo-Embedding
Damo-Embedding is a Python library for training models with large-scale sparse embedding parameters. Embedding weights are kept in a local key-value store with TTL-based expiry, and initializers and optimizers are configured through the bundled damo extension module.
Quick Install
pip install damo-embedding
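If the wheel installed correctly, the native damo extension should import cleanly. A minimal smoke test (the storage directory and TTL below are example values, not defaults):

import damo

# open (or create) a local key-value store; args: directory, ttl in seconds
storage = damo.PyStorage("./embeddings", 86400)
print(storage)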
Example
Embedding
import damo
import torch
import numpy as np
from typing import Union
from collections import defaultdict


class Storage(object):
    """Singleton wrapper around damo.PyStorage."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = object.__new__(cls)
            cls._instance.dir = kwargs.get("dir", "./embeddings")
            cls._instance.ttl = kwargs.get("ttl", 8640000)  # key expiry in seconds
            cls._instance.storage = damo.PyStorage(cls._instance.dir, cls._instance.ttl)
        return cls._instance

    @staticmethod
    def checkpoint(path: str):
        assert Storage._instance is not None
        Storage._instance.storage.checkpoint(path)

    @staticmethod
    def dump(path: str):
        assert Storage._instance is not None
        Storage._instance.storage.dump(path)

    @staticmethod
    def load_from_checkpoint(path: str):
        assert Storage._instance is not None
        Storage._instance.storage.load_from_checkpoint(path)
class Embedding(torch.nn.Module):
    _group = -1  # auto-incremented group id shared across instances

    def __init__(self, dim: int, initializer={}, optimizer={}, group=-1, **kwargs):
        super(Embedding, self).__init__()
        self.dim = dim
        if group != -1:
            self.group = group
        else:
            Embedding._group += 1
            self.group = Embedding._group
        assert 0 <= self.group < 256
        self.storage = Storage(**kwargs).storage

        # create initializer
        init_params = damo.Parameters()
        for k, v in initializer.items():
            init_params.insert(k, v)
        self.initializer = damo.PyInitializer(init_params)

        # create optimizer
        opt_params = damo.Parameters()
        for k, v in optimizer.items():
            opt_params.insert(k, v)
        self.optimizer = damo.PyOptimizer(opt_params)

        self.embedding = damo.PyEmbedding(
            self.storage, self.optimizer, self.initializer, self.dim, self.group
        )
    def forward(self, inputs: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Embedding lookup.

        Args:
            inputs (Union[torch.Tensor, np.ndarray]): input ids

        Returns:
            torch.Tensor: embedding values, shape (inputs.shape[0], inputs.shape[1], self.dim)
        """
        data = inputs
        if isinstance(inputs, torch.Tensor):
            data = inputs.numpy().astype(np.uint64)
        elif isinstance(inputs, np.ndarray):
            if data.dtype != np.uint64:
                data = inputs.astype(np.uint64)
        batch_size, width = data.shape

        # look up each distinct id once
        keys = np.unique(np.concatenate(data)).astype(np.uint64)
        length = keys.shape[0]
        weights = np.zeros(length * self.dim, dtype=np.float32)
        self.embedding.lookup(keys, weights)
        weights = weights.reshape((length, self.dim))
        weight_dict = {k: v for k, v in zip(keys, weights)}

        # scatter the looked-up rows back into the input layout
        values = np.zeros(shape=(batch_size, width, self.dim), dtype=np.float32)
        for i in range(batch_size):
            for j in range(width):
                key = data[i][j]
                # 0 is the padding value: its row stays all-zero
                if key != 0:
                    values[i][j] = weight_dict[key]

        def apply_gradients(gradients):
            # backward hook: accumulate per-key gradients, average over the
            # batch, and hand them to damo's optimizer
            grad = gradients.numpy()
            grad = grad.reshape((batch_size, width, self.dim))
            grad_dict = defaultdict(lambda: np.zeros(self.dim, dtype=np.float32))
            for i in range(batch_size):
                for j in range(width):
                    key = data[i][j]
                    if key != 0:
                        grad_dict[key] += grad[i][j]
            grads = np.zeros(length * self.dim, dtype=np.float32)
            for i in range(length):
                grads[i * self.dim : (i + 1) * self.dim] = (
                    grad_dict[keys[i]] / batch_size
                )
            self.embedding.apply_gradients(keys, grads)

        ret = torch.from_numpy(values)
        ret.requires_grad_()
        ret.register_hook(apply_gradients)
        return ret
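A short usage sketch, assuming the code above is saved as embedding.py (the DeepFM example below imports it under that name). The ids, dimensions, and configuration values are arbitrary examples:

import torch
from embedding import Embedding, Storage

emb = Embedding(
    8,  # embedding dimension
    initializer={"name": "truncate_normal", "mean": 0.0, "stddev": 0.0001},
    optimizer={"name": "adam", "gamma": 0.001},
    dir="./embeddings",
)
ids = torch.tensor([[1, 2, 3], [4, 5, 0]], dtype=torch.int64)  # 0 acts as padding
out = emb(ids)            # shape (2, 3, 8)
out.sum().backward()      # fires the hook, pushing gradients into damo
Storage.checkpoint("./ckpt")  # persist the current weights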
DeepFM
import torch
import torch.nn as nn
import numpy as np
from typing import Union

from embedding import Embedding


class DeepFM(torch.nn.Module):
    def __init__(
        self,
        emb_size: int,
        fea_size: int,
        hid_dims=[256, 128],
        num_classes=1,
        dropout=[0.2, 0.2],
        **kwargs,
    ):
        super(DeepFM, self).__init__()
        self.emb_size = emb_size
        self.fea_size = fea_size

        initializer = {
            "name": "truncate_normal",
            "mean": float(kwargs.get("mean", 0.0)),
            "stddev": float(kwargs.get("stddev", 0.0001)),
        }
        optimizer = {
            "name": "adam",
            "gamma": float(kwargs.get("gamma", 0.001)),
            "beta1": float(kwargs.get("beta1", 0.9)),
            "beta2": float(kwargs.get("beta2", 0.999)),
            "lambda": float(kwargs.get("lambda", 0.0)),
            "epsilon": float(kwargs.get("epsilon", 1e-8)),
        }

        # first-order weights (dim=1) and second-order factors (dim=emb_size)
        self.w = Embedding(
            1,
            initializer=initializer,
            optimizer=optimizer,
            group=0,
            **kwargs,
        )
        self.v = Embedding(
            self.emb_size,
            initializer=initializer,
            optimizer=optimizer,
            group=1,
            **kwargs,
        )
        self.w0 = torch.zeros(1, dtype=torch.float32, requires_grad=True)

        # DNN part
        self.dims = [fea_size * emb_size] + hid_dims
        self.layers = nn.ModuleList()
        for i in range(1, len(self.dims)):
            self.layers.append(nn.Linear(self.dims[i - 1], self.dims[i]))
            self.layers.append(nn.BatchNorm1d(self.dims[i]))
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout[i - 1]))
        self.layers.append(nn.Linear(self.dims[-1], num_classes))
        self.sigmoid = nn.Sigmoid()
    def forward(self, inputs: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Forward pass.

        Args:
            inputs (Union[torch.Tensor, np.ndarray]): feature ids, shape (batch, fea_size)

        Returns:
            torch.Tensor: DeepFM predictions, shape (batch, num_classes)
        """
        assert inputs.shape[1] == self.fea_size
        w = self.w.forward(inputs)  # (batch, fea_size, 1)
        v = self.v.forward(inputs)  # (batch, fea_size, emb_size)

        # FM second-order term via the identity
        # sum_{i<j} <v_i, v_j> = 0.5 * ((sum_i v_i)^2 - sum_i v_i^2)
        square_of_sum = torch.pow(torch.sum(v, dim=1), 2)
        sum_of_square = torch.sum(v * v, dim=1)
        fm_out = (
            torch.sum((square_of_sum - sum_of_square) * 0.5, dim=1, keepdim=True)
            + torch.sum(w, dim=1)
            + self.w0
        )

        # DNN part on the flattened factor embeddings
        dnn_out = torch.flatten(v, 1)
        for layer in self.layers:
            dnn_out = layer(dnn_out)

        out = fm_out + dnn_out
        out = self.sigmoid(out)
        return out
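A hedged end-to-end training sketch. The random ids and labels are placeholders; a real pipeline would hash raw features into non-zero uint64 ids (0 is reserved for padding). Note that torch.optim.Adam here only updates the dense layers; the sparse embeddings are updated by damo's own optimizer through the backward hook:

import torch

model = DeepFM(emb_size=8, fea_size=16, dir="./embeddings")
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # dense layers only

for step in range(10):
    ids = torch.randint(1, 1000, (32, 16), dtype=torch.int64)  # fake feature ids
    labels = torch.randint(0, 2, (32, 1)).float()              # fake binary labels
    preds = model(ids)
    loss = criterion(preds, labels)
    optimizer.zero_grad()
    loss.backward()   # also pushes embedding gradients into damo
    optimizer.step()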