Python wrapper for damo, a set of fast and robust hash functions.
Project description
Damo-Embedding
Quick Install
pip install damo-embedding
Example
DeepFM
import torch
import torch.nn as nn
from damo_embedding import Embedding
class DeepFM(torch.nn.Module):
    """DeepFM model: a factorization-machine (FM) term plus a deep MLP,
    both operating on shared damo-embedding sparse embeddings.

    Args:
        emb_size (int): embedding dimension per feature field.
        fea_size (int): number of feature fields in each input row.
        hid_dims (list[int] | None): hidden-layer widths of the deep part
            (defaults to [256, 128]).
        num_classes (int): output width (1 for binary classification).
        dropout (list[float] | None): dropout rate per hidden layer
            (defaults to [0.2, 0.2]); must be at least as long as hid_dims.
        **kwargs: optional hyperparameters forwarded to the embedding
            initializer (``mean``, ``stddev``) and Adam optimizer
            (``gamma``, ``beta1``, ``beta2``, ``lambda``, ``epsilon``).
    """

    def __init__(
        self,
        emb_size: int,
        fea_size: int,
        hid_dims=None,
        num_classes=1,
        dropout=None,
        **kwargs,
    ):
        super(DeepFM, self).__init__()
        # None sentinels avoid the shared-mutable-default pitfall while
        # keeping the original default values.
        hid_dims = [256, 128] if hid_dims is None else hid_dims
        dropout = [0.2, 0.2] if dropout is None else dropout
        self.emb_size = emb_size
        self.fea_size = fea_size
        initializer = {
            "name": "truncate_normal",
            "mean": float(kwargs.get("mean", 0.0)),
            "stddev": float(kwargs.get("stddev", 0.0001)),
        }
        optimizer = {
            "name": "adam",
            "gamma": float(kwargs.get("gamma", 0.001)),
            "beta1": float(kwargs.get("beta1", 0.9)),
            "beta2": float(kwargs.get("beta2", 0.999)),
            "lambda": float(kwargs.get("lambda", 0.0)),
            "epsilon": float(kwargs.get("epsilon", 1e-8)),
        }
        # First-order (linear) weight: one scalar per feature id.
        self.w = Embedding(
            1,
            initializer=initializer,
            optimizer=optimizer,
        )
        # Second-order FM latent vectors of width emb_size.
        self.v = Embedding(
            self.emb_size,
            initializer=initializer,
            optimizer=optimizer,
        )
        # Global bias. FIX: registered as an nn.Parameter so it appears in
        # model.parameters(), state_dict(), and follows Module.to(); the
        # original bare requires_grad tensor was silently excluded from all
        # of those.
        self.w0 = nn.Parameter(torch.zeros(1, dtype=torch.float32))
        self.dims = [fea_size * emb_size] + hid_dims
        self.layers = nn.ModuleList()
        for i in range(1, len(self.dims)):
            self.layers.append(nn.Linear(self.dims[i - 1], self.dims[i]))
            # FIX: the original appended BatchNorm1d twice per hidden layer
            # (duplicated line); one normalization per layer is intended.
            self.layers.append(nn.BatchNorm1d(self.dims[i]))
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout[i - 1]))
        self.layers.append(nn.Linear(self.dims[-1], num_classes))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute DeepFM scores for a batch of feature-id rows.

        Args:
            input (torch.Tensor): feature-id tensor of shape
                (batch, fea_size) — assumed from the shape assert and the
                damo Embedding lookup; confirm against damo-embedding docs.

        Returns:
            torch.Tensor: sigmoid scores of shape (batch, num_classes).
        """
        assert input.shape[1] == self.fea_size
        w = self.w.forward(input)  # first-order terms, one scalar per field
        v = self.v.forward(input)  # latent vectors per field
        # FM pairwise-interaction identity:
        # sum_{i<j} <v_i, v_j> = 0.5 * ((sum_i v_i)^2 - sum_i v_i^2)
        square_of_sum = torch.pow(torch.sum(v, dim=1), 2)
        sum_of_square = torch.sum(v * v, dim=1)
        fm_out = (
            torch.sum((square_of_sum - sum_of_square)
                      * 0.5, dim=1, keepdim=True)
            + torch.sum(w, dim=1)
            + self.w0
        )
        # Deep part: flattened embeddings pushed through the MLP stack.
        dnn_out = torch.flatten(v, 1)
        for layer in self.layers:
            dnn_out = layer(dnn_out)
        out = fm_out + dnn_out
        out = self.sigmoid(out)
        return out
Save Model
from damo_embedding import save_model
# Build the example model: 8-dim embeddings over 39 feature fields
# (presumably the Criteo-style layout — confirm with the dataset used).
model = DeepFM(8, 39)
# Export the model to the current directory; training=False requests an
# inference-mode artifact (exact on-disk format is defined by damo_embedding).
save_model(model, "./", training=False)
Document
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
damo-embedding-1.1.7.tar.gz
(222.5 kB
view hashes)
Built Distributions
Close
Hashes for damo_embedding-1.1.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | f8076dc2dbbb0f540d1489a8f69e68853119ea37247fd19fac86a883e300bba5 |
|
MD5 | 9ea7e884095d84838e195e7106a1322b |
|
BLAKE2b-256 | 8617fa0405d6c6b7abf71d3be6f6921597eadc4df892c5d61b7632659d9de144 |
Close
Hashes for damo_embedding-1.1.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 5848b39570eeea042c4ac765cc9eb5402c7e5d69f5220c2d66158e211cd1373f |
|
MD5 | 8fe6cd0d06d87bac173de46b5e97b175 |
|
BLAKE2b-256 | 68ced09d1dd1aa228958a6b3361f0de405ee0be87fc8cdbd3a6ff879fa948c23 |
Close
Hashes for damo_embedding-1.1.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 08841375c28ab42037359ccd5a6c0b6235b1f1aa0c4b538703043b93248bdde9 |
|
MD5 | 07474896f89de0e35a03ff4c46f92766 |
|
BLAKE2b-256 | f3d7a873fb1c51d0dc1b1b2462c61f4b1e8ac6dbc26815367cabd7c1d01e361b |
Close
Hashes for damo_embedding-1.1.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 60ef664c5b1118d80a31d02988a7d6ef4b40f214581815c5534339c40c39c76d |
|
MD5 | c4e9072a0e88f0550fd50409fc23d715 |
|
BLAKE2b-256 | 4620cb087ff45d8a5113710b8c3a879926438f6a8d84ec9f576721dc5c1b6307 |
Close
Hashes for damo_embedding-1.1.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 0bf9372fbc88d1e419e888a395c2730e542ed40c566b783d03ddfdf97633ad88 |
|
MD5 | c3c42df6df047a5efb9bc595672dc2ae |
|
BLAKE2b-256 | 4f0303c3fadd4da9269f42d31f129a791d740817549a8b84fe89586eb5f2375b |
Close
Hashes for damo_embedding-1.1.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 1b59e32796e8488352458d97b2e37f1c53eb967b17f5b326c0e157444d63b1b4 |
|
MD5 | 9850a665c90881e55a6b51b6035d2ef2 |
|
BLAKE2b-256 | 4212a38ee32c84d0cc7db67604f8239f272dd957c77c328f6a6ff8d8f8f8f941 |
Close
Hashes for damo_embedding-1.1.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 17d4f6059677c2aa79d1c02638ffdb28b3f79ac940ec194f88d931b4399c2a23 |
|
MD5 | 1fa7e015d4f2baecd8decadb22f2f4e3 |
|
BLAKE2b-256 | a0ce22e4dbffc24ee7848679eb63b2b99d911ddb1423cdfc1973a2442ef39991 |
Close
Hashes for damo_embedding-1.1.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | c02d7a1488632a8fbed06bedb26f5689d73cf3d003eb5289b392fc52716374ae |
|
MD5 | 109c491d5a25c0b0d937450b3390f208 |
|
BLAKE2b-256 | 0b236bf99cfd1bbcb16d5dc868fb8fcd6e8dda3443c6a1bfb34bccf269ffbd5d |
Close
Hashes for damo_embedding-1.1.7-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | e61289e3a0dfc50ee4720d7b6baca7bc56e739b8da7251aa1a97bc9356ec851e |
|
MD5 | 5e8fb8416d71d15c389737f7e92d5566 |
|
BLAKE2b-256 | 2eec292ec814ef85c2cd82f324c4737b8fb4b2e28a621a81daf879ea9308f636 |
Close
Hashes for damo_embedding-1.1.7-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | de83d1579d7550a4e97680152d3d395efa4a7c7aeb77ce9fe1a162f08be9283d |
|
MD5 | 85801287b94376e376f7ab6f658c43b2 |
|
BLAKE2b-256 | 7e99e3a42762c267420268071484a7edbbb409edcb8032f109bb337288d6f983 |