High-performance safetensors model loader

Project description

fastsafetensors is an efficient safetensors model loader. We introduced three major features to optimize model loading performance:

  1. Batched, lazy tensor instantiation
  2. GPU offloading for sharding, type conversions, and device pointer alignment
  3. GPU Direct Storage enablement for loading files directly from storage into GPU memory

A major design difference from the original safetensors file loader is that we do NOT use mmap. The original loader instantiates tensors on demand from mmap'ed files, but this approach cannot fully utilize high-throughput storage such as NVMe SSDs. Instead, we transfer files asynchronously and in parallel to saturate storage throughput, and then lazily instantiate tensors in GPU device memory with DLPack.
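
To illustrate the idea (a minimal sketch under simplified assumptions, not fastsafetensors internals; parallel_read_to_gpu is a made-up name): split a file into chunks, read them in parallel into one pinned staging buffer, and only then copy to the GPU.

import concurrent.futures
import os
import torch

def parallel_read_to_gpu(path, nthreads=8):
    # Read `path` in parallel chunks into one pinned host buffer, then issue
    # a single asynchronous host-to-device copy (requires a CUDA device).
    size = os.path.getsize(path)
    buf = torch.empty(size, dtype=torch.uint8).pin_memory()
    chunk = (size + nthreads - 1) // nthreads
    def read_chunk(off):
        with open(path, "rb") as f:
            f.seek(off)
            data = f.read(min(chunk, size - off))
            buf[off:off + len(data)] = torch.frombuffer(bytearray(data), dtype=torch.uint8)
    with concurrent.futures.ThreadPoolExecutor(nthreads) as ex:
        list(ex.map(read_chunk, range(0, size, chunk)))
    return buf.to("cuda", non_blocking=True)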

Another design change is to offload sharding and other tensor manipulations to GPUs. The original loader performs slicing for sharding in user programs before copying to device memory, which incurs high CPU usage for host memory accesses. We instead introduce special APIs that run sharding with torch.distributed collective operations such as broadcast and scatter. The same offloading applies to other tensor manipulations such as type conversions.
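
Conceptually (a hedged sketch of the collective-based approach, not the library's actual code; the function name and signature are made up), row-wise sharding with a scatter looks like this; the debug logs in the examples below show the library's own version as "shuffle: scatter" lines:

import torch
import torch.distributed as dist

def shard_rows(full, shape, dtype, device):
    # Scatter contiguous row chunks of `full` (held only on rank 0) to all ranks.
    # Assumes shape[0] is divisible by the world size.
    world, rank = dist.get_world_size(), dist.get_rank()
    rows = shape[0] // world
    out = torch.empty((rows,) + tuple(shape[1:]), dtype=dtype, device=device)
    scatter_list = [c.contiguous() for c in full.chunk(world, dim=0)] if rank == 0 else None
    dist.scatter(out, scatter_list, src=0)  # with a GPU-aware backend, data movement stays device-side
    return out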

These two design choices extend naturally to device-to-device data transfers with GPU Direct Storage, which minimizes copy overheads from NVMe SSDs to GPU memory by bypassing the host CPU and host memory.

See doc/overview.md for more details.

Dependencies

We currently test fastsafetensors only with Python 3.11, PyTorch 2.1, and CUDA 12. Note: with other PyTorch versions, you may need to adjust the build environment for libtorch, since its ABI appears to change slightly between releases.

Install from PyPI

pip install fastsafetensors==0.1.3

Local installation

Prerequisites: Install torch, cuda, and numa headers

make install

Package build

Prerequisites: Install Docker (libtorch 2.1, cuda, and numa are automatically pulled)

make dist

Unit tests

make install-test # install stubbed fastsafetensors without torch, cuda, and numa
make unittest

Basic API usage

SafeTensorsFileLoader is the primary entry point of the fastsafetensors library. To construct it, pass either SingleGroup() for simple inference or a ProcessGroup() (from torch.distributed) for tensor-parallel inference. The loader supports both CPU and CUDA devices, with optional GPU Direct Storage (GDS) support; you select them with the device and nogds arguments, respectively. Note that if GDS is not available, the loader fails to open files when nogds=False. For more information on enabling GDS, please refer to the NVIDIA documentation.
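
For example (a hedged illustration; the constructor arguments shown are the ones described above and used in the examples below):

import torch
from fastsafetensors import SafeTensorsFileLoader, SingleGroup

# CPU loading without GDS:
loader = SafeTensorsFileLoader(SingleGroup(), torch.device("cpu"), nogds=True)

# GPU loading with GDS; opening files fails here if GDS is unavailable:
# loader = SafeTensorsFileLoader(SingleGroup(), torch.device("cuda:0"), nogds=False)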

After creating a SafeTensorsFileLoader instance, first map target files to ranks with the .add_filenames() method. Then call .copy_files_to_device() to trigger the actual file copies into aggregated GPU memory fragments and directly instantiate a group of tensors. Once the files are loaded, you can retrieve a tensor with the .get_tensor() method. Additionally, you can obtain sharded tensors with .get_sharded(), which internally runs collective operations in torch.distributed.

Important: To release the GPU memory allocated for tensors, you must explicitly call the .close() method. This is because fastsafetensors allows multiple tensors to share a limited number of GPU memory fragments. As a result, it is the user's responsibility to ensure that all tensors are released before calling .close(), which then safely frees the underlying GPU memory.
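
For example (a sketch based on the single-run example below; .clone() is plain PyTorch, used here to copy a tensor out of the shared fragments):

weight = fb.get_tensor(tensor_name="a0").clone()  # private copy, detached from the shared fragments
loader.close()  # safe: no live views into the loader's GPU memory remain
print(weight)   # the cloned tensor remains valid after close()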

Example: single run

examples/test_single.py:

import torch
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
loader = SafeTensorsFileLoader(SingleGroup(), torch.device("cpu"), nogds=True, debug_log=True)
loader.add_filenames({0: ["a.safetensors", "b.safetensors"]}) # {rank: files}
fb = loader.copy_files_to_device()
tensor_a0 = fb.get_tensor(tensor_name="a0")
print(f"a0: {tensor_a0}")
loader.close()

Run the script:

cd examples
python test_single.py

Example output:

add_filenames 1: path=a.safetensors
[DEBUG] raw_device_pointer: raw_alloc: 0x7acf000, length=256, elapsed=3 us
[DEBUG] nogds_file_reader.submit_read: cudaHostAlloc, size=1048576, elapsed=10 us
[DEBUG] nogds_file_reader.submit_read #3, thread_id=1
[DEBUG] nogds_file_reader._thread: read (mmap=0), fd=4, offset=104, count=256, c=256, copy=13 us, cuda_copy=0 us
wait_io: tensor=a0
a0: tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.],
        [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
        [ 4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.],
        [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
        [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
        [ 7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.],
        [ 8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.],
        [ 9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.],
        [10., 10., 10., 10., 10., 10., 10., 10.],
        [11., 11., 11., 11., 11., 11., 11., 11.],
        [12., 12., 12., 12., 12., 12., 12., 12.],
        [13., 13., 13., 13., 13., 13., 13., 13.],
        [14., 14., 14., 14., 14., 14., 14., 14.],
        [15., 15., 15., 15., 15., 15., 15., 15.]], dtype=torch.float16)
[DEBUG] ~nogds_file_reader: elapsed=28 us
[DEBUG] ~raw_device_pointer: torch_raw_delete: 0x7acf000, elapsed=0 us

Example: parallel run

examples/test_parallel.py:

import torch
import torch.distributed as dist
from fastsafetensors import SafeTensorsFileLoader
dist.init_process_group(backend="gloo")
dist.barrier()
pg = dist.group.WORLD
loader = SafeTensorsFileLoader(pg, torch.device("cpu"), nogds=True, debug_log=True)
loader.add_filenames({0: ["a.safetensors"], 1: ["b.safetensors"]}) # {rank: files}
fb = loader.copy_files_to_device()
tensor_name = "a0" if pg.rank() == 0 else "b0"
dim = 0 if pg.rank() == 0 else 1
tensor = fb.get_sharded(tensor_name=tensor_name, dim=dim)
print(f"RANK {pg.rank()}: tensor_name={tensor}")
loader.close()

You can test the script with torchrun:

cd examples
torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=0 test_parallel.py &
PIDS+=($!)
torchrun --nnodes=2 --master_addr=0.0.0.0 --master_port=1234 --node_rank=1 test_parallel.py &
PIDS+=($!)
wait ${PIDS[@]}
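
If both ranks run on a single machine, a standalone torchrun invocation with two local processes should also work (a hedged alternative using standard torchrun flags, not a command taken from this project):

cd examples
torchrun --standalone --nproc_per_node=2 test_parallel.py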

Example output:

add_filenames 1: path=a.safetensors
add_filenames 2: path=b.safetensors
[DEBUG] raw_device_pointer: raw_alloc: 0x7c83000, length=256, elapsed=0 us
[DEBUG] raw_device_pointer: raw_alloc: 0x6b0f000, length=256, elapsed=0 us
[DEBUG] nogds_file_reader.submit_read: cudaHostAlloc, size=1048576, elapsed=14 us
[DEBUG] nogds_file_reader.submit_read: cudaHostAlloc, size=1048576, elapsed=14 us
[DEBUG] nogds_file_reader.submit_read #3, thread_id=1
[DEBUG] nogds_file_reader.submit_read #3, thread_id=1
[DEBUG] nogds_file_reader._thread: read (mmap=0), fd=10, offset=104, count=256, c=256, copy=19 us, cuda_copy=1 us
[DEBUG] nogds_file_reader._thread: read (mmap=0), fd=10, offset=104, count=256, c=256, copy=18 us, cuda_copy=1 us
wait_io: tensor=b0
wait_io: tensor=a0
shuffle: scatter, tensor_name=a0, shape=torch.Size([16, 8])->torch.Size([8, 8]), self.rank=0, pg.rank()=1, rank_slices=[(slice(0, 8, 1),), (slice(8, 16, 1),)], len(scatter_list)=0
shuffle: scatter, tensor_name=a0, shape=torch.Size([16, 8])->torch.Size([8, 8]), self.rank=0, pg.rank()=0, rank_slices=[(slice(0, 8, 1),), (slice(8, 16, 1),)], len(scatter_list)=2
_get_tensor: free_dev_ptrs, lidx=0, src=a.safetensors
[DEBUG] ~raw_device_pointer: torch_raw_delete: 0x6b0f000, elapsed=0 us
shuffle: scatter, tensor_name=b0, shape=torch.Size([16, 8])->torch.Size([16, 4]), self.rank=1, pg.rank()=0, rank_slices=[(slice(None, None, None), slice(0, 4, 1)), (slice(None, None, None), slice(4, 8, 1))], len(scatter_list)=0
shuffle: scatter, tensor_name=b0, shape=torch.Size([16, 8])->torch.Size([16, 4]), self.rank=1, pg.rank()=1, rank_slices=[(slice(None, None, None), slice(0, 4, 1)), (slice(None, None, None), slice(4, 8, 1))], len(scatter_list)=2
_get_tensor: free_dev_ptrs, lidx=0, src=b.safetensors
[DEBUG] ~raw_device_pointer: torch_raw_delete: 0x7c83000, elapsed=0 us
RANK 1: tensor_name=tensor([[ 8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.],
        [ 9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.],
        [10., 10., 10., 10., 10., 10., 10., 10.],
        [11., 11., 11., 11., 11., 11., 11., 11.],
        [12., 12., 12., 12., 12., 12., 12., 12.],
        [13., 13., 13., 13., 13., 13., 13., 13.],
        [14., 14., 14., 14., 14., 14., 14., 14.],
        [15., 15., 15., 15., 15., 15., 15., 15.]], dtype=torch.float16)
RANK 0: tensor_name=tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [4., 4., 4., 4., 4., 4., 4., 4.],
        [5., 5., 5., 5., 5., 5., 5., 5.],
        [6., 6., 6., 6., 6., 6., 6., 6.],
        [7., 7., 7., 7., 7., 7., 7., 7.]], dtype=torch.float16)
RANK 1: tensor_name=tensor([[ 0.,  0.,  0.,  0.],
        [ 1.,  1.,  1.,  1.],
        [ 2.,  2.,  2.,  2.],
        [ 3.,  3.,  3.,  3.],
        [ 4.,  4.,  4.,  4.],
        [ 5.,  5.,  5.,  5.],
        [ 6.,  6.,  6.,  6.],
        [ 7.,  7.,  7.,  7.],
        [ 8.,  8.,  8.,  8.],
        [ 9.,  9.,  9.,  9.],
        [10., 10., 10., 10.],
        [11., 11., 11., 11.],
        [12., 12., 12., 12.],
        [13., 13., 13., 13.],
        [14., 14., 14., 14.],
        [15., 15., 15., 15.]], dtype=torch.float16)
RANK 0: tensor_name=tensor([[ 0.,  0.,  0.,  0.],
        [ 1.,  1.,  1.,  1.],
        [ 2.,  2.,  2.,  2.],
        [ 3.,  3.,  3.,  3.],
        [ 4.,  4.,  4.,  4.],
        [ 5.,  5.,  5.,  5.],
        [ 6.,  6.,  6.,  6.],
        [ 7.,  7.,  7.,  7.],
        [ 8.,  8.,  8.,  8.],
        [ 9.,  9.,  9.,  9.],
        [10., 10., 10., 10.],
        [11., 11., 11., 11.],
        [12., 12., 12., 12.],
        [13., 13., 13., 13.],
        [14., 14., 14., 14.],
        [15., 15., 15., 15.]], dtype=torch.float16)
[DEBUG] ~nogds_file_reader: elapsed=53 us
[DEBUG] ~nogds_file_reader: elapsed=110 us

Download files

Download the file for your platform.

Source Distribution

fastsafetensors-0.1.3.tar.gz (31.8 kB)

Uploaded Source

Built Distributions

fastsafetensors-0.1.3-cp311-cp311-manylinux_2_34_x86_64.whl (1.4 MB)

Uploaded CPython 3.11 manylinux: glibc 2.34+ x86-64

fastsafetensors-0.1.3-cp310-cp310-manylinux_2_34_x86_64.whl (1.3 MB)

Uploaded CPython 3.10 manylinux: glibc 2.34+ x86-64

File details

Details for the file fastsafetensors-0.1.3.tar.gz.

File metadata

  • Download URL: fastsafetensors-0.1.3.tar.gz
  • Size: 31.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.0 CPython/3.10.12

File hashes

Hashes for fastsafetensors-0.1.3.tar.gz
Algorithm    Hash digest
SHA256       1c75c573026636da128c2f4c67813b927465a99b4575bf29838c3364ecfaa9d5
MD5          357ea248be694d93e02614e65396c265
BLAKE2b-256  d3e269193a6ff139933d11338943d7af8dc6690fb5676c0ab90c646b28ddba5b

File details

Details for the file fastsafetensors-0.1.3-cp311-cp311-manylinux_2_34_x86_64.whl.

File hashes

Hashes for fastsafetensors-0.1.3-cp311-cp311-manylinux_2_34_x86_64.whl
Algorithm    Hash digest
SHA256       bf9e643fbb1bda510eac84088f690aa47eb0d0f626791f87eef84da12ced417d
MD5          9c1dabfa374da3d82fe1493af82e20be
BLAKE2b-256  902784469c230ec35bb822650adf6a73b8cf2dfc0484bb0fde682be3cd371bcc

File details

Details for the file fastsafetensors-0.1.3-cp310-cp310-manylinux_2_34_x86_64.whl.

File hashes

Hashes for fastsafetensors-0.1.3-cp310-cp310-manylinux_2_34_x86_64.whl
Algorithm    Hash digest
SHA256       4007c66bea1281dc95102d27df435403f852f6d958ed0d86208b4eb45e3d0bd8
MD5          643a60bb1af35b719d187d092f0230e0
BLAKE2b-256  cb0705cecacb2835350bb92d529a03cb10b0b3203e8f86475a0d0271c807d95d
