Python wrapper for damo, a set of fast and robust hash functions.
Project description
Damo-Embedding
Quick Install
pip install damo-embedding
Example
DeepFM
import torch
import torch.nn as nn
from damo_embedding import Embedding
class DeepFM(torch.nn.Module):
    """DeepFM model: a factorization-machine term plus a DNN tower.

    ``forward`` returns ``sigmoid(fm_out + dnn_out)``. ``fm_out`` combines a
    width-1 embedding table ``w``, a latent table ``v`` of width
    ``emb_size`` and a learnable global bias ``w0``. ``dnn_out`` is an MLP
    applied to the flattened ``v`` embeddings.

    Args:
        emb_size: width of each latent vector in ``v``.
        fea_size: number of feature columns every input row must have.
        hid_dims: hidden-layer widths of the DNN tower; ``None`` means
            ``[256, 128]``.
        num_classes: output width of the last DNN layer.
        dropout: dropout probability for each hidden layer; must have at
            least ``len(hid_dims)`` entries. ``None`` means ``[0.2, 0.2]``.
        **kwargs: overrides for the embedding initializer (``mean``,
            ``stddev``) and the Adam optimizer (``gamma``, ``beta1``,
            ``beta2``, ``lambda``, ``epsilon``); both tables share them.

    Raises:
        ValueError: if ``dropout`` has fewer entries than ``hid_dims``.
    """

    def __init__(
        self,
        emb_size: int,
        fea_size: int,
        hid_dims=None,
        num_classes=1,
        dropout=None,
        **kwargs,
    ):
        super(DeepFM, self).__init__()
        # None sentinels: a list literal as a default would be a single
        # object shared by every call and every instance.
        hid_dims = [256, 128] if hid_dims is None else list(hid_dims)
        dropout = [0.2, 0.2] if dropout is None else list(dropout)
        if len(dropout) < len(hid_dims):
            raise ValueError(
                "dropout needs one probability per hidden layer: got %d "
                "for %d hidden layers" % (len(dropout), len(hid_dims))
            )
        self.emb_size = emb_size
        self.fea_size = fea_size

        initializer = {
            "name": "truncate_normal",
            "mean": float(kwargs.get("mean", 0.0)),
            "stddev": float(kwargs.get("stddev", 0.0001)),
        }
        optimizer = {
            "name": "adam",
            "gamma": float(kwargs.get("gamma", 0.001)),
            "beta1": float(kwargs.get("beta1", 0.9)),
            "beta2": float(kwargs.get("beta2", 0.999)),
            "lambda": float(kwargs.get("lambda", 0.0)),
            "epsilon": float(kwargs.get("epsilon", 1e-8)),
        }
        # Width-1 table: first-order weights (presumably one scalar per
        # feature id -- confirm against damo_embedding.Embedding).
        self.w = Embedding(
            1,
            initializer=initializer,
            optimizer=optimizer,
        )
        # Latent vectors shared by the FM interaction term and the DNN tower.
        self.v = Embedding(
            self.emb_size,
            initializer=initializer,
            optimizer=optimizer,
        )
        # Global bias. As an nn.Parameter it is registered on the module, so
        # it shows up in parameters()/state_dict() and is updated by an
        # optimizer built from model.parameters(); a bare tensor is not.
        self.w0 = nn.Parameter(torch.zeros(1, dtype=torch.float32))

        # DNN tower: Linear -> BatchNorm -> ReLU -> Dropout per hidden layer,
        # followed by a final Linear projection to num_classes.
        self.dims = [fea_size * emb_size] + hid_dims
        self.layers = nn.ModuleList()
        for in_dim, out_dim, p in zip(self.dims[:-1], self.dims[1:], dropout):
            self.layers.append(nn.Linear(in_dim, out_dim))
            self.layers.append(nn.BatchNorm1d(out_dim))
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(p))
        self.layers.append(nn.Linear(self.dims[-1], num_classes))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Score a batch of feature rows.

        Args:
            input (torch.Tensor): batch whose second dimension must equal
                ``fea_size`` (one column per feature).

        Returns:
            torch.Tensor: ``sigmoid(fm_out + dnn_out)``; presumably of shape
            ``(batch, num_classes)`` -- confirm with the Embedding output.

        Raises:
            ValueError: if ``input.shape[1] != fea_size``.
        """
        # Explicit check instead of assert, which is stripped under -O.
        if input.shape[1] != self.fea_size:
            raise ValueError("input.shape[1] must equal fea_size")
        w = self.w.forward(input)
        v = self.v.forward(input)

        # Pairwise FM interactions via the identity
        #   sum_{i<j} <v_i, v_j> = 0.5 * ((sum_i v_i)^2 - sum_i v_i^2):
        # reduce over dim=1 first, then sum the remaining embedding axis.
        square_of_sum = torch.pow(torch.sum(v, dim=1), 2)
        sum_of_square = torch.sum(v * v, dim=1)
        fm_out = (
            torch.sum(
                (square_of_sum - sum_of_square) * 0.5, dim=1, keepdim=True
            )
            + torch.sum(w, dim=1)
            + self.w0
        )

        # The first Linear expects fea_size * emb_size inputs, so v is
        # presumably (batch, fea_size, emb_size) before flattening.
        dnn_out = torch.flatten(v, 1)
        for layer in self.layers:
            dnn_out = layer(dnn_out)

        return self.sigmoid(fm_out + dnn_out)
Save Model
from damo_embedding import save_model
# Example: build a DeepFM with 8-wide embeddings over 39 feature columns and
# export it to the current directory via damo_embedding.save_model.
model = DeepFM(emb_size=8, fea_size=39)
save_model(model, "./", training=False)
Documentation
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
damo-embedding-1.1.8.tar.gz (224.2 kB) — view hashes
Built Distributions
Close
Hashes for damo_embedding-1.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | d363dbf57aa11508b44deba611e0f64847801aba90002e6ae6ebd4aa1cf95439 |
| MD5 | 291f8cedb79ed93e3853e75467fb200d |
| BLAKE2b-256 | d574ac3938abce51b495d189a0bb3382c17fbc8a2bda4ec13cdb96201eb45865 |
Close
Hashes for damo_embedding-1.1.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | e84d7d91d0b5c1725fd6e95fe68d69706bcd14208a0e224be1dc006ef7882e59 |
| MD5 | f557c21163a3d2971305d7084c6de0f2 |
| BLAKE2b-256 | 21c7398a1a8b556e97a4451d9d3715ae608d0158e4b822d40f9ee60a0b0fce65 |
Close
Hashes for damo_embedding-1.1.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 5cf5884f26ae31a6e66cd9976ed44d017186ace2b3bc8772922411a38574c80a |
| MD5 | 772f0985f427b584fe4657a92f3ca675 |
| BLAKE2b-256 | 6b2cf5cbd37d7ff899ec5ad72018f2135026fb443721c6bbeacd4a4b0c4691f1 |
Close
Hashes for damo_embedding-1.1.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | beefdd0b142f0b3774db7ec50ec3feb0c38986fe8d59530f53780bbf02417a1c |
| MD5 | 8fde1fece3c11328d6798e6032121f83 |
| BLAKE2b-256 | f6995d0b118234f844d81e3aec36195c2abde25e29301e316edef7ec316d588e |
Close
Hashes for damo_embedding-1.1.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | e5d2d72a763dd31135d22524cdfdf3f7bead3472d01e379a179d252b96c84b78 |
| MD5 | 6213bea8996751239514142afcbcd51f |
| BLAKE2b-256 | 716518f6d8b1170546d185a1a04bfe7a527142733aabceef26da5066050d1023 |
Close
Hashes for damo_embedding-1.1.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | cab50e6cb59119b442bf32396fbeb9620d44c8ab7e4d281bbffdeeb51c51009b |
| MD5 | f155b181806e4e08b01a0c5981a4f477 |
| BLAKE2b-256 | 6b4fcfd9367e7b3a601a274573cf9e70d54dc6cbc7da0530dc13eda7612d62f5 |
Close
Hashes for damo_embedding-1.1.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | f98bab79830cacf7fe245fee39d4be43adcf6a8176ea3168241a01549f956167 |
| MD5 | 1c1aa9936d117c9560dfd638a8a3e9f8 |
| BLAKE2b-256 | 844c6843a3b7b2f34b6cee112cbc4d2618b868ce5fcdfb598d65edbb7c953fad |
Close
Hashes for damo_embedding-1.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 442a47b4011cefca8d40a480192991b95f90b5b3df8a39b4832299de50f84324 |
| MD5 | 12b15a7c0a5a8d631c4bf946a1f2e3ab |
| BLAKE2b-256 | 698b61be8a5f52096bd147970156fc22820e9cebafaf6675a49167c33b8dfe77 |
Close
Hashes for damo_embedding-1.1.8-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 14628c9dea8433a2262882b34df9c497f1c8746acd2a5a2fe4e91ce32a668e41 |
| MD5 | 1457a19d1a78f07d485b524138eb2dc5 |
| BLAKE2b-256 | c261319979b81fbe8d3db4794b15df93aa4df8b30f939956855c52778e5a65b7 |
Close
Hashes for damo_embedding-1.1.8-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | d1d49904b08c4fd5b5f0df355dc16a00df8704f18584b779f2c5fc714445ed7f |
| MD5 | a84a30a88ee6c5275e952c60db4ce8d1 |
| BLAKE2b-256 | dd874734c2605798295b0ebde0544a518267ed311e978c6e746e90a931d6034b |