Python wrapper for damo, a set of fast and robust hash functions.
Project description
Damo-Embedding
Quick Install
pip install damo-embedding
Example
DeepFM
import torch
import torch.nn as nn
from damo_embedding import Embedding
class DeepFM(torch.nn.Module):
    """DeepFM model built on two damo ``Embedding`` tables.

    The prediction is ``sigmoid(fm + dnn)``:

    * ``fm`` adds a global bias ``w0``, the first-order weights looked up
      from ``self.w`` (embedding dimension 1) and the second-order pairwise
      interactions of the vectors looked up from ``self.v``.
    * ``dnn`` is an MLP (Linear -> BatchNorm1d -> ReLU -> Dropout per hidden
      layer, then a final Linear) applied to the flattened ``self.v`` vectors.
    """

    def __init__(
        self,
        emb_size: int,
        fea_size: int,
        hid_dims=(256, 128),
        num_classes=1,
        dropout=(0.2, 0.2),
        **kwargs,
    ):
        """Build the embedding tables and the DNN tower.

        Args:
            emb_size (int): dimension of the second-order embedding ``v``.
            fea_size (int): number of feature fields per sample.
            hid_dims: sizes of the hidden layers of the DNN tower.
            num_classes: output width of the final linear layer.
            dropout: dropout probability per hidden layer; needs at least
                one entry per element of ``hid_dims``.
            **kwargs: optional overrides, each converted with ``float``:
                ``mean`` (0.0) and ``stddev`` (0.0001) for the
                ``truncate_normal`` initializer; ``gamma`` (0.001),
                ``beta1`` (0.9), ``beta2`` (0.999), ``lambda`` (0.0) and
                ``epsilon`` (1e-8) for the ``adam`` optimizer config.

        Raises:
            ValueError: if ``dropout`` has fewer entries than ``hid_dims``.
        """
        super().__init__()
        # Copy into fresh lists so any sequence type is accepted and the
        # defaults are never shared between instances.
        hid_dims = list(hid_dims)
        dropout = list(dropout)
        if len(dropout) < len(hid_dims):
            raise ValueError(
                f"dropout needs one rate per hidden layer: got "
                f"{len(dropout)} rates for {len(hid_dims)} layers"
            )

        self.emb_size = emb_size
        self.fea_size = fea_size

        # Both tables share the same initializer / optimizer configuration.
        initializer = {
            "name": "truncate_normal",
            "mean": float(kwargs.get("mean", 0.0)),
            "stddev": float(kwargs.get("stddev", 0.0001)),
        }
        optimizer = {
            "name": "adam",
            "gamma": float(kwargs.get("gamma", 0.001)),
            "beta1": float(kwargs.get("beta1", 0.9)),
            "beta2": float(kwargs.get("beta2", 0.999)),
            "lambda": float(kwargs.get("lambda", 0.0)),
            "epsilon": float(kwargs.get("epsilon", 1e-8)),
        }

        # First-order weights: one scalar per feature value.
        self.w = Embedding(
            1,
            initializer=initializer,
            optimizer=optimizer,
        )
        # Second-order latent vectors, also the input of the DNN tower.
        self.v = Embedding(
            self.emb_size,
            initializer=initializer,
            optimizer=optimizer,
        )
        # Registered as a Parameter so the bias is returned by parameters(),
        # saved in state_dict() and moved by .to(device).
        self.w0 = nn.Parameter(torch.zeros(1, dtype=torch.float32))

        self.dims = [fea_size * emb_size] + hid_dims
        self.layers = nn.ModuleList()
        for i in range(1, len(self.dims)):
            self.layers.append(nn.Linear(self.dims[i - 1], self.dims[i]))
            # A single normalisation per hidden layer.
            self.layers.append(nn.BatchNorm1d(self.dims[i]))
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout[i - 1]))
        self.layers.append(nn.Linear(self.dims[-1], num_classes))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute the DeepFM prediction for a batch.

        Args:
            input (torch.Tensor): feature tensor whose dimension 1 has
                ``fea_size`` entries; it is passed unchanged to both
                embedding tables.

        Returns:
            torch.Tensor: sigmoid of the summed FM and DNN outputs.

        Raises:
            ValueError: if ``input`` has fewer than two dimensions or its
                dimension 1 does not equal ``fea_size``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if input.dim() < 2 or input.shape[1] != self.fea_size:
            raise ValueError(
                f"expected input with {self.fea_size} features in dim 1, "
                f"got shape {tuple(input.shape)}"
            )

        # NOTE(review): the code below assumes the tables return
        # (batch, fea_size, 1) and (batch, fea_size, emb_size) - confirm
        # against damo_embedding.Embedding.
        w = self.w.forward(input)
        v = self.v.forward(input)

        # Pairwise interactions via 0.5 * ((sum v)^2 - sum v^2), which avoids
        # enumerating every feature pair.
        square_of_sum = torch.pow(torch.sum(v, dim=1), 2)
        sum_of_square = torch.sum(v * v, dim=1)
        second_order = torch.sum(
            (square_of_sum - sum_of_square) * 0.5, dim=1, keepdim=True
        )
        fm_out = second_order + torch.sum(w, dim=1) + self.w0

        # DNN tower on the concatenated latent vectors.
        dnn_out = torch.flatten(v, 1)
        for layer in self.layers:
            dnn_out = layer(dnn_out)

        out = fm_out + dnn_out
        out = self.sigmoid(out)
        return out
