handy-nn
Delightful and useful neural network models, including OrdinalRegressionLoss, etc.
Install
```sh
$ pip install handy-nn
```
Usage
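```python
import torch
from handy_nn import OrdinalRegressionLoss

# Initialize the loss function
num_classes = 5
criterion = OrdinalRegressionLoss(num_classes)

# Placeholder model and data for illustration only: any model that
# outputs one score per sample can be used here
model = torch.nn.Linear(16, 1)
inputs = torch.randn(8, 16)
targets = torch.randint(0, num_classes, (8,))  # shape: (batch_size,)

# For training
logits = model(inputs)             # shape: (batch_size, 1)
loss = criterion(logits, targets)  # shape: torch.Size([])
loss.backward()

# To get class probabilities
probas = criterion.predict_probas(logits)  # shape: (batch_size, num_classes)
```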
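predict_probas returns a full distribution over the ranks. If a single predicted rank is needed, one option (an assumption, not part of the handy_nn API) is to decode each row in plain Python after converting it to a list, e.g. with probas[i].tolist(). The helper names below are hypothetical. Taking the median of the distribution respects the order of the classes, whereas argmax ignores it:

```python
from typing import Sequence


def argmax_rank(probas: Sequence[float]) -> int:
    """Most likely rank; treats the classes as unordered."""
    return max(range(len(probas)), key=lambda k: probas[k])


def median_rank(probas: Sequence[float]) -> int:
    """Smallest rank whose cumulative probability reaches 0.5."""
    cumulative = 0.0
    for k, p in enumerate(probas):
        cumulative += p
        if cumulative >= 0.5:
            return k
    return len(probas) - 1  # fallback for inputs that do not sum to 1


def expected_rank(probas: Sequence[float]) -> float:
    """Probability-weighted mean rank, a continuous ordinal score."""
    return sum(k * p for k, p in enumerate(probas))


# One hypothetical row of probas with two separated modes
row = [0.4, 0.0, 0.0, 0.3, 0.3]
print(argmax_rank(row), median_rank(row), round(expected_rank(row), 2))  # → 0 3 2.1
```

For this row, argmax picks rank 0 even though 60% of the mass sits at ranks 3 and 4; the median and the expected rank both reflect that.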
Shapes
| Variable | Shape |
|---|---|
| logits | (batch_size, 1) |
| targets | (batch_size,) or one-hot (batch_size, num_classes) |
| loss | torch.Size([]) |
| probas | (batch_size, num_classes) |
APIs
OrdinalRegressionLoss(num_classes, learn_thresholds=True, init_scale=2.0)
- num_classes (int): Number of ordinal classes (ranks).
- learn_thresholds (bool, default True): Whether to learn the threshold parameters or keep them fixed.
- init_scale (float, default 2.0): Scale used to initialize the thresholds.
Creates the loss function for ordinal regression.
The goal of ordinal regression is to model the relationship between one or more independent variables and an ordinal dependent variable, i.e. one whose categories have a natural order (for example, ratings from 1 to 5) but no fixed spacing between them. It predicts the probability that an observation falls into a specific ordinal category, or into a category above a given threshold. This is particularly useful in fields such as the social sciences, medicine, and customer surveys, where outcomes are often ordinal.
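The exact formulation used inside handy_nn is not documented here. As an illustration only, the framework-free sketch below shows the common cumulative-logit (threshold) formulation of ordinal regression that the API suggests: one score per sample (matching logits of shape (batch_size, 1)) is compared against num_classes - 1 increasing thresholds. The function names and the evenly spaced initialization driven by init_scale are assumptions, not the library's code.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def init_thresholds(num_classes: int, init_scale: float = 2.0) -> List[float]:
    """Hypothetical init: num_classes - 1 thresholds evenly spaced in [-init_scale, init_scale]."""
    if num_classes < 2:
        raise ValueError("need at least two classes")
    n = num_classes - 1
    if n == 1:
        return [0.0]
    return [-init_scale + 2.0 * init_scale * i / (n - 1) for i in range(n)]


def ordinal_probas(score: float, thresholds: List[float]) -> List[float]:
    """Cumulative-logit model: P(y <= k) = sigmoid(thresholds[k] - score).

    Each class probability is the difference of consecutive cumulative
    probabilities, so thresholds must be strictly increasing.
    """
    if any(b <= a for a, b in zip(thresholds, thresholds[1:])):
        raise ValueError("thresholds must be strictly increasing")
    cumulative = [sigmoid(t - score) for t in thresholds] + [1.0]
    probas, previous = [], 0.0
    for c in cumulative:
        probas.append(c - previous)
        previous = c
    return probas


def ordinal_nll(score: float, target: int, thresholds: List[float], eps: float = 1e-12) -> float:
    """Negative log-likelihood of the true rank, the quantity training would minimize."""
    return -math.log(max(ordinal_probas(score, thresholds)[target], eps))


thresholds = init_thresholds(5)  # approx. [-2.0, -0.667, 0.667, 2.0]
print([round(p, 3) for p in ordinal_probas(0.0, thresholds)])  # → [0.119, 0.22, 0.322, 0.22, 0.119]
```

A higher score shifts probability mass toward higher ranks; with learn_thresholds=True one would expect the thresholds themselves to be trainable parameters rather than fixed constants.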
TrendAwareLoss()
```python
from handy_nn import TrendAwareLoss

criterion = TrendAwareLoss()
loss = criterion(logits, targets)
loss.backward()
```
TrendAwareLoss penalizes "too-early / too-late" misclassification inside a label segment more heavily: it multiplies each sample's cross-entropy by a weight given by the remaining length of that sample's label segment.
This loss is useful in situations where a misclassification leads to an indirect loss, such as financial trading.
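A minimal plain-Python sketch of the weighting described above, under stated assumptions: the batch is one time-ordered sequence, a "segment" is a maximal run of identical consecutive labels, and the weight is the number of steps left in the segment including the current one. The weighted-mean normalization and all function names are hypothetical; TrendAwareLoss itself may normalize or orient the weight differently.

```python
import math
from typing import List, Sequence


def remaining_lengths(labels: Sequence[int]) -> List[int]:
    """Steps left in each position's label segment, counting the position itself."""
    out = [0] * len(labels)
    for i in range(len(labels) - 1, -1, -1):
        same_as_next = i + 1 < len(labels) and labels[i + 1] == labels[i]
        out[i] = out[i + 1] + 1 if same_as_next else 1
    return out


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Per-sample cross-entropy using a numerically stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def trend_aware_loss(batch_logits: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Cross-entropy weighted by remaining segment length (hypothetical weighted mean)."""
    if not targets:
        return 0.0
    weights = remaining_lengths(targets)
    total = sum(w * cross_entropy(z, t) for z, t, w in zip(batch_logits, targets, weights))
    return total / sum(weights)


print(remaining_lengths([0, 0, 0, 1, 1, 2]))  # → [3, 2, 1, 2, 1, 1]
```

With this weighting, an error at the first step of a three-step segment costs three times as much as the same error at its last step.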
License
Download files
File details
Details for the file handy_nn-0.0.6.tar.gz.
File metadata
- Download URL: handy_nn-0.0.6.tar.gz
- Upload date:
- Size: 6.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d9ba501f4be4a7d3abff83a6c8885b0519e91037992ad1b799fc623fb7ec62d4 |
| MD5 | 8e0e51beeff0c6c6969e6c9577b664b2 |
| BLAKE2b-256 | ee6c236adbdbfa5ccb97842190813255c0c8e3a62e7fac09ed3c384b47c316e6 |
File details
Details for the file handy_nn-0.0.6-py3-none-any.whl.
File metadata
- Download URL: handy_nn-0.0.6-py3-none-any.whl
- Upload date:
- Size: 6.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.11
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ab78f85dc101ed59ef913a76bcbb808f5aa0a86ae8dd1063efa1083423f2876b |
| MD5 | 7c2a46b8cd4d931ea2235375a0e8a668 |
| BLAKE2b-256 | c40c1403ef048f9da11d752f8d02600b721ae8b5593d493dae5a9e2e0519a62a |