Yet another implementation of Bayesian factorization machines.
myFM
myFM is an implementation of Bayesian Factorization Machines based on Gibbs sampling, which I believe is a wheel worth reinventing.
Currently it supports most options of the libFM MCMC engine, such as
- Grouping of input variables (the -meta option of libFM)
- Relation Data format (see the paper "Scaling Factorization Machines to relational data")
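As a rough illustration of what grouping means: each input column is assigned to a group, and, as I understand libFM's MCMC engine, columns in the same group share hyperparameters. In the MovieLens example further down, the grouping is passed to fit as group_shapes, a list of group sizes. The sketch below only shows how such a list partitions the columns; the helper names are hypothetical and are not part of myFM.

```python
# Hypothetical helpers for illustration; only the name group_shapes comes from the README.

def group_shapes_from_categories(categories):
    """Return how many one-hot columns each categorical field contributes."""
    return [len(values) for values in categories]


def column_groups(group_shapes):
    """Map every column index to the index of the group it belongs to."""
    return [g for g, size in enumerate(group_shapes) for _ in range(size)]


# Two fields: three distinct users and two distinct items.
categories = [["u1", "u2", "u3"], ["i1", "i2"]]
group_shapes = group_shapes_from_categories(categories)
print(group_shapes)                 # [3, 2]
print(column_groups(group_shapes))  # [0, 0, 0, 1, 1]
```

The sizes must add up to the number of columns of the design matrix, which is what the assert in the MovieLens example checks.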
It also provides features not present in libFM:
- A Gibbs sampler for ordered probit regression [5], implementing the Metropolis-within-Gibbs scheme of [6].
- Variational inference for regression and binary classification.
A tutorial and reference documentation are available at https://myfm.readthedocs.io/en/latest/.
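For readers new to the model itself, the following is a minimal pure-Python sketch of the second-order factorization machine score from [1]: a bias, a linear term, and pairwise interactions weighted by inner products of latent vectors. It is not myFM's implementation (which is written in C++ and samples the parameters rather than taking them as given); the function names and the toy numbers are made up for illustration.

```python
def fm_predict(x, w0, w, V):
    """FM score for one dense feature vector x.

    V[i] is the rank-k latent vector of feature i. The pairwise term uses the
    identity sum_{i<j} <V_i, V_j> x_i x_j
        = 0.5 * sum_f [(sum_i V[i][f] x_i)^2 - sum_i (V[i][f] x_i)^2],
    which costs O(k * n) instead of O(k * n^2).
    """
    linear = sum(w_i * x_i for w_i, x_i in zip(w, x))
    pairwise = 0.0
    for f in range(len(V[0])):
        terms = [V[i][f] * x[i] for i in range(len(x))]
        pairwise += 0.5 * (sum(terms) ** 2 - sum(t * t for t in terms))
    return w0 + linear + pairwise


def fm_predict_naive(x, w0, w, V):
    """Same score with the explicit double loop, kept as a cross-check."""
    score = w0 + sum(w_i * x_i for w_i, x_i in zip(w, x))
    for i in range(len(x)):
        for j in range(i + 1, len(x)):
            dot = sum(a * b for a, b in zip(V[i], V[j]))
            score += dot * x[i] * x[j]
    return score


# Toy parameters (illustrative only).
x = [1.0, 0.0, 1.0]
w0, w = 0.5, [0.1, 0.2, 0.3]
V = [[1.0, 2.0], [0.5, 0.5], [3.0, -1.0]]
print(fm_predict(x, w0, w, V), fm_predict_naive(x, w0, w, V))
# both are 1.9 up to floating-point rounding
```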
Installation
The package is pip-installable.
pip install myfm
Prebuilt binaries are available for major operating systems.
If you are working with a less popular OS/architecture, pip will attempt to build myFM from source (you need a decent C++ compiler!). In that case, in addition to installing the Python dependencies (numpy, scipy, pandas, ...), the above command will automatically download Eigen (ver 3.4.0) to its build directory and use it during the build.
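After installation, a quick way to confirm that the package and its Python dependencies can be found is to ask the import system for them. This is a generic standard-library check, not something myFM provides; is_installed is a hypothetical helper name.

```python
import importlib.util


def is_installed(module_name):
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


for name in ("numpy", "scipy", "pandas", "myfm"):
    print(name, "found" if is_installed(name) else "missing")
```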
Examples
A Toy Example
This example is taken from pyfm with some modifications.
import myfm
from sklearn.feature_extraction import DictVectorizer
import numpy as np
train = [
    {"user": "1", "item": "5", "age": 19},
    {"user": "2", "item": "43", "age": 33},
    {"user": "3", "item": "20", "age": 55},
    {"user": "4", "item": "10", "age": 20},
]
v = DictVectorizer()
X = v.fit_transform(train)
print(X.toarray())
# prints
# [[ 19. 0. 0. 0. 1. 1. 0. 0. 0.]
# [ 33. 0. 0. 1. 0. 0. 1. 0. 0.]
# [ 55. 0. 1. 0. 0. 0. 0. 1. 0.]
# [ 20. 1. 0. 0. 0. 0. 0. 0. 1.]]
y = np.asarray([0, 1, 1, 0])
fm = myfm.MyFMClassifier(rank=4)
fm.fit(X,y)
fm.predict(v.transform({"user": "1", "item": "10", "age": 24}))
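To make the printed matrix easier to read, here is a small pure-Python sketch of the encoding that DictVectorizer performs with its default settings: string values become one-hot key=value columns, numeric values are passed through, and columns are ordered by sorted feature name. The vectorize function is a hypothetical stand-in written for this illustration; it returns dense lists rather than a sparse matrix.

```python
def vectorize(records):
    """Encode a list of dicts the way the toy example's matrix is laid out."""

    def features(record):
        out = {}
        for key, value in record.items():
            if isinstance(value, str):
                out[f"{key}={value}"] = 1.0   # one-hot column per (key, value)
            else:
                out[key] = float(value)       # numeric values pass through
        return out

    rows = [features(r) for r in records]
    names = sorted({name for row in rows for name in row})
    matrix = [[row.get(name, 0.0) for name in names] for row in rows]
    return names, matrix


train = [
    {"user": "1", "item": "5", "age": 19},
    {"user": "2", "item": "43", "age": 33},
    {"user": "3", "item": "20", "age": 55},
    {"user": "4", "item": "10", "age": 20},
]
names, matrix = vectorize(train)
print(names)
# ['age', 'item=10', 'item=20', 'item=43', 'item=5',
#  'user=1', 'user=2', 'user=3', 'user=4']
print(matrix[0])
# [19.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]
```

Note that the item columns are sorted as strings, which is why item=5 comes after item=43.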
A Movielens-100k Example
This example requires pandas and scikit-learn. movielens100k_loader is present in examples/movielens100k_loader.py.
You will be able to obtain a result comparable to SOTA algorithms like GC-MC. See examples/ml-100k.ipynb for the detailed version.
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from sklearn import metrics
import myfm
from myfm.utils.benchmark_data import MovieLens100kDataManager
data_manager = MovieLens100kDataManager()
df_train, df_test = data_manager.load_rating_predefined_split(
    fold=3
)  # Note the dependence on the fold
def test_myfm(df_train, df_test, rank=8, grouping=None, n_iter=100, samples=95):
    explanation_columns = ["user_id", "movie_id"]
    ohe = OneHotEncoder(handle_unknown="ignore")
    X_train = ohe.fit_transform(df_train[explanation_columns])
    X_test = ohe.transform(df_test[explanation_columns])
    y_train = df_train.rating.values
    y_test = df_test.rating.values
    fm = myfm.MyFMRegressor(rank=rank, random_seed=114514)
    if grouping:
        # specify how columns of X_train are grouped
        group_shapes = [len(category) for category in ohe.categories_]
        assert sum(group_shapes) == X_train.shape[1]
    else:
        group_shapes = None
    fm.fit(
        X_train,
        y_train,
        group_shapes=group_shapes,
        n_iter=n_iter,
        n_kept_samples=samples,
    )
    prediction = fm.predict(X_test)
    rmse = ((y_test - prediction) ** 2).mean() ** 0.5
    mae = np.abs(y_test - prediction).mean()
    print("rmse={rmse}, mae={mae}".format(rmse=rmse, mae=mae))
    return fm
# basic regression
test_myfm(df_train, df_test, rank=8)
# rmse=0.90321, mae=0.71164
# with grouping
fm = test_myfm(df_train, df_test, rank=8, grouping=True)
# rmse=0.89594, mae=0.70481
Examples for Relational Data format
Below is a toy movielens-like example that uses the relational data format proposed in [3].
This example, however, is too simplistic to exhibit the computational advantage of this data format. For an example with drastically reduced computational complexity, see examples/ml-100k-extended.ipynb.
import pandas as pd
import numpy as np
from myfm import MyFMRegressor, RelationBlock
from sklearn.preprocessing import OneHotEncoder
users = pd.DataFrame([
    {'user_id': 1, 'age': '20s', 'married': False},
    {'user_id': 2, 'age': '30s', 'married': False},
    {'user_id': 3, 'age': '40s', 'married': True}
]).set_index('user_id')
movies = pd.DataFrame([
    {'movie_id': 1, 'comedy': True, 'action': False},
    {'movie_id': 2, 'comedy': False, 'action': True},
    {'movie_id': 3, 'comedy': True, 'action': True}
]).set_index('movie_id')
ratings = pd.DataFrame([
    {'user_id': 1, 'movie_id': 1, 'rating': 2},
    {'user_id': 1, 'movie_id': 2, 'rating': 5},
    {'user_id': 2, 'movie_id': 2, 'rating': 4},
    {'user_id': 2, 'movie_id': 3, 'rating': 3},
    {'user_id': 3, 'movie_id': 3, 'rating': 3},
])
user_ids, user_indices = np.unique(ratings.user_id, return_inverse=True)
movie_ids, movie_indices = np.unique(ratings.movie_id, return_inverse=True)
user_ohe = OneHotEncoder(handle_unknown='ignore').fit(users.reset_index()) # include user id as feature
movie_ohe = OneHotEncoder(handle_unknown='ignore').fit(movies.reset_index())
X_user = user_ohe.transform(
    users.reindex(user_ids).reset_index()
)
X_movie = movie_ohe.transform(
    movies.reindex(movie_ids).reset_index()
)
block_user = RelationBlock(user_indices, X_user)
block_movie = RelationBlock(movie_indices, X_movie)
fm = MyFMRegressor(rank=2).fit(None, ratings.rating, X_rel=[block_user, block_movie])
prediction_df = pd.DataFrame([
    dict(user_id=user_id, movie_id=movie_id,
         user_index=user_index, movie_index=movie_index)
    for user_index, user_id in enumerate(user_ids)
    for movie_index, movie_id in enumerate(movie_ids)
])
predicted_rating = fm.predict(None, [
    RelationBlock(prediction_df.user_index, X_user),
    RelationBlock(prediction_df.movie_index, X_movie)
])
prediction_df['prediction'] = predicted_rating
print(
    prediction_df.merge(ratings.rename(columns={'rating': 'ground_truth'}), how='left')
)
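Conceptually, each relation block pairs a small feature table (one row per distinct user or movie) with an index array saying which row each rating uses. The sketch below expands such blocks into the equivalent flat design matrix in pure Python, using the index pattern of the five ratings above with simplified feature rows. It only illustrates what the format denotes: the point of the format in [3] is to avoid materializing this matrix, and expand_blocks is a hypothetical helper, not part of myFM.

```python
def expand_blocks(blocks):
    """Build the flat design matrix that a list of (indices, rows) blocks denotes.

    Row i of the result is the concatenation of rows[indices[i]] over all blocks.
    """
    n_samples = len(blocks[0][0])
    design = []
    for i in range(n_samples):
        row = []
        for indices, rows in blocks:
            row.extend(rows[indices[i]])
        design.append(row)
    return design


# Index of the user / movie behind each of the five ratings.
user_indices = [0, 0, 1, 1, 2]
movie_indices = [0, 1, 1, 2, 2]
# Simplified per-entity features: one-hot user id, and [comedy, action] flags.
user_rows = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
movie_rows = [[1, 0], [0, 1], [1, 1]]

design = expand_blocks([(user_indices, user_rows), (movie_indices, movie_rows)])
for row in design:
    print(row)
# [1, 0, 0, 1, 0]
# [1, 0, 0, 0, 1]
# [0, 1, 0, 0, 1]
# [0, 1, 0, 1, 1]
# [0, 0, 1, 1, 1]
```

With many ratings per user and wide per-user feature rows, the flat matrix repeats the same rows over and over, which is the redundancy the relational format removes.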
References
[1] Rendle, Steffen. "Factorization machines." 2010 IEEE International Conference on Data Mining. IEEE, 2010.
[2] Rendle, Steffen. "Factorization machines with libfm." ACM Transactions on Intelligent Systems and Technology (TIST) 3.3 (2012): 57.
[3] Rendle, Steffen. "Scaling factorization machines to relational data." Proceedings of the VLDB Endowment. Vol. 6. No. 5. VLDB Endowment, 2013.
[4] Bayer, Immanuel. "fastfm: A library for factorization machines." arXiv preprint arXiv:1505.00641 (2015).
[5] Albert, James H., and Siddhartha Chib. "Bayesian analysis of binary and polychotomous response data." Journal of the American Statistical Association 88.422 (1993): 669-679.
[6] Albert, James H., and Siddhartha Chib. "Sequential ordinal modeling with applications to survival data." Biometrics 57.3 (2001): 829-836.