Torch Floating Point
A PyTorch library for custom floating point quantization with autograd support. This library provides efficient implementations of custom floating point formats with automatic differentiation capabilities.
Features
- Custom Floating Point Formats: Support for arbitrary floating point configurations (sign bits, exponent bits, mantissa bits, bias)
- Autograd Support: Full PyTorch autograd integration for training with quantized weights
- CUDA Support: GPU acceleration for both forward and backward passes
- Multiple Precision: Support for various bit widths (4-bit, 8-bit, 16-bit, 32-bit)
- Straight-Through Estimator: Gradient-friendly quantization for training
- Comprehensive Testing: Extensive test suite covering differentiability and accuracy
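The four parameters combine in the usual IEEE 754 way. The pure-Python sketch below is not part of this library. It decodes an integer bit pattern laid out as sign | exponent | mantissa, which shows how each field contributes to the value. It assumes IEEE 754-style conventions: the stored exponent has `bias` subtracted, and an all-zero exponent field marks a subnormal with no implicit leading 1. For simplicity it also treats every code as finite, so there is no inf/NaN handling. The helper name `decode` is hypothetical.

```python
def decode(code: int, exponent_bits: int, mantissa_bits: int, bias: int) -> float:
    """Decode a [sign | exponent | mantissa] bit pattern (hypothetical helper).

    Assumes IEEE 754-style fields and treats every code as a finite number
    (no inf/NaN handling).
    """
    sign = -1.0 if (code >> (exponent_bits + mantissa_bits)) & 1 else 1.0
    exp_field = (code >> mantissa_bits) & ((1 << exponent_bits) - 1)
    mant_field = code & ((1 << mantissa_bits) - 1)
    if exp_field == 0:
        # Subnormal: no implicit leading 1, same scale as the smallest normal
        return sign * mant_field * 2.0 ** (1 - bias - mantissa_bits)
    # Normal: implicit leading 1 plus the fractional mantissa
    return sign * (1 + mant_field / 2**mantissa_bits) * 2.0 ** (exp_field - bias)


# All non-negative values of a 4-bit format (2 exponent bits, 1 mantissa bit, bias 1)
print([decode(c, 2, 1, 1) for c in range(8)])
# → [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

# FP16 sanity check: 0x3C00 encodes 1.0
print(decode(0x3C00, 5, 10, 15))  # → 1.0
```

The 4-bit listing shows the spacing between neighbours doubling with each exponent step. This is why low-bit formats lose precision fastest at large magnitudes.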
Installation
From PyPI (Recommended)
```bash
pip install torch-floating-point
```
From Source
```bash
git clone https://github.com/SamirMoustafa/torch-floating-point.git
cd torch-floating-point
pip install -e .
```
Development Installation
```bash
git clone https://github.com/SamirMoustafa/torch-floating-point.git
cd torch-floating-point
pip install -e ".[dev,test]"
pre-commit install
```
Quick Start
```python
import torch
from floating_point import FloatingPoint, Round

# Define a custom 8-bit floating point format (1 sign, 4 exponent, 3 mantissa bits)
fp8 = FloatingPoint(sign_bits=1, exponent_bits=4, mantissa_bits=3, bias=7, bits=8)

# Create a rounding function
rounder = Round(fp8)

# Create a tensor with gradients
x = torch.randn(10, requires_grad=True)

# Quantize the tensor
quantized = rounder(x)

# Use in training (gradients flow through)
loss = quantized.sum()
loss.backward()

print(f"Original: {x}")
print(f"Quantized: {quantized}")
print(f"Gradients: {x.grad}")
```
Examples
The examples/ directory contains the following examples:
Simple Rounding Example (examples/01_simple_rounding.py)
Demonstrates basic rounding functionality with different floating point formats:
- Compares FP4, FP8, and FP16 precision
- Shows quantization errors and differences
- Demonstrates range limitations and edge cases
- Handles subnormal values
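The pure-Python reference below gives a feel for what rounding to such a format involves. It is an independent sketch, not this library's implementation. It assumes IEEE 754-like behaviour: the all-ones exponent is reserved, subnormals are kept, and ties round to even. Unlike IEEE 754, which would overflow to infinity, it saturates magnitudes beyond the largest finite value. The function name `quantize` is hypothetical.

```python
import math


def quantize(x: float, exponent_bits: int, mantissa_bits: int, bias: int) -> float:
    """Round x to the nearest value of a hypothetical IEEE-like format.

    Assumptions: the all-ones exponent is reserved for inf/NaN, subnormals
    are supported, ties round to even, and overflow saturates.
    """
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    min_exp = 1 - bias                        # exponent of the smallest normal
    max_exp = (2**exponent_bits - 2) - bias   # top exponent code is reserved
    max_val = (2 - 2.0**-mantissa_bits) * 2.0**max_exp
    if a >= max_val:
        return sign * max_val                 # saturate instead of overflowing
    # frexp gives a = f * 2**e with 0.5 <= f < 1, so floor(log2(a)) == e - 1.
    # Clamping to min_exp makes subnormals share the smallest normal's spacing.
    e = max(math.frexp(a)[1] - 1, min_exp)
    step = 2.0 ** (e - mantissa_bits)         # gap between neighbours in this binade
    return sign * round(a / step) * step      # round() is round-half-to-even


formats = {"FP4": (2, 1, 1), "FP8": (4, 3, 7), "FP16": (5, 10, 15)}
for name, (e_bits, m_bits, b) in formats.items():
    print(name, quantize(0.3, e_bits, m_bits, b))
# FP4 0.5
# FP8 0.3125
# FP16 0.300048828125
```

For the FP16 parameters this matches Python's own half-precision packing (`struct` format `"e"`) within the finite range. This makes it a convenient cross-check when comparing formats.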
Gradient Flow Example (examples/02_gradient_flow.py)
Demonstrates gradient flow through quantized operations:
- Shows Straight-Through Estimator (STE) in action
- Compares gradient flow across different formats
- Includes a complete training loop with quantized weights
- Analyzes gradient patterns and clipping behavior
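The core idea of a straight-through estimator fits in a few lines of plain PyTorch. The sketch below is illustrative rather than this library's code. To keep it short, it rounds to a uniform grid with spacing `step` instead of a floating point grid, and the class name `RoundSTE` is hypothetical. The forward pass rounds, and the backward pass returns the incoming gradient unchanged. By contrast, `torch.round` on its own yields zero gradients.

```python
import torch


class RoundSTE(torch.autograd.Function):
    """Hypothetical straight-through rounding to a uniform grid of spacing `step`."""

    @staticmethod
    def forward(ctx, x, step):
        # Non-differentiable quantization in the forward pass
        return torch.round(x / step) * step

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat rounding as the identity; `step` gets no gradient
        return grad_output, None


x = torch.tensor([0.1, 0.26, -0.74], requires_grad=True)
y = RoundSTE.apply(x, 0.25)
y.sum().backward()
print(y.detach())  # values 0.0, 0.25, -0.75
print(x.grad)      # tensor([1., 1., 1.])

# For comparison, plain torch.round passes back zeros
z = torch.tensor([0.1, 0.26, -0.74], requires_grad=True)
torch.round(z).sum().backward()
print(z.grad)      # tensor([0., 0., 0.])
```

A common variant, often called a clipped STE, zeroes the gradient for inputs outside the representable range instead of passing it through everywhere.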
Running Examples
```bash
# Run individual examples
python examples/01_simple_rounding.py
python examples/02_gradient_flow.py

# Run all examples
python examples/run_all_examples.py
```
Usage Examples
Custom Floating Point Configuration
```python
from floating_point import FloatingPoint

# 4-bit floating point (1 sign, 2 exponent, 1 mantissa)
fp4 = FloatingPoint(sign_bits=1, exponent_bits=2, mantissa_bits=1, bias=1, bits=4)

# 8-bit floating point with custom max mantissa
fp8_custom = FloatingPoint(
    sign_bits=1,
    exponent_bits=4,
    mantissa_bits=3,
    bias=7,
    bits=8,
    max_mantissa_at_max_exponent=6,  # Custom max mantissa
    reserved_exponent=False,  # No reserved exponent for inf/nan
)

# 16-bit floating point (standard)
fp16 = FloatingPoint(sign_bits=1, exponent_bits=5, mantissa_bits=10, bias=15, bits=16)
```
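The sketch below computes a format's magnitude range by hand, to show what `max_mantissa_at_max_exponent` and `reserved_exponent` might change. The semantics are assumed rather than taken from the library. Here, `reserved_exponent=True` sets aside the all-ones exponent for inf/NaN as in IEEE 754, and the hypothetical `max_mantissa` argument caps the mantissa code allowed at the top exponent. Under those assumptions, the `fp8_custom` settings give a largest finite value of 448. This matches the E4M3 variant of the OCP FP8 specification. The helper name `format_range` is hypothetical.

```python
def format_range(exponent_bits, mantissa_bits, bias, reserved_exponent=True, max_mantissa=None):
    """Return (smallest subnormal, smallest normal, largest finite) magnitudes.

    Hypothetical helper; the meaning of reserved_exponent and max_mantissa is
    an assumption modelled on IEEE 754 and OCP FP8, not the library's docs.
    """
    # Highest exponent code that still encodes finite numbers
    top_code = 2**exponent_bits - (2 if reserved_exponent else 1)
    if max_mantissa is None:
        max_mantissa = 2**mantissa_bits - 1  # all mantissa codes allowed
    smallest_subnormal = 2.0 ** (1 - bias - mantissa_bits)
    smallest_normal = 2.0 ** (1 - bias)
    largest = (1 + max_mantissa / 2**mantissa_bits) * 2.0 ** (top_code - bias)
    return smallest_subnormal, smallest_normal, largest


print(format_range(5, 10, 15))  # FP16: 2**-24, 2**-14, 65504.0
print(format_range(4, 3, 7))    # IEEE-style E4M3: largest is 240.0
print(format_range(4, 3, 7, reserved_exponent=False, max_mantissa=6))  # largest is 448.0
```

Freeing the top exponent for finite values roughly doubles the range. Capping its mantissa keeps one code free, for example for NaN.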
Training with Quantized Weights
```python
import torch
import torch.nn as nn
from floating_point import FloatingPoint, Round

class QuantizedLinear(nn.Module):
    def __init__(self, in_features, out_features, fp_config):
        super().__init__()
        # Full-precision master weights; only their quantized copy is used in forward
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.rounder = Round(fp_config)

    def forward(self, x):
        quantized_weight = self.rounder(self.weight)
        return torch.nn.functional.linear(x, quantized_weight)

# Define quantization format
fp8 = FloatingPoint(sign_bits=1, exponent_bits=4, mantissa_bits=3, bias=7, bits=8)

# Create model with quantized weights
model = QuantizedLinear(784, 10, fp8)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

# Placeholder random batch; replace with your own data loader
inputs = torch.randn(32, 784)
targets = torch.randint(0, 10, (32,))

# Training loop
for epoch in range(10):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
```
Direct Function Usage
```python
import torch
from floating_point import autograd

# Direct quantization function
x = torch.randn(100, requires_grad=True)
quantized = autograd(x, exponent_bits=4, mantissa_bits=3, bias=7)

# Gradients work automatically
loss = quantized.sum()
loss.backward()
```
Supported Formats
Common formats and their parameters (other combinations can be configured the same way):
| Format | Sign Bits | Exponent Bits | Mantissa Bits | Bias | Total Bits |
|---|---|---|---|---|---|
| FP4 | 1 | 2 | 1 | 1 | 4 |
| FP8 | 1 | 4 | 3 | 7 | 8 |
| FP16 | 1 | 5 | 10 | 15 | 16 |
| BF16 | 1 | 8 | 7 | 127 | 16 |
| FP32 | 1 | 8 | 23 | 127 | 32 |
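The Bias column follows the usual IEEE 754 convention: for e exponent bits, bias = 2^(e-1) - 1. The mantissa width m sets the relative precision, because the gap between 1.0 and the next representable value is 2^-m. The short check below recomputes both in plain Python, with the table rows copied in by hand. It also makes the FP16/BF16 trade-off visible: BF16 keeps FP32's exponent range but has a much coarser step.

```python
# Rows copied from the table: (sign, exponent, mantissa, bias, total bits)
FORMATS = {
    "FP4": (1, 2, 1, 1, 4),
    "FP8": (1, 4, 3, 7, 8),
    "FP16": (1, 5, 10, 15, 16),
    "BF16": (1, 8, 7, 127, 16),
    "FP32": (1, 8, 23, 127, 32),
}

for name, (s, e, m, bias, total) in FORMATS.items():
    assert s + e + m == total, name        # field widths add up to the total
    assert bias == 2 ** (e - 1) - 1, name  # IEEE-style bias
    epsilon = 2.0**-m                      # gap between 1.0 and its successor
    print(f"{name:>4}: bias={bias:<3} epsilon={epsilon:.3g}")
```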
Development
Testing
The project includes two testing approaches:
- CI/CD Tests (GitHub Actions): Fast, lightweight tests that verify core functionality without heavy numerical computations
- Full Test Suite: Complete test coverage including all numerical precision tests (run locally or via manual workflow trigger)
To run the full test suite locally:
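```bash
# torch.__file__ points to torch/__init__.py, so take its directory before appending /lib
export LD_LIBRARY_PATH=$(python -c "import os, torch; print(os.path.dirname(torch.__file__))")/lib:$LD_LIBRARY_PATH
python -m pytest test/round.py test/data_types.py -v
```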
Running Tests
```bash
# Run all tests
make test

# Run tests with coverage
make test-cov

# Run specific test file
python -m pytest test/round.py -v
```
Code Quality
```bash
# Run linting
make lint

# Format code
make format

# Run all checks
make full-check
```
Building
```bash
# Build extension
python setup.py build_ext --inplace

# Build package
make build

# Clean build artifacts
make clean
```
Contributing
- Fork the repository
- Create a feature branch (`git checkout -b feature/amazing-feature`)
- Install development dependencies (`make setup-dev`)
- Make your changes
- Run tests (`make test`)
- Run linting (`make lint`)
- Commit your changes (`git commit -m 'Add amazing feature'`)
- Push to the branch (`git push origin feature/amazing-feature`)
- Open a Pull Request
License
This project is licensed under the MIT License - see the LICENSE file for details.
Citation
If you use this library in your research, please cite:
```bibtex
@software{torch_floating_point,
  title={Torch Floating Point: A PyTorch library for custom floating point quantization},
  author={Samir Moustafa},
  year={2024},
  url={https://github.com/SamirMoustafa/torch-floating-point}
}
```
Acknowledgments
- PyTorch team for the excellent autograd system
- The PyTorch C++ extension community for guidance on extension development
- Contributors and users of this library
Support
- Issues: GitHub Issues
- Discussions: GitHub Discussions
- Email: samir.moustafa.97@gmail.com
Download files
File details
Details for the file torch-floating-point-0.0.6.tar.gz.
File metadata
- Download URL: torch-floating-point-0.0.6.tar.gz
- Upload date:
- Size: 16.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ff904bcc1f8c5de1b0a9c903174e98d9be264cbe64fa14163955444b19324482 |
| MD5 | fe0f18c92b86c5728f7f0036a9406b8c |
| BLAKE2b-256 | 23b3ebb8e0533c3253465c554b34c91edb5e2d9842dcd588284390dca42fd789 |
File details
Details for the file torch_floating_point-0.0.6-cp310-cp310-manylinux_2_28_x86_64.whl.
File metadata
- Download URL: torch_floating_point-0.0.6-cp310-cp310-manylinux_2_28_x86_64.whl
- Upload date:
- Size: 3.7 MB
- Tags: CPython 3.10, manylinux: glibc 2.28+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | fc84d4932c6e79785efba4707e98d803cee3f431f6748fbe7198ce2b398f8af2 |
| MD5 | c4746bbfc329bc807b6cf340c74b506e |
| BLAKE2b-256 | 18bf6212c63fedbd28401c2b4632494e1b8c51bec28c13b50a4ba666f901c27e |