GPU Estimator
A Python package for estimating GPU memory requirements and the number of GPUs needed for training machine learning models.
Features
- Estimate GPU memory requirements based on model parameters
- Calculate optimal number of GPUs for training
- Support for different precision types (FP32, FP16, BF16, INT8)
- Account for optimizer states and gradient storage
- Integration with Hugging Face Hub for latest models
- Discover and search trending models
- Support for popular architectures (GPT, LLaMA, BERT, T5, Mistral, etc.)
- CLI interface for quick estimates
- Detailed memory breakdown and recommendations
Installation
```bash
pip install gpu-estimator
```
Quick Start
Basic Usage
```python
from gpu_estimator import GPUEstimator

estimator = GPUEstimator()

# Estimate for a 7B-parameter model
result = estimator.estimate(
    model_params=7e9,
    batch_size=32,
    sequence_length=2048,
    precision="fp16",
)

print(f"Memory needed per GPU: {result.memory_per_gpu_gb:.2f} GB")
print(f"Recommended GPUs: {result.num_gpus}")
```
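As a sanity check on estimates like the one above, the dominant memory terms for mixed-precision Adam training can be worked out by hand. This is a simplified back-of-envelope model, not the package's exact formula; real totals also include activations and framework overhead:

```python
# Back-of-envelope training memory for a 7B-parameter model in FP16
# with the Adam optimizer. Illustrative only, not gpu_estimator's
# internal formula.
params = 7e9

weights_gb = params * 2 / 1e9    # FP16 weights: 2 bytes/param
grads_gb = params * 2 / 1e9      # FP16 gradients: 2 bytes/param
optimizer_gb = params * 12 / 1e9 # Adam in mixed precision: FP32 master
                                 # weights + two FP32 moment buffers
                                 # = 12 bytes/param

total_gb = weights_gb + grads_gb + optimizer_gb
print(f"Model states only: {total_gb:.0f} GB")  # prints: Model states only: 112 GB
```

Even before activations, a 7B model's training state far exceeds a single 80 GB GPU, which is why multi-GPU recommendations are common at this scale.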
Hugging Face Integration
```python
from gpu_estimator import GPUEstimator

estimator = GPUEstimator()

# Estimate directly from a Hugging Face model ID
result = estimator.estimate_from_huggingface(
    model_id="meta-llama/Llama-2-7b-hf",
    batch_size=4,
    sequence_length=2048,
    precision="fp16",
    gradient_checkpointing=True,
)
print(f"Total memory required: {result.total_memory_gb:.2f} GB")
print(f"GPUs needed: {result.num_gpus}")

# Discover trending models
trending = estimator.list_trending_models(limit=10, task="text-generation")
for model in trending:
    print(f"{model.model_id} - {model.downloads:,} downloads")

# Search for specific models
models = estimator.search_models("mistral", limit=5)
for model in models:
    print(f"{model.model_id} - {model.architecture}")
```
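The `gradient_checkpointing` flag mainly affects activation memory: instead of keeping every intermediate tensor for the backward pass, only layer boundaries are stored and the rest is recomputed. A toy model of that trade-off, with illustrative constants rather than the package's actual formula:

```python
# Rough effect of gradient checkpointing on transformer activation
# memory. The acts_per_layer constant is an illustrative assumption.
def activation_gb(batch, seq, hidden, layers, bytes_per_val=2,
                  acts_per_layer=16):
    # Without checkpointing: keep ~acts_per_layer tensors of shape
    # (batch, seq, hidden) per layer for the backward pass.
    full = batch * seq * hidden * layers * acts_per_layer * bytes_per_val
    # With checkpointing: keep only one boundary tensor per layer and
    # recompute the rest during the backward pass.
    ckpt = batch * seq * hidden * layers * 1 * bytes_per_val
    return full / 1e9, ckpt / 1e9

# Llama-2-7b-like shape: hidden=4096, 32 layers
full, ckpt = activation_gb(batch=4, seq=2048, hidden=4096, layers=32)
print(f"{full:.1f} GB -> {ckpt:.1f} GB with checkpointing")
```

The savings come at the cost of extra compute during the backward pass, typically around a third more training time.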
CLI Usage
Basic Estimation
```bash
# Estimate for any model by parameter count
gpu-estimate estimate --model-params 7e9 --batch-size 4 --precision fp16

# Estimate for predefined models
gpu-estimate estimate --model-name llama-7b --batch-size 8

# Estimate for Hugging Face models
gpu-estimate estimate --huggingface-model meta-llama/Llama-2-7b-hf --batch-size 4
```
Model Discovery
```bash
# List trending models
gpu-estimate trending --limit 20 --task text-generation

# Search for models
gpu-estimate search "mistral" --limit 10

# Get popular models by architecture
gpu-estimate popular llama --limit 5

# Get model information
gpu-estimate info llama-7b
```
Advanced Options
```bash
# With gradient checkpointing and a specific GPU type
gpu-estimate estimate \
    --huggingface-model microsoft/DialoGPT-large \
    --batch-size 8 \
    --seq-length 1024 \
    --precision fp16 \
    --gpu-type A100 \
    --gradient-checkpointing \
    --verbose
```
Interactive Mode
Launch an interactive session for guided GPU estimation:
```bash
gpu-estimate interactive
```
Features:
- Guided workflows for all estimation tasks
- Model discovery with direct estimation
- Flexible model specification (parameters, names, or HF IDs)
- Step-by-step configuration of training parameters
- Quick estimates from trending model lists
Supported Models & Architectures
Hugging Face Models
The package automatically supports any model on the Hugging Face Hub by reading the model's configuration. Popular architectures include:
| Architecture | Examples | Use Cases |
|---|---|---|
| LLaMA/LLaMA 2 | meta-llama/Llama-2-7b-hf, meta-llama/Llama-2-13b-hf | General language modeling, chat |
| GPT | gpt2, microsoft/DialoGPT-large | Text generation, conversation |
| Mistral | mistralai/Mistral-7B-v0.1 | Efficient language modeling |
| CodeLlama | codellama/CodeLlama-7b-Python-hf | Code generation |
| BERT | google-bert/bert-base-uncased | Text classification, NLU |
| T5 | google-t5/t5-base, google/flan-t5-large | Text-to-text tasks |
| Phi | microsoft/phi-2 | Small efficient models |
| Gemma | google/gemma-7b | Google's language models |
| Qwen | Qwen/Qwen-7B | Multilingual models |
Predefined Models
Classic models with known configurations:
- GPT Family: gpt2, gpt2-medium, gpt2-large, gpt2-xl, gpt3
- LLaMA Family: llama-7b, llama-13b, llama-30b, llama-65b
- LLaMA 2: llama2-7b, llama2-13b, llama2-70b
- Code LLaMA: codellama-7b, codellama-13b, codellama-34b
- Mistral: mistral-7b
- Phi: phi-1.5b, phi-2.7b
- Gemma: gemma-2b, gemma-7b
Flexible Naming: Model names support flexible matching. Use custom-llama-7b, my-mistral-7b, or any name containing a known model identifier.
GPU Types Supported
| GPU | Memory | Use Case |
|---|---|---|
| H100 | 80 GB | Latest high-performance training |
| A100 | 80 GB | Large model training and inference |
| A40 | 48 GB | Professional workstation training |
| A6000 | 48 GB | Creative and AI workstation |
| L40 | 48 GB | Data center inference |
| L4 | 24 GB | Efficient inference |
| RTX 4090 | 24 GB | Consumer high-end |
| RTX 3090 | 24 GB | Consumer enthusiast |
| V100 | 32 GB | Previous generation training |
| T4 | 16 GB | Cloud inference |
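Given a total memory estimate and a GPU from the table above, the minimum GPU count under plain memory sharding can be sketched as a ceiling division. This is a simplified model that ignores parallelism overhead; the 0.9 usable-memory fraction is an assumption, not the package's setting:

```python
import math

# Memory capacities from the table above (GB)
GPU_MEMORY_GB = {"H100": 80, "A100": 80, "A40": 48, "A6000": 48,
                 "L40": 48, "L4": 24, "RTX 4090": 24, "RTX 3090": 24,
                 "V100": 32, "T4": 16}

def gpus_needed(total_memory_gb: float, gpu_type: str,
                usable_fraction: float = 0.9) -> int:
    # Leave headroom for the CUDA context and fragmentation, then take
    # the ceiling of required memory over usable memory per GPU.
    usable = GPU_MEMORY_GB[gpu_type] * usable_fraction
    return math.ceil(total_memory_gb / usable)

print(gpus_needed(112, "A100"))  # 112 / 72 usable GB -> 2
```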
File details
Details for the file gpu_estimator-0.1.2.tar.gz.
File metadata
- Download URL: gpu_estimator-0.1.2.tar.gz
- Upload date:
- Size: 19.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 28cc19974fd9c05f3466155811b2e1479c7cf0fc41fb69a891dff926c3433cf6 |
| MD5 | c506a02c83ccb97028a1d815ccd7cf75 |
| BLAKE2b-256 | e8da4b8c53da4c5e59999f278d63317b01f1b2f5a8a3d3dd13c8c7d88bdf7cb4 |
File details
Details for the file gpu_estimator-0.1.2-py3-none-any.whl.
File metadata
- Download URL: gpu_estimator-0.1.2-py3-none-any.whl
- Upload date:
- Size: 17.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.10
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 97355c5439a6959a0d88d4051a37a393053c98efd923d6a4c0881f78a04f6f93 |
| MD5 | ea6b98327a98c3dbf78307c77edb5e0a |
| BLAKE2b-256 | 22757c86dd387d58df57b7851415441ee7811a2ac521292b6997d019d382857a |