Fine-tune Gemma models in Keras using LoRA

Large language models (LLMs) such as Gemma have been shown to be effective at a variety of natural language processing (NLP) tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps the LLM learn general-purpose knowledge, such as statistical relationships between words. The LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).

LLMs are extremely large (parameter counts run into the billions). Full fine-tuning (which updates every parameter in the model) is not required for most applications, because typical fine-tuning datasets are much smaller than the pre-training datasets.

Low-Rank Adaptation (LoRA) is a fine-tuning technique that greatly reduces the number of trainable parameters for downstream tasks by freezing the model's weights and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MB), all while maintaining the quality of the model outputs.
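Conceptually, LoRA keeps the original weight matrix frozen and learns two much smaller matrices whose product forms a low-rank update of it. The plain-NumPy sketch below (hypothetical shapes, not the KerasNLP implementation) shows why the number of new trainable parameters stays small:

import numpy as np

d_out, d_in, rank = 2048, 2048, 4        # hypothetical layer shape and LoRA rank
W = np.random.randn(d_out, d_in)         # frozen pre-trained weight (never updated)
A = np.random.randn(rank, d_in) * 0.01   # trainable, rank x d_in
B = np.zeros((d_out, rank))              # trainable, d_out x rank (initialized to zero)

x = np.random.randn(d_in)
y = W @ x + B @ (A @ x)                  # forward pass with the low-rank update added

# Trainable parameters: d_out*rank + rank*d_in instead of d_out*d_in.
print(A.size + B.size, "vs", W.size)     # 16384 vs 4194304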

This tutorial walks you through using KerasNLP to perform LoRA fine-tuning on a Gemma 2B model with the Databricks Dolly 15k dataset. This dataset contains 15,000 high-quality human-generated prompt/response pairs designed specifically for fine-tuning LLMs.

Setup

Gemma setup

To complete this tutorial, you first need to work through the Gemma setup instructions, which cover getting access to Gemma and setting up your environment.

Gemma models are hosted by Kaggle. To use Gemma, request access on Kaggle:

  • Sign in or register at kaggle.com.
  • Open the Gemma model card and select "Request Access".
  • Complete the consent form and accept the terms and conditions.

Install dependencies

Install Keras, KerasNLP, and other dependencies.

# Install the latest Keras 3. See https://keras.io/getting_started/ for more details.

!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Keras 3 lets you choose the backend: TensorFlow, JAX, or PyTorch. All three will work for this tutorial.

This tutorial uses JAX as the backend.

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
# Avoid memory fragmentation on the JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

Import packages

Import Keras and KerasNLP.

import keras
import keras_nlp
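
To confirm that the environment variable took effect, you can check which backend Keras actually picked up (a quick sanity check added here, not part of the original notebook):

print(keras.backend.backend())  # should print "jax" if the variable was set before importing Keras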

Load the dataset

Preprocessing your data is an important step when fine-tuning a model, especially with large language models. This tutorial uses a subset of 1,000 training examples so it runs faster. For higher-quality fine-tuning, consider using more training data.
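
Each line of the JSONL file is a single JSON object; the loading code below relies only on its instruction, context, and response fields, roughly of the form:

{"instruction": "...", "context": "...", "response": "..."}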

import json
data = []
with open('/kaggle/input/databricks-dolly-15k/databricks-dolly-15k.jsonl') as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]
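
To sanity-check the formatting, you can print one of the prepared strings (a small check added here; the exact output depends on the dataset contents):

# Inspect the first formatted training example.
print(data[0])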

Load the model

KerasNLP provides implementations of many popular model architectures. In this tutorial, you will create a model using GemmaCausalLM, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on the previous tokens.

Create the model using the from_preset method:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()
Preprocessor: "gemma_causal_lm_preprocessor"
Tokenizer (type)                          | Vocab #
gemma_tokenizer (GemmaTokenizer)          | 256,000

Model: "gemma_causal_lm"
Layer (type)                              | Output Shape          | Param #        | Connected to
padding_mask (InputLayer)                 | (None, None)          | 0              | -
token_ids (InputLayer)                    | (None, None)          | 0              | -
gemma_backbone (GemmaBackbone)            | (None, None, 2048)    | 2,506,172,416  | padding_mask[0][0], token_ids[0][0]
token_embedding (ReversibleEmbedding)     | (None, None, 256000)  | 524,288,000    | gemma_backbone[0][0]

Total params: 2,506,172,416 (9.34 GB)
Trainable params: 2,506,172,416 (9.34 GB)
Non-trainable params: 0 (0.00 B)

The from_preset method instantiates the model from a preset architecture and weights. In the code above, the string "gemma_2b_en" specifies the preset: a Gemma model with 2 billion parameters.

Note: A Gemma model with 7 billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform distributed tuning of the Gemma 7B model on Kaggle or Google Cloud.

Inference before fine-tuning

In this section, you will query the model with various prompts to see how it responds.

Europe trip prompt

Query the model for suggestions about what to do on a trip to Europe.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
1. Take a trip to Europe.
2. Take a trip to Europe.
3. Take a trip to Europe.
4. Take a trip to Europe.
5. Take a trip to Europe.
6. Take a trip to Europe.
7. Take a trip to Europe.
8. Take a trip to Europe.
9. Take a trip to Europe.
10. Take a trip to Europe.
11. Take a trip to Europe.
12. Take a trip to Europe.
13. Take a trip to Europe.
14. Take a trip to Europe.
15. Take a trip to Europe.
16. Take a trip to Europe.
17. Take a trip to Europe.
18. Take a trip to Europe.
19. Take a trip to Europe.
20. Take a trip to Europe.
21. Take a trip to Europe.
22. Take a trip to Europe.
23. Take a trip to Europe.
24. Take a trip to Europe.
25. Take a trip to

The model simply responds by repeating "Take a trip to Europe" over and over.

ELI5 photosynthesis prompt

Prompt the model to explain photosynthesis in terms simple enough for a 5-year-old child to understand.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is the process by which plants use the energy from the sun to convert water and carbon dioxide into oxygen and glucose. The process begins with the absorption of light energy by chlorophyll molecules in the leaves of plants. The energy from the light is used to split water molecules into hydrogen and oxygen. The oxygen is released into the atmosphere, while the hydrogen is used to make glucose. The glucose is then used by the plant to make energy and grow.

Explanation:
Photosynthesis is the process by which plants use the energy from the sun to convert water and carbon dioxide into oxygen and glucose. The process begins with the absorption of light energy by chlorophyll molecules in the leaves of plants. The energy from the light is used to split water molecules into hydrogen and oxygen. The oxygen is released into the atmosphere, while the hydrogen is used to make glucose. The glucose is then used by the plant to make energy and grow.

Explanation:

Photosynthesis is the process by which plants use the energy from the sun to convert water and carbon dioxide into oxygen and glucose. The process begins with the absorption of light energy by chlorophyll molecules in the leaves of plants. The energy from

The response contains words that may not be easy for a child to understand, such as chlorophyll and glucose.

LoRA fine-tuning

To get better responses from the model, fine-tune it with Low-Rank Adaptation (LoRA) using the Databricks Dolly 15k dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, starting with a relatively small rank (such as 4, 8, or 16) is a computationally efficient way to experiment. Train your model with this rank and evaluate the performance improvement on your task, then gradually increase the rank in subsequent trials to see whether it boosts performance further.
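
As a rough way to compare rank settings, the helper below (a hypothetical utility, not part of KerasNLP) estimates how many parameters a rank-r LoRA update adds to a single dense layer of shape (d_in, d_out); 2048 is used only because it matches the hidden size shown in the model summary above:

def lora_params(d_in, d_out, rank):
    # A rank-r update adds a (d_in x rank) matrix and a (rank x d_out) matrix.
    return d_in * rank + rank * d_out

for rank in (4, 8, 16):
    print(rank, lora_params(2048, 2048, rank))  # 4 -> 16384, 8 -> 32768, 16 -> 65536 per layer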

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
Preprocessor: "gemma_causal_lm_preprocessor"
Tokenizer (type)                          | Vocab #
gemma_tokenizer (GemmaTokenizer)          | 256,000

Model: "gemma_causal_lm"
Layer (type)                              | Output Shape          | Param #        | Connected to
padding_mask (InputLayer)                 | (None, None)          | 0              | -
token_ids (InputLayer)                    | (None, None)          | 0              | -
gemma_backbone (GemmaBackbone)            | (None, None, 2048)    | 2,507,536,384  | padding_mask[0][0], token_ids[0][0]
token_embedding (ReversibleEmbedding)     | (None, None, 256000)  | 524,288,000    | gemma_backbone[0][0]

Total params: 2,507,536,384 (9.34 GB)
Trainable params: 1,363,968 (5.20 MB)
Non-trainable params: 2,506,172,416 (9.34 GB)

Note that enabling LoRA significantly reduces the number of trainable parameters (from 2.5 billion down to 1.3 million).

# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
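
If you want to reuse the fine-tuned weights later, you can save and restore them with the standard Keras weight-saving API (a minimal sketch; the file name here is arbitrary):

# Save the fine-tuned weights (the LoRA weights are part of the model's weights).
gemma_lm.save_weights("gemma_lm_dolly.weights.h5")
# Later, restore them into a model built the same way (same preset, LoRA enabled):
# gemma_lm.load_weights("gemma_lm_dolly.weights.h5")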

Inference after fine-tuning

After fine-tuning, the model's responses follow the instruction provided in the prompt.

Europe trip prompt

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
You should plan to see the most famous sights in Europe. The Eiffel Tower, the Acropolis, and the Colosseum are just a few. You should also plan on seeing as many countries as possible. There are so many amazing places in Europe, it is a shame to not see them all.

Additional Information:
Europe is a very interesting place to visit for many reasons, not least of which is that there are so many different places to see.

The fine-tuned model now recommends places to visit in Europe.

ELI5 photosynthesis prompt

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is a process in which plants and photosynthetic organisms (such as algae, cyanobacteria, and some bacteria and archaea) use light energy to convert water and carbon dioxide into sugar and release oxygen. This process requires chlorophyll, water, carbon dioxide, and energy. The chlorophyll captures the light energy and uses it to power a reaction that converts the carbon from carbon dioxide into organic molecules (such as sugar) that can be used for energy. The process also generates oxygen as a by-product.

The model now explains photosynthesis in simpler terms.

Note that, for demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for only one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with the following (an illustrative configuration combining these ideas is sketched after the list):

  1. Increasing the size of the fine-tuning dataset.
  2. Training for more steps (epochs).
  3. Setting a higher LoRA rank.
  4. Modifying the hyperparameter values, such as learning_rate and weight_decay.
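
For example, a longer run starting from a freshly loaded gemma_lm might look like the following (hypothetical values chosen only to illustrate the knobs above, not tuned settings):

# Hypothetical settings for a longer run (illustrative only).
gemma_lm.backbone.enable_lora(rank=8)           # item 3: higher LoRA rank
gemma_lm.preprocessor.sequence_length = 512
optimizer = keras.optimizers.AdamW(
    learning_rate=2e-5,                         # item 4: adjusted hyperparameters
    weight_decay=0.02,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=3, batch_size=1)      # item 2: more epochs; pass a larger `data` subset for item 1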

Source

This article is a translation of Fine-tune Gemma models in Keras using LoRA: https://www.kaggle.com/code/nilaychauhan/fine-tune-gemma-models-in-keras-using-lora