A simple wrapper around DeepSpeed for model parallelism.
Project description
Model Parallelism
This package is a simple wrapper around DeepSpeed that makes it as easy as possible to add model parallelism to your PyTorch models.
Example Usage
# Your training script
+ import model_parallelism
# All your data preparation, logging, etc.
model = create_model(...)
- model = model.to(device)
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+ model = model_parallelism.initialize(
+     model, learning_rate=1e-4, optimizer="Adam", batch_size=batch_size
+ )
for batch in dataloader:
    loss = model(batch)
-     loss.backward()
+     model.backward(loss)
-     optimizer.step()
+     model.step()
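
Under the hood, a wrapper like this presumably builds a DeepSpeed config from the keyword arguments and hands your model to deepspeed.initialize. The sketch below is a hypothetical reconstruction, not this package's actual source: the initialize signature is taken from the example above, while the config keys chosen and the decision to return the engine directly are assumptions. It does, however, use DeepSpeed's real deepspeed.initialize API, whose returned engine provides the backward(loss) and step() methods used in the loop above.

# Hypothetical sketch of the wrapper -- not this package's actual source.
import deepspeed

def initialize(model, learning_rate, optimizer="Adam", batch_size=32):
    # Translate the keyword arguments into a minimal DeepSpeed config dict.
    ds_config = {
        "train_batch_size": batch_size,
        "optimizer": {"type": optimizer, "params": {"lr": learning_rate}},
    }
    # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler).
    # The engine wraps the model and exposes backward(loss) and step().
    engine, _, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )
    return engine

Returning the engine is what lets the training loop call model.backward(loss) and model.step() in place of the usual loss.backward() and optimizer.step() pair.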
Hashes for model_parallelism-0.1.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 880dd831f4f643de34d3aa45a4443a52d4c27adb6bad6634c99a06ecd4340b9f
MD5 | 3f800594e831a0446274f2eb5312a098
BLAKE2b-256 | fc87744e42608fe0e4abdfe7feccadd42c36f40f23a68bf208cb5970aa0f65d0