Save Model
from damo_embedding import save_model
# Build a DeepFM with 8-dimensional embeddings over 39 feature fields.
model = DeepFM(emb_size=8, fea_size=39)
# Export the model into the current directory with training disabled.
# NOTE(review): the exact artefacts written depend on damo_embedding.save_model.
save_model(model, "./", training=False)
Document
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
damo-embedding-1.1.4.tar.gz
(222.5 kB, view hashes)
Built Distributions
Close
Hashes for damo_embedding-1.1.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | c3da17fe0e3db12bf6583b96ea59ec749891d3371a76d0aa8bd40dae260a740f |
| MD5 | 1ef52590edcd2f2e559530562ec85339 |
| BLAKE2b-256 | fda9e3e9eab89fe77ffb931a933001ca6af586df9e5ff840176a26c59cc7bd57 |
Close
Hashes for damo_embedding-1.1.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 17ebb02e4eb589b67c586e3e4bc0d36c3a21434b8d9124d5acb8b09792fb74fa |
| MD5 | 4f7d3ca13b61f782681be2613140170c |
| BLAKE2b-256 | 09048f4ac927a1ac1689f0addd5a72b828db4bdf683dbea087eb6a1d56c9abc9 |
Close
Hashes for damo_embedding-1.1.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7ff2e58411dda139402f3e4fb10c8084c35f6b584b042c2ac7cf2a0b9543d8e7 |
| MD5 | 01dfb60fa0627c6610a1f7bf5137d329 |
| BLAKE2b-256 | ad966169b033c74fff705fc355d57bde314fbfad17242cbeb9a4c93d59684166 |
Close
Hashes for damo_embedding-1.1.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | b064ee7abae8282eb41249c02a79b53490cee18bfd4cbe44829b33c92d91654a |
| MD5 | 4295cee38ac7a4c898e53c3bd9fa3e1e |
| BLAKE2b-256 | dc0d5a28e131a8e713ee7ce79f32b5bab1545b4adb4cb3b88a405f0031b422b3 |
Close
Hashes for damo_embedding-1.1.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | a05cdb708e37041a1599a52a480bd1d4e7124af0e7e7d59a82d39ee8793c5e7c |
| MD5 | 781a65459a5db7c6d5820116e4e9c8d8 |
| BLAKE2b-256 | 6d441257087a09c5c44b80a71aa0ea83decc3b97620fcf047eb9358e1beda151 |
Close
Hashes for damo_embedding-1.1.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2afdef882bdb09b093df0f17742e2de720dfed3f7244a47d82caa3c4ef38dc03 |
| MD5 | eb73fda8d16468b8d6b0a9b3aeb69e20 |
| BLAKE2b-256 | be792bc8426a804694199708ffbce2cbae74879d8bd27b8785dc51cefb7a45e5 |
Close
Hashes for damo_embedding-1.1.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 60e1a6cf79d3c3969007eca53141f1caca0741bdf958e08ba7d962a2be77b93b |
| MD5 | b48f87fe3d453ffed91bb810b21c9b6a |
| BLAKE2b-256 | 89f0d55f3925a4ce19d031dfc7496d1745e0e4377d001ca1b0d73afd8b76c5c6 |
Close
Hashes for damo_embedding-1.1.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 09921ad8e454bc00b5de4a7129c65441df14f3dc8ed52f4d5e7e6f7014383c80 |
| MD5 | eb0fd2c5fb5e02ae738ea797db034787 |
| BLAKE2b-256 | a2b93e7069c670134f655b6fcfdb101f86e000c38b610b1d61b3dfa4e68022de |
Close
Hashes for damo_embedding-1.1.4-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | ea2bbdca690c3a0beac5ea8c6dcc94106a4324075086e0a10e5007d1d5dcace9 |
| MD5 | f676de7106f40047390e1576fa8fb5bd |
| BLAKE2b-256 | 84d21c1f5e151e310b7b6befbd626b644c4cb4160fc443d6b6735a573c19e335 |
Close
Hashes for damo_embedding-1.1.4-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | defcb9a1c855f5734242300ad527f5c57f0344442f477d061b0e8b29a0e3b893 |
| MD5 | 2b6a68852aa03423e9ec465307084ed2 |
| BLAKE2b-256 | 98b752c4907c61611efe94f9ca4e106a8c2afcac12b2d1de16c30383acaaf4d6 |