在使用 Keras 进行 Gemma 模型的分布式微调和推理

Gemma 是一个轻量级、先进的开放模型家族,由用于创建 Google Gemini 模型的研究和技术构建。Gemma 可以进一步微调以满足特定需求。但是,像 Gemma 这样的大型语言模型可能会非常大,有些可能不适合在单个加速器上进行微调。在这种情况下,有两种一般方法来对它们进行微调:

  1. 参数高效微调(PEFT),旨在通过牺牲一些保真度来缩小有效模型大小。LoRA 属于这一类别,在 Keras 中使用 LoRA 微调 Gemma 模型1的教程演示了如何使用 KerasNLP 在单个 GPU 上使用 LoRA 微调 Gemma 2B 模型 gemma_2b_en。

  2. 使用模型并行性的全参数微调。模型并行性将单个模型的权重分布在多个设备上,并实现水平扩展。您可以在这个 Keras指南2中了解更多关于分布式训练的信息。

本教程指导您如何使用 Keras 与 JAX 后端,通过 LoRA 和模型并行分布式训练在 Google 的张量处理单元(TPU)上微调 Gemma 7B 模型。请注意,在本教程中可以关闭 LoRA,进行较慢但更准确的全参数调整。

使用加速器

技术上,您可以在本教程中使用 TPU 或 GPU。

关于 TPU 环境的说明

Google 提供了3种TPU产品:

  • Colab3 提供的 TPU v2 对于本教程来说不够用。
  • Kaggle4 免费提供 TPU v3,适用于本教程。
  • Cloud TPU5 提供 TPU v3 和更新一代的 TPU。设置它的一种方式是:
    1. 创建一个新的 TPU VM6
    2. 为您打算使用的 Jupyter 服务器端口设置SSH端口转发7
    3. 在 TPU VM 上安装 Jupyter 并启动,然后通过“连接到本地运行时”连接到 Colab

关于多 GPU 设置的说明

虽然本教程侧重于 TPU 的使用案例,但如果您有多 GPU 机器,可以轻松地根据自己的需要调整它。

如果您倾向于通过 Colab 工作,也可以直接通过 Colab 连接菜单中的“连接到自定义 GCE VM”为 Colab 配置多 GPU VM。

我们将重点放在使用 Kaggle 提供的免费 TPU 上。

开始之前

Gemma 设置

要完成本教程,您首先需要按照 Gemma 设置8的说明完成设置。Gemma 设置说明将向您展示如何进行以下操作:

Gemma 模型由 Kaggle 托管。要使用 Gemma,请在 Kaggle 上请求访问权限:

  • 在 kaggle.com9 登录或注册。
  • 打开 Gemma 模型卡片10并选择“请求访问权限”。
  • 完成同意表格并接受条款和条件。

安装

安装 Keras 和 KerasNLP。

# 安装最新的 Keras 3。更多信息查看 https://keras.io/getting_started/。
!pip install -q tensorflow-cpu
!pip install -q -U keras-nlp tensorflow-hub
!pip install -q -U keras>=3
!pip install -U tensorflow-text

设置 Keras JAX 后端

导入 JAX 并在 TPU 上运行健全性检查。 Kaggle 提供 TPUv3-8 设备,该设备具有 8 个 TPU 核心,每个核心有 16GB 内存。

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# Keras 3 分布式 API 目前仅针对 JAX 后端实现
os.environ["KERAS_BACKEND"] = "jax"
# 预分配 90% 的 TPU 内存,以最大程度地减少内存碎片和分配开销
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

加载模型

import keras
import keras_nlp

在 NVIDIA GPU 上进行混合精度训练的说明

在 NVIDIA GPU 上训练时,可以使用混合精度(keras.mixed_precision.set_global_policy('mixed_bfloat16'))来加速训练,对训练质量的影响最小。在大多数情况下,建议开启混合精度,因为它既节省内存又节省时间。但是,请注意,在小批量大小下,它可能会使内存使用量增加1.5倍(权重将被加载两次,一次为半精度,一次为全精度)。

对于推理,半精度(keras.config.set_floatx("bfloat16"))将会起作用并节省内存,而混合精度则不适用。

# 如果您想在 GPU 上启用混合精度训练,请取消注释下面的行
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

要在 TPU上分布式加载模型及其权重和张量,首先需要创建一个新的DeviceMeshDeviceMesh 代表了一组为分布式计算配置的硬件设备,这在 Keras 3 中作为统一分布式API的一部分被引入。

分布式 API 支持数据和模型并行性,允许在多个加速器和主机上高效地扩展深度学习模型。它利用底层框架(例如JAX)根据分片指令通过称为单程序多数据(SPMD)扩展的过程分布程序和张量。欲了解更多详情,请查看新的 Keras 3 分布式API指南11

# 要创建一个形状为(1, 8)的DeviceMesh,以便在所有8个TPUs上分片权重。
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

来自分布式 API 的 LayoutMap 指定了如何使用字符串键对权重和张量进行分片或复制,例如下面的 token_embedding/embeddings,这些键被当作正则表达式来匹配张量路径。匹配到的张量将根据模型维度(8 个 TPU)进行分片;其他的则会被完全复制。

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# 与“token_embedding/embeddings”匹配的权重将在 8 个 TPU 上进行分片
layout_map["token_embedding/embeddings"] = (None, model_dim)
# 用于匹配解码器中的查询、键和值矩阵的正则表达式
# 注意力层
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    None, model_dim, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    None, None, model_dim)
layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)

ModelParallel 允许您在 DeviceMesh 上的所有设备之间分片模型权重或激活张量。在这种情况下,一些 Gemma 7B 模型的权重根据上面定义的 layout_map 在 8 个 TPU 芯片之间进行了分片。现在以分布式方式加载模型。

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")

现在验证模型是否已正确分区。 我们以decoder_block_1为例。

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

微调前推理

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

该模型生成了 90 年代值得观看的精彩喜剧电影列表。 现在我们微调 Gemma 模型来改变输出风格。

使用 IMDB数据微调

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# 丢弃标签
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
# 使用数据集的子集来加快训练速度。
imdb_train = imdb_train.take(2000)

使用低秩适应(LoRA)12进行微调。LoRA 是一种微调技术,通过冻结模型的全部权重并在模型中插入较少数量的新可训练权重,大大减少了下游任务的可训练参数数量。基本上,LoRA 通过两个较小的低秩矩阵 AxB 进行重新参数化,以进行训练,这种技术使训练更加快速和内存高效。

# 为模型启用 LoRA 并将 LoRA 秩设置为 4。
gemma_lm.backbone.enable_lora(rank=4)
# 对 IMDb 电影评论数据集进行微调。

# 将输入序列长度限制为 128 以控制内存使用。
gemma_lm.preprocessor.sequence_length = 128
# 使用 AdamW(transformer 模型的常见优化器)。
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# 从衰减(decay)中排除 layernorm 和偏置项。
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)

请注意,启用 LoRA 会显着减少可训练参数的数量,从 70 亿个减少到仅 1100 万个。

微调后推理

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1990-1999. 10. Austin Powers - International Man of Mystery (1997) 9. The Wedding Singer (1998) 8. The Cable Guy (1996) 7'

经过微调后,该模型已经了解了电影评论的风格,并且不会在 90 年代喜剧电影的背景下生成该风格的输出。

来源

本文翻译自 Keras Gemma distributed finetuning and inference https://www.kaggle.com/code/nilaychauhan/keras-gemma-distributed-finetuning-and-inference