Fine tuning LLMs with MLRun#
This demo fine tunes a pre-trained model so that its vocabulary and behavior change and it talks like a pirate. In a more practical scenario, the same technique can be used to emulate the tone of a given brand's marketing material or call center script.
What is fine tuning?#
Fine tuning is the process of adjusting the weights of a pre-trained large language model (LLM) to fit a specific task or dataset, while still leveraging the knowledge it has gained from its initial training on a broader corpus. This technique allows you to adapt an existing LLM to your particular use case without having to train a new model from scratch.
Benefits of fine tuning#
Efficiency: Fine-tuning is significantly faster and more efficient than training a model from scratch, as it only requires a few epochs of training on the specific task.
Accuracy: By leveraging the knowledge gained during pre-training, fine-tuned models often achieve better results than models trained solely on the target dataset.
Domain Adaptation: Fine-tuning allows you to adapt an LLM to your specific domain or task, enabling it to perform well even when the target data is limited.
Fine tuning vs. RAG#
Both approaches extend a pre-trained LLM, but each is better suited to different types of tasks:
Fine Tuning:
Especially useful for modifying language, tone, vocabulary, etc. of a pre-trained model.
Requires training data to be available, although far less than is needed to train a model from scratch.
Does not require as many input prompt tokens to get the desired behavior, potentially leading to faster inference times and less expensive calls to API-based models.
RAG:
Especially useful when external, changing knowledge is required. Knowledge sources can be updated independently of the model, so retraining is not required to update knowledge.
Requires more input prompt tokens, leading to longer inference times and more expensive calls to API-based models.
Can be used to minimize hallucinations by providing the model with objective ground truth from external data sources.
RAG + Fine Tuning:
A good choice when both the knowledge (RAG) and the behavior (fine tuning) of the model need to be customized.
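The difference in prompt size can be illustrated with a rough sketch. The question and the retrieved passages below are hypothetical placeholders, and whitespace-separated words stand in for tokens, so only the relative sizes are meaningful.

```python
# Rough illustration: whitespace-separated words stand in for tokens.
# Real tokenizers split text differently, so treat the counts as relative only.


def approx_token_count(text: str) -> int:
    """Count whitespace-separated words as a crude proxy for tokens."""
    return len(text.split())


question = "What is your return policy?"

# Fine-tuned model: tone and vocabulary live in the weights,
# so the prompt can be just the question.
fine_tuned_prompt = question

# RAG: retrieved passages (hypothetical placeholder text) are added to every request.
retrieved_passages = [
    "Policy document, section 1: items can be returned within 30 days of delivery.",
    "Policy document, section 2: refunds are issued to the original payment method.",
]
rag_prompt = "\n".join(
    ["Answer using only the context below.", *retrieved_passages, f"Question: {question}"]
)

print("fine-tuned prompt:", approx_token_count(fine_tuned_prompt), "words")
print("RAG prompt:", approx_token_count(rag_prompt), "words")
```

The RAG prompt grows with every retrieved passage, which is where the extra inference cost comes from.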
Prerequisite#
# %pip install --upgrade torch peft accelerate arrr
Setup#
import os
import random
import mlrun
import torch
import pandas as pd
import zipfile
from arrr import translate
from peft import PeftModel
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
print(torch.__version__)
project = mlrun.get_or_create_project("fine-tune")
2.3.1+cu121
> 2024-06-13 21:18:01,029 [info] Project loaded successfully: {'project_name': 'fine-tune'}
Prepare the dataset#
This example uses the open-source databricks-dolly-15k dataset. It consists of instruction-following records generated by thousands of Databricks employees in several of the behavioral categories outlined in the InstructGPT paper, including:
Brainstorming
Classification
Closed QA
Generation
Information extraction
Open QA
Summarization
This example uses a Python library called arrr to translate the responses into "pirate speak."
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "Input:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"
PROMPT_NO_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
{response_key}
{response}
{end_key}""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    instruction="{instruction}",
    response_key=RESPONSE_KEY,
    response="{response}",
    end_key=END_KEY,
)
PROMPT_WITH_INPUT_FORMAT = """{intro}
{instruction_key}
{instruction}
{input_key}
{input}
{response_key}
{response}
{end_key}""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    instruction="{instruction}",
    input_key=INPUT_KEY,
    input="{input}",
    response_key=RESPONSE_KEY,
    response="{response}",
    end_key=END_KEY,
)
Translate the Dolly dataset into "pirate speak"#
DATASET_NAME = "databricks/databricks-dolly-15k"
random.seed(10000)
df = load_dataset(DATASET_NAME, split="train").to_pandas().head(50)
df["response"] = df["response"].apply(translate)
df.head()
| | instruction | context | response | category |
|---|---|---|---|---|
| 0 | When did Virgin Australia start operating? | Virgin Australia, the trading name of Virgin A... | Virgin Australia commenced services on 31 Augu... | closed_qa |
| 1 | Which is a species of fish? Tope or Rope | | Tope | classification |
| 2 | Why can camels survive for long without water? | | Camels use th' fat in their humps to keep them... | open_qa |
| 3 | Alice's parents have three daughters: Amy, Jes... | | Th' name of th' third daughter be Alice | open_qa |
| 4 | When was Tomoaki Komorida born? | Komorida was born in Kumamoto Prefecture on Ju... | Tomoaki Komorida was born on July 10,1981. Arrr! | closed_qa |
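To get a feel for what the translation step does to each response, here is a toy stand-in written with the standard library only. It is not how arrr is implemented: the word list, the phrases, and the `toy_translate` helper are all hypothetical, chosen to mimic the kind of output seen in the table above.

```python
import random

# Hypothetical mini-vocabulary for illustration only; arrr ships its own word list.
TOY_PIRATE_WORDS = {"the": "th'", "is": "be", "you": "ye", "my": "me", "hello": "ahoy"}
TOY_PIRATE_PHRASES = ["Arrr!", "Shiver me timbers!"]


def toy_translate(text: str, rng: random.Random, phrase_probability: float = 0.5) -> str:
    """Swap known words for pirate equivalents and sometimes append an exclamation."""
    out = []
    for word in text.split():
        core = word.rstrip(".,!?")
        punctuation = word[len(core):]
        replacement = TOY_PIRATE_WORDS.get(core.lower(), core)
        # Preserve a leading capital letter when the original word had one.
        if core[:1].isupper():
            replacement = replacement[:1].upper() + replacement[1:]
        out.append(replacement + punctuation)
        # After a sentence end, occasionally add a stock pirate phrase.
        if punctuation in {".", "!", "?"} and rng.random() < phrase_probability:
            out.append(rng.choice(TOY_PIRATE_PHRASES))
    return " ".join(out)


rng = random.Random(10000)  # seeded so repeated runs give the same text
print(toy_translate("Hello, the name of the third daughter is Alice.", rng))
```

Because phrases are inserted at random, seeding the generator keeps the toy output reproducible between runs.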
Apply the prompt template#
def _apply_prompt_template(examples):
    instruction = examples["instruction"]
    response = examples["response"]
    context = examples.get("context")
    if context:
        full_prompt = PROMPT_WITH_INPUT_FORMAT.format(
            instruction=instruction, response=response, input=context
        )
    else:
        full_prompt = PROMPT_NO_INPUT_FORMAT.format(
            instruction=instruction, response=response
        )
    return {"text": full_prompt}
pirate_dataset = Dataset.from_pandas(df=df)
pirate_dataset = pirate_dataset.map(_apply_prompt_template)
pirate_dataset_df = pirate_dataset.to_pandas()[["text"]]
print(pirate_dataset_df.loc[14, "text"])
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
How do I start running?
### Response:
Make sure ye get comfortable running shoes and attire. Thar she blows! Start with achievable goal in mind like a 5K race. If ye never ran before, start gradually from a walk, to brisk walk, light jog aiming for 15-30mins initially. Shiver me timbers! Slowly increase your running time and distance as your fitness level improves. One of th' most important things be cool down and gentle stretching. Always listen to your body, and pillage rest days when needed to prevent injury.
### End
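At inference time, a model trained on this layout is typically prompted with everything up to and including the response key, and the generated text is cut at the end key. The sketch below shows one way to do that. The `build_inference_prompt` and `extract_response` helpers are hypothetical and not part of the original notebook; the constants are repeated from above so the snippet runs on its own.

```python
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "Input:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"


def build_inference_prompt(instruction: str, context: str = "") -> str:
    """Hypothetical helper: the training layout, cut off right after the response key."""
    lines = [INTRO_BLURB, INSTRUCTION_KEY, instruction]
    if context:
        lines += [INPUT_KEY, context]
    lines.append(RESPONSE_KEY)
    return "\n".join(lines) + "\n"


def extract_response(generated: str) -> str:
    """Keep the text between the response key and the end key, if present."""
    answer = generated.split(RESPONSE_KEY, 1)[-1]
    return answer.split(END_KEY, 1)[0].strip()


prompt = build_inference_prompt("How do I start running?")
print(prompt)
```

Keeping the inference prompt identical to the training layout matters because the model has only seen the response key in that exact position.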
Log the dataset in MLRun#
dataset_artifact = project.log_dataset(key="pirate-data", df=pirate_dataset_df)
dataset_artifact.uri
'store://datasets/fine-tune/pirate-data#0@c48cf677-2a23-4b3c-a404-48ae9a7970c9'
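The returned URI can be read apart by eye, but a small parser makes its pieces explicit. This is a sketch that assumes the layout visible in the URIs on this page (`store://<kind>/<project>/<key>`, optionally followed by `#<iteration>` and `@<suffix>`). It does not handle tags or other variants, and MLRun resolves these URIs itself, so the `parse_store_uri` helper and the field name `tree` are illustrative assumptions.

```python
import re
from typing import NamedTuple, Optional


class StoreUri(NamedTuple):
    kind: str
    project: str
    key: str
    iteration: Optional[int]
    tree: Optional[str]  # assumed name for the suffix after "@"


_STORE_URI = re.compile(
    r"^store://(?P<kind>[^/]+)/(?P<project>[^/]+)/(?P<key>[^#@]+)"
    r"(?:#(?P<iteration>\d+))?(?:@(?P<tree>.+))?$"
)


def parse_store_uri(uri: str) -> StoreUri:
    """Split a store URI of the shape shown on this page into its parts."""
    match = _STORE_URI.match(uri)
    if match is None:
        raise ValueError(f"unrecognized store URI: {uri!r}")
    iteration = match["iteration"]
    return StoreUri(
        kind=match["kind"],
        project=match["project"],
        key=match["key"],
        iteration=int(iteration) if iteration is not None else None,
        tree=match["tree"],
    )


parsed = parse_store_uri(
    "store://datasets/fine-tune/pirate-data#0@c48cf677-2a23-4b3c-a404-48ae9a7970c9"
)
print(parsed.project, parsed.key, parsed.iteration)
```

The project name and artifact key are the parts needed to find the dataset again in the MLRun UI.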
Fine tune the LLM#
This example leverages the pre-made fine tuning function from the MLRun Function Hub. This handles much of the complexity of fine tuning an LLM, both from a syntax and an infrastructure perspective.
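The model in this example is Llama 2, but other Hugging Face causal language models should work as well. The LoRA target module names below follow the Llama architecture and may need to change for other model families.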
model_name = "NousResearch/Llama-2-7b-hf"
tokenizer = model_name
lora_target_modules = [
    "q_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
    "k_proj",
    "v_proj",
]
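training_arguments = {
    "per_device_train_batch_size": 4,  # samples per device per step
    "gradient_accumulation_steps": 1,  # one optimizer step per batch
    "warmup_steps": 2,
    "max_steps": 100,  # overrides num_train_epochs
    "learning_rate": 2e-4,
    "fp16": True,  # mixed-precision training
    "logging_steps": 1,  # log the loss at every step
    "optim": "paged_adamw_8bit",  # paged 8-bit AdamW optimizer
    "max_grad_norm": 0.3,  # gradient clipping threshold
    "warmup_ratio": 0.03,
    "group_by_length": True,  # batch samples of similar length together
    "lr_scheduler_type": "constant",  # keep the learning rate fixed
    "ddp_find_unused_parameters": False,
}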
huggingface_auto_trainer_fn = mlrun.import_function("hub://huggingface_auto_trainer")
training_run = mlrun.run_function(
    function=huggingface_auto_trainer_fn,
    name="auto-trainer",
    local=True,
    inputs={"train_dataset": dataset_artifact.uri},
    params={
        "model": [model_name, "transformers.AutoModelForCausalLM"],
        "tokenizer": tokenizer,
        "training_config": training_arguments,
        "quantization_config": True,
        "lora_config": True,
        "dataset_columns_to_train": "text",
        "lora_target_modules": lora_target_modules,
        "model_pretrained_config": {"trust_remote_code": True, "use_cache": False},
    },
    handler="finetune_llm",
    outputs=["model"],
)
> 2024-06-13 21:22:30,080 [info] Storing function: {'name': 'auto-trainer', 'uid': '51c6292500214d57b30ac35fb9037b87', 'db': 'http://mlrun-api:8080'}
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
trainable params: 19988480 || all params: 3520401408 || trainable%: 0.5677897967708119
max_steps is given, it will override any value given in num_train_epochs
> 2024-06-13 21:22:42,573 [info] training 'NousResearch/Llama-2-7b-hf'
torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
Step | Training Loss |
---|---|
1 | 2.699500 |
2 | 1.738200 |
3 | 1.727400 |
4 | 1.081600 |
5 | 1.704400 |
6 | 2.043700 |
7 | 1.327100 |
8 | 1.419800 |
9 | 1.540700 |
10 | 1.394800 |
11 | 1.098300 |
12 | 1.656400 |
13 | 1.110300 |
14 | 1.419500 |
15 | 1.456500 |
16 | 1.107100 |
17 | 1.265300 |
18 | 1.403700 |
19 | 1.015800 |
20 | 0.658500 |
21 | 0.843200 |
22 | 1.048500 |
23 | 1.290700 |
24 | 1.314500 |
25 | 0.872700 |
26 | 0.749600 |
27 | 1.278200 |
28 | 1.163500 |
29 | 1.057100 |
30 | 0.712800 |
31 | 1.108100 |
32 | 0.730400 |
33 | 1.077600 |
34 | 0.486200 |
35 | 0.819700 |
36 | 1.209400 |
37 | 1.084500 |
38 | 0.749300 |
39 | 1.292400 |
40 | 0.487600 |
41 | 0.787000 |
42 | 0.752900 |
43 | 0.378800 |
44 | 0.827400 |
45 | 0.568600 |
46 | 0.705800 |
47 | 0.327100 |
48 | 0.672700 |
49 | 0.361400 |
50 | 0.645200 |
51 | 0.612100 |
52 | 0.478700 |
53 | 0.447700 |
54 | 0.296300 |
55 | 0.274600 |
56 | 0.477300 |
57 | 0.279700 |
58 | 0.513100 |
59 | 0.273400 |
60 | 0.474600 |
61 | 0.470700 |
62 | 0.302100 |
63 | 0.280800 |
64 | 0.201300 |
65 | 0.234400 |
66 | 0.149700 |
67 | 0.304200 |
68 | 0.208100 |
69 | 0.181300 |
70 | 0.129400 |
71 | 0.190000 |
72 | 0.174500 |
73 | 0.286000 |
74 | 0.174800 |
75 | 0.182300 |
76 | 0.251000 |
77 | 0.270400 |
78 | 0.145000 |
79 | 0.134700 |
80 | 0.103700 |
81 | 0.117700 |
82 | 0.100700 |
83 | 0.188900 |
84 | 0.138500 |
85 | 0.057500 |
86 | 0.150500 |
87 | 0.079200 |
88 | 0.151500 |
89 | 0.205500 |
90 | 0.113900 |
91 | 0.119400 |
92 | 0.097500 |
93 | 0.130800 |
94 | 0.222800 |
95 | 0.082900 |
96 | 0.088600 |
97 | 0.091300 |
98 | 0.058600 |
99 | 0.051900 |
100 | 0.098300 |
/User/.pythonlibs/mlrun-extended/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning:
`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
project | uid | iter | start | state | name | labels | inputs | parameters | results | artifacts |
---|---|---|---|---|---|---|---|---|---|---|
fine-tune | | 0 | Jun 13 21:22:30 | completed | auto-trainer | v3io_user=nick kind=local owner=nick host=jupyter-nick-7f68777f75-57td5 | train_dataset | model=['NousResearch/Llama-2-7b-hf', 'transformers.AutoModelForCausalLM'] tokenizer=NousResearch/Llama-2-7b-hf training_config={'per_device_train_batch_size': 4, 'gradient_accumulation_steps': 1, 'warmup_steps': 2, 'max_steps': 100, 'learning_rate': 0.0002, 'fp16': True, 'logging_steps': 1, 'optim': 'paged_adamw_8bit', 'max_grad_norm': 0.3, 'warmup_ratio': 0.03, 'group_by_length': True, 'lr_scheduler_type': 'constant', 'ddp_find_unused_parameters': False} quantization_config=True lora_config=True dataset_columns_to_train=text lora_target_modules=['q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'k_proj', 'v_proj'] model_pretrained_config={'trust_remote_code': True, 'use_cache': False} | loss=0.0983 grad_norm=1.130889892578125 learning_rate=0.0002 train_runtime=1928.7363 train_samples_per_second=0.207 train_steps_per_second=0.052 total_flos=2.394428576956416e+16 | loss_plot grad_norm_plot learning_rate_plot model |
> 2024-06-13 21:55:00,760 [info] Run execution finished: {'status': 'completed', 'name': 'auto-trainer'}
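The `trainable params: 19988480` line in the log above can be reproduced with a little arithmetic. With LoRA, each targeted linear layer with `in` inputs and `out` outputs gains two small matrices holding `rank * (in + out)` weights in total. The sketch below uses the published Llama-2-7B dimensions (hidden size 4096, MLP size 11008, 32 decoder blocks). The rank is not stated on this page; a rank of 8 is an assumption that happens to match the logged count exactly.

```python
# Llama-2-7B dimensions (from the published model configuration).
HIDDEN_SIZE = 4096
INTERMEDIATE_SIZE = 11008
NUM_LAYERS = 32

# (in_features, out_features) of each targeted linear layer in one decoder block.
TARGET_SHAPES = {
    "q_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
    "k_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
    "v_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
    "o_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
    "gate_proj": (HIDDEN_SIZE, INTERMEDIATE_SIZE),
    "up_proj": (HIDDEN_SIZE, INTERMEDIATE_SIZE),
    "down_proj": (INTERMEDIATE_SIZE, HIDDEN_SIZE),
}


def lora_trainable_params(rank: int) -> int:
    """Each adapted layer adds A (rank x in) and B (out x rank) matrices."""
    per_block = sum(rank * (n_in + n_out) for n_in, n_out in TARGET_SHAPES.values())
    return per_block * NUM_LAYERS


ASSUMED_RANK = 8  # assumption: not stated on this page, chosen because it matches the log
ALL_PARAMS_FROM_LOG = 3520401408  # "all params" value copied from the log above

trainable = lora_trainable_params(ASSUMED_RANK)
print(trainable)
print(100 * trainable / ALL_PARAMS_FROM_LOG)
```

Only about half a percent of the reported parameters are updated, which is what keeps the training run feasible on limited hardware.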
View the training artifacts#
In addition to the model itself, there are a number of plots and metrics that MLRun collects automatically.
training_run.outputs
{'loss': 0.0983,
'grad_norm': 1.130889892578125,
'learning_rate': 0.0002,
'train_runtime': 1928.7363,
'train_samples_per_second': 0.207,
'train_steps_per_second': 0.052,
'total_flos': 2.394428576956416e+16,
'loss_plot': 'v3io:///projects/fine-tune/artifacts/auto-trainer/0/loss_plot.html',
'grad_norm_plot': 'v3io:///projects/fine-tune/artifacts/auto-trainer/0/grad_norm_plot.html',
'learning_rate_plot': 'v3io:///projects/fine-tune/artifacts/auto-trainer/0/learning_rate_plot.html',
'model': 'store://artifacts/fine-tune/Llama-2-7b-hf@51c6292500214d57b30ac35fb9037b87'}
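The throughput figures in these outputs can be cross-checked against the training configuration. The sketch below assumes the run used a single device, which is not stated on this page. The other numbers are copied from the configuration, the `.head(50)` call, and the outputs above.

```python
import math

per_device_train_batch_size = 4
gradient_accumulation_steps = 1
max_steps = 100
num_devices = 1  # assumption: single-device run
dataset_rows = 50  # the dataset was truncated with .head(50)
train_runtime = 1928.7363  # seconds, from the outputs above

# Samples consumed per optimizer step, and in total over the run.
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_devices
samples_processed = effective_batch * max_steps

# How many passes over the 50 rows the 100 steps amount to.
steps_per_epoch = math.ceil(dataset_rows / effective_batch)
approx_epochs = max_steps / steps_per_epoch

samples_per_second = round(samples_processed / train_runtime, 3)
steps_per_second = round(max_steps / train_runtime, 3)

print(samples_per_second, steps_per_second)
print(steps_per_epoch, round(approx_epochs, 1))
```

Under that assumption the computed rates agree with the reported `train_samples_per_second` and `train_steps_per_second`, and the run makes roughly eight passes over the 50 examples.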
training_run.artifact("loss_plot").show()
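The per-step loss is noisy, so the trend is easier to read after smoothing. As a minimal sketch, the snippet below averages the first ten and last ten loss values copied from the training log above and includes a small trailing moving average. The `moving_average` helper is illustrative and not part of MLRun.

```python
from statistics import mean

# Training loss for steps 1-10 and 91-100, copied from the log above.
first_ten = [2.6995, 1.7382, 1.7274, 1.0816, 1.7044, 2.0437, 1.3271, 1.4198, 1.5407, 1.3948]
last_ten = [0.1194, 0.0975, 0.1308, 0.2228, 0.0829, 0.0886, 0.0913, 0.0586, 0.0519, 0.0983]


def moving_average(values, window):
    """Trailing moving average; returns len(values) - window + 1 points."""
    if window < 1 or window > len(values):
        raise ValueError("window must be between 1 and len(values)")
    return [mean(values[i : i + window]) for i in range(len(values) - window + 1)]


print("mean loss, steps 1-10:  ", round(mean(first_ten), 4))
print("mean loss, steps 91-100:", round(mean(last_ten), 4))
print("smoothed start:", [round(v, 3) for v in moving_average(first_ten, 5)])
```

With only 50 training rows, a final loss this low may partly reflect memorization of the examples, so it is worth checking the model on prompts outside the training set.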