Native PyTorch CUDA IPC over Unix Domain Socket for same-host process separation

Shared Tensor

shared_tensor is a narrow library for one job: sharing CUDA torch.Tensor and CUDA torch.nn.Module objects across processes on the same host and the same GPU with native PyTorch IPC semantics.

The control plane is a local Unix Domain Socket RPC channel. The data plane is native torch CUDA IPC serialization. CPU fallback is intentionally out of scope.

Scope

Supported:

  • same-host trusted processes
  • same-GPU CUDA tensors and modules
  • explicit endpoint registration
  • sync call and task-backed submit
  • managed object handles with explicit release
  • server-side caching, cache_format_key, and singleflight
  • manual two-process deployment as the primary production path
  • zero-branch auto mode gated by SHARED_TENSOR_ENABLED=1

Not supported:

  • CPU tensor or CPU module transport
  • generic Python object RPC
  • cross-host transport
  • MPS (Apple Metal) devices
  • implicit device migration

Install

Use Python 3.9+ and a CUDA-enabled PyTorch build.

pip install shared-tensor

For local development:

conda create -y -n shared-tensor-dev python=3.11
conda activate shared-tensor-dev
pip install -e ".[dev,test]"

Docs

Read the examples first, then the design notes:

  • docs/overview.md
  • docs/patterns.md
  • docs/architecture.md
  • docs/lifecycle.md
  • docs/diagrams.md

Example: Manual Two-Process Deployment

Production should prefer two explicitly started processes: one server process that owns CUDA objects, and one or more client processes that reopen them through torch IPC.

See examples/model_service.py for endpoint definitions.

Server process:

from shared_tensor import SharedTensorProvider, SharedTensorServer

provider = SharedTensorProvider(execution_mode="server")

@provider.share(
    execution="task",
    managed=True,
    concurrency="serialized",
    cache_format_key="model:{hidden_size}",
)
def load_model(hidden_size: int = 4):
    ...

server = SharedTensorServer(provider)
server.start(blocking=True)

Client process:

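import torch

from shared_tensor import SharedObjectHandle, SharedTensorClient

client = SharedTensorClient()
# Inputs must already live on the same GPU as the server's objects.
x = torch.ones(1, 4, device="cuda")
# load_model is registered with managed=True, so the result is a handle.
result = client.call("load_model", hidden_size=4)
if isinstance(result, SharedObjectHandle):
    # Leaving the with-block releases the managed handle.
    with result as handle:
        y = handle.value(x)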

This keeps the contract explicit:

server process                      client process
------------------------------      ------------------------------
owns CUDA allocations               issues local UDS RPC requests
executes endpoint functions         reopens CUDA objects via torch IPC
manages cache and refcounts         releases managed handles explicitly

Example: Same Code, Two Processes

See examples/zero_branch_env.py. This is a convenience mode for environments that want one file and environment-controlled behavior.

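# Server: runs shared functions locally and auto-starts the local server thread
SHARED_TENSOR_ENABLED=1 SHARED_TENSOR_ROLE=server python demo.py
# Client: role unset, so shared functions become RPC calls
SHARED_TENSOR_ENABLED=1 python demo.py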

Only the environment changes:

same code

server process                      client process
------------------------------      ------------------------------
provider auto-starts local thread   provider builds client wrappers
shared function runs locally        shared function becomes RPC call
CUDA object stays on same GPU       CUDA object is reopened via torch IPC

Example: Reusable Model Registry

See examples/model_service.py.

@provider.share(
    execution="task",
    managed=True,
    concurrency="serialized",
    cache_format_key="model:{input_dim}:{output_dim}",
)
def load_linear_model(input_dim: int = 16, output_dim: int = 4) -> torch.nn.Module:
    ...

Recommended settings for expensive reusable models:

  • execution="task"
  • managed=True
  • concurrency="serialized"
  • singleflight=True
  • explicit cache_format_key

This gives one build per cache key, shared handles for identical requests, and explicit release semantics. Task submission uses the same server-side cache as a sync call: repeated submits for the same cache key reuse the cached result instead of rebuilding the CUDA object.
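To make "one build per cache key" concrete, here is a minimal, self-contained sketch of singleflight caching. It is not shared_tensor's implementation: the class name SingleflightCache, its method get_or_build, and the use of str.format to render the key template are all assumptions made for illustration.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class SingleflightCache:
    """Illustrative only: one build per rendered cache key (hypothetical class)."""

    def __init__(self):
        self._lock = threading.Lock()
        self._values = {}    # rendered key -> finished result
        self._inflight = {}  # rendered key -> Event set when the build ends

    def get_or_build(self, key_template, build, **kwargs):
        # Assumption: the template is rendered from the call's keyword arguments.
        key = key_template.format(**kwargs)  # e.g. "model:16:4"
        while True:
            with self._lock:
                if key in self._values:
                    return self._values[key]
                event = self._inflight.get(key)
                is_builder = event is None
                if is_builder:
                    # The first caller for this key becomes the builder.
                    event = self._inflight[key] = threading.Event()
            if not is_builder:
                event.wait()  # wait for the builder, then re-check the cache
                continue
            try:
                value = build(**kwargs)
                with self._lock:
                    self._values[key] = value
                return value
            finally:
                with self._lock:
                    del self._inflight[key]
                event.set()


builds = []  # records every real build; list.append is thread-safe


def fake_build(input_dim, output_dim):
    builds.append((input_dim, output_dim))
    return object()  # stand-in for an expensive CUDA module


cache = SingleflightCache()
with ThreadPoolExecutor(max_workers=8) as pool:
    futures = [
        pool.submit(
            cache.get_or_build,
            "model:{input_dim}:{output_dim}",
            fake_build,
            input_dim=16,
            output_dim=4,
        )
        for _ in range(8)
    ]
    results = [f.result() for f in futures]
```

Eight concurrent requests with identical arguments trigger a single build, and every caller receives the same object. A request with different arguments renders a different key and builds separately.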

Example: Direct Tensor Path

See examples/basic_service.py.

@provider.share(execution="direct", cache=False)
def echo_tensor(tensor: torch.Tensor) -> torch.Tensor:
    return tensor

Use this for short-lived request-scoped CUDA transforms. The main production path is still task-backed model construction.

Configuration

SharedTensorProvider() defaults to safe local mode, with no server or client behavior, unless shared-tensor behavior is explicitly enabled.

Environment gate:

export SHARED_TENSOR_ENABLED=1

Per-provider override:

SharedTensorProvider(enabled=True)
SharedTensorProvider(enabled=False)
SharedTensorProvider(enabled=None)

Provider runtime controls:

SharedTensorProvider(server_startup_timeout=30.0)
provider.get_runtime_info()

Non-blocking provider autostart runs the UDS server in a background thread inside the current process.

execution_mode="auto" behaves as follows:

  • disabled: local mode
  • enabled + SHARED_TENSOR_ROLE=server: auto-start a local background server thread and execute endpoints locally
  • enabled + role unset: build client wrappers
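The three rules can be read as a small decision function. The sketch below is an illustration, not the library's code: resolve_auto_mode is a hypothetical helper, and it assumes that enabled=None defers to SHARED_TENSOR_ENABLED, that only the value "1" enables the feature, and that any role other than "server" is treated like an unset role.

```python
def resolve_auto_mode(env, enabled=None):
    """Return "local", "server", or "client" for execution_mode="auto".

    Hypothetical helper. Assumes enabled=None defers to the environment
    and that only SHARED_TENSOR_ENABLED=1 turns the feature on.
    """
    if enabled is None:
        enabled = env.get("SHARED_TENSOR_ENABLED") == "1"
    if not enabled:
        return "local"   # disabled: local mode
    if env.get("SHARED_TENSOR_ROLE") == "server":
        return "server"  # auto-start a background server thread, run locally
    return "client"      # role unset: build client wrappers


mode_disabled = resolve_auto_mode({})
mode_server = resolve_auto_mode(
    {"SHARED_TENSOR_ENABLED": "1", "SHARED_TENSOR_ROLE": "server"}
)
mode_client = resolve_auto_mode({"SHARED_TENSOR_ENABLED": "1"})
```

Passing the environment as a plain mapping keeps the sketch testable; real code would typically read os.environ.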

For production deployment, prefer explicit SharedTensorServer(...).start(blocking=True) in a dedicated server process.

Socket selection is per CUDA device:

  • base path comes from SHARED_TENSOR_BASE_PATH or /tmp/shared-tensor
  • runtime socket path is <base_path>-<device_index>.sock
  • device_index=None means probe lazily from the current CUDA device when needed
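The path rule above can be sketched in a few lines. socket_path_for_device is a hypothetical helper, not part of the shared_tensor API, and it takes an explicit device index; the lazy probing used when device_index=None is left out.

```python
import os

DEFAULT_BASE_PATH = "/tmp/shared-tensor"


def socket_path_for_device(device_index, env=None):
    """Build <base_path>-<device_index>.sock (hypothetical helper)."""
    env = os.environ if env is None else env
    base_path = env.get("SHARED_TENSOR_BASE_PATH", DEFAULT_BASE_PATH)
    return f"{base_path}-{device_index}.sock"


default_path = socket_path_for_device(0, env={})
custom_path = socket_path_for_device(
    1, env={"SHARED_TENSOR_BASE_PATH": "/run/user/1000/st"}
)
```

Because each CUDA device maps to its own socket, a server on device 0 and a server on device 1 do not collide under the same base path.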

Payload Contract

Allowed result payloads:

  • CUDA torch.Tensor
  • CUDA torch.nn.Module

Allowed call payloads:

  • CUDA tensors and modules
  • scalar control values in args and kwargs
  • tuple, list, and dict[str, ...] wrappers
  • empty args and kwargs through the control path

Rejected:

  • CPU tensors or modules
  • plain Python result payloads
  • MPS (Apple Metal) devices

Managed Objects

When managed=True, the client receives a SharedObjectHandle.

# x is a CUDA tensor on the same GPU as the shared model.
handle = load_model(hidden_size=4096)
with handle as model_handle:
    y = model_handle.value(x)

You can also release explicitly:

handle.release()

Use managed mode for cached models or other reusable long-lived CUDA objects.
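The two release styles can be pictured with a small stand-in class. This is a sketch of the pattern only, not the SharedObjectHandle implementation: ManagedHandleSketch, its on_release callback, and the idempotent release are assumptions made for illustration.

```python
class ManagedHandleSketch:
    """Illustrative stand-in for a managed handle (hypothetical class)."""

    def __init__(self, value, on_release):
        self._value = value
        self._on_release = on_release  # e.g. tells the server to drop a refcount
        self._released = False

    @property
    def value(self):
        if self._released:
            raise RuntimeError("handle already released")
        return self._value

    def release(self):
        # Assumed idempotent here: the callback fires at most once.
        if not self._released:
            self._released = True
            self._on_release()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.release()
        return False  # never swallow exceptions from the with-block


released = []
handle = ManagedHandleSketch(lambda x: x * 2, on_release=lambda: released.append(True))
with handle as h:
    y = h.value(3)  # the wrapped object is used like a callable module
handle.release()    # explicit release after the with-block is a no-op
```

Tying release to the with-block means the release also runs when the body raises, which is why the context-manager form is usually the safer default.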

Runtime Introspection

client.get_server_info() returns readiness and process metadata in addition to endpoint and capability data. In client mode, provider.get_runtime_info() wraps that into a provider-oriented view.

info = provider.get_runtime_info()
# execution_mode, server_socket_path, server_running, server_ready, server_info...
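One use for this view is waiting for the server before sending requests. The sketch below assumes the runtime info behaves like a mapping with a boolean "server_ready" entry, which the comment above suggests but does not confirm; wait_until_ready is a hypothetical helper, and a fake info function stands in for provider.get_runtime_info.

```python
import time


def wait_until_ready(get_info, timeout=30.0, interval=0.1):
    """Poll get_info() until it reports server_ready (hypothetical helper).

    Assumes get_info returns a mapping with a "server_ready" entry.
    """
    deadline = time.monotonic() + timeout
    while True:
        info = get_info()
        if info.get("server_ready"):
            return info
        if time.monotonic() >= deadline:
            raise TimeoutError("server did not become ready in time")
        time.sleep(interval)


calls = []


def fake_info():
    # Stand-in for provider.get_runtime_info: ready on the third poll.
    calls.append(1)
    return {"server_ready": len(calls) >= 3}


info = wait_until_ready(fake_info, timeout=5.0, interval=0.01)
```

In real use, the callable passed in would be provider.get_runtime_info, adapted if the actual return type is not a mapping.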

Testing

Default suite:

python -m pytest -m "not gpu"

GPU suite:

python -m pytest -m gpu
