Torch Floating Point

A PyTorch library for custom floating point quantization with autograd support. It implements custom floating point formats as a C++/CUDA extension whose rounding operations plug into PyTorch's automatic differentiation.

Features

  • Custom Floating Point Formats: Support for arbitrary floating point configurations (sign bits, exponent bits, mantissa bits, bias)
  • Autograd Support: Full PyTorch autograd integration for training with quantized weights
  • CUDA Support: GPU acceleration for both forward and backward passes
  • Multiple Precision: Support for various bit widths (4-bit, 8-bit, 16-bit, 32-bit)
  • Straight-Through Estimator: Gradient-friendly quantization for training
  • Comprehensive Testing: Extensive test suite covering differentiability and accuracy

Installation

From PyPI (Recommended)

pip install torch-floating-point

From Source

git clone https://github.com/SamirMoustafa/torch-floating-point.git
cd torch-floating-point
pip install -e .

Development Installation

git clone https://github.com/SamirMoustafa/torch-floating-point.git
cd torch-floating-point
pip install -e ".[dev,test]"
pre-commit install

Quick Start

import torch
from floating_point import FloatingPoint, Round

# Define a custom 8-bit floating point format (1 sign, 4 exponent, 3 mantissa bits)
fp8 = FloatingPoint(sign_bits=1, exponent_bits=4, mantissa_bits=3, bias=7, bits=8)

# Create a rounding function
rounder = Round(fp8)

# Create a tensor with gradients
x = torch.randn(10, requires_grad=True)

# Quantize the tensor
quantized = rounder(x)

# Use in training (gradients flow through)
loss = quantized.sum()
loss.backward()

print(f"Original: {x}")
print(f"Quantized: {quantized}")
print(f"Gradients: {x.grad}")
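The following stdlib-only sketch shows what the forward pass does to a single Python float. It rounds the value into a format described by `exponent_bits`, `mantissa_bits`, and `bias`. It is not the library's implementation. Ties-to-even rounding, saturation at the largest finite value, and reading `reserved_exponent` as "keep the all-ones exponent for inf/NaN" are all assumptions. `max_mantissa_at_max_exponent` is ignored. The name `quantize_to_format` is hypothetical.

```python
import math

def quantize_to_format(x: float, exponent_bits: int, mantissa_bits: int, bias: int,
                       reserved_exponent: bool = True) -> float:
    """Round x to the nearest value of a 1-sign-bit floating point format (sketch)."""
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)

    # Largest exponent field used by finite numbers (all-ones reserved if requested).
    top_field = (1 << exponent_bits) - (2 if reserved_exponent else 1)
    max_val = 2.0 ** (top_field - bias) * (2.0 - 2.0 ** -mantissa_bits)
    if a >= max_val:  # assumed: saturate instead of overflowing to inf
        return sign * max_val

    # Unbiased exponent of a; below the normal range, clamp to the smallest
    # normal exponent so the step size stays fixed (subnormals).
    e = max(math.frexp(a)[1] - 1, 1 - bias)
    step = 2.0 ** (e - mantissa_bits)  # spacing between neighbours near a
    q = round(a / step) * step         # Python's round() ties to even
    return sign * min(q, max_val)

fp8_values = [quantize_to_format(v, 4, 3, 7) for v in (0.3, 1.0625, -2.6, 1000.0)]
print(fp8_values)  # [0.3125, 1.0, -2.5, 240.0]
```

Under these assumptions, the Quick Start format (4 exponent bits, 3 mantissa bits, bias 7) snaps each value to its nearest representable neighbour. Ties go to the even mantissa, and anything beyond 240 saturates.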

Examples

The examples/ directory contains the following example scripts:

Simple Rounding Example (examples/01_simple_rounding.py)

Demonstrates basic rounding functionality with different floating point formats:

  • Compares FP4, FP8, and FP16 precision
  • Shows quantization errors and differences
  • Demonstrates range limitations and edge cases
  • Shows how subnormal values are handled
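The precision gap between these formats can be estimated from the format parameters alone. Assume IEEE-style semantics: an implicit leading 1 for normal numbers and an all-zeros exponent for subnormals. Under that assumption, four quantities depend only on the mantissa bits and the bias: the spacing just above 1.0, the round-to-nearest error bound, and the smallest normal and subnormal magnitudes. `precision_stats` is a hypothetical helper, not part of the library.

```python
def precision_stats(mantissa_bits: int, bias: int) -> dict:
    """Spacing and small-magnitude limits of an IEEE-style binary format (sketch)."""
    return {
        # Gap between 1.0 and the next representable value.
        "epsilon": 2.0 ** -mantissa_bits,
        # Bound on the relative error of round-to-nearest in the normal range.
        "unit_roundoff": 2.0 ** -(mantissa_bits + 1),
        # Smallest positive normal number (exponent field = 1).
        "min_normal": 2.0 ** (1 - bias),
        # Smallest positive subnormal number (exponent field = 0, mantissa = 1).
        "min_subnormal": 2.0 ** (1 - bias - mantissa_bits),
    }

for name, mantissa_bits, bias in [("FP4", 1, 1), ("FP8", 3, 7), ("FP16", 10, 15)]:
    print(name, precision_stats(mantissa_bits, bias))
```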

Gradient Flow Example (examples/02_gradient_flow.py)

Demonstrates gradient flow through quantized operations:

  • Shows Straight-Through Estimator (STE) in action
  • Compares gradient flow across different formats
  • Includes a complete training loop with quantized weights
  • Analyzes gradient patterns and clipping behavior
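Rounding is piecewise constant, so its true derivative is zero almost everywhere and would stop all gradients. A straight-through estimator (STE) substitutes a pass-through gradient in the backward pass. This stdlib-only sketch contrasts the two. Several parts are illustrative assumptions: the fixed-step `grid_quantize`, the `max_val=240.0` limit, and zeroing the gradient outside that range (a "clipped" STE). None of this describes how the library's `Round` actually computes its backward pass.

```python
def grid_quantize(x: float, step: float = 0.125) -> float:
    # Stand-in quantizer: round to a fixed grid (not the library's Round).
    return round(x / step) * step

def numerical_grad(f, x: float, h: float = 1e-6) -> float:
    # Central finite-difference approximation of df/dx.
    return (f(x + h) - f(x - h)) / (2 * h)

def ste_grad(x: float, grad_out: float, max_val: float = 240.0) -> float:
    # Straight-through estimator: treat rounding as the identity in the
    # backward pass, but zero the gradient where x is outside the assumed range.
    return grad_out if abs(x) <= max_val else 0.0

print(numerical_grad(grid_quantize, 0.3))  # 0.0: the rounding step is flat here
print(ste_grad(0.3, 1.0))                  # 1.0: gradient passes straight through
print(ste_grad(500.0, 1.0))                # 0.0: clipped outside the range
```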

Running Examples

# Run individual examples
python examples/01_simple_rounding.py
python examples/02_gradient_flow.py

# Run all examples
python examples/run_all_examples.py

Usage Examples

Custom Floating Point Configuration

from floating_point import FloatingPoint

# 4-bit floating point (1 sign, 2 exponent, 1 mantissa)
fp4 = FloatingPoint(sign_bits=1, exponent_bits=2, mantissa_bits=1, bias=1, bits=4)

# 8-bit floating point that caps the mantissa at the maximum exponent
fp8_custom = FloatingPoint(
    sign_bits=1,
    exponent_bits=4,
    mantissa_bits=3,
    bias=7,
    bits=8,
    max_mantissa_at_max_exponent=6,  # Largest mantissa value allowed at the maximum exponent
    reserved_exponent=False  # No exponent value reserved for inf/NaN
)

# 16-bit floating point (IEEE 754 half precision)
fp16 = FloatingPoint(sign_bits=1, exponent_bits=5, mantissa_bits=10, bias=15, bits=16)

Training with Quantized Weights

import torch
import torch.nn as nn
from floating_point import FloatingPoint, Round

class QuantizedLinear(nn.Module):
    def __init__(self, in_features, out_features, fp_config):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.rounder = Round(fp_config)
    
    def forward(self, x):
        quantized_weight = self.rounder(self.weight)
        return torch.nn.functional.linear(x, quantized_weight)

# Define quantization format
fp8 = FloatingPoint(sign_bits=1, exponent_bits=4, mantissa_bits=3, bias=7, bits=8)

# Create model with quantized weights
model = QuantizedLinear(784, 10, fp8)
optimizer = torch.optim.Adam(model.parameters())

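# Training loop (random placeholder data; replace with your own dataset)
inputs = torch.randn(64, 784)
targets = torch.randint(0, 10, (64,))
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()  # gradients reach model.weight through the rounder
    optimizer.step()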
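`QuantizedLinear` keeps `self.weight` in full precision and only quantizes it inside `forward`. A scalar sketch shows why that matters. With a straight-through gradient, small updates accumulate in the full-precision weight until they cross a rounding boundary. If only the quantized value is stored, every small update is rounded away. The `grid_quantize` here is a simplified stand-in, not the library's `Round`: it uses a fixed step of 1/8, which matches a 3-bit mantissa only in [1, 2).

```python
def grid_quantize(w: float) -> float:
    # Stand-in quantizer: with 3 mantissa bits, values in [1, 2) are spaced 1/8 apart.
    return round(w * 8) / 8

def train(steps: int, lr: float = 0.01, keep_master: bool = True) -> float:
    w = 1.0
    for _ in range(steps):
        grad = -1.0  # constant STE gradient that pushes w upward
        if keep_master:
            w = w - lr * grad                 # update the full-precision weight
        else:
            w = grid_quantize(w - lr * grad)  # store only the quantized weight
    return grid_quantize(w)  # value the forward pass would use

print(train(10, keep_master=True))   # 1.125: ten small updates cross a rounding boundary
print(train(10, keep_master=False))  # 1.0: each 0.01 update is rounded away
```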

Direct Function Usage

import torch
from floating_point import autograd

# Direct quantization function
x = torch.randn(100, requires_grad=True)
quantized = autograd(x, exponent_bits=4, mantissa_bits=3, bias=7)

# Gradients work automatically
loss = quantized.sum()
loss.backward()

Supported Formats

The library supports various floating point formats:

Format  Sign Bits  Exponent Bits  Mantissa Bits  Bias  Total Bits
FP4     1          2              1              1     4
FP8     1          4              3              7     8
FP16    1          5              10             15    16
BF16    1          8              7              127   16
FP32    1          8              23             127   32
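Each row of the table fixes how a bit pattern splits into sign, exponent, and mantissa fields. A small stdlib-only decoder makes the role of the bias concrete. It assumes an IEEE-style field layout: sign bit first, then exponent, then mantissa, with an all-zeros exponent for subnormals. It does not special-case inf/NaN, so it illustrates the table rather than the library's behavior. `decode` is a hypothetical helper.

```python
def decode(code: int, exponent_bits: int, mantissa_bits: int, bias: int) -> float:
    """Decode a sign|exponent|mantissa bit pattern (1 sign bit, sign first)."""
    mantissa = code & ((1 << mantissa_bits) - 1)
    exponent = (code >> mantissa_bits) & ((1 << exponent_bits) - 1)
    sign = -1.0 if (code >> (exponent_bits + mantissa_bits)) & 1 else 1.0
    fraction = mantissa / (1 << mantissa_bits)
    if exponent == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias.
        return sign * fraction * 2.0 ** (1 - bias)
    # Normal: implicit leading 1, exponent is the stored field minus the bias.
    # Note: inf/NaN encodings are not special-cased in this sketch.
    return sign * (1.0 + fraction) * 2.0 ** (exponent - bias)

print(decode(0b0_0111_000, 4, 3, 7))  # 1.0   (FP8: exponent field 7 equals the bias)
print(decode(0b1_1000_100, 4, 3, 7))  # -3.0  (-(1 + 4/8) * 2**1)
print(decode(0x3C00, 5, 10, 15))      # 1.0   (FP16)
print(decode(0x3F80, 8, 7, 127))      # 1.0   (BF16)
```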

Development

Testing

The project includes two testing approaches:

  1. CI/CD Tests (GitHub Actions): Fast, lightweight tests that verify core functionality without heavy numerical computations
  2. Full Test Suite: Complete test coverage including all numerical precision tests (run locally or via manual workflow trigger)

To run the full test suite locally:

export LD_LIBRARY_PATH=$(python -c "import os, torch; print(os.path.dirname(torch.__file__))")/lib:$LD_LIBRARY_PATH
python -m pytest test/round.py test/data_types.py -v

Running Tests

# Run all tests
make test

# Run tests with coverage
make test-cov

# Run specific test file
python -m pytest test/round.py -v

Code Quality

# Run linting
make lint

# Format code
make format

# Run all checks
make full-check

Building

# Build extension
python setup.py build_ext --inplace

# Build package
make build

# Clean build artifacts
make clean

Contributing

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Install development dependencies (make setup-dev)
  4. Make your changes
  5. Run tests (make test)
  6. Run linting (make lint)
  7. Commit your changes (git commit -m 'Add amazing feature')
  8. Push to the branch (git push origin feature/amazing-feature)
  9. Open a Pull Request

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If you use this library in your research, please cite:

@software{torch_floating_point,
  title={Torch Floating Point: A PyTorch library for custom floating point quantization},
  author={Samir Moustafa},
  year={2024},
  url={https://github.com/SamirMoustafa/torch-floating-point}
}

Acknowledgments

  • PyTorch team for the excellent autograd system
  • The PyTorch C++ extension community for guidance on extension development
  • Contributors and users of this library

Support

Download files

Source Distribution

torch-floating-point-0.0.6.tar.gz (16.2 kB)

Built Distribution

torch_floating_point-0.0.6-cp310-cp310-manylinux_2_28_x86_64.whl (3.7 MB)

Tags: CPython 3.10, manylinux: glibc 2.28+ x86-64

File details

Details for the file torch-floating-point-0.0.6.tar.gz.

File metadata

  • Download URL: torch-floating-point-0.0.6.tar.gz
  • Upload date:
  • Size: 16.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.18

File hashes

Hashes for torch-floating-point-0.0.6.tar.gz
Algorithm Hash digest
SHA256 ff904bcc1f8c5de1b0a9c903174e98d9be264cbe64fa14163955444b19324482
MD5 fe0f18c92b86c5728f7f0036a9406b8c
BLAKE2b-256 23b3ebb8e0533c3253465c554b34c91edb5e2d9842dcd588284390dca42fd789

File details

Details for the file torch_floating_point-0.0.6-cp310-cp310-manylinux_2_28_x86_64.whl.

File hashes

Hashes for torch_floating_point-0.0.6-cp310-cp310-manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 fc84d4932c6e79785efba4707e98d803cee3f431f6748fbe7198ce2b398f8af2
MD5 c4746bbfc329bc807b6cf340c74b506e
BLAKE2b-256 18bf6212c63fedbd28401c2b4632494e1b8c51bec28c13b50a4ba666f901c27e
