
Project description

handy-nn

Delightful and useful neural networks models, including OrdinalRegressionLoss, etc.

Install

$ pip install handy-nn

Usage

from handy_nn import OrdinalRegressionLoss

# Initialize the loss function
num_classes = 5
criterion = OrdinalRegressionLoss(num_classes)

# For training
logits = model(inputs)  # Shape: (batch_size, 1)
loss = criterion(logits, targets)  # Shape: torch.Size([]) (scalar)
loss.backward()

# To get class probabilities
probas = criterion.predict_probas(logits)  # Shape: (batch_size, num_classes)

Shapes

Variable  Shape
logits    (batch_size, 1)
targets   (batch_size,)
loss      torch.Size([]) (scalar)
probas    (batch_size, num_classes)

APIs

OrdinalRegressionLoss(num_classes, learn_thresholds=True, init_scale=2.0)

  • num_classes (int): Number of ordinal classes (ranks).
  • learn_thresholds (bool, default True): Whether to learn the threshold parameters or keep them fixed.
  • init_scale (float, default 2.0): Scale used to initialize the thresholds.

Creates the loss function for ordinal regression.

The goal of ordinal regression is to model the relationship between one or more independent variables and an ordinal dependent variable. It predicts the probability that an observation falls into a specific ordinal category or a category higher than a certain threshold. This is particularly useful in fields like social sciences, medicine, and customer surveys where outcomes are often ordinal.
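To make the idea concrete, here is a minimal plain-Python sketch of the cumulative-link formulation commonly used for ordinal regression (an illustration of the general technique, not necessarily handy-nn's exact implementation; the threshold values below are hypothetical):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def ordinal_probas(score, thresholds):
    # Cumulative-link model: P(y <= k) = sigmoid(theta_k - score),
    # with the last cumulative term fixed at 1.0.
    cdf = [sigmoid(t - score) for t in thresholds] + [1.0]
    # Each class probability is the difference of adjacent cumulative terms.
    return [cdf[0]] + [cdf[k] - cdf[k - 1] for k in range(1, len(cdf))]

# 5 classes need 4 ordered thresholds; evenly spaced here, echoing an
# init_scale-style initialization (hypothetical values).
thresholds = [-2.0, -0.667, 0.667, 2.0]
probas = ordinal_probas(0.0, thresholds)  # one probability per class
```

A single scalar score per sample is enough, which is why the logits in the usage example above have shape (batch_size, 1): raising the score shifts probability mass toward the higher ordinal classes.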

TrendAwareLoss()

criterion = TrendAwareLoss()
loss = criterion(logits, targets)
loss.backward()

  • logits: torch.Tensor of shape (batch_size, num_classes), the model's output logits
  • targets: torch.Tensor of shape (batch_size,) (class indices) or (batch_size, num_classes) (one-hot tensors)

TrendAwareLoss penalizes "too-early / too-late" misclassifications within a label segment more heavily: each sample's cross-entropy is multiplied by a weight based on the remaining length of its label segment.

This loss function is useful in situations where a misclassification causes an indirect cost, such as in financial trading.
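The weighting scheme can be sketched in plain Python as follows (a minimal illustration of the idea; handy-nn's exact weight definition may differ): each position's cross-entropy is scaled by how many steps remain, inclusive, before its label segment ends, so errors early in a segment cost more.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def remaining_lengths(targets):
    # Steps remaining (inclusive) until the label changes: for targets
    # [1, 1, 1, 0] this yields [3, 2, 1, 1].
    out = [1] * len(targets)
    for i in range(len(targets) - 2, -1, -1):
        if targets[i] == targets[i + 1]:
            out[i] = out[i + 1] + 1
    return out

def trend_aware_loss(logits, targets):
    # Per-sample cross-entropy, weighted by segment remaining length,
    # then averaged over the batch.
    weights = remaining_lengths(targets)
    losses = [-w * math.log(softmax(row)[y])
              for row, y, w in zip(logits, targets, weights)]
    return sum(losses) / len(losses)
```

Under this weighting, mislabeling the first step of a three-step segment is penalized three times as hard as mislabeling its last step.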

License

MIT

Download files

Download the file for your platform.

Source Distribution

handy_nn-0.0.5.tar.gz (6.1 kB)


Built Distribution


handy_nn-0.0.5-py3-none-any.whl (6.0 kB)


File details

Details for the file handy_nn-0.0.5.tar.gz.

File metadata

  • Download URL: handy_nn-0.0.5.tar.gz
  • Size: 6.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.11

File hashes

Hashes for handy_nn-0.0.5.tar.gz
Algorithm    Hash digest
SHA256       f08767fd3bdc34fdb309bfdd7f4eb5a15a1b0f61316f5906e9984c7acc0aab6e
MD5          52d16ec87b6a0b1688fefb2916013fd0
BLAKE2b-256  bf3ff4e3f532cf8ffefbca36bff350c51dceab10e6ab22f4b5fa3877c3f0a3e2


File details

Details for the file handy_nn-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: handy_nn-0.0.5-py3-none-any.whl
  • Size: 6.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.11

File hashes

Hashes for handy_nn-0.0.5-py3-none-any.whl
Algorithm    Hash digest
SHA256       052ae302d58fb4ddfbc7bce6fb5c173ef67638911e9ef3ac288df40f5d5cb6a6
MD5          538ed4e375d781e1003f31b4ec7cacee
BLAKE2b-256  6147aa93d27118f26a4cd6db2178bd76dbba3c9f41b964cf9df8a9008caa7060

