Torch Floating Point

Python 3.10 · PyTorch 1.13.1

A PyTorch library for custom floating point quantization with autograd support. This library provides efficient implementations of custom floating point formats with automatic differentiation capabilities.

Features

  • Custom Floating Point Formats: Support for arbitrary floating point configurations (sign bits, exponent bits, mantissa bits, bias)
  • Autograd Support: Full PyTorch autograd integration for training with quantized weights
  • CUDA Support: GPU acceleration for both forward and backward passes
  • Straight-Through Estimator: Gradient-friendly quantization for training
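
To make the straight-through estimator idea concrete, here is a minimal sketch in plain PyTorch. It is not this library's implementation: `RoundSTE` is a hypothetical stand-in that rounds to a uniform grid rather than to a floating-point format. It only illustrates the general technique, where the forward pass applies a non-differentiable rounding step and the backward pass hands the incoming gradient through unchanged.

```python
import torch


class RoundSTE(torch.autograd.Function):
    """Hypothetical stand-in: round in forward, identity gradient in backward."""

    @staticmethod
    def forward(ctx, x, step):
        # Snap each value to the nearest multiple of `step`.
        # On its own, rounding has zero gradient almost everywhere.
        return torch.round(x / step) * step

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass the upstream gradient to `x` unchanged.
        # `step` is a plain number, so it receives no gradient.
        return grad_output, None


x = torch.tensor([0.30, -1.26, 2.49], requires_grad=True)
y = RoundSTE.apply(x, 0.25)
y.sum().backward()

print(y)       # values snapped to the 0.25 grid: 0.25, -1.25, 2.5
print(x.grad)  # all ones, as if the rounding were the identity
```

Without the custom backward, the gradient of a rounding step would be zero almost everywhere and the weights would never update, which is why quantization-aware training usually relies on an estimator of this kind.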

Installation

From PyPI (Recommended)

pip install torch-floating-point

From Source

git clone https://github.com/SamirMoustafa/torch-floating-point.git
cd torch-floating-point
pip install -e .

Quick Start

import torch
from floating_point import FloatingPoint, Round

# Define a custom 8-bit floating point format (1 sign, 4 exponent, 3 mantissa bits)
fp8 = FloatingPoint(sign_bits=1, exponent_bits=4, mantissa_bits=3, bias=7, bits=8)

# Create a rounding function
rounder = Round(fp8)

# Create a tensor with gradients
x = torch.randn(10, requires_grad=True)

# Quantize the tensor
quantized = rounder(x)

# Use in training (gradients flow through)
loss = quantized.sum()
loss.backward()

print(f"Original: {x}")
print(f"Quantized: {quantized}")
print(f"Gradients: {x.grad}")
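
For a feel of what a 1-4-3 format with bias 7 can represent, the arithmetic can be sketched in plain Python. This assumes IEEE 754-style conventions (an implicit leading 1 for normal values, and subnormals at exponent code 0). Whether this library reserves the top exponent code for infinity and NaN is not stated here, so the hypothetical helper `format_limits` exposes that choice as a flag and prints both cases.

```python
def format_limits(exponent_bits, mantissa_bits, bias, reserve_top_exponent=True):
    """Return (smallest subnormal, smallest normal, largest finite) magnitudes.

    Assumes an IEEE 754-style layout; this is an illustration, not the
    library's own definition of the format.
    """
    max_code = 2 ** exponent_bits - 1
    if reserve_top_exponent:
        # IEEE 754 keeps the all-ones exponent code for inf/NaN.
        max_code -= 1
    frac_step = 2.0 ** -mantissa_bits           # spacing of the fraction field
    min_normal = 2.0 ** (1 - bias)              # exponent code 1, fraction 0
    min_subnormal = min_normal * frac_step      # exponent code 0, fraction 1
    max_finite = 2.0 ** (max_code - bias) * (2.0 - frac_step)
    return min_subnormal, min_normal, max_finite


ieee_like = format_limits(4, 3, 7)
all_codes = format_limits(4, 3, 7, reserve_top_exponent=False)

print(ieee_like)  # (0.001953125, 0.015625, 240.0)
print(all_codes)  # (0.001953125, 0.015625, 480.0)
```

Under these assumptions, inputs drawn from `torch.randn` sit well inside the representable range, so the visible effect of quantization in the Quick Start output is a coarser spacing between values rather than clipping.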

Training with Custom Floating Point Weights

import torch
import torch.nn as nn
from floating_point import FloatingPoint, Round

class FloatPointLinear(nn.Module):
    """Linear layer whose weight is quantized on every forward pass."""

    def __init__(self, in_features, out_features, fp_config):
        super().__init__()
        # Parameters are stored in full precision; the optimizer updates these.
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))
        self.rounder = Round(fp_config)

    def forward(self, x):
        # Only the weight is quantized; the bias is used as-is.
        quantized_weight = self.rounder(self.weight)
        return torch.nn.functional.linear(x, quantized_weight, self.bias)

# Define custom floating point format
fp8 = FloatingPoint(sign_bits=1, exponent_bits=4, mantissa_bits=3, bias=7, bits=8)

# Create model with quantized weights
model = FloatPointLinear(10, 5, fp8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Create simple data
x = torch.randn(32, 10)
y = torch.randn(32, 5)

# Training loop
for epoch in range(5):
    optimizer.zero_grad()
    
    # Forward pass
    output = model(x)
    loss = criterion(output, y)
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    print(f"Epoch {epoch + 1}: Loss = {loss.item():.6f}")

Contributing

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Install development dependencies (make setup-dev)
  4. Make your changes
  5. Run tests (make test)
  6. Run linting (make lint)
  7. Commit your changes (git commit -m 'Add amazing feature')
  8. Push to the branch (git push origin feature/amazing-feature)
  9. Open a Pull Request

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use this library in your research, please cite:

@software{moustafa2025torchfloatingpoint,
  title={Torch Floating Point: A PyTorch library for custom floating point quantization},
  author={Samir Moustafa},
  year={2025},
  url={https://github.com/SamirMoustafa/torch-floating-point}
}
