Python wrapper for damo, a set of fast and robust hash functions.
Project description
Damo-Embedding
Quick Install
pip install damo-embedding
Example
DeepFM
import torch
import torch.nn as nn
from damo_embedding import Embedding
class DeepFM(torch.nn.Module):
    """DeepFM: factorization-machine (FM) term plus a DNN over the same embeddings.

    Uses damo_embedding.Embedding for the first-order weights (`w`) and the
    second-order/DNN embeddings (`v`).

    Args:
        emb_size: dimensionality of each feature embedding.
        fea_size: number of input feature fields (input.shape[1] must equal this).
        hid_dims: hidden-layer widths of the DNN tower (default [256, 128]).
        num_classes: output dimension of the final linear layer (default 1).
        dropout: dropout rate per hidden layer; must have at least
            len(hid_dims) entries (default [0.2, 0.2]).
        **kwargs: optional initializer/optimizer hyper-parameters
            (mean, stddev, gamma, beta1, beta2, lambda, epsilon).
    """

    def __init__(
        self,
        emb_size: int,
        fea_size: int,
        hid_dims=None,
        num_classes=1,
        dropout=None,
        **kwargs,
    ):
        super().__init__()
        # Avoid mutable default arguments; the effective defaults are unchanged.
        if hid_dims is None:
            hid_dims = [256, 128]
        if dropout is None:
            dropout = [0.2, 0.2]
        self.emb_size = emb_size
        self.fea_size = fea_size
        initializer = {
            "name": "truncate_normal",
            "mean": float(kwargs.get("mean", 0.0)),
            "stddev": float(kwargs.get("stddev", 0.0001)),
        }
        optimizer = {
            "name": "adam",
            "gamma": float(kwargs.get("gamma", 0.001)),
            "beta1": float(kwargs.get("beta1", 0.9)),
            "beta2": float(kwargs.get("beta2", 0.999)),
            "lambda": float(kwargs.get("lambda", 0.0)),
            "epsilon": float(kwargs.get("epsilon", 1e-8)),
        }
        # First-order (linear) weights: one scalar per feature.
        self.w = Embedding(
            1,
            initializer=initializer,
            optimizer=optimizer,
        )
        # Second-order embeddings, shared between the FM term and the DNN.
        self.v = Embedding(
            self.emb_size,
            initializer=initializer,
            optimizer=optimizer,
        )
        # FIX: register the global bias as a Parameter. A bare tensor with
        # requires_grad=True is not returned by model.parameters(), so it was
        # never updated by the optimizer nor saved in state_dict().
        self.w0 = nn.Parameter(torch.zeros(1, dtype=torch.float32))
        self.dims = [fea_size * emb_size] + hid_dims
        self.layers = nn.ModuleList()
        for i in range(1, len(self.dims)):
            self.layers.append(nn.Linear(self.dims[i - 1], self.dims[i]))
            # FIX: the original appended BatchNorm1d twice per hidden layer
            # (an accidental duplicate); one normalization per layer suffices.
            self.layers.append(nn.BatchNorm1d(self.dims[i]))
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(dropout[i - 1]))
        self.layers.append(nn.Linear(self.dims[-1], num_classes))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute sigmoid(FM(input) + DNN(input)).

        Args:
            input (torch.Tensor): feature tensor with input.shape[1] == fea_size.
                (Presumably integer feature ids consumed by damo Embedding —
                confirm against the damo_embedding docs.)

        Returns:
            torch.Tensor: per-sample scores in (0, 1).
        """
        assert input.shape[1] == self.fea_size
        w = self.w.forward(input)
        v = self.v.forward(input)
        # Classic FM identity: pairwise interactions =
        # 0.5 * ((sum_i v_i)^2 - sum_i v_i^2), summed over the embedding dim.
        square_of_sum = torch.pow(torch.sum(v, dim=1), 2)
        sum_of_square = torch.sum(v * v, dim=1)
        fm_out = (
            torch.sum((square_of_sum - sum_of_square)
                      * 0.5, dim=1, keepdim=True)
            + torch.sum(w, dim=1)
            + self.w0
        )
        # DNN tower over the flattened embeddings.
        dnn_out = torch.flatten(v, 1)
        for layer in self.layers:
            dnn_out = layer(dnn_out)
        out = fm_out + dnn_out
        out = self.sigmoid(out)
        return out
Save Model
from damo_embedding import save_model
# Build the example DeepFM (embedding size 8, 39 feature fields — the
# Criteo-style field count) and export it to the current directory.
model = DeepFM(8, 39)
# training=False exports an inference-only artifact; presumably a training
# export would retain optimizer state — confirm against the damo docs.
save_model(model, "./", training=False)
Document
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
damo-embedding-1.1.10.tar.gz
(224.4 kB
view hashes)
Built Distributions
Close
Hashes for damo_embedding-1.1.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | bafb81fa0981aee6597e57a6617fb506403d4803057e7510718de57badb6c063 |
|
MD5 | 51ac4733622020a6e9bd24bb1a23abe7 |
|
BLAKE2b-256 | 69cf3fafae7e1c88e2ae3d8dfcff56acb0e3825aaf355afbcc205e6ebb24f6f8 |
Close
Hashes for damo_embedding-1.1.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 64202e8a230700d9807dab7cc8eb31351e2302db3d27d726c727d021a0249771 |
|
MD5 | d233499b9315049fba03691bbc0cc570 |
|
BLAKE2b-256 | dae5c5f141289de5945932a2553e224eada5fd249acd4cfc839b54af908cbdf4 |
Close
Hashes for damo_embedding-1.1.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 87eaf7ac122161a307495b3c23ccb4b60be1839cf9cca178823cab434a84ac7b |
|
MD5 | 1eabac66e4ba25b86f39c468a9f14bda |
|
BLAKE2b-256 | 7404c7a87241e47e594f4eebbd731e36dc3a314be62d135686a623ac321b6a31 |
Close
Hashes for damo_embedding-1.1.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 3849692bc6ae867017aac002218d329122b72e643edf7391107884b25cf0a337 |
|
MD5 | 7a316414f5cf5a4191d1753a0992bfec |
|
BLAKE2b-256 | fdd967ed80876b4022f94bec8dbd54d3e0a5bd019f4789a876fc6cecf77b7eff |
Close
Hashes for damo_embedding-1.1.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 602e142c857431b6e7b0afe8c055bc86e6379824ee814627946c989e07f18e88 |
|
MD5 | b245fd9133bf9fa9f41206d9dfeeccfc |
|
BLAKE2b-256 | 063893b0dc42c565a47842a4f7f3d74c6c3357f6cde4bf1608f095b4d87a2d75 |
Close
Hashes for damo_embedding-1.1.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 1778dad7b2e1fe1d0b5c828e2c6fdc46e1841801ccbdcde6d3c25ab5db34fea8 |
|
MD5 | dc2f4eac82be19fb95e20fcd772e038a |
|
BLAKE2b-256 | 32f1fd974c78d87f40e0cf186884d50ea2443f094c32b31ef800e52b7fe111f4 |
Close
Hashes for damo_embedding-1.1.10-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | 471c89c147542399d8c24779b5ea9679966542bc20bbee4b6dd87eab5d30ef02 |
|
MD5 | 5bc5904f73f8b26fa556da1722c1b6f6 |
|
BLAKE2b-256 | 6bf42ef58af688091f09efdd9e8d8c3e816870940815a08b443619fae46c798e |
Close
Hashes for damo_embedding-1.1.10-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | ff53f78fcc32681830d7b69a06a39a2eefc08c401dac1bfab7a13a3518e77095 |
|
MD5 | adc16b17cb7915455819d8a8b929954c |
|
BLAKE2b-256 | 9a963c9292b92718acfe9ea93bf6d12f6468260e226df36f44aa644c29f1394d |
Close
Hashes for damo_embedding-1.1.10-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | e73d6f9ef00c3cf6e37ad4d4a51dbd07dd95fce695a0e141b961d1547131819d |
|
MD5 | 9b00796173c4303ae037b0435f3624c1 |
|
BLAKE2b-256 | 745544b6a61a4d0086c70434f068c01a10746ee2c99896438dfe97cc64a2a31b |
Close
Hashes for damo_embedding-1.1.10-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | e80672db92a30af92dc3ad00ad877c77737e451454b9c02bb436bf257b2a9364 |
|
MD5 | d2023d1995c541397416d12c8c3d0dc2 |
|
BLAKE2b-256 | 602248ef4d00704d98d49f6b2cb51cf1d9239e01c6794c8dc68a19036932e95b |