02 finetune and eval

Details

File: third_party/wandb/02_finetune_and_eval.ipynb

Type: Jupyter Notebook

Use Cases: Finetuning, Evaluation

Integrations: Wandb

Content

Notebook content (JSON format):

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/tcapelle/mistral_wandb/blob/main/02_finetune_and_eval.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "<!--- @wandbcode{wandb-mistral-webinar} -->"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "urSaPCZoqEmM"
   },
   "source": [
    "# Mistral and Weights & Biases: Finetune an LLM judge to detect hallucinations\n",
    "\n",
    "In this notebook you will learn how to trace your MistralAI API calls using W&B Weave, how to evaluate the performance of your models, and how to close the gap by leveraging the MistralAI fine-tuning capabilities.\n",
    "\n",
    "We will fine-tune a Mistral 7B model as an LLM judge. This idea comes from the [amazing blog post from Eugene](https://eugeneyan.com/writing/finetuning/). The main goal is to fine-tune a small model like Mistral 7B to act as a hallucination judge. We will do this in 2 steps:\n",
    "- Training on the challenging [Factual Inconsistency Benchmark](https://arxiv.org/abs/2211.08412v1) dataset to improve the model's ability to detect hallucinations by detecting inconsistencies between a piece of text and a \"summary\"\n",
    "- Mixing that dataset with a Wikipedia summaries dataset to improve performance even more.\n",
    "\n",
    "[![](./static/eugene1.png)](https://eugeneyan.com/writing/finetuning/)\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "- Weights & Biases: https://wandb.ai/\n",
    "- Mistral finetuning docs: https://docs.mistral.ai/capabilities/finetuning/\n",
    "- Tracing with W&B Weave: https://wandb.me/weave"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load some data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install \"mistralai==0.4.2\" \"weave==0.50.7\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's import the relevant pieces:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, asyncio, json\n",
    "from pathlib import Path\n",
    "\n",
    "import weave\n",
    "\n",
    "from mistralai.client import MistralClient\n",
    "from mistralai.models.chat_completion import ChatMessage\n",
    "\n",
    "client = MistralClient(api_key=os.environ[\"MISTRAL_API_KEY\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some globals:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_PATH = Path(\"./data\")\n",
    "NUM_SAMPLES = 100 # Number of samples to use for evaluation, use None for all\n",
    "PROJECT_NAME = \"llm-judge-webinar\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weave.init(PROJECT_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_jsonl(path):\n",
    "    \"returns a list of dictionaries\"\n",
    "    with open(path, 'r') as file:\n",
    "        return [json.loads(line) for line in file]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = read_jsonl(DATA_PATH / \"fib-train.jsonl\")\n",
    "val_ds = read_jsonl(DATA_PATH / \"fib-val.jsonl\")[0:NUM_SAMPLES]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./static/nli.png)\n",
    "\n",
    "We are going to map the labels to 0 (factually inconsistent) and 1 (factually consistent)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(train_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will probably integrate MistralAI API calls in your codebase by creating a function like the one below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@weave.op()  # <---- add this and you are good to go\n",
    "def call_mistral(model: str, messages: list, **kwargs) -> dict:\n",
    "    \"Call the Mistral API and return the parsed JSON response\"\n",
    "    chat_response = client.chat(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "        response_format={\"type\": \"json_object\"},\n",
    "        **kwargs,\n",
    "    )\n",
    "    return json.loads(chat_response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a prompt that explains the task..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"\"\"You are an expert to detect factual inconsistencies and hallucinations. You will be given a document and a summary.\n",
    "- Carefully read the full document and the provided summary.\n",
    "- Identify Factual Inconsistencies: any statements in the summary that are not supported by or contradict the information in the document.\n",
    "Factually Inconsistent: If any statement in the summary is not supported by or contradicts the document, label it as 0\n",
    "Factually Consistent: If all statements in the summary are supported by the document, label it as 1\n",
    "\n",
    "Highlight or list the specific statements in the summary that are inconsistent.\n",
    "Provide a brief explanation of why each highlighted statement is inconsistent with the document.\n",
    "\n",
    "Return in JSON format with `consistency` and a `reason` for the given choice.\n",
    "\n",
    "Document: \n",
    "{premise}\n",
    "Summary: \n",
    "{hypothesis}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_prompt(prompt, premise: str, hypothesis: str, cls=ChatMessage):\n",
    "    messages = [\n",
    "        cls(\n",
    "            role=\"user\", \n",
    "            content=prompt.format(premise=premise, hypothesis=hypothesis)\n",
    "        )\n",
    "    ]\n",
    "    return messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "premise, hypothesis, target = train_ds[1]['premise'], train_ds[1]['hypothesis'], train_ds[1]['target']\n",
    "messages=format_prompt(prompt, premise, hypothesis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output = call_mistral(model=\"open-mistral-7b\", messages=messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MistralModel(weave.Model):\n",
    "    model: str\n",
    "    prompt: str\n",
    "    temperature: float = 0.7\n",
    "    \n",
    "    @weave.op\n",
    "    def create_messages(self, premise:str, hypothesis:str):\n",
    "        return format_prompt(self.prompt, premise, hypothesis)\n",
    "\n",
    "    @weave.op\n",
    "    def predict(self, premise:str, hypothesis:str):\n",
    "        messages = self.create_messages(premise, hypothesis)\n",
    "        return call_mistral(model=self.model, messages=messages, temperature=self.temperature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_7b = MistralModel(model=\"open-mistral-7b\", prompt=prompt, temperature=0.7)\n",
    "output = model_7b.predict(premise, hypothesis)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Eval"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's evaluate the model on the validation split of the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(model_output, target):\n",
    "    class_model_output = model_output.get('consistency') if model_output else None\n",
    "    return {\"accuracy\": class_model_output == target}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BinaryMetrics(weave.Scorer):\n",
    "    class_name: str\n",
    "    eps: float = 1e-8\n",
    "\n",
    "    @weave.op()\n",
    "    def summarize(self, score_rows) -> dict:\n",
    "        # filter out None rows, model may error out sometimes...\n",
    "        score_rows = [score for score in score_rows if score[\"correct\"] is not None]\n",
    "        # Compute f1, precision, recall\n",
    "        tp = sum([not score[\"negative\"] and score[\"correct\"] for score in score_rows])\n",
    "        fp = sum([not score[\"negative\"] and not score[\"correct\"] for score in score_rows])\n",
    "        fn = sum([score[\"negative\"] and not score[\"correct\"] for score in score_rows])\n",
    "        precision = tp / (tp + fp + self.eps)\n",
    "        recall = tp / (tp + fn + self.eps)\n",
    "        f1 = 2 * precision * recall / (precision + recall + self.eps)\n",
    "        result = {\"f1\": f1, \"precision\": precision, \"recall\": recall}\n",
    "        return result\n",
    "\n",
    "    @weave.op()\n",
    "    def score(self, target: dict, model_output: dict) -> dict:\n",
    "        class_model_output = model_output.get(self.class_name) if model_output else None  # 0 or 1\n",
    "        result = {\n",
    "            \"correct\": class_model_output == target,\n",
    "            \"negative\": not class_model_output,\n",
    "        }\n",
    "        return result\n",
    "\n",
    "F1 = BinaryMetrics(class_name=\"consistency\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluation = weave.Evaluation(dataset=val_ds, scorers=[accuracy, F1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "await evaluation.evaluate(model_7b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./static/eval_7b.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Iterate a bit on the prompt..."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try adding the example from Eugene's blog:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_example = \"\"\"You are an expert to detect factual inconsistencies and hallucinations. You will be given a document and a summary.\n",
    "- Carefully read the full document and the provided summary.\n",
    "- Identify Factual Inconsistencies: any statements in the summary that are not supported by or contradict the information in the document.\n",
    "Factually Inconsistent: If any statement in the summary is not supported by or contradicts the document, label it as 0\n",
    "Factually Consistent: If all statements in the summary are supported by the document, label it as 1\n",
    "\n",
    "Here you have an example:\n",
    "\n",
    "Document: \n",
    "Vehicles and pedestrians will now embark and disembark the Cowes ferry separately following Maritime and Coastguard Agency (MCA) guidance. \n",
    "Isle of Wight Council said its new procedures were in response to a resident’s complaint. Councillor Shirley Smart said it would \n",
    "“initially result in a slower service”. Originally passengers and vehicles boarded or disembarked the so-called “floating bridge” at the same time. \n",
    "Ms Smart, who is the executive member for economy and tourism, said the council already had measures in place to control how passengers \n",
    "and vehicles left or embarked the chain ferry “in a safe manner”. However, it was “responding” to the MCA’s recommendations “following this \n",
    "complaint”. She added: “This may initially result in a slower service while the measures are introduced and our customers get used to \n",
    "the changes.” The service has been in operation since 1859.\n",
    "\n",
    "Inconsistent summary: A new service on the Isle of Wight’s chain ferry has been launched following a complaint from a resident.\n",
    "\n",
    "Consistent summary: Passengers using a chain ferry have been warned crossing times will be longer because of new safety measures.\n",
    "\n",
    "Highlight or list the specific statements in the summary that are inconsistent.\n",
    "Provide a brief explanation of why each highlighted statement is inconsistent with the document.\n",
    "\n",
    "Return in JSON format with `consistency` and a `reason` for the given choice.\n",
    "\n",
    "Document: \n",
    "{premise}\n",
    "Summary: \n",
    "{hypothesis}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_7b_ex = MistralModel(model=\"open-mistral-7b\", prompt=prompt_example, temperature=0.7)\n",
    "output = model_7b_ex.predict(premise, hypothesis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a hard dataset!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "await evaluation.evaluate(model_7b_ex)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Large"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_large = MistralModel(model=\"mistral-large-latest\", prompt=prompt_example, temperature=0.7)\n",
    "await evaluation.evaluate(model_large)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./static/eval_large.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This model is considerably better! Over 80% accuracy is great on this hard task 😎"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-Tune FTW"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see if fine-tuning improves this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_prompt = \"\"\"You are an expert to detect factual inconsistencies and hallucinations. You will be given a document and a summary.\n",
    "- Carefully read the full document and the provided summary.\n",
    "- Identify Factual Inconsistencies: any statements in the summary that are not supported by or contradict the information in the document.\n",
    "Factually Inconsistent: If any statement in the summary is not supported by or contradicts the document, label it as 0\n",
    "Factually Consistent: If all statements in the summary are supported by the document, label it as 1\n",
    "\n",
    "Return in JSON format with `consistency` for the given choice.\n",
    "\n",
    "Document: \n",
    "{premise}\n",
    "Summary: \n",
    "{hypothesis}\n",
    "\"\"\"\n",
    "\n",
    "answer = \"\"\"{{\"consistency\": {label}}}\"\"\"  # <- json schema"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will need to format your prompts slightly differently for fine-tuning:\n",
    "- Instead of `ChatMessage`, use a `dict`\n",
    "- Add the expected output as an assistant message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_prompt_ft(row, cls=dict, with_answer=True):\n",
    "    \"Format a row into the expected MistralAI fine-tuning message format\"\n",
    "    premise = row['premise']\n",
    "    hypothesis = row['hypothesis']\n",
    "    messages = [\n",
    "        cls(\n",
    "            role=\"user\", \n",
    "            # use ft_prompt so training matches the prompt used to evaluate the fine-tuned model\n",
    "            content=ft_prompt.format(premise=premise, hypothesis=hypothesis)\n",
    "        )\n",
    "    ]\n",
    "    if with_answer:\n",
    "        label = row['target']\n",
    "        messages.append(\n",
    "            cls(\n",
    "                role=\"assistant\",\n",
    "                content=answer.format(label=label)\n",
    "            )\n",
    "        )\n",
    "    return messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "format_prompt_ft(train_ds[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You could use a fancier dataset library or pandas, but this is a small dataset so let's not add more complexity..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "formatted_train_ds = [format_prompt_ft(row) for row in train_ds]\n",
    "formatted_val_ds = [format_prompt_ft(row) for row in val_ds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_jsonl(ds, path):\n",
    "    with open(path, \"w\") as f:\n",
    "        for row in ds:\n",
    "            f.write(json.dumps(row) + \"\\n\")\n",
    "save_jsonl(formatted_train_ds, DATA_PATH/\"formatted_fib_train.jsonl\")\n",
    "save_jsonl(formatted_val_ds, DATA_PATH/\"formatted_fib_val.jsonl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from mistralai.client import MistralClient\n",
    "\n",
    "api_key = os.environ.get(\"MISTRAL_API_KEY\")\n",
    "client = MistralClient(api_key=api_key)\n",
    "\n",
    "with open(DATA_PATH/\"formatted_fib_train.jsonl\", \"rb\") as f:\n",
    "    ds_train = client.files.create(file=(\"formatted_df_train.jsonl\", f))\n",
    "with open(DATA_PATH/\"formatted_fib_val.jsonl\", \"rb\") as f:\n",
    "    ds_eval = client.files.create(file=(\"eval.jsonl\", f))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "def pprint(obj):\n",
    "    print(json.dumps(obj.dict(), indent=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pprint(ds_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a fine-tuning job"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OK, now let's create a fine-tuning job with the Mistral API. Some things to know:\n",
    "- You only have 2 parameters to play with: `training_steps` and `learning_rate`\n",
    "- You can use `dry_run=True` to get an estimated cost\n",
    "- `training_steps` is not directly linked to epochs; there is a rule of thumb in the docs. If you do a dry run, the epochs will be calculated for you.\n",
    "\n",
    "We want to run for 10 epochs to reproduce Eugene's results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mistralai.models.jobs import TrainingParameters, WandbIntegrationIn\n",
    "\n",
    "created_jobs = client.jobs.create(\n",
    "    # dry_run=True,\n",
    "    model=\"open-mistral-7b\",\n",
    "    training_files=[ds_train.id],\n",
    "    validation_files=[ds_eval.id],\n",
    "    hyperparameters=TrainingParameters(\n",
    "        training_steps=35,\n",
    "        learning_rate=0.0001,\n",
    "        ),\n",
    "    integrations=[\n",
    "        WandbIntegrationIn(\n",
    "            project=PROJECT_NAME,\n",
    "            run_name=\"mistral_7b_fib\",\n",
    "            api_key=os.environ.get(\"WANDB_API_KEY\"),\n",
    "        ).dict()\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pprint(created_jobs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "retrieved_job = client.jobs.retrieve(created_jobs.id)\n",
    "while retrieved_job.status in [\"RUNNING\", \"QUEUED\"]:\n",
    "    retrieved_job = client.jobs.retrieve(created_jobs.id)\n",
    "    pprint(retrieved_job)\n",
    "    print(f\"Job is {retrieved_job.status}, waiting 10 seconds\")\n",
    "    time.sleep(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](./static/ft_dashboard.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use a fine-tuned model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's compute the predictions using the fine-tuned 7B model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jobs = client.jobs.list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retrieved_job = jobs.data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retrieved_job.fine_tuned_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mistral_7b_ft = MistralModel(prompt=ft_prompt, model=retrieved_job.fine_tuned_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "await evaluation.evaluate(mistral_7b_ft)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quite a substantial improvement! Some takeaways:\n",
    "- Mistral 7B is a much more powerful model than the original BART that Eugene used in his blog post\n",
    "- With a relatively small, high-quality dataset, the improvements on this downstream task are enormous!\n",
    "- Now we can leverage a faster and cheaper 7B model instead of tapping into `mistral-large`. Of course, we could have some filtering logic to decide when to use the big gun anyway."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pre-finetuning on USB to improve performance on FIB\n",
    "\n",
    "The [Unified Summarization Benchmark (USB)](https://arxiv.org/abs/2305.14296) is made up of eight summarization tasks including abstractive summarization, evidence extraction, and factuality classification. While FIB documents are based on news, USB documents are based on a different domain: Wikipedia. Labels for factual consistency were created based on edits to summary sentences; inconsistent and consistent labels were assigned to the before and after versions respectively.\n",
    "\n",
    "> Check Eugene's Analysis [here](https://eugeneyan.com/writing/finetuning/#pre-finetuning-on-usb-to-improve-performance-on-fib)\n",
    "\n",
    "\n",
    "Let's mix the USB dataset in the training data..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds_usb = read_jsonl(DATA_PATH / \"usb-train.jsonl\")\n",
    "\n",
    "formatted_train_usb_ds = [format_prompt_ft(row) for row in train_ds_usb]\n",
    "save_jsonl(formatted_train_usb_ds, DATA_PATH/\"formatted_train_usb.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(DATA_PATH/\"formatted_train_usb.jsonl\", \"rb\") as f:\n",
    "    ds_train_usb = client.files.create(file=(\"formatted_df_train_usb.jsonl\", f))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mistralai.models.jobs import TrainingParameters, WandbIntegrationIn\n",
    "\n",
    "created_jobs = client.jobs.create(\n",
    "    # dry_run=True,\n",
    "    model=\"open-mistral-7b\",\n",
    "    training_files=[ds_train.id, ds_train_usb.id], # <- just add this new file\n",
    "    validation_files=[ds_eval.id],\n",
    "    hyperparameters=TrainingParameters(\n",
    "        training_steps=200,\n",
    "        learning_rate=0.0001,\n",
    "        ),\n",
    "    integrations=[\n",
    "        WandbIntegrationIn(\n",
    "            project=PROJECT_NAME,\n",
    "            run_name=\"mistral_7b_fib_usb\", # <- change the run name\n",
    "            api_key=os.environ.get(\"WANDB_API_KEY\"),\n",
    "        ).dict()\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jobs = client.jobs.list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "created_jobs = jobs.data[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pprint(created_jobs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retrieved_job = client.jobs.retrieve(created_jobs.id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retrieved_job.fine_tuned_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mistral_7b_usb_ft = MistralModel(prompt=ft_prompt, model=retrieved_job.fine_tuned_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "await evaluation.evaluate(mistral_7b_usb_ft)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Final results\n",
    "\n",
    "The model fine-tuned on USB + FIB is now 90%+ accurate!\n",
    "\n",
    "![](./static/compare.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}