Gemma JAX 推理

Gemma 是一个基于 Google DeepMind 的 Gemini1 研究和技术的轻量级、最先进的开放大型语言模型家族。本教程演示了如何使用 Google DeepMind 的 gemma 库来执行 Gemma 2B Instruct 模型的基本采样/推理,该库是用 JAX2(一个高性能数值计算库)、Flax3(基于 JAX 的神经网络库)、Orbax4(一个基于 JAX 的用于训练工具如检查点的库)和 SentencePiece5(一个分词器/合词器库)编写的。尽管在这个笔记本中没有直接使用 Flax,但 Flax 被用于创建 Gemma。

设置

1. 设置 Kaggle 上的 Gemma 访问权限

要完成本教程,您首先需要按照 Gemma 设置6说明操作,这些说明将向您展示如何执行以下操作:

  • 在 kaggle.com7 上获取对 Gemma 的访问权限。
  • 选择一个具有足够资源的 Colab 运行环境来运行 Gemma 模型。
  • 生成并配置 Kaggle 用户名和 API 密钥。

完成 Gemma 设置后,继续进行下一部分,您将为您的 Colab 环境设置环境变量。

2. 设置环境变量

KAGGLE_USERNAMEKAGGLE_KEY 设置环境变量。在出现“授权访问?”的提示时,同意提供密钥访问。

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("KAGGLE_KEY ")
secret_value_1 = user_secrets.get_secret("KAGGLE_USERNAME")

3. 安装 gemma 库

本笔记本着重于使用免费的 Colab GPU。要启用硬件加速,请点击编辑 > 笔记本设置 > 选择 T4 GPU > 保存。

接下来,您需要从 github.com/google-deepmind/gemma 安装 Google DeepMind 的 gemma 库。如果您遇到关于“pip 的依赖解析器”的错误,通常可以忽略它。

注意:通过安装 gemma,您还将安装 flax、core jax、optax(基于 JAX 的梯度处理和优化库)、orbax 和 sentencepiece。

!pip install -q git+https://github.com/google-deepmind/gemma.git

加载并准备 Gemma 模型

使用 kagglehub.model_download 加载 Gemma 模型,该函数需要三个参数:

  • handle:来自 Kaggle 的模型句柄
  • path:(可选字符串)本地路径
  • force_download:(可选布尔值)强制重新下载模型

注意:gemma-2b-it 模型的大小约为 3.7Gb。

GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /kaggle/input/gemma/flax/2b-it/2

注意:上面输出的路径是模型权重和令牌器在本地保存的地方,之后我们还会需要。

检查模型权重和分词器的位置,然后设置路径变量。分词器目录将位于您下载模型的主目录中,而模型权重将位于一个子目录中。例如:

  • 分词器的 tokenizer.model 文件将位于 /LOCAL/PATH/TO/gemma/flax/2b-it/2
  • 模型的检查点将位于 /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it

在进行下一步之前,请确保正确设置这些路径,以便您的代码能够正确访问和加载模型及其组件。

CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /kaggle/input/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /kaggle/input/gemma/flax/2b-it/2/tokenizer.model

执行抽样/推理

加载和格式化 Gemma 模型检查点,使用 gemma.params.load_and_format_params 方法:

from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)

加载使用 sentencepiece.SentencePieceProcessor 构建的 Gemma 分词器:

import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

要从 Gemma 模型检查点自动加载正确的配置,请使用 gemma.transformer.TransformerConfig。cache_size 参数是 Gemma Transformer 缓存中的时间步数。之后,使用 gemma.transformer.Transformer(继承自 flax.linen.Module)实例化 Gemma 模型为 transformer。

注意:由于当前 Gemma 版本中存在未使用的令牌,词汇量比输入嵌入的数量小。

from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)

创建一个采样器,使用 gemma.sampler.Sampler 在 Gemma 模型的检查点/权重和分词器之上:

from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)

在 input_batch 中编写一个提示,并执行推理。可以调整 total_generation_steps(生成响应时执行的步数 —— 本例中使用100以保存主机内存)。

注意:如果内存不足,请点击 Runtime > Disconnect and delete runtime,然后选择 Runtime > Run all。

prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that

(可选)如果已完成笔记本并想尝试其他提示,运行此单元以释放内存。之后,可以在第三步再次实例化采样器,并在第四步自定义并运行提示。

del sampler

来源

本文翻译自 Gemma JAX inference https://www.kaggle.com/code/windmaple/gemma-jax-inference