Megatron LM 11B on Huggingface Transformers
Megatron 11B
- A port of the Megatron LM 11B model published by Facebook to Huggingface Transformers.
- This repo contains the model code, checkpoints, and parallelization examples.
Installation
pip install megatron-11b
Usage
1. Tokenizer
- The tokenizer is used in the same way as any other Huggingface tokenizer.
- BOS and EOS tokens are attached automatically, so if you want to use the encoded text as a prompt, exclude the EOS token (using [:-1]).
from megatron_11b import MegatronTokenizer
tokenizer = MegatronTokenizer.from_pretrained("hyunwoongko/megatron-11B")
tokens = tokenizer.encode("Kevin is")
# [0, 21910, 16, 2] ---> includes EOS
tokens = tokenizer.encode("Kevin is")[:-1]
# [0, 21910, 16] ---> excludes EOS
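- For reference, decoding round-trips back to text (a minimal sketch; the special-token handling here assumes the standard Huggingface decode behavior):
text = tokenizer.decode(tokenizer.encode("Kevin is"))
# "<s>Kevin is</s>" ---> special tokens appear unless skip_special_tokens=True is passed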
2. Model
- We currently support the CausalLM model and the SequenceClassification model.
- The models are used in the same way as any other Huggingface model.
from megatron_11b import MegatronForCausalLM, MegatronForSequenceClassification
model_clm = MegatronForCausalLM.from_pretrained("hyunwoongko/megatron-11B")
model_clf = MegatronForSequenceClassification.from_pretrained("hyunwoongko/megatron-11B")
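- A minimal sketch of a classification forward pass, reusing the tokenizer from step 1 (this assumes the port follows the usual Huggingface output convention with a .logits field; note the classification head is randomly initialized until fine-tuned, so the logits are not meaningful out of the box):
import torch

inputs = tokenizer.encode("Kevin is a great guy.", return_tensors="pt")
with torch.no_grad():
    logits = model_clf(inputs).logits  # shape: (batch_size, num_labels)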
3. Generation
from megatron_11b import MegatronForCausalLM, MegatronTokenizer
tokenizer = MegatronTokenizer.from_pretrained("hyunwoongko/megatron-11B")
model = MegatronForCausalLM.from_pretrained("hyunwoongko/megatron-11B").half().cuda()
inputs = "Kevin is"
inputs = tokenizer.encode(inputs, return_tensors="pt").cuda()[:, :-1] # exclude EOS
output = model.generate(inputs, num_beams=5, no_repeat_ngram_size=4, repetition_penalty=1.2)
print(tokenizer.batch_decode(output))
- Output of generation:
<s>Kevin is a great guy.</s>
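- Beam search is deterministic; for more varied continuations you can sample instead, continuing the example above (standard Huggingface generate arguments):
output = model.generate(inputs, do_sample=True, top_p=0.9, top_k=50, max_length=40)
print(tokenizer.batch_decode(output))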
4. Model Parallelism
- Currently, I'm preparing an open source library called Parallelformers that can parallelize all models of Huggingface Transformers.
- I plan to support model parallelization through this library (maybe I can release it next month).
- The relevant code can be found via the MegatronPolicy object below.
from parallelformers.policies.base import Policy, Layer
from parallelformers.utils.dist_utils import AllReduceLinear
from megatron_11b.modeling_megatron import MegatronDecoderLayer


class MegatronPolicy(Policy):
    @staticmethod
    def replace_arguments(config, world_size):
        return {
            # 1. reduce hidden size
            "self_attn.embed_dim": config.d_model // world_size,
            # 2. reduce number of heads
            "self_attn.num_heads": config.encoder_attention_heads // world_size,
        }

    @staticmethod
    def attn_qkv():
        # Q, K, V projections are column-sliced across ranks
        return [
            Layer(
                weight="self_attn.q_proj.weight",
                bias="self_attn.q_proj.bias",
            ),
            Layer(
                weight="self_attn.k_proj.weight",
                bias="self_attn.k_proj.bias",
            ),
            Layer(
                weight="self_attn.v_proj.weight",
                bias="self_attn.v_proj.bias",
            ),
        ]

    @staticmethod
    def attn_out():
        # attention output projection; results are all-reduced across ranks
        return [
            Layer(
                weight="self_attn.out_proj.weight",
                bias="self_attn.out_proj.bias",
                replace=AllReduceLinear,
            ),
        ]

    @staticmethod
    def mlp_in():
        return [
            Layer(
                weight="fc1.weight",
                bias="fc1.bias",
            ),
        ]

    @staticmethod
    def mlp_out():
        # MLP output projection; results are all-reduced across ranks
        return [
            Layer(
                weight="fc2.weight",
                bias="fc2.bias",
                replace=AllReduceLinear,
            ),
        ]

    @staticmethod
    def original_layer_class():
        return MegatronDecoderLayer
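- Since Parallelformers is not released yet, the exact entry point is subject to change; as a rough sketch only, usage might look like the following (the parallelize function and its argument names are hypothetical here):
from parallelformers import parallelize  # hypothetical entry point, not yet released

model = MegatronForCausalLM.from_pretrained("hyunwoongko/megatron-11B")
parallelize(model, num_gpus=4, fp16=True)  # shard decoder layers across 4 GPUs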