Fine tuning LLMs with MLRun

This demo fine tunes a pre-trained model, changing its vocabulary and behavior so that it talks like a pirate. In a more practical scenario, the same technique can be used to emulate the tone of a brand's marketing material or a call center script.


What is fine tuning?

Fine tuning is the process of adjusting the weights of a pre-trained large language model (LLM) to fit a specific task or dataset, while still leveraging the knowledge it has gained from its initial training on a broader corpus. This technique allows you to adapt an existing LLM to your particular use case without having to train a new model from scratch.

Benefits of fine tuning

  • Efficiency: Fine-tuning is significantly faster and more efficient than training a model from scratch, because it typically requires only a few epochs of training on task-specific data.

  • Accuracy: By leveraging the knowledge gained during pre-training, fine-tuned models often achieve better results than models trained solely on the target dataset.

  • Domain Adaptation: Fine-tuning allows you to adapt an LLM to your specific domain or task, enabling it to perform well even when the target data is limited.

Fine tuning vs. RAG

Both approaches are used to extend a pre-trained LLM; however, each is better suited to different types of tasks:

  • Fine Tuning:

    • Especially useful for modifying the language, tone, and vocabulary of a pre-trained model.

    • Requires training data, although far less than training a model from scratch.

    • Requires fewer input prompt tokens to achieve the desired behavior, potentially leading to faster inference and less expensive calls to API-based models.

  • RAG:

    • Especially useful when external, changing knowledge is required. Knowledge sources can be updated independently of the model, so no retraining is needed when the knowledge changes.

    • Requires more input prompt tokens, leading to longer inference times and more expensive calls to API-based models.

    • Can be used to minimize hallucinations by providing the model with objective ground truth from external data sources.

  • RAG + Fine Tuning:

    • A good choice when you want to control both the knowledge (RAG) and the behavior (fine tuning) of the model.
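To make the prompt-size trade-off concrete, here is a toy sketch (not MLRun code) that builds a RAG-style prompt by prepending a retrieved document and compares it with the bare prompt a fine-tuned model could receive. It makes two simplifying assumptions for illustration: whitespace-separated words stand in for real tokenizer tokens, and keyword overlap stands in for the vector search a real RAG pipeline would use. The knowledge-base sentences are hypothetical.

```python
# Toy comparison of prompt sizes: RAG vs. a fine-tuned model.
# Word counts are a rough stand-in for tokenizer token counts.

KNOWLEDGE_BASE = [  # hypothetical documents
    "Virgin Australia commenced services on 31 August 2000.",
    "Tope is a species of shark.",
    "Camels store fat in their humps.",
]


def retrieve(question, documents, k=1):
    """Return the k documents sharing the most lowercase words with the question."""
    q_words = set(question.lower().split())
    return sorted(
        documents,
        key=lambda doc: len(q_words & set(doc.lower().split())),
        reverse=True,
    )[:k]


def rag_prompt(question, documents):
    # RAG: retrieved knowledge is sent with every request, adding input tokens
    context = "\n".join(retrieve(question, documents))
    return f"Context:\n{context}\n\nQuestion: {question}"


def fine_tuned_prompt(question):
    # Fine tuning: the desired behavior lives in the weights, so only the question is sent
    return f"Question: {question}"


question = "When did Virgin Australia start operating?"
rag_words = len(rag_prompt(question, KNOWLEDGE_BASE).split())
plain_words = len(fine_tuned_prompt(question).split())
print(rag_words, plain_words)  # 16 7
```

The gap grows with the number and length of retrieved documents; in exchange, RAG can answer from knowledge the model was never trained on.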

Prerequisites

# %pip install --upgrade torch peft accelerate arrr

Setup

import os
import random
import mlrun
import torch
import pandas as pd
import zipfile
from arrr import translate
from peft import PeftModel
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

print(torch.__version__)

project = mlrun.get_or_create_project("fine-tune")
2.3.1+cu121
> 2024-06-13 21:18:01,029 [info] Project loaded successfully: {'project_name': 'fine-tune'}

Prepare the dataset

This example uses the open-source Databricks-dolly-15k dataset. It consists of instruction-following records generated by thousands of Databricks employees in several of the behavioral categories outlined in the InstructGPT paper, including:

  • Brainstorming

  • Classification

  • Closed QA

  • Generation

  • Information extraction

  • Open QA

  • Summarization

This example uses a Python library called arrr to translate the responses into "pirate speak."
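To show what kind of transformation the training data goes through, here is a hypothetical toy translator written only for this explanation; it is not arrr's implementation. Its word substitutions are taken from the translated samples in this demo (for example "the" becomes "th'" and "is" becomes "be"), and the optional random phrases imitate exclamations such as "Arrr!" that appear in those samples.

```python
import random
import re

# Hypothetical substitutions inferred from the translated samples in this demo;
# arrr's real word list and rules are more extensive.
PIRATE_WORDS = {"the": "th'", "is": "be", "you": "ye", "take": "pillage"}
PIRATE_PHRASES = ["Arrr!", "Shiver me timbers!", "Thar she blows!"]


def toy_pirate_translate(text, rng=None, phrase_chance=0.0):
    """Swap known words for pirate equivalents and, with probability
    phrase_chance, add a pirate phrase after each sentence-ending word."""
    rng = rng or random.Random(0)
    words = []
    for word in text.split():
        # Split off trailing punctuation so that "is." still matches "is"
        match = re.fullmatch(r"(\w+)(\W*)", word)
        if match:
            core, punct = match.groups()
            pirate = PIRATE_WORDS.get(core.lower(), core)
            if core[0].isupper():
                # Keep sentence-initial capitalization ("The" -> "Th'")
                pirate = pirate[0].upper() + pirate[1:]
            word = pirate + punct
        words.append(word)
        if word.endswith((".", "!", "?")) and rng.random() < phrase_chance:
            words.append(rng.choice(PIRATE_PHRASES))
    return " ".join(words)


print(toy_pirate_translate("The name of the third daughter is Alice"))
# Th' name of th' third daughter be Alice
```

The random phrase insertion is why seeding the random number generator matters if you want the same translated dataset on every run.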

INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "Input:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"

PROMPT_NO_INPUT_FORMAT = """{intro}

{instruction_key}
{instruction}

{response_key}
{response}

{end_key}""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    instruction="{instruction}",
    response_key=RESPONSE_KEY,
    response="{response}",
    end_key=END_KEY,
)

PROMPT_WITH_INPUT_FORMAT = """{intro}

{instruction_key}
{instruction}

{input_key}
{input}

{response_key}
{response}

{end_key}""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    instruction="{instruction}",
    input_key=INPUT_KEY,
    input="{input}",
    response_key=RESPONSE_KEY,
    response="{response}",
    end_key=END_KEY,
)
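The templates above are built with two rounds of str.format. The first round, shown above, fills in the fixed keys and passes placeholders such as "{instruction}" through as literal text; the second round, done later for each record, fills in the actual values. This self-contained sketch rebuilds the no-input template so that the intermediate result can be inspected.

```python
INTRO_BLURB = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)

TEMPLATE_SPEC = (
    "{intro}\n\n"
    "{instruction_key}\n{instruction}\n\n"
    "{response_key}\n{response}\n\n"
    "{end_key}"
)

# Round 1: fill the fixed parts; the per-record fields receive their own
# placeholder text, so "{instruction}" and "{response}" survive as literals
template = TEMPLATE_SPEC.format(
    intro=INTRO_BLURB,
    instruction_key="### Instruction:",
    instruction="{instruction}",
    response_key="### Response:",
    response="{response}",
    end_key="### End",
)
print(template)

# Round 2: fill in one record
prompt = template.format(
    instruction="Which is a species of fish? Tope or Rope", response="Tope"
)
print(prompt)
```

Because str.format does not re-parse the values it inserts, braces inside a record's instruction or response are safe in round 2. Literal braces in the intro text or keys, however, would need to be doubled ({{ and }}), since after round 1 that text becomes part of the template parsed in round 2.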

Translate the Dolly dataset into "pirate speak"

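DATASET_NAME = "databricks/databricks-dolly-15k"
# Seed Python's random number generator so that the translation, which inserts
# pirate phrases at random, is reproducible
random.seed(10000)
# Keep only the first 50 records so that the demo trains quickly
df = load_dataset(DATASET_NAME, split="train").to_pandas().head(50)
# Translate only the responses; instructions and context stay in plain English
df["response"] = df["response"].apply(translate)
df.head()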
  | instruction | context | response | category
0 | When did Virgin Australia start operating? | Virgin Australia, the trading name of Virgin A... | Virgin Australia commenced services on 31 Augu... | closed_qa
1 | Which is a species of fish? Tope or Rope | | Tope | classification
2 | Why can camels survive for long without water? | | Camels use th' fat in their humps to keep them... | open_qa
3 | Alice's parents have three daughters: Amy, Jes... | | Th' name of th' third daughter be Alice | open_qa
4 | When was Tomoaki Komorida born? | Komorida was born in Kumamoto Prefecture on Ju... | Tomoaki Komorida was born on July 10,1981. Arrr! | closed_qa

Apply the prompt template

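def _apply_prompt_template(examples):
    """Format a single Dolly record into a training prompt stored in a "text" column."""
    instruction = examples["instruction"]
    response = examples["response"]
    context = examples.get("context")

    # Use the template with an "Input:" section only when the record has context
    # (an empty string counts as no context)
    if context:
        full_prompt = PROMPT_WITH_INPUT_FORMAT.format(
            instruction=instruction, response=response, input=context
        )
    else:
        full_prompt = PROMPT_NO_INPUT_FORMAT.format(
            instruction=instruction, response=response
        )
    return {"text": full_prompt}


# Convert to a Hugging Face Dataset and add the formatted "text" column
pirate_dataset = Dataset.from_pandas(df=df)
pirate_dataset = pirate_dataset.map(_apply_prompt_template)
# Keep only the formatted prompts for training
pirate_dataset_df = pirate_dataset.to_pandas()[["text"]]
print(pirate_dataset_df.loc[14, "text"])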
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
How do I start running?

### Response:
Make sure ye get comfortable running shoes and attire. Thar she blows! Start with achievable goal in mind like a 5K race. If ye never ran before, start gradually from a walk, to brisk walk, light jog aiming for 15-30mins initially. Shiver me timbers! Slowly increase your running time and distance as your fitness level improves. One of th' most important things be cool down and gentle stretching. Always listen to your body, and pillage rest days when needed to prevent injury.

### End
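Because every training example places its response between the ### Response: and ### End markers, the same markers can delimit generation at inference time: a prompt can stop right after ### Response:, and the model's output can be cut at ### End. The helpers below are a hypothetical sketch of that idea using plain string operations; they are not part of MLRun or of this demo's code, and the "generated" text is made up for illustration.

```python
INTRO_BLURB = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)
INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"


def build_inference_prompt(instruction):
    """Same layout as the training text, but stop where the response should start."""
    return f"{INTRO_BLURB}\n\n{INSTRUCTION_KEY}\n{instruction}\n\n{RESPONSE_KEY}\n"


def extract_response(generated_text):
    """Return the text after the response key, up to the end key if present."""
    _, found, after = generated_text.partition(RESPONSE_KEY)
    if not found:
        return ""
    response, _, _ = after.partition(END_KEY)
    return response.strip()


prompt = build_inference_prompt("How do I start running?")
# Hypothetical model output: the prompt plus a generated continuation
generated = prompt + "Make sure ye get comfortable running shoes. Arrr!\n\n### End\nExtra text"
print(extract_response(generated))
# Make sure ye get comfortable running shoes. Arrr!
```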

Log the dataset in MLRun

dataset_artifact = project.log_dataset(key="pirate-data", df=pirate_dataset_df)
dataset_artifact.uri
'store://datasets/fine-tune/pirate-data#0@c48cf677-2a23-4b3c-a404-48ae9a7970c9'
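This URI is what the training step receives as its train_dataset input. As a rough guide to reading such URIs, the sketch below splits them into parts with a regular expression. The layout store://<kind>/<project>/<key>[#<iteration>][:<tag>][@<tree>] is an assumption based on the URIs printed in this notebook; the optional :tag part is assumed from MLRun's general URI convention rather than seen in this output. MLRun has its own parsing utilities, and this is not one of them. For the model URI printed later, the value after @ matches the run uid shown in the training log.

```python
import re

# Assumed layout: store://<kind>/<project>/<key>[#<iteration>][:<tag>][@<tree>]
STORE_URI = re.compile(
    r"^store://(?P<kind>[^/]+)/(?P<project>[^/]+)/(?P<key>[^#:@]+)"
    r"(?:#(?P<iteration>\d+))?"
    r"(?::(?P<tag>[^@]+))?"
    r"(?:@(?P<tree>.+))?$"
)


def parse_store_uri(uri):
    """Split a store URI into its named parts (missing parts are None)."""
    match = STORE_URI.match(uri)
    if match is None:
        raise ValueError(f"not a store URI: {uri!r}")
    return match.groupdict()


dataset_parts = parse_store_uri(
    "store://datasets/fine-tune/pirate-data#0@c48cf677-2a23-4b3c-a404-48ae9a7970c9"
)
print(dataset_parts)
```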

Fine tune the LLM

This example uses the pre-built huggingface_auto_trainer function from the MLRun Function Hub, which handles much of the complexity of fine tuning an LLM, from both a coding and an infrastructure perspective.

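The model in this example is Llama 2 (NousResearch/Llama-2-7b-hf); however, other Hugging Face causal language models should also work. For a different architecture, you might need to change lora_target_modules, because layer names vary between models.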

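model_name = "NousResearch/Llama-2-7b-hf"
tokenizer = model_name  # load the tokenizer from the same Hugging Face repo as the model
# Llama attention (q/k/v/o) and MLP (gate/up/down) projections that receive LoRA adapters
lora_target_modules = [
    "q_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
    "k_proj",
    "v_proj",
]
# Hugging Face TrainingArguments settings, passed to the trainer as training_config
training_arguments = {
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 1,
    "warmup_steps": 2,  # when > 0, takes precedence over warmup_ratio
    "max_steps": 100,  # overrides num_train_epochs
    "learning_rate": 2e-4,
    "fp16": True,  # mixed-precision training (requires a GPU)
    "logging_steps": 1,
    "optim": "paged_adamw_8bit",  # paged 8-bit AdamW (requires bitsandbytes)
    "max_grad_norm": 0.3,  # gradient clipping threshold
    "warmup_ratio": 0.03,
    "group_by_length": True,  # batch samples of similar length to reduce padding
    # Note: the "constant" schedule applies no warmup; "constant_with_warmup" would
    "lr_scheduler_type": "constant",
    "ddp_find_unused_parameters": False,
}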
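Because max_steps is set, it, rather than a number of epochs, decides how long training runs. A quick calculation, assuming a single training device and the 50-record dataset prepared above, shows how many passes over the data 100 steps amount to. It roughly mirrors how the Hugging Face Trainer derives its epoch count and lines up with the "Epoch 7/8" progress indicator in the training log.

```python
import math

# Values from this notebook; a single training device is assumed
num_examples = 50          # df.head(50)
per_device_batch_size = 4  # per_device_train_batch_size
grad_accum_steps = 1       # gradient_accumulation_steps
max_steps = 100            # max_steps

# The last, partial batch still counts as a batch
batches_per_epoch = math.ceil(num_examples / per_device_batch_size)
steps_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
epochs_needed = math.ceil(max_steps / steps_per_epoch)  # epochs the trainer will start
epochs_covered = max_steps / steps_per_epoch            # fractional passes over the data

print(steps_per_epoch, epochs_needed, round(epochs_covered, 2))  # 13 8 7.69
```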
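# Import the fine-tuning function from the MLRun Function Hub
huggingface_auto_trainer_fn = mlrun.import_function("hub://huggingface_auto_trainer")
training_run = mlrun.run_function(
    function=huggingface_auto_trainer_fn,
    name="auto-trainer",
    local=True,  # run in the current notebook environment instead of as a cluster job
    inputs={"train_dataset": dataset_artifact.uri},  # the dataset logged above
    params={
        # Model to fine tune and the transformers class used to load it
        "model": [model_name, "transformers.AutoModelForCausalLM"],
        "tokenizer": tokenizer,
        "training_config": training_arguments,
        "quantization_config": True,
        "lora_config": True,
        "dataset_columns_to_train": "text",  # the column holding the formatted prompts
        "lora_target_modules": lora_target_modules,
        # use_cache=False disables the generation KV cache, which training does not need
        "model_pretrained_config": {"trust_remote_code": True, "use_cache": False},
    },
    handler="finetune_llm",  # the function's fine-tuning entry point
    outputs=["model"],
)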
> 2024-06-13 21:22:30,080 [info] Storing function: {'name': 'auto-trainer', 'uid': '51c6292500214d57b30ac35fb9037b87', 'db': 'http://mlrun-api:8080'}
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
trainable params: 19988480 || all params: 3520401408 || trainable%: 0.5677897967708119
max_steps is given, it will override any value given in num_train_epochs
> 2024-06-13 21:22:42,573 [info] training 'NousResearch/Llama-2-7b-hf'
torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
[100/100 31:48, Epoch 7/8]
Step Training Loss
1 2.699500
2 1.738200
3 1.727400
4 1.081600
5 1.704400
6 2.043700
7 1.327100
8 1.419800
9 1.540700
10 1.394800
11 1.098300
12 1.656400
13 1.110300
14 1.419500
15 1.456500
16 1.107100
17 1.265300
18 1.403700
19 1.015800
20 0.658500
21 0.843200
22 1.048500
23 1.290700
24 1.314500
25 0.872700
26 0.749600
27 1.278200
28 1.163500
29 1.057100
30 0.712800
31 1.108100
32 0.730400
33 1.077600
34 0.486200
35 0.819700
36 1.209400
37 1.084500
38 0.749300
39 1.292400
40 0.487600
41 0.787000
42 0.752900
43 0.378800
44 0.827400
45 0.568600
46 0.705800
47 0.327100
48 0.672700
49 0.361400
50 0.645200
51 0.612100
52 0.478700
53 0.447700
54 0.296300
55 0.274600
56 0.477300
57 0.279700
58 0.513100
59 0.273400
60 0.474600
61 0.470700
62 0.302100
63 0.280800
64 0.201300
65 0.234400
66 0.149700
67 0.304200
68 0.208100
69 0.181300
70 0.129400
71 0.190000
72 0.174500
73 0.286000
74 0.174800
75 0.182300
76 0.251000
77 0.270400
78 0.145000
79 0.134700
80 0.103700
81 0.117700
82 0.100700
83 0.188900
84 0.138500
85 0.057500
86 0.150500
87 0.079200
88 0.151500
89 0.205500
90 0.113900
91 0.119400
92 0.097500
93 0.130800
94 0.222800
95 0.082900
96 0.088600
97 0.091300
98 0.058600
99 0.051900
100 0.098300

/User/.pythonlibs/mlrun-extended/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning:

`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
project: fine-tune
iter: 0
start: Jun 13 21:22:30
state: completed
name: auto-trainer
labels: v3io_user=nick, kind=local, owner=nick, host=jupyter-nick-7f68777f75-57td5
inputs: train_dataset
parameters:
  model=['NousResearch/Llama-2-7b-hf', 'transformers.AutoModelForCausalLM']
  tokenizer=NousResearch/Llama-2-7b-hf
  training_config={'per_device_train_batch_size': 4, 'gradient_accumulation_steps': 1, 'warmup_steps': 2, 'max_steps': 100, 'learning_rate': 0.0002, 'fp16': True, 'logging_steps': 1, 'optim': 'paged_adamw_8bit', 'max_grad_norm': 0.3, 'warmup_ratio': 0.03, 'group_by_length': True, 'lr_scheduler_type': 'constant', 'ddp_find_unused_parameters': False}
  quantization_config=True
  lora_config=True
  dataset_columns_to_train=text
  lora_target_modules=['q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'k_proj', 'v_proj']
  model_pretrained_config={'trust_remote_code': True, 'use_cache': False}
results:
  loss=0.0983
  grad_norm=1.130889892578125
  learning_rate=0.0002
  train_runtime=1928.7363
  train_samples_per_second=0.207
  train_steps_per_second=0.052
  total_flos=2.394428576956416e+16
artifacts: loss_plot, grad_norm_plot, learning_rate_plot, model
> to track results use the .show() or .logs() methods or click here to open in UI
> 2024-06-13 21:55:00,760 [info] Run execution finished: {'status': 'completed', 'name': 'auto-trainer'}
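The "trainable params" line in the log can be checked by hand. LoRA attaches two low-rank matrices to each targeted projection, A with shape (r, in_features) and B with shape (out_features, r), so each adapter adds r * (in_features + out_features) parameters. Using the published Llama-2-7B dimensions (hidden size 4096, MLP size 11008, 32 decoder layers) and the seven modules in lora_target_modules, a rank of r = 8 reproduces the logged 19,988,480. The rank is an assumption: the LoRA settings behind lora_config=True are not shown in this notebook, and r = 8 is simply the value consistent with the log.

```python
# Llama-2-7B shapes (hidden size 4096, MLP size 11008, 32 decoder layers)
HIDDEN = 4096
INTERMEDIATE = 11008
NUM_LAYERS = 32
RANK = 8                 # assumption: the value that reproduces the logged count
ALL_PARAMS = 3520401408  # "all params" as reported in the training log

# (in_features, out_features) of each targeted projection in one decoder layer
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, HIDDEN),
    "v_proj": (HIDDEN, HIDDEN),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features, out_features, rank):
    # A is (rank x in_features) and B is (out_features x rank)
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in TARGET_SHAPES.values())
total = per_layer * NUM_LAYERS
trainable_pct = 100 * total / ALL_PARAMS
print(per_layer, total)  # 624640 19988480
print(trainable_pct)
```

Only about 0.57% of the reported parameters are trained, which is what makes fine tuning a 7B model on a single notebook GPU practical.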

View the training artifacts

In addition to the model itself, MLRun automatically collects a number of metrics and plots.

training_run.outputs
{'loss': 0.0983,
 'grad_norm': 1.130889892578125,
 'learning_rate': 0.0002,
 'train_runtime': 1928.7363,
 'train_samples_per_second': 0.207,
 'train_steps_per_second': 0.052,
 'total_flos': 2.394428576956416e+16,
 'loss_plot': 'v3io:///projects/fine-tune/artifacts/auto-trainer/0/loss_plot.html',
 'grad_norm_plot': 'v3io:///projects/fine-tune/artifacts/auto-trainer/0/grad_norm_plot.html',
 'learning_rate_plot': 'v3io:///projects/fine-tune/artifacts/auto-trainer/0/learning_rate_plot.html',
 'model': 'store://artifacts/fine-tune/Llama-2-7b-hf@51c6292500214d57b30ac35fb9037b87'}
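The throughput figures in training_run.outputs follow from the other values. Assuming a single device, the Hugging Face Trainer counts max_steps * per_device_train_batch_size * gradient_accumulation_steps training samples and divides samples and steps by train_runtime, rounding to three decimals. The sketch below recomputes them from the numbers shown above.

```python
# Values copied from training_run.outputs and training_arguments above
train_runtime = 1928.7363  # seconds
max_steps = 100
per_device_batch_size = 4
grad_accum_steps = 1
num_devices = 1  # assumption: single-GPU local run

total_batch_size = per_device_batch_size * grad_accum_steps * num_devices
num_samples = max_steps * total_batch_size  # 400 samples processed

samples_per_second = round(num_samples / train_runtime, 3)
steps_per_second = round(max_steps / train_runtime, 3)
print(samples_per_second, steps_per_second)  # 0.207 0.052
print(f"{train_runtime / 60:.1f} minutes")   # 32.1 minutes
```

Matching values are a quick sanity check that the run used the batch configuration you intended.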
training_run.artifact("loss_plot").show()