Generalized semantic regression with a BERT base.
generalized-semantic-regression
RiskBERT makes it easier to incorporate text fragments into GLM-based models, such as insurance frequency and severity models. It pairs a BERT base with ordinary numeric covariates, so free text can enter the regression alongside the usual rating factors.
To learn more about the RiskBERT implementation, read this article: https://www.thebigdatablog.com/generalized-semantic-regression-using-contextual-embeddings/
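The GLM side of such a model can be illustrated in isolation. The following is a minimal sketch, not RiskBERT's implementation: it fits a Poisson regression with a log link by plain gradient descent, using pure Python. The function names, the toy feature rows (two stand-in "embedding" dimensions plus one covariate), and the claim counts are all hypothetical and chosen only for illustration.

```python
import math


def predict_mean(weights, bias, features):
    """Poisson GLM with log link: mu = exp(bias + w . x)."""
    linear = bias + sum(w * x for w, x in zip(weights, features))
    return math.exp(linear)


def poisson_nll(weights, bias, rows, counts):
    """Average negative log-likelihood, dropping the constant log(y!) term."""
    total = 0.0
    for x, y in zip(rows, counts):
        mu = predict_mean(weights, bias, x)
        total += mu - y * math.log(mu)
    return total / len(rows)


def fit(rows, counts, lr=0.05, epochs=500):
    """Batch gradient descent on the Poisson negative log-likelihood."""
    n_features = len(rows[0])
    weights, bias = [0.0] * n_features, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * n_features, 0.0
        for x, y in zip(rows, counts):
            # Derivative of (mu - y * log(mu)) w.r.t. the linear predictor.
            err = predict_mean(weights, bias, x) - y
            grad_b += err
            for j, xj in enumerate(x):
                grad_w[j] += err * xj
        bias -= lr * grad_b / len(rows)
        weights = [w - lr * g / len(rows) for w, g in zip(weights, grad_w)]
    return weights, bias


# Hypothetical toy data: two stand-in "embedding" values, then one covariate.
rows = [
    [0.1, 0.9, 1.0],
    [0.8, 0.2, 0.0],
    [0.4, 0.5, 1.0],
    [0.9, 0.1, 0.0],
]
counts = [3, 0, 2, 1]  # hypothetical observed claim counts

weights, bias = fit(rows, counts)
initial_nll = poisson_nll([0.0, 0.0, 0.0], 0.0, rows, counts)
fitted_nll = poisson_nll(weights, bias, rows, counts)
print(f"NLL before fitting: {initial_nll:.3f}, after fitting: {fitted_nll:.3f}")
```

In a semantic regression the first feature columns would come from a text encoder rather than being typed in by hand; the log link keeps every predicted mean positive, which is why it is a common choice for frequency models.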
Example:
pip install RiskBERT
from transformers import AutoTokenizer
import torch
from RiskBERT import glmModel, RiskBertModel
from RiskBERT import trainer, evaluate_model
from RiskBERT.simulation.data_functions import Data
from RiskBERT.utils import DataConstructor
# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Create a simulated data set and initialise the model
model_dataset = Data(20000, scores=torch.tensor([[0.2],[0.4]]), weigth=5)
pre_model = "distilbert-base-uncased"
model = RiskBertModel(model=pre_model, input_dim=2, dropout=0.4, freeze_bert=True, mode="CLS")
tokenizer = AutoTokenizer.from_pretrained(pre_model)
# Train the model
model, Total_Loss, Validation_Loss, Test_Loss = trainer(
    model=model,
    model_dataset=model_dataset,
    epochs=100,
    batch_size=1000,
    evaluate_fkt=evaluate_model,
    tokenizer=tokenizer,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.001),
    device=device,
)
# Predict from the model
my_data = DataConstructor(
    sentences=[["Dies ist ein Test"], ["Hallo Welt", "RiskBERT ist das Beste"]],
    covariates=[[1, 5], [2, 6]],
    tokenizer=tokenizer,
).prepare_for_model()
my_prediction = model(**my_data)
Upload to PyPI
python -m pip install build twine
python -m build
twine check dist/*
twine upload dist/*
File details
Details for the file RiskBERT-0.0.9.tar.gz.
File metadata
- Download URL: RiskBERT-0.0.9.tar.gz
- Upload date:
- Size: 21.6 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 60497ca774844542f2984c3b1e1a47c22e563ceac72f596aac2899ff5e0ede08 |
| MD5 | a83a0fbc87dbcb32e4d0efde1d4b6975 |
| BLAKE2b-256 | 15c0c92d553c20a8e3878ccc44d9105839a20b24b7fcef1a8bf335c134908874 |
File details
Details for the file RiskBERT-0.0.9-py3-none-any.whl.
File metadata
- Download URL: RiskBERT-0.0.9-py3-none-any.whl
- Upload date:
- Size: 16.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.9.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 720f02e9303eb36911c36d73916111b29858aed4f1e0298902480b09710a8c3c |
| MD5 | 50a9dd37387190fa30466f953d5585b0 |
| BLAKE2b-256 | 35a0d5768ccba22360beca65a5b13a7a7f78da337ce74f334cef2e019ea3b837 |
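
The published digests can be checked against a downloaded file before installing it. The sketch below uses only Python's standard `hashlib` module. The helper names `file_digests` and `matches_published` are hypothetical, and the code assumes that the BLAKE2b-256 value listed above corresponds to BLAKE2b with a 32-byte digest. MD5 is left out because it may be disabled on some systems.

```python
import hashlib
from pathlib import Path


def file_digests(path, chunk_size=65536):
    """Return hex digests of a file, read in chunks to keep memory use low."""
    sha256 = hashlib.sha256()
    # Assumption: "BLAKE2b-256" means BLAKE2b with a 32-byte (256-bit) digest.
    blake2b_256 = hashlib.blake2b(digest_size=32)
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            sha256.update(chunk)
            blake2b_256.update(chunk)
    return {
        "sha256": sha256.hexdigest(),
        "blake2b_256": blake2b_256.hexdigest(),
    }


def matches_published(path, expected_sha256):
    """Compare a file's SHA256 digest with a published value, ignoring case."""
    return file_digests(path)["sha256"] == expected_sha256.strip().lower()


# Only runs if the source distribution has been downloaded next to this script.
archive = Path("RiskBERT-0.0.9.tar.gz")
if archive.exists():
    published = "60497ca774844542f2984c3b1e1a47c22e563ceac72f596aac2899ff5e0ede08"
    print("SHA256 matches:", matches_published(archive, published))
```

A mismatch means the local file differs from the one that was uploaded, in which case it is safer to download it again than to install it.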