Skip to main content

JAX implementation of the Mistral v0.2 base model.

Project description

Mistral 7B v0.2 JAX

This project is the JAX implementation of Mistral 7B v0.2 Base, advancing the work of my earlier repository mistral 7B JAX.

It is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).

Go to Mistral 7B v0.2 JAX Documentation Page.

Roadmap

  • Model architecture
  • Publish a Python library
  • 1D Model parallelism
  • Generation
    • KV cache
    • Left padding
    • Top-K sampling / Top-p / Temperature
    • Beam search
  • Training

Quick Installation

Simple installation from PyPI.

pip install mistral-v0.2-jax

Usage

For usage of the Mistral 7B v0.2 Base JAX model, see the example below::

import jax
import jax.numpy as jnp
from mistral_v0_2.model import convert_mistral_lm_params, forward_mistral_lm, make_rotary_values, shard_mistral_lm_params
from transformers import AutoTokenizer, MistralForCausalLM

model_dir = 'mistral-hf-7B-v0.2'  # convert first with 'Mistral 7B v0.2 Parameter Conversion' part in README
model = MistralForCausalLM.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
tokenizer.pad_token = tokenizer.eos_token

sentences = ['I have a cat.', 'There is a cat in my home.']
inputs = tokenizer(sentences, padding=True, return_tensors='jax')
input_ids = inputs.input_ids
batch_size, batch_len = input_ids.shape
attn_mask = inputs.attention_mask.astype(jnp.bool_)
qk_mask = jnp.tril(jnp.einsum('bi,bj->bij', attn_mask, attn_mask))[:, None, None]
rotary_values = make_rotary_values(batch_size, batch_len)

# load on CPU first to avoid OOM
cpu_device = jax.devices('cpu')[0]
with jax.default_device(cpu_device):
    params = convert_mistral_lm_params(model)
params = shard_mistral_lm_params(params)

logits, kv_cache = forward_mistral_lm(params, input_ids, qk_mask, rotary_values, None)
print(logits)

If you want to generate with this model, you can run it in the terminal:

python generate.py

Install from Source

This project requires Python 3.12, JAX 0.4.26.

Git clone and create venv:

git clone https://github.com/yixiaoer/mistral-v0.2-jax.git
cd mistral-v0.2-jax

python3.12 -m venv venv
. venv/bin/activate

Install dependencies:

CPU:

pip install -U pip
pip install -U wheel
pip install "jax[cpu]"
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install git+https://github.com/huggingface/transformers
pip install -r requirements.txt

CUDA 11:

pip install -U pip
pip install -U wheel
pip install "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
pip install git+https://github.com/huggingface/transformers
pip install -r requirements.txt

TPU VM:

pip install -U pip
pip install -U wheel
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install git+https://github.com/huggingface/transformers
pip install -r requirements.txt

Mistral 7B v0.2 Parameter Conversion

After downloading model v0.2 and tokenizer v0.2, place them together in an input_dir, for example with name mistral-7B-v0.2.

Convert Mistral 7B v0.2 model weight to HuggingFace format by specifying an output_dir in the command, such as mistral-hf-7B-v0.2. (Later, use this directory as model_dir to access the model):

python convert_mistral_weight_to_hf.py --input_dir mistral-7B-v0.2 --model_size 7B --output_dir mistral-hf-7B-v0.2

The architecture of Mistral 7B v0.2 base remains largely consistent with previous versions.

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

The updates include "rope_theta" from 10000.0 to 1000000.0 and "sliding_window" from 4096 to null:

MistralConfig {
  "_name_or_path": "mistral-hf-7B-v0.2",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.37.2",
  "use_cache": true,
  "vocab_size": 32000
}

Problems Encountered

Encountered numerous challenges from the initial Mistral JAX implementation to the present.

Click Problems Part to see more details.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mistral_v0_2_jax-0.0.1.tar.gz (16.8 kB view hashes)

Uploaded Source

Built Distribution

mistral_v0.2_jax-0.0.1-py3-none-any.whl (23.5 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page