
A simple wrapper around DeepSpeed for model parallelism.

Project description

Model Parallelism

This package is a simple wrapper around DeepSpeed to make it as easy as possible to implement model parallelism in your PyTorch models.

Example Usage

  # Your training script
+ import model_parallelism

  # All your data preparation, logging, etc.

  model = create_model(...)

- model = model.to(device)
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+ model = model_parallelism.initialize(
+     model, learning_rate=1e-4, optimizer="Adam", batch_size=batch_size
+ )

  for batch in dataloader:
      loss = model(batch)
-     optimizer.zero_grad()
-     loss.backward()
+     model.backward(loss)
-     optimizer.step()
+     model.step()
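The package's internals are not documented on this page. As a rough illustration of what a wrapper like this typically has to do, here is a minimal sketch that assumes `initialize` turns its keyword arguments into a DeepSpeed configuration dictionary. The helper name `build_deepspeed_config` and its `world_size` and `gradient_accumulation_steps` parameters are hypothetical. The keys it emits (`train_batch_size`, `train_micro_batch_size_per_gpu`, `gradient_accumulation_steps`, `optimizer`) are standard DeepSpeed config fields. The sketch also assumes that `batch_size` means the per-GPU micro-batch; the package may interpret it differently.

```python
from typing import Any, Dict


def build_deepspeed_config(
    learning_rate: float,
    optimizer: str,
    batch_size: int,
    world_size: int = 1,
    gradient_accumulation_steps: int = 1,
) -> Dict[str, Any]:
    """Hypothetical helper: map wrapper-style kwargs onto a DeepSpeed config dict.

    Assumption: `batch_size` is the per-GPU micro-batch, not the global batch.
    """
    if batch_size <= 0 or world_size <= 0 or gradient_accumulation_steps <= 0:
        raise ValueError("batch_size, world_size and gradient_accumulation_steps must be positive")
    return {
        "train_micro_batch_size_per_gpu": batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        # DeepSpeed requires:
        # train_batch_size == micro_batch_per_gpu * grad_accum_steps * world_size
        "train_batch_size": batch_size * gradient_accumulation_steps * world_size,
        # Optimizer block in DeepSpeed's {"type": ..., "params": {...}} format.
        "optimizer": {"type": optimizer, "params": {"lr": learning_rate}},
    }


if __name__ == "__main__":
    config = build_deepspeed_config(1e-4, "Adam", batch_size=8, world_size=4)
    print(config["train_batch_size"])
```

You could pass a dict like this to DeepSpeed directly, for example `engine, optimizer, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)`. The returned engine is the object that exposes the `backward(loss)` and `step()` calls used in the loop above. Because `step()` on the engine also clears gradients, the wrapped loop has no separate `zero_grad()` call.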



Download files


Source Distribution

model_parallelism-0.1.0.tar.gz (2.8 kB)

Built Distribution

model_parallelism-0.1.0-py3-none-any.whl (3.2 kB)

File details

Details for the file model_parallelism-0.1.0.tar.gz.

File metadata

  • Download URL: model_parallelism-0.1.0.tar.gz
  • Upload date:
  • Size: 2.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.8.2 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for model_parallelism-0.1.0.tar.gz
  • SHA256: a8b0e0a17d367882b18e741627ac1111569959335da794092a6a20701562551a
  • MD5: 2155568b1f52846c01bee7fdba29eff5
  • BLAKE2b-256: 50f6a6c229c9464c277c454664bf8d5c9985375d662a8137d9647e919b4565c0


File details

Details for the file model_parallelism-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: model_parallelism-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 3.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.7.1 importlib_metadata/4.8.2 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for model_parallelism-0.1.0-py3-none-any.whl
  • SHA256: 880dd831f4f643de34d3aa45a4443a52d4c27adb6bad6634c99a06ecd4340b9f
  • MD5: 3f800594e831a0446274f2eb5312a098
  • BLAKE2b-256: fc87744e42608fe0e4abdfe7feccadd42c36f40f23a68bf208cb5970aa0f65d0
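Before installing a downloaded file, you can check it against the published SHA256 digests listed above. The following is a minimal sketch using only the standard library. The function names `sha256_of` and `verify_download` are illustrative, not part of this package. The file is hashed in fixed-size chunks so that large files need not be read into memory at once.

```python
import hashlib
from pathlib import Path
from typing import Union

# SHA256 digests published on this page for version 0.1.0.
EXPECTED_SHA256 = {
    "model_parallelism-0.1.0.tar.gz":
        "a8b0e0a17d367882b18e741627ac1111569959335da794092a6a20701562551a",
    "model_parallelism-0.1.0-py3-none-any.whl":
        "880dd831f4f643de34d3aa45a4443a52d4c27adb6bad6634c99a06ecd4340b9f",
}


def sha256_of(path: Union[str, Path], chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it chunk by chunk."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # iter() with a sentinel stops when read() returns b"" at end of file.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: Union[str, Path]) -> bool:
    """Return True if the file matches the published digest for its filename."""
    path = Path(path)
    expected = EXPECTED_SHA256.get(path.name)
    if expected is None:
        raise KeyError(f"no published digest for {path.name}")
    return sha256_of(path) == expected
```

For example, `verify_download("model_parallelism-0.1.0-py3-none-any.whl")` should return `True` for an untampered wheel. pip can also enforce the same check at install time through its hash-checking mode, using a requirements line such as `model_parallelism==0.1.0 --hash=sha256:<digest>`.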

