
Mistral 7B v0.2 JAX

This project is a JAX implementation of the Mistral 7B v0.2 base model, building on my earlier repository mistral 7B JAX.

It is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).

Go to the Mistral 7B v0.2 JAX Documentation Page.

Roadmap

  • Model architecture
  • Publish a Python library
  • 1D Model parallelism
  • Generation
    • KV cache
    • Left padding
    • Top-K sampling / Top-p / Temperature (see the sampling sketch after this list)
    • Beam search
  • Training
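
For the sampling items above, here is a minimal sketch of what temperature scaling plus top-k sampling could look like in JAX. The function sample_top_k is purely illustrative and not part of this library:

import jax
import jax.numpy as jnp

def sample_top_k(logits: jnp.ndarray, key: jax.Array, k: int = 50, temperature: float = 1.0) -> jnp.ndarray:
    # scale by temperature, keep only the k highest-scoring tokens
    top_logits, top_indices = jax.lax.top_k(logits / temperature, k)
    # sample within the top-k set, then map back to vocabulary ids
    idx = jax.random.categorical(key, top_logits, axis=-1)
    return jnp.take_along_axis(top_indices, idx[..., None], axis=-1)[..., 0]

# e.g.: next_id = sample_top_k(logits[:, -1, :], jax.random.PRNGKey(0))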

Quick Installation

Simple installation from PyPI.

pip install mistral-v0.2-jax

Usage

To use the Mistral 7B v0.2 Base JAX model, see the example below:

import jax
import jax.numpy as jnp
from mistral_v0_2.model import convert_mistral_lm_params, forward_mistral_lm, make_rotary_values, shard_mistral_lm_params
from transformers import AutoTokenizer, MistralForCausalLM

model_dir = 'mistral-hf-7B-v0.2'  # convert the weights first; see the 'Mistral 7B v0.2 Parameter Conversion' section below
model = MistralForCausalLM.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
tokenizer.pad_token = tokenizer.eos_token  # the Mistral tokenizer defines no pad token, so reuse EOS

sentences = ['I have a cat.', 'There is a cat in my home.']
inputs = tokenizer(sentences, padding=True, return_tensors='jax')
input_ids = inputs.input_ids
batch_size, batch_len = input_ids.shape
attn_mask = inputs.attention_mask.astype(jnp.bool_)
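# combine the causal (lower-triangular) mask with the padding mask; shape (batch, 1, 1, len, len)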
qk_mask = jnp.tril(jnp.einsum('bi,bj->bij', attn_mask, attn_mask))[:, None, None]
rotary_values = make_rotary_values(batch_size, batch_len)  # precompute rotary position embedding values

# load on CPU first to avoid OOM
cpu_device = jax.devices('cpu')[0]
with jax.default_device(cpu_device):
    params = convert_mistral_lm_params(model)
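# then shard the parameters across the available accelerator devices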
params = shard_mistral_lm_params(params)

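# forward pass; pass None for the KV cache on the first call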
logits, kv_cache = forward_mistral_lm(params, input_ids, qk_mask, rotary_values, None)
print(logits)
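
Assuming the returned logits have shape (batch, len, vocab), a greedy next-token pick for this right-padded batch might look like the following sketch:

# position of each sequence's last non-padding token
last_pos = attn_mask.sum(axis=-1) - 1
next_token_ids = jnp.argmax(logits[jnp.arange(batch_size), last_pos], axis=-1)
print(tokenizer.batch_decode(next_token_ids[:, None].tolist()))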

To generate text with this model, run the following in a terminal:

python generate.py

Install from Source

This project requires Python 3.12 and JAX 0.4.26.

Clone the repository and create a virtual environment:

git clone https://github.com/yixiaoer/mistral-v0.2-jax.git
cd mistral-v0.2-jax

python3.12 -m venv venv
. venv/bin/activate

Install dependencies:

CPU:

pip install -U pip
pip install -U wheel
pip install "jax[cpu]"
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install git+https://github.com/huggingface/transformers
pip install -r requirements.txt

CUDA 11:

pip install -U pip
pip install -U wheel
pip install "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
pip install git+https://github.com/huggingface/transformers
pip install -r requirements.txt

TPU VM:

pip install -U pip
pip install -U wheel
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install git+https://github.com/huggingface/transformers
pip install -r requirements.txt

Mistral 7B v0.2 Parameter Conversion

After downloading the v0.2 model weights and tokenizer, place them together in one input directory, for example named mistral-7B-v0.2.

Convert the Mistral 7B v0.2 model weights to Hugging Face format by specifying an output_dir in the command, such as mistral-hf-7B-v0.2 (this directory is later used as model_dir to load the model):

python convert_mistral_weight_to_hf.py --input_dir mistral-7B-v0.2 --model_size 7B --output_dir mistral-hf-7B-v0.2

The architecture of Mistral 7B v0.2 base remains largely consistent with previous versions:

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
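
Note that k_proj and v_proj project to 1024 features rather than 4096: the model uses grouped-query attention, with 32 query heads sharing 8 key/value heads. A quick check of the arithmetic, using the config values shown below:

hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
head_dim = hidden_size // num_attention_heads  # 128
assert num_key_value_heads * head_dim == 1024  # out_features of k_proj and v_proj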

The updates are "rope_theta", raised from 10000.0 to 1000000.0, and "sliding_window", changed from 4096 to null:

MistralConfig {
  "_name_or_path": "mistral-hf-7B-v0.2",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.37.2",
  "use_cache": true,
  "vocab_size": 32000
}
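
A larger rope_theta stretches the rotary-embedding wavelengths, which suits the 32768-token context now that sliding-window attention is disabled. As a sketch of where the value enters, the standard RoPE inverse frequencies are derived from it as follows (whether make_rotary_values uses exactly this form is an assumption):

import jax.numpy as jnp

head_dim = 128          # hidden_size 4096 / 32 attention heads
rope_theta = 1000000.0  # raised from 10000.0 in v0.1
inv_freq = 1.0 / rope_theta ** (jnp.arange(0, head_dim, 2) / head_dim)  # per-pair rotation frequencies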

Problems Encountered

I encountered numerous challenges from the initial Mistral JAX implementation to the present.

See the Problems Part for more details.
