product classification
Details
File: mistral/classifier_factory/product_classification.ipynb
Type: Jupyter Notebook
Use Cases: Classification
Content
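The notebook below turns the fine-tuned classifier's per-category scores into predicted categories by testing several decision thresholds; its evaluation run reports a best threshold of 0.25. It then scores each category with `calculate_score`: the number of entries where the category is both actual and predicted, divided by the number where it is actual or predicted. The following is a minimal standalone sketch of those two steps in plain Python. It is not the notebook's own implementation. `decode_categories` and `label_overlap_score` are hypothetical names. Falling back to the single highest-scoring category when nothing clears the threshold is an assumption based on the notebook's rule that every product has at least one category. The example scores are the `food` scores returned for the first test sample, rounded to four decimals.

```python
from typing import Dict, List, Sequence

# `food` scores for the first test sample ("Avena e nocciole cioccolato fondente"),
# copied from the notebook's classifier response and rounded to 4 decimals.
FIRST_SAMPLE_FOOD_SCORES: Dict[str, float] = {
    "plant-based-foods-and-beverages": 0.1901,
    "meats-and-their-products": 0.0002,
    "sweet-snacks": 0.1758,
    "snacks": 0.1391,
    "dairies": 0.0025,
    "plant-based-foods": 0.1626,
    "cereals-and-potatoes": 0.3259,
    "beverages": 0.0037,
}


def decode_categories(scores: Dict[str, float], threshold: float = 0.25) -> List[str]:
    """Return every category scoring at or above `threshold`, sorted by name.

    Hypothetical helper: if no category clears the threshold, fall back to the
    single highest-scoring one so each product keeps at least one category.
    """
    if not scores:
        return []
    selected = [label for label, score in scores.items() if score >= threshold]
    if not selected:
        selected = [max(scores, key=scores.get)]
    return sorted(selected)


def label_overlap_score(actual: Sequence[bool], predicted: Sequence[bool]) -> float:
    """Per-category intersection over union, mirroring `calculate_score`.

    Counts entries where the category is both actual and predicted, divided by
    entries where it is actual or predicted; returns 0.0 if there are none.
    """
    either = sum(1 for a, p in zip(actual, predicted) if a or p)
    both = sum(1 for a, p in zip(actual, predicted) if a and p)
    return both / either if either else 0.0


if __name__ == "__main__":
    # Show how the decoded categories change with the threshold.
    for threshold in (0.15, 0.25, 0.5):
        print(threshold, decode_categories(FIRST_SAMPLE_FOOD_SCORES, threshold))
```

With these scores, a threshold of 0.25 keeps only `cereals-and-potatoes`. A threshold of 0.15 also keeps `plant-based-foods`, `plant-based-foods-and-beverages` and `sweet-snacks`. The reference label for this product in the notebook's JSONL example is `sweet-snacks`, which shows how strongly the threshold choice affects multi-label results.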
Notebook content (JSON format):
{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "AcooU1NoTWl4" }, "source": [ "# Product Classification: Customise your own classifier for tailored food categorization\n", "\n", "In this cookbook, we will delve into classification, specifically focusing on how to leverage the Classifier Factory to create classifiers tailored to your needs and use cases.\n", "\n", "For simplicity, we will concentrate on a specific example that requires multi-target classification.\n", "\n", "## Food Classification\n", "\n", "The specific use case we will explore is food classification. We aim to classify different dishes and food products into various categories, and also by the country they belong to.\n", "\n", "We will focus on three values:\n", "- The dish or food name\n", "- The country it belongs to\n", "- The categories it belongs to (one or more)\n", "\n", "This means we need to classify two main aspects: the country and the categories to which the food belongs.\n", "\n", "We will also arbitrarily decide that no food can be left without a category; there should always be at least one.\n", "\n", "## Dataset\n", "\n", "For this purpose, we will use a [subset](https://huggingface.co/datasets/pandora-s/openfood-classification) of the [Open Food Facts product database](https://huggingface.co/datasets/openfoodfacts/product-database) as the data relevant to our use case.\n", "\n", "This subset was curated to focus on the most prevalent labels and underwent a few balancing steps.\n", "\n", "### Labels\n", "There are 2 main labels:\n", "- Country *single target*: The corresponding country of the food/dish among 8 possible values: `italy`, `spain`, `germany`, `france`, `united-states`, `belgium`, `united-kingdom` and `switzerland`.\n", "- Category *multi-target*: The categories it belongs to among 8 possible values: `snacks`, `beverages`, `cereals-and-potatoes`, `plant-based-foods`, `dairies`, `plant-based-foods-and-beverages`, `meats-and-their-products` and 
`sweet-snacks`.\n", "\n", "There are 8 countries and 8 different categories.\n", "\n", "Given these labels, each sample has the following fields:\n", "- `name`: The name of the food/dish, extracted from the `product_name` of the openfoodfacts/product-database dataset.\n", "- `country_label`: The country ID, extracted from `countries_tags` of the openfoodfacts/product-database dataset.\n", "- `category_labels`: The categories it belongs to, stored as a mapping from each category name to `'true'` or `'false'`, extracted from `categories_tags` of the openfoodfacts/product-database dataset.\n", "\n", "### Distribution\n", "\n", "Note that the food categories overlap, since a sample can belong to multiple categories.\n", "\n", "### Splits\n", "The dataset was split into 3 sets:\n", "- `train`: 80%\n", "- `validation`: 10%\n", "- `test`: 10%" ] }, { "cell_type": "markdown", "metadata": { "id": "7lZo_t--T1FV" }, "source": [ "### Data Preparation\n", "Let's download the dataset: we will install `datasets` and load it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YAwwzneFhci9" }, "outputs": [], "source": [ "%%capture\n", "!pip install datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wDe_aYx5hd7p" }, "outputs": [], "source": [ "%%capture\n", "from datasets import load_dataset\n", "\n", "dataset = load_dataset('pandora-s/openfood-classification')\n", "dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "3g0lSUNVUtgK" }, "source": [ "We can take a look at the test set directly in Colab by converting it to a pandas DataFrame." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "id": "fnH0zEJ8UtHq", "outputId": "c11e1dc1-6e8a-4e40-a215-6e565cbf91a3" }, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>name</th>\n", " <th>country_label</th>\n", " <th>category_labels</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>Avena e nocciole cioccolato fondente</td>\n", " <td>italy</td>\n", " <td>{'plant-based-foods': 'false', 'cereals-and-po...</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>Pomodori in pezzi</td>\n", " <td>belgium</td>\n", " <td>{'plant-based-foods': 'false', 'cereals-and-po...</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>Grandyoats, Nori Sesame Cashews</td>\n", " <td>united-states</td>\n", " <td>{'plant-based-foods': 'false', 'cereals-and-po...</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>Jus d'orange Profit</td>\n", " <td>switzerland</td>\n", " <td>{'plant-based-foods': 'false', 'cereals-and-po...</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>Rote Beete</td>\n", " <td>germany</td>\n", " <td>{'plant-based-foods': 'true', 'cereals-and-pot...</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>10030</th>\n", " <td>Yaourt doux</td>\n", " <td>france</td>\n", " <td>{'plant-based-foods': 'false', 'cereals-and-po...</td>\n", " </tr>\n", " <tr>\n", " <th>10031</th>\n", " <td>Mirtillo di bosco</td>\n", " <td>italy</td>\n", " <td>{'plant-based-foods': 'false', 
'cereals-and-po...</td>\n", " </tr>\n", " <tr>\n", " <th>10032</th>\n", " <td>Rôti de porc cuit supérieur</td>\n", " <td>france</td>\n", " <td>{'plant-based-foods': 'false', 'cereals-and-po...</td>\n", " </tr>\n", " <tr>\n", " <th>10033</th>\n", " <td>Mix de vegetales con pepinillo</td>\n", " <td>spain</td>\n", " <td>{'plant-based-foods': 'true', 'cereals-and-pot...</td>\n", " </tr>\n", " <tr>\n", " <th>10034</th>\n", " <td>Olives vertes ail et persil bio</td>\n", " <td>france</td>\n", " <td>{'plant-based-foods': 'false', 'cereals-and-po...</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>10035 rows × 3 columns</p>\n", "</div>" ], "text/plain": [ " name country_label \\\n", "0 Avena e nocciole cioccolato fondente italy \n", "1 Pomodori in pezzi belgium \n", "2 Grandyoats, Nori Sesame Cashews united-states \n", "3 Jus d'orange Profit switzerland \n", "4 Rote Beete germany \n", "... ... ... \n", "10030 Yaourt doux france \n", "10031 Mirtillo di bosco italy \n", "10032 Rôti de porc cuit supérieur france \n", "10033 Mix de vegetales con pepinillo spain \n", "10034 Olives vertes ail et persil bio france \n", "\n", " category_labels \n", "0 {'plant-based-foods': 'false', 'cereals-and-po... \n", "1 {'plant-based-foods': 'false', 'cereals-and-po... \n", "2 {'plant-based-foods': 'false', 'cereals-and-po... \n", "3 {'plant-based-foods': 'false', 'cereals-and-po... \n", "4 {'plant-based-foods': 'true', 'cereals-and-pot... \n", "... ... \n", "10030 {'plant-based-foods': 'false', 'cereals-and-po... \n", "10031 {'plant-based-foods': 'false', 'cereals-and-po... \n", "10032 {'plant-based-foods': 'false', 'cereals-and-po... \n", "10033 {'plant-based-foods': 'true', 'cereals-and-pot... \n", "10034 {'plant-based-foods': 'false', 'cereals-and-po... 
\n", "\n", "[10035 rows x 3 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = dataset[\"test\"].to_pandas()\n", "df" ] }, { "cell_type": "markdown", "metadata": { "id": "9HKm6rNHUF_n" }, "source": [ "Now that we have loaded our dataset, we will convert it into the format expected for uploading training data.\n", "\n", "Each sample becomes one line of a JSONL file:\n", "```json\n", "{\"text\": \"Avena e nocciole cioccolato fondente\", \"labels\": {\"food\": [\"sweet-snacks\"], \"country_label\": \"italy\"}}\n", "{\"text\": \"Pomodori in pezzi\", \"labels\": {\"food\": [\"plant-based-foods-and-beverages\"], \"country_label\": \"belgium\"}}\n", "{\"text\": \"Grandyoats, Nori Sesame Cashews\", \"labels\": {\"food\": [\"snacks\"], \"country_label\": \"united-states\"}}\n", "{\"text\": \"Jus d'orange Profit\", \"labels\": {\"food\": [\"beverages\", \"plant-based-foods-and-beverages\"], \"country_label\": \"switzerland\"}}\n", "{\"text\": \"Rote Beete\", \"labels\": {\"food\": [\"plant-based-foods\", \"plant-based-foods-and-beverages\"], \"country_label\": \"germany\"}}\n", "...\n", "```\n", "For multi-target classification, the `labels` object maps each target name to its value(s). For example:\n", "```json\n", "\"labels\": {\n", " \"food\": [\n", " \"beverages\",\n", " \"plant-based-foods-and-beverages\"\n", " ],\n", " \"country_label\": \"switzerland\"\n", "}\n", "```\n", "Here, `food` is a multi-target label (a list of categories) and `country_label` is a single-target label (a single string)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ltHoh_QXJh0n", "outputId": "64499d2e-f496-4fdc-b58a-d57dfa8f4385" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 80273/80273 [00:02<00:00, 38886.76it/s]\n", "100%|██████████| 10034/10034 [00:00<00:00, 28462.58it/s]\n", "100%|██████████| 10035/10035 [00:00<00:00, 42115.74it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "JSONL files have been saved.\n" ] } ], "source": [ "from tqdm import tqdm\n", "import json\n", "\n", "def dataset_to_jsonl(split):\n", " jsonl_data = []\n", "\n", " all_category_labels = set()\n", " all_countries = set()\n", "\n", " # Collect all unique category labels and countries\n", " for example in dataset[split]:\n", " all_category_labels.update(example['category_labels'].keys())\n", " all_countries.add(example['country_label'])\n", "\n", " # Convert sets to sorted lists for consistent formatting\n", " all_category_labels = sorted(all_category_labels)\n", " all_countries = sorted(all_countries)\n", "\n", " # Process each example in the split\n", " for example in tqdm(dataset[split]):\n", " labels = {\n", " \"food\": [\n", " tag\n", " for tag in all_category_labels\n", " if example['category_labels'][tag] == \"true\"\n", " ]\n", " }\n", " labels[\"country_label\"] = example['country_label']\n", "\n", " jsonl_data.append({\n", " \"text\": example['name'],\n", " \"labels\": labels\n", " })\n", "\n", " return jsonl_data, all_category_labels, all_countries\n", "\n", "# Process each split\n", "train_jsonl, _, _ = dataset_to_jsonl('train')\n", "validation_jsonl, _, _ = dataset_to_jsonl('validation')\n", "test_jsonl, all_category_labels, all_country_labels = dataset_to_jsonl('test')\n", "\n", "# Save the formatted data as JSONL files\n", "for split, data in zip(['train', 'validation', 'test'], [train_jsonl, validation_jsonl, test_jsonl]):\n", " with 
open(f'{split}_openfood_classification.jsonl', 'w') as f:\n", " for entry in data:\n", " f.write(json.dumps(entry) + '\\n')\n", "\n", "print(\"JSONL files have been saved.\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "b-wOQ0B9WCwm" }, "source": [ "The data was converted and saved properly. We can now train our model.\n", "\n", "## Training\n", "There are two ways to train the model: upload the data and train through [la Plateforme](https://console.mistral.ai/build/finetuned-models), or do both through the [API](https://classifier-factory.platform-docs-9m1.pages.dev/capabilities/finetuning/classifier_factory/). This notebook uses the API.\n", "\n", "First, we need to install `mistralai`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Td3wp01pWJkC", "outputId": "17f15cbc-1047-45ea-a581-c1a897b805a0" }, "outputs": [], "source": [ "!pip install mistralai" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HLDxqIh_WRAu" }, "outputs": [], "source": [ "from mistralai import Mistral\n", "\n", "# Set the API key for Mistral\n", "api_key = \"API_KEY\"\n", "\n", "# Set your Weights and Biases key\n", "wandb_key = \"WANDB_KEY\"\n", "\n", "# Initialize the Mistral client\n", "client = Mistral(api_key=api_key)" ] }, { "cell_type": "markdown", "metadata": { "id": "FqzOnqF00HwF" }, "source": [ "We will upload two files: the training set and the (optional) validation set, which is used to compute the validation loss." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WQPCfr5izp7a" }, "outputs": [], "source": [ "# Upload the training data\n", "training_data = client.files.upload(\n", " file={\n", " \"file_name\": \"train_openfood_classification.jsonl\",\n", " \"content\": open(\"train_openfood_classification.jsonl\", \"rb\"),\n", " }\n", ")\n", "\n", "# Upload the validation data\n", "validation_data = client.files.upload(\n", " file={\n", " \"file_name\": \"validation_openfood_classification.jsonl\",\n", " \"content\": open(\"validation_openfood_classification.jsonl\", \"rb\"),\n", " }\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "kz9agHvZ0JR5" }, "source": [ "With the data uploaded, we can create a job.\n", "\n", "You can track a considerable number of metrics through our Weights and Biases integration, which we strongly recommend. To use it, provide the project name and your W&B API key in `integrations`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "i6w1mH7v0CJq", "outputId": "6cad81cb-a4db-44f6-85ff-235e6c9272dc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"id\": \"1905e7d8-c6b1-4eb1-b349-e656944276d6\",\n", " \"auto_start\": false,\n", " \"model\": \"ministral-3b-latest\",\n", " \"status\": \"QUEUED\",\n", " \"created_at\": 1744810479,\n", " \"modified_at\": 1744810479,\n", " \"training_files\": [\n", " \"7587cbfd-0c3e-4413-834e-b7e1c588d892\"\n", " ],\n", " \"hyperparameters\": {\n", " \"training_steps\": 250,\n", " \"learning_rate\": 7e-05,\n", " \"weight_decay\": 0.1,\n", " \"warmup_fraction\": 0.05,\n", " \"epochs\": null,\n", " \"seq_len\": 16384\n", " },\n", " \"validation_files\": [\n", " \"e37c81ca-01c0-4b78-b335-8682d961e8be\"\n", " ],\n", " \"fine_tuned_model\": null,\n", " \"suffix\": null,\n", " \"integrations\": [\n", " {\n", " \"project\": \"product-classifier\",\n", " \"name\": null,\n", " \"run_name\": null,\n", " \"url\": null\n", " }\n", " 
],\n", " 
\"trained_tokens\": null,\n", " \"metadata\": {\n", " \"expected_duration_seconds\": null,\n", " \"cost\": 0.0,\n", " \"cost_currency\": null,\n", " \"train_tokens_per_step\": null,\n", " \"train_tokens\": null,\n", " \"data_tokens\": null,\n", " \"estimated_start_time\": null\n", " }\n", "}\n" ] } ], "source": [ "# Create a fine-tuning job\n", "created_job = client.fine_tuning.jobs.create(\n", " model=\"ministral-3b-latest\",\n", " job_type=\"classifier\",\n", " training_files=[{\"file_id\": training_data.id, \"weight\": 1}],\n", " validation_files=[validation_data.id],\n", " hyperparameters={\"training_steps\": 250, \"learning_rate\": 0.00007},\n", " auto_start=False,\n", " integrations=[\n", " {\n", " \"project\": \"product-classifier\",\n", " \"api_key\": wandb_key,\n", " }\n", " ]\n", ")\n", "print(json.dumps(created_job.model_dump(), indent=4))" ] }, { "cell_type": "markdown", "metadata": { "id": "sKYBTLEL0dFd" }, "source": [ "Once the job is created, we can review details such as the number of epochs and other relevant information. This allows us to make informed decisions before initiating the job.\n", "\n", "We'll retrieve the job and wait for it to complete the validation process before starting. This validation step ensures the job is ready to begin." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OHPqiWt10Eyh", "outputId": "b4229360-1f20-47aa-dff2-53cdf3157853" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"id\": \"1905e7d8-c6b1-4eb1-b349-e656944276d6\",\n", " \"auto_start\": false,\n", " \"model\": \"ministral-3b-latest\",\n", " \"status\": \"VALIDATED\",\n", " \"created_at\": 1744810479,\n", " \"modified_at\": 1744810483,\n", " \"training_files\": [\n", " \"7587cbfd-0c3e-4413-834e-b7e1c588d892\"\n", " ],\n", " \"hyperparameters\": {\n", " \"training_steps\": 250,\n", " \"learning_rate\": 7e-05,\n", " \"weight_decay\": 0.1,\n", " \"warmup_fraction\": 0.05,\n", " \"epochs\": 6.414733841792868,\n", " \"seq_len\": 16384\n", " },\n", " \"classifier_targets\": [\n", " {\n", " \"name\": \"food\",\n", " \"labels\": [\n", " \"plant-based-foods-and-beverages\",\n", " \"meats-and-their-products\",\n", " \"sweet-snacks\",\n", " \"snacks\",\n", " \"dairies\",\n", " \"plant-based-foods\",\n", " \"cereals-and-potatoes\",\n", " \"beverages\"\n", " ]\n", " },\n", " {\n", " \"name\": \"country_label\",\n", " \"labels\": [\n", " \"united-states\",\n", " \"italy\",\n", " \"germany\",\n", " \"france\",\n", " \"switzerland\",\n", " \"united-kingdom\",\n", " \"spain\",\n", " \"belgium\"\n", " ]\n", " }\n", " ],\n", " \"validation_files\": [\n", " \"e37c81ca-01c0-4b78-b335-8682d961e8be\"\n", " ],\n", " \"fine_tuned_model\": null,\n", " \"suffix\": null,\n", " \"integrations\": [\n", " {\n", " \"project\": \"product-classifier\",\n", " \"name\": null,\n", " \"run_name\": null,\n", " \"url\": null\n", " }\n", " ],\n", " \"trained_tokens\": null,\n", " \"metadata\": {\n", " \"expected_duration_seconds\": 1000,\n", " \"cost\": 8.2,\n", " \"cost_currency\": \"EUR\",\n", " \"train_tokens_per_step\": 65536,\n", " \"train_tokens\": 16384000,\n", " \"data_tokens\": 2554120,\n", " \"estimated_start_time\": null\n", " },\n", " \"events\": [\n", " {\n", " \"name\": 
\"status-updated\",\n", " \"created_at\": 1744810479,\n", " \"data\": {\n", " \"status\": \"QUEUED\"\n", " }\n", " },\n", " {\n", " \"name\": \"status-updated\",\n", " \"created_at\": 1744810480,\n", " \"data\": {\n", " \"status\": \"VALIDATING\"\n", " }\n", " },\n", " {\n", " \"name\": \"status-updated\",\n", " \"created_at\": 1744810483,\n", " \"data\": {\n", " \"status\": \"VALIDATED\"\n", " }\n", " }\n", " ],\n", " \"checkpoints\": []\n", "}\n" ] } ], "source": [ "# Retrieve the job details\n", "retrieved_job = client.fine_tuning.jobs.get(job_id=created_job.id)\n", "print(json.dumps(retrieved_job.model_dump(), indent=4))\n", "\n", "import time\n", "from IPython.display import clear_output\n", "\n", "# Wait for the job to be validated\n", "while retrieved_job.status not in [\"VALIDATED\"]:\n", " retrieved_job = client.fine_tuning.jobs.get(job_id=created_job.id)\n", "\n", " clear_output(wait=True) # Clear the previous output (User Friendly)\n", " print(json.dumps(retrieved_job.model_dump(), indent=4))\n", " time.sleep(1)" ] }, { "cell_type": "markdown", "metadata": { "id": "qr1uaK9L0fz4" }, "source": [ "We can now run the job." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gaRvSsmN0iB0", "outputId": "d1357fb4-7262-44b5-fcea-d0af0e9457e5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"id\": \"1905e7d8-c6b1-4eb1-b349-e656944276d6\",\n", " \"auto_start\": false,\n", " \"model\": \"ministral-3b-latest\",\n", " \"status\": \"QUEUED\",\n", " \"created_at\": 1744810479,\n", " \"modified_at\": 1744810486,\n", " \"training_files\": [\n", " \"7587cbfd-0c3e-4413-834e-b7e1c588d892\"\n", " ],\n", " \"hyperparameters\": {\n", " \"training_steps\": 250,\n", " \"learning_rate\": 7e-05,\n", " \"weight_decay\": 0.1,\n", " \"warmup_fraction\": 0.05,\n", " \"epochs\": 6.414733841792868,\n", " \"seq_len\": 16384\n", " },\n", " \"classifier_targets\": [\n", " {\n", " \"name\": \"food\",\n", " \"labels\": [\n", " \"plant-based-foods-and-beverages\",\n", " \"meats-and-their-products\",\n", " \"sweet-snacks\",\n", " \"snacks\",\n", " \"dairies\",\n", " \"plant-based-foods\",\n", " \"cereals-and-potatoes\",\n", " \"beverages\"\n", " ]\n", " },\n", " {\n", " \"name\": \"country_label\",\n", " \"labels\": [\n", " \"united-states\",\n", " \"italy\",\n", " \"germany\",\n", " \"france\",\n", " \"switzerland\",\n", " \"united-kingdom\",\n", " \"spain\",\n", " \"belgium\"\n", " ]\n", " }\n", " ],\n", " \"validation_files\": [\n", " \"e37c81ca-01c0-4b78-b335-8682d961e8be\"\n", " ],\n", " \"fine_tuned_model\": null,\n", " \"suffix\": null,\n", " \"integrations\": [\n", " {\n", " \"project\": \"product-classifier\",\n", " \"name\": null,\n", " \"run_name\": null,\n", " \"url\": null\n", " }\n", " ],\n", " \"trained_tokens\": null,\n", " \"metadata\": {\n", " \"expected_duration_seconds\": 1000,\n", " \"cost\": 8.2,\n", " \"cost_currency\": \"EUR\",\n", " \"train_tokens_per_step\": 65536,\n", " \"train_tokens\": 16384000,\n", " \"data_tokens\": 2554120,\n", " \"estimated_start_time\": null\n", " },\n", " \"events\": [\n", " {\n", " \"name\": 
\"status-updated\",\n", " \"created_at\": 1744810479,\n", " \"data\": {\n", " \"status\": \"QUEUED\"\n", " }\n", " },\n", " {\n", " \"name\": \"status-updated\",\n", " \"created_at\": 1744810480,\n", " \"data\": {\n", " \"status\": \"VALIDATING\"\n", " }\n", " },\n", " {\n", " \"name\": \"status-updated\",\n", " \"created_at\": 1744810483,\n", " \"data\": {\n", " \"status\": \"VALIDATED\"\n", " }\n", " }\n", " ],\n", " \"checkpoints\": []\n", "}\n" ] } ], "source": [ "# Start the fine-tuning job\n", "client.fine_tuning.jobs.start(job_id=created_job.id)\n", "\n", "# Retrieve the job details again\n", "retrieved_job = client.fine_tuning.jobs.get(job_id=created_job.id)\n", "print(json.dumps(retrieved_job.model_dump(), indent=4))" ] }, { "cell_type": "markdown", "metadata": { "id": "eqW9NYAv0r5v" }, "source": [ "The job is now starting. Let's keep track of the status and plot the loss.\n", "\n", "For that, we highly recommend making use of our Weights and Biases integration, but we will also keep track of it directly in this notebook.\n", "\n", "### WANDB\n", "\n", "**Training:**\n", "\n", "\n", "\n", "**Eval/Validation:**\n", "\n", "\n", "\n", "**More:**\n", "\n", "\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "afQRQHHj0muJ", "outputId": "f74abaae-236b-4b5f-ebf8-3c4f7fc7c735", "cellView": "form" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SUCCESS\n" ] }, { "data": { "image/png": "", "text/plain": [ "<Figure size 1000x600 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# @title Loss Plot\n", "import pandas as pd\n", "import time\n", "import matplotlib.pyplot as plt\n", "from IPython.display import clear_output\n", "\n", "# Initialize DataFrames to store the metrics\n", "train_metrics_df = pd.DataFrame(columns=[\"Step Number\", \"Train Loss\"])\n", "valid_metrics_df = pd.DataFrame(columns=[\"Step Number\", \"Valid Loss\"])\n", "\n", "# Total training steps\n", 
"total_training_steps = retrieved_job.hyperparameters.training_steps\n", "\n", "# Wait for the job to complete\n", "while retrieved_job.status in [\"QUEUED\", \"RUNNING\"]:\n", " retrieved_job = client.fine_tuning.jobs.get(job_id=created_job.id)\n", "\n", " if retrieved_job.status == \"QUEUED\":\n", " time.sleep(5)\n", " continue\n", "\n", " # Clear the previous output (User Friendly)\n", " clear_output(wait=True)\n", " print(retrieved_job.status)\n", "\n", " # Extract metrics from all checkpoints\n", " for checkpoint in retrieved_job.checkpoints[::-1]:\n", " metrics = checkpoint.metrics\n", " step_number = checkpoint.step_number\n", "\n", " # Skip steps already recorded (check the column values, not the index)\n", " if (\n", " step_number\n", " not in train_metrics_df[\"Step Number\"].values\n", " ):\n", " # Prepare the new row for train loss\n", " train_row = {\n", " \"Step Number\": step_number,\n", " \"Train Loss\": metrics.train_loss,\n", " }\n", "\n", " # Append the new train metrics to the DataFrame\n", " train_metrics_df = pd.concat(\n", " [train_metrics_df, pd.DataFrame([train_row])], ignore_index=True\n", " )\n", "\n", " # Prepare the new row for valid loss if available\n", " if metrics.valid_loss != 0:\n", " valid_row = {\n", " \"Step Number\": step_number,\n", " \"Valid Loss\": metrics.valid_loss,\n", " }\n", " # Append the new valid metrics to the DataFrame\n", " valid_metrics_df = pd.concat(\n", " [valid_metrics_df, pd.DataFrame([valid_row])], ignore_index=True\n", " )\n", "\n", " if len(retrieved_job.checkpoints) > 0:\n", " # Sort the DataFrames by step number\n", " train_metrics_df = train_metrics_df.sort_values(by=\"Step Number\")\n", " valid_metrics_df = valid_metrics_df.sort_values(by=\"Step Number\")\n", "\n", " # Plot the evolution of train loss and valid loss\n", " plt.figure(figsize=(10, 6))\n", "\n", " # Plot train loss\n", " plt.plot(\n", " train_metrics_df[\"Step Number\"],\n", " 
train_metrics_df[\"Train Loss\"],\n", " label=\"Train Loss\",\n", " 
linestyle=\"-\",\n", " )\n", "\n", " # Highlight start and end points of train loss\n", " plt.scatter(\n", " train_metrics_df.iloc[[0, -1]][\"Step Number\"],\n", " train_metrics_df.iloc[[0, -1]][\"Train Loss\"],\n", " color=\"blue\",\n", " zorder=5,\n", " )\n", "\n", " # Plot valid loss only if available\n", " if not valid_metrics_df.empty:\n", " plt.plot(\n", " valid_metrics_df[\"Step Number\"],\n", " valid_metrics_df[\"Valid Loss\"],\n", " label=\"Valid Loss\",\n", " linestyle=\"--\",\n", " )\n", "\n", " # Highlight start and end points of valid loss\n", " plt.scatter(\n", " valid_metrics_df.iloc[[0, -1]][\"Step Number\"],\n", " valid_metrics_df.iloc[[0, -1]][\"Valid Loss\"],\n", " color=\"orange\",\n", " zorder=5,\n", " )\n", "\n", " plt.xlabel(\"Step Number\")\n", " plt.ylabel(\"Loss\")\n", " plt.title(\"Train Loss and Valid Loss\")\n", " plt.legend()\n", " plt.grid(True)\n", " plt.show()\n", "\n", " time.sleep(1)" ] }, { "cell_type": "markdown", "metadata": { "id": "g-ZvhYEi1KQF" }, "source": [ "### Inference\n", "Our model is trained and ready for use! Let's test it on a sample from our test set!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "S5ifw9pG1J51", "outputId": "a25436ae-b9fc-42dd-a60a-4898a377102f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Text: Avena e nocciole cioccolato fondente\n", "Classifier Response: {\n", " \"id\": \"05f72b5b50bf4d4b9b58bf650dadfe3c\",\n", " \"model\": \"ft:classifier:ministral-3b-latest:8e2706f0:20250416:1905e7d8\",\n", " \"results\": [\n", " {\n", " \"food\": {\n", " \"scores\": {\n", " \"plant-based-foods-and-beverages\": 0.19011782109737396,\n", " \"meats-and-their-products\": 0.00018027107580564916,\n", " \"sweet-snacks\": 0.17583023011684418,\n", " \"snacks\": 0.13909316062927246,\n", " \"dairies\": 0.0024885653983801603,\n", " \"plant-based-foods\": 0.16261638700962067,\n", " \"cereals-and-potatoes\": 0.32593774795532227,\n", " \"beverages\": 0.003735779784619808\n", " }\n", " },\n", " \"country_label\": {\n", " \"scores\": {\n", " \"united-states\": 0.0010206910083070397,\n", " \"italy\": 0.9425649046897888,\n", " \"germany\": 0.0010048666736111045,\n", " \"france\": 0.02936650812625885,\n", " \"switzerland\": 0.020262280479073524,\n", " \"united-kingdom\": 0.0014394049067050219,\n", " \"spain\": 0.003071120008826256,\n", " \"belgium\": 0.0012702704407274723\n", " }\n", " }\n", " }\n", " ]\n", "}\n" ] } ], "source": [ "# Load the test samples\n", "with open(\"test_openfood_classification.jsonl\", \"r\") as f:\n", " test_samples = [json.loads(l) for l in f.readlines()]\n", "\n", "# Classify the first test sample\n", "classifier_response = client.classifiers.classify(\n", " model=retrieved_job.fine_tuned_model,\n", " inputs=[test_samples[0][\"text\"]],\n", ")\n", "print(\"Text:\", test_samples[0][\"text\"])\n", "print(\"Classifier Response:\", json.dumps(classifier_response.model_dump(), indent=4))" ] }, { "cell_type": "markdown", "metadata": { "id": "X8w2XmKOWewV" }, "source": [ "We can go even further and compare side by side normal prompting techniques with 
LLMs against our new classifier. To do so, we will run the test set through several LLMs using structured outputs and compare their results with those of our classifier." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dTkda9Ja3aVs", "cellView": "form" }, "outputs": [], "source": [ "# @title Load Test and Set Response Type\n", "import json\n", "from pydantic import BaseModel\n", "from enum import Enum\n", "from typing import List\n", "\n", "# Load the JSONL file\n", "file_path = 'test_openfood_classification.jsonl'\n", "test_dataset = []\n", "\n", "with open(file_path, 'r') as file:\n", " for line in file:\n", " test_dataset.append(json.loads(line))\n", "\n", "# Define the enumerators for categories and countries\n", "Category = Enum('Category', {category.replace('-', '_'): category for category in all_category_labels})\n", "Country = Enum('Country', {country.replace('-', '_'): country for country in all_country_labels})\n", "\n", "# Define the Food model using the enumerators\n", "class Food(BaseModel):\n", " categories: List[Category]\n", " country: Country" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZhbeM8FZ29e4", "cellView": "form" }, "outputs": [], "source": [ "# @title Define the Classify Function\n", "import random\n", "\n", "instruction_prompt = \"\"\"Classify the following food product, you need to classify the country of the dish and the food categories it belongs to.\n", "\n", "Product Name: {}\"\"\"\n", "\n", "def classify(text: str, model: dict) -> tuple:\n", " try:\n", " if model[\"type\"] == \"random\":\n", " possible_categories = list(all_category_labels)\n", " possible_countries = list(all_country_labels)\n", " predicted_categories = random.sample(possible_categories, random.randint(0, len(possible_categories)))\n", " predicted_country = random.choice(possible_countries)\n", " return predicted_categories, predicted_country\n", " elif model[\"type\"] == \"classifier\":\n", " classifier_response = 
client.classifiers.classify(\n", " model=model[\"model_id\"],\n", " inputs=[text],\n", " )\n", " results = classifier_response.results[0]\n", "\n", " # Extract all labels with their scores\n", " labels_with_scores = {label: results['food'].scores[label] for label in results['food'].scores.keys()}\n", "\n", " # Find the country with the highest score\n", " country_scores = results['country_label'].scores\n", " country_with_highest_score = max(country_scores, key=country_scores.get)\n", "\n", " return labels_with_scores, country_with_highest_score\n", " else:\n", " chat_response = client.chat.parse(\n", " model=model[\"model_id\"],\n", " messages=[\n", " {\n", " \"role\": \"user\",\n", " \"content\": instruction_prompt.format(text),\n", " },\n", " ],\n", " response_format=Food,\n", " max_tokens=512,\n", " temperature=0\n", " )\n", "\n", " return [c.value for c in chat_response.choices[0].message.parsed.categories], chat_response.choices[0].message.parsed.country.value\n", " except Exception as e:\n", " return {}, None" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gSRPWKMP5VIK", "outputId": "4da5a389-f6be-4745-cf77-fc20b7837660", "cellView": "form" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Running {'type': 'random', 'model_name': 'Random'} ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/300 [00:00<?, ?it/s]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 300/300 [00:00<00:00, 67671.89it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Category Scores: {'beverages': 0.09944751381215469, 'cereals-and-potatoes': 0.12637362637362637, 'dairies': 0.1393939393939394, 'meats-and-their-products': 0.09090909090909091, 'plant-based-foods': 0.14606741573033707, 'plant-based-foods-and-beverages': 0.14871794871794872, 'snacks': 0.14673913043478262, 'sweet-snacks': 0.17543859649122806}\n", 
"Average Category Score: 0.13413590773288847\n", "Country Score: 0.15\n", "\n", "Running {'type': 'classifier', 'model_name': 'Finetuned Classifier 3B', 'model_id': 'ft:classifier:ministral-3b-latest:8e2706f0:20250416:1905e7d8', 'thresholds': [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.6, 0.7, 0.8, 0.9]} ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 300/300 [02:08<00:00, 2.33it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Category Scores: {'beverages': 0.6346153846153846, 'cereals-and-potatoes': 0.7368421052631579, 'dairies': 0.7391304347826086, 'meats-and-their-products': 0.7878787878787878, 'plant-based-foods': 0.3132530120481928, 'plant-based-foods-and-beverages': 0.2823529411764706, 'snacks': 0.4430379746835443, 'sweet-snacks': 0.6470588235294118}\n", "Average Category Score: 0.5730211829971948\n", "Country Score: 0.7666666666666667\n", "Best Threshold: 0.25\n", "\n", "Running {'type': 'instruction', 'model_name': 'Ministral 3B', 'model_id': 'ministral-3b-latest'} ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 300/300 [08:09<00:00, 1.63s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Category Scores: {'beverages': 0.39080459770114945, 'cereals-and-potatoes': 0.3888888888888889, 'dairies': 0.6346153846153846, 'meats-and-their-products': 0.6086956521739131, 'plant-based-foods': 0.1282051282051282, 'plant-based-foods-and-beverages': 0.0, 'snacks': 0.2684563758389262, 'sweet-snacks': 0.35526315789473684}\n", "Average Category Score: 0.3468661481647659\n", "Country Score: 0.47\n", "\n", "Running {'type': 'instruction', 'model_name': 'Ministral 8B', 'model_id': 'ministral-8b-latest'} ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 300/300 [07:26<00:00, 1.49s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Category Scores: {'beverages': 0.5961538461538461, 'cereals-and-potatoes': 
0.41818181818181815, 'dairies': 0.631578947368421, 'meats-and-their-products': 0.7428571428571429, 'plant-based-foods': 0.24731182795698925, 'plant-based-foods-and-beverages': 0.09523809523809523, 'snacks': 0.29365079365079366, 'sweet-snacks': 0.3220338983050847}\n", "Average Category Score: 0.4183757962140239\n", "Country Score: 0.5566666666666666\n", "\n", "Running {'type': 'instruction', 'model_name': 'Mistral Small 24B', 'model_id': 'mistral-small-latest'} ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 300/300 [03:05<00:00, 1.62it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Category Scores: {'beverages': 0.5555555555555556, 'cereals-and-potatoes': 0.37037037037037035, 'dairies': 0.7083333333333334, 'meats-and-their-products': 0.7631578947368421, 'plant-based-foods': 0.297029702970297, 'plant-based-foods-and-beverages': 0.15, 'snacks': 0.2926829268292683, 'sweet-snacks': 0.449438202247191}\n", "Average Category Score: 0.4483209982553572\n", "Country Score: 0.5566666666666666\n", "\n", "Running {'type': 'instruction', 'model_name': 'Mistral Large 123B', 'model_id': 'mistral-large-latest'} ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 300/300 [11:02<00:00, 2.21s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Category Scores: {'beverages': 0.7083333333333334, 'cereals-and-potatoes': 0.6, 'dairies': 0.7222222222222222, 'meats-and-their-products': 0.7837837837837838, 'plant-based-foods': 0.2912621359223301, 'plant-based-foods-and-beverages': 0.06153846153846154, 'snacks': 0.22666666666666666, 'sweet-snacks': 0.48484848484848486}\n", "Average Category Score: 0.4848318860394103\n", "Country Score: 0.6233333333333333\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# @title Run Evaluation\n", "from tqdm import tqdm\n", "import numpy as np\n", "\n", "# Number of samples to evaluate\n", "n_samples = 300\n", "\n", 
"def calculate_score(actual, predicted):\n", " \"\"\"\n", " Calculate the accuracy score for a single label.\n", "\n", " Parameters:\n", " - actual: List of actual labels for each entry.\n", " - predicted: List of predicted labels for each entry.\n", "\n", " Returns:\n", " - The accuracy score as a float.\n", " \"\"\"\n", " correct_predictions_count = 0\n", " total_predictions_count = 0\n", "\n", " for actual_labels, predicted_labels in zip(actual, predicted):\n", " if actual_labels or predicted_labels:\n", " total_predictions_count += 1\n", " if actual_labels and predicted_labels:\n", " correct_predictions_count += 1\n", "\n", " return correct_predictions_count / total_predictions_count if total_predictions_count > 0 else 0\n", "\n", "def calculate_country_score(actual, predicted):\n", " \"\"\"\n", " Calculate the accuracy score for country predictions.\n", "\n", " Parameters:\n", " - actual: List of actual country labels for each entry.\n", " - predicted: List of predicted country labels for each entry.\n", "\n", " Returns:\n", " - The accuracy score as a float.\n", " \"\"\"\n", " correct_predictions_count = sum(actual_country == predicted_country for actual_country, predicted_country in zip(actual, predicted))\n", " total_predictions = len(actual)\n", " accuracy_score = correct_predictions_count / total_predictions if total_predictions > 0 else 0\n", "\n", " return accuracy_score\n", "\n", "def evaluate_classifier(dataset, model):\n", " \"\"\"\n", " Evaluate the classifier model on the dataset.\n", "\n", " Parameters:\n", " - dataset: List of entries with text and labels.\n", " - model: Dictionary containing model details.\n", "\n", " Returns:\n", " - Category scores, country score, average category score, and best threshold (if applicable).\n", " \"\"\"\n", "\n", " # Initialize dictionaries to store actual and predicted labels for each category\n", " category_scores = {label: {\"actual\": [], \"predicted\": []} for label in all_category_labels}\n", " 
all_actual_countries = []\n", " all_predicted_countries = []\n", "\n", " # Store raw scores for classifier models\n", " raw_scores = []\n", "\n", " # Evaluate each entry in the dataset\n", " for entry in tqdm(dataset[:n_samples]):\n", " text = entry[\"text\"]\n", " actual_categories = list(entry[\"labels\"][\"food\"])\n", " actual_country = entry[\"labels\"][\"country_label\"]\n", "\n", " # Predict categories and country using the model\n", " predicted_categories, predicted_country = classify(text, model)\n", " # For classifier models, predicted_categories holds per-category scores; keep them for the threshold search\n", " if model[\"type\"] == \"classifier\":\n", " raw_scores.append((predicted_categories, predicted_country))\n", "\n", " # Accumulate actual and predicted categories and countries\n", " all_actual_countries.append(actual_country)\n", " all_predicted_countries.append(predicted_country)\n", "\n", " for label in all_category_labels:\n", " actual_label = [label] if label in actual_categories else []\n", " predicted_label = [label] if label in predicted_categories else []\n", " category_scores[label][\"actual\"].append(actual_label)\n", " category_scores[label][\"predicted\"].append(predicted_label)\n", "\n", " if model[\"type\"] == \"classifier\":\n", " best_threshold = None\n", " best_average_category_score = -1 # start below any valid score so a threshold is always selected\n", " # Ground-truth labels do not depend on the threshold, so collect them once\n", " actual_labels_per_category = {label: data[\"actual\"] for label, data in category_scores.items()}\n", "\n", " # Find the best threshold for the classifier model\n", " for threshold in model[\"thresholds\"]:\n", " predicted_labels_per_category = {label: [] for label in all_category_labels}\n", "\n", " for raw_score in raw_scores:\n", " predicted_categories, _ = raw_score\n", " predicted_labels = [label for label, score in predicted_categories.items() if score > threshold]\n", " for label in all_category_labels:\n", " predicted_labels_per_category[label].append([label] if label in predicted_labels else [])\n", "\n", " category_score_results = {\n", " label: 
calculate_score(actual_labels_per_category[label], predicted_labels_per_category[label])\n", " for label in all_category_labels\n", " }\n", " average_category_score = np.mean(list(category_score_results.values()))\n", "\n", " if average_category_score > best_average_category_score:\n", " best_average_category_score = average_category_score\n", " best_threshold = threshold\n", "\n", " # Use the best threshold to compute final scores\n", " predicted_labels_per_category = {label: [] for label in all_category_labels}\n", "\n", " for raw_score in raw_scores:\n", " predicted_categories, _ = raw_score\n", " predicted_labels = [label for label, score in predicted_categories.items() if score > best_threshold]\n", " for label in all_category_labels:\n", " predicted_labels_per_category[label].append([label] if label in predicted_labels else [])\n", "\n", " category_score_results = {\n", " label: calculate_score(actual_labels_per_category[label], predicted_labels_per_category[label])\n", " for label in all_category_labels\n", " }\n", " country_score = calculate_country_score(all_actual_countries, all_predicted_countries)\n", " average_category_score = best_average_category_score\n", " return category_score_results, country_score, average_category_score, best_threshold\n", " else:\n", " # Prepare the actual and predicted labels for each category\n", " actual_labels_per_category = {label: data[\"actual\"] for label, data in category_scores.items()}\n", " predicted_labels_per_category = {label: data[\"predicted\"] for label, data in category_scores.items()}\n", "\n", " # Calculate score for each category and overall country score\n", " category_score_results = {\n", " label: calculate_score(actual_labels_per_category[label], predicted_labels_per_category[label])\n", " for label in all_category_labels\n", " }\n", " country_score = calculate_country_score(all_actual_countries, all_predicted_countries)\n", "\n", " # Calculate average category score\n", " average_category_score = 
np.mean(list(category_score_results.values()))\n", "\n", " return category_score_results, country_score, average_category_score, None\n", "\n", "# Dictionary to store model evaluation results\n", "model_results = {}\n", "\n", "# List of models to evaluate\n", "models = [\n", " {\"type\": \"random\", \"model_name\": \"Random\"},\n", " {\n", " \"type\": \"classifier\",\n", " \"model_name\": \"Finetuned Classifier 3B\",\n", " \"model_id\": retrieved_job.fine_tuned_model,\n", " \"thresholds\": [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.6, 0.7, 0.8, 0.9] # candidate decision thresholds for the multi-label category scores; depending on your data, training and use case, you may want to adjust this list to get the best score for your metric\n", " },\n", " {\"type\": \"instruction\", \"model_name\": \"Ministral 3B\", \"model_id\": \"ministral-3b-latest\"},\n", " {\"type\": \"instruction\", \"model_name\": \"Ministral 8B\", \"model_id\": \"ministral-8b-latest\"},\n", " {\"type\": \"instruction\", \"model_name\": \"Mistral Small 24B\", \"model_id\": \"mistral-small-latest\"},\n", " {\"type\": \"instruction\", \"model_name\": \"Mistral Large 123B\", \"model_id\": \"mistral-large-latest\"},\n", "]\n", "\n", "# Evaluate each model\n", "for model in models:\n", " print(\"\\nRunning\", model, \"...\")\n", " category_scores, country_score, average_category_score, best_threshold = evaluate_classifier(test_dataset, model)\n", "\n", " result = {\n", " \"category_scores\": category_scores,\n", " \"average_category_score\": average_category_score,\n", " \"country_score\": country_score,\n", " }\n", "\n", " model_name = model['model_name']\n", " if model[\"type\"] == \"classifier\":\n", " model_name = f\"{model['model_name']} Threshold: {best_threshold}\"\n", "\n", " model_results[model_name] = result\n", "\n", " print(f\"Category Scores: {category_scores}\")\n", " print(f\"Average Category Score: {average_category_score}\")\n", " print(f\"Country Score: 
{country_score}\")\n", " if model[\"type\"] == \"classifier\":\n", " print(f\"Best Threshold: {best_threshold}\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "5W98pgFJ1-Lt", "outputId": "3679d25c-083d-49b6-e137-f6bd2c2d2b6e", "cellView": "form" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "<Figure size 2400x1600 with 2 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# @title Plot Results\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from matplotlib import cm\n", "\n", "def plot_bar_chart(ax, data, labels, title, xlabel, ylabel, colors, bar_width, best_model=None):\n", " index = np.arange(len(labels))\n", " total_bars = len(labels) * len(data)\n", "\n", " group_spacing = 0.4\n", " group_width = bar_width * len(data)\n", "\n", " # Get the scores for the \"random\" model\n", " random_scores = data.get('Random', [0] * len(labels))\n", "\n", " # Get the scores for the best model\n", " best_model_scores = data.get(best_model, [0] * len(labels)) if best_model else [0] * len(labels)\n", "\n", " for i, (model, metrics) in enumerate(data.items()):\n", " positions = index + i * bar_width - group_width / 2 + group_spacing / 2\n", " color = 'red' if model == 'Random' else colors[i]\n", " bars = ax.bar(positions, metrics, width=bar_width, label=model, color=color, zorder=2)\n", "\n", " # Check if the model is the best model\n", " if model == best_model:\n", " for bar in bars:\n", " bar.set_hatch('//')\n", " bar.set_edgecolor('#FF8C00')\n", " height = bar.get_height()\n", " ax.annotate(f'{int(100 * height)}%',\n", " xy=(bar.get_x() + bar.get_width() / 2, height),\n", " xytext=(0, 5),\n", " textcoords=\"offset points\",\n", " ha='center', va='bottom',\n", " color='orange', fontsize=10)\n", "\n", " # Annotate the random model's bars\n", " if model == 'Random':\n", " for bar in bars:\n", " height = bar.get_height()\n", " 
ax.annotate(f'{int(100 * height)}%',\n", " xy=(bar.get_x() + bar.get_width() / 2, height),\n", " xytext=(0, 5), # 5 points vertical offset\n", " textcoords=\"offset points\",\n", " ha='center', va='bottom',\n", " color='red', fontsize=10)\n", "\n", " # Add a red horizontal line for the \"random\" model's scores\n", " for idx, score in enumerate(random_scores):\n", " ax.hlines(y=score, xmin=index[idx] - group_width / 2 + bar_width,\n", " xmax=index[idx] + group_width / 2 + bar_width, color='red', linestyle=':', linewidth=0.8, zorder=3)\n", "\n", " # Add a green horizontal line for the \"best\" model's scores\n", " for idx, score in enumerate(best_model_scores):\n", " ax.hlines(y=score, xmin=index[idx] - group_width / 2 + bar_width,\n", " xmax=index[idx] + group_width / 2 + bar_width, color='orange', linestyle=':', linewidth=0.8, zorder=3)\n", "\n", " ax.set_title(title)\n", " ax.set_xlabel(xlabel)\n", " ax.set_ylabel(ylabel)\n", " ax.set_xticks(index + group_spacing / 2)\n", " ax.set_xticklabels(labels, rotation=0, ha='center', fontsize=6)\n", " ax.set_ylim(0, 1.19)\n", " ax.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')\n", "\n", " # Add a light grid in the background\n", " ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.7, zorder=1)\n", "\n", "def plot_score_metrics(model_results, n_samples):\n", " # Find best model\n", " best_model = max(\n", " (model for model in model_results.keys()),\n", " key=lambda model: model_results[model]['average_category_score'],\n", " default=None\n", " )\n", "\n", " models = [m for m in model_results.keys()]\n", " colors = cm.YlOrBr(np.linspace(0.4, 0.6, len(models)))\n", "\n", " # Create a figure with a 2x1 grid of subplots\n", " fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(24, 16))\n", " fig.suptitle(f'Scores (n_samples = {n_samples})')\n", "\n", " # Categories Score\n", " categories_score = {model: [model_results[model]['category_scores'][category] for category in 
model_results[model]['category_scores']] for model in models}\n", " plot_bar_chart(axes[0], categories_score, list(model_results[models[0]]['category_scores'].keys()),\n", " 'Category Scores', 'Category', 'Score', colors, 0.14, best_model)\n", "\n", " # Average Category Score and Country Score, combined in the second-row subplot\n", " average_category_score = {model: [model_results[model]['average_category_score']] for model in models}\n", " countries_score = {model: [model_results[model]['country_score']] for model in models}\n", "\n", " # Combine the two metrics into one subplot\n", " combined_metrics = {model: average_category_score[model] + countries_score[model] for model in models}\n", " plot_bar_chart(axes[1], combined_metrics, ['Average Category Scores', 'Country Scores'],\n", " 'Average Category and Country Scores', 'Metric', 'Score', colors, 0.14, best_model)\n", "\n", " plt.show()\n", "\n", "plot_score_metrics(model_results, n_samples)" ] }, { "cell_type": "markdown", "metadata": { "id": "vvE5fQsbW1IQ" }, "source": [ "For this specific use case, most LLMs struggle. This can have several causes, such as suboptimal prompting, limited model size, or a use case that is too specific.\n", "\n", "However, our fine-tuned classifier performs best, outperforming all other models by a clear margin. On the 300 evaluated samples it reaches an average category score of about 0.57 and a country score of about 0.77, versus about 0.48 and 0.62 for Mistral Large 123B, the strongest instruction model here. It is not only more accurate but also more efficient and cheaper, being a considerably smaller model than its larger siblings." ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 0 }
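The `calculate_score` helper in the evaluation cell is described as an accuracy. However, it only counts entries where the label appears in the ground truth or in the prediction. For each label it therefore computes TP / (TP + FP + FN), the Jaccard index, which ignores true negatives. The following standalone sketch reproduces that logic on hypothetical toy data and contrasts it with plain accuracy. The function name `jaccard_per_label` is illustrative and not part of the notebook.

```python
def jaccard_per_label(actual, predicted):
    """Per-label score TP / (TP + FP + FN), mirroring the notebook's calculate_score logic.

    actual, predicted: one list per sample, [label] if present/predicted, else [].
    Samples where the label is absent in both (true negatives) are ignored.
    """
    tp = sum(1 for a, p in zip(actual, predicted) if a and p)
    union = sum(1 for a, p in zip(actual, predicted) if a or p)
    return tp / union if union > 0 else 0


# Hypothetical toy data for a single label, "dairies", over five samples
actual = [["dairies"], [], ["dairies"], [], ["dairies"]]
predicted = [["dairies"], ["dairies"], [], [], ["dairies"]]

# TP = 2 (samples 0 and 4), FP = 1 (sample 1), FN = 1 (sample 2); sample 3 is ignored
print(jaccard_per_label(actual, predicted))  # 0.5

# Plain accuracy also rewards sample 3, where the label is correctly absent
accuracy = sum(a == p for a, p in zip(actual, predicted)) / len(actual)
print(accuracy)  # 0.6
```

TP / (TP + FP + FN) can never exceed accuracy, because dropping true negatives removes only correct cases. As a result, a model cannot score well on a rare label simply by never predicting it.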
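For the fine-tuned classifier, `classify` returns a score per category. The evaluation keeps a category when its score is strictly above a threshold. It then picks the value from `model["thresholds"]` that maximizes the mean per-category score. The sketch below illustrates that search and assumes each sample's scores arrive as a plain `{label: score}` dict. The names `sweep_thresholds` and `jaccard` and the toy scores are hypothetical. The running best starts at -1, so some threshold is always returned even if every candidate scores 0.

```python
def jaccard(actual, predicted):
    """TP / (TP + FP + FN) over per-sample indicator lists ([label] or [])."""
    tp = sum(1 for a, p in zip(actual, predicted) if a and p)
    union = sum(1 for a, p in zip(actual, predicted) if a or p)
    return tp / union if union else 0


def sweep_thresholds(score_dicts, true_labels, labels, thresholds):
    """Return (best_threshold, best_mean_score) over the candidate thresholds.

    score_dicts: one {label: score} dict per sample (assumed classifier output shape).
    true_labels: one list of ground-truth labels per sample.
    """
    # Ground-truth indicator lists do not depend on the threshold, so build them once
    actual = {label: [[label] if label in truth else [] for truth in true_labels] for label in labels}
    best_t, best_mean = None, -1.0
    for t in thresholds:
        per_label = []
        for label in labels:
            # A label is predicted when its score is strictly above the threshold
            predicted = [[label] if scores.get(label, 0.0) > t else [] for scores in score_dicts]
            per_label.append(jaccard(actual[label], predicted))
        mean = sum(per_label) / len(per_label)
        if mean > best_mean:  # strict ">" keeps the earliest threshold on ties
            best_t, best_mean = t, mean
    return best_t, best_mean


# Hypothetical toy scores for three samples and two categories
labels = ["snacks", "dairies"]
score_dicts = [
    {"snacks": 0.9, "dairies": 0.2},
    {"snacks": 0.3, "dairies": 0.6},
    {"snacks": 0.4, "dairies": 0.1},
]
true_labels = [["snacks"], ["dairies"], ["snacks"]]

best_t, best_mean = sweep_thresholds(score_dicts, true_labels, labels, [0.1, 0.25, 0.5])
print(best_t, round(best_mean, 3))  # 0.25 0.833
```

A low threshold admits false positives, for example "dairies" at 0.2 on the first sample. A high threshold drops true labels, for example "snacks" at 0.4 on the third. The sweep simply picks the best trade-off on the evaluated data.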
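The notebook decides that every food should have at least one category. However, thresholding the classifier's category scores can return an empty list when no score clears the threshold. One possible way to enforce the rule at inference time is to fall back to the single highest-scoring category. This is sketched here as an assumption, not as the cookbook's method, and `labels_from_scores` is a hypothetical helper.

```python
def labels_from_scores(scores, threshold):
    """Turn {label: score} into a label list, guaranteeing at least one label.

    Labels scoring strictly above `threshold` are kept; if none qualify, the
    single highest-scoring label is returned instead (an assumed fallback rule).
    """
    selected = [label for label, score in scores.items() if score > threshold]
    if not selected and scores:
        selected = [max(scores, key=scores.get)]
    return selected


# Hypothetical category scores for one product
scores = {"snacks": 0.12, "sweet-snacks": 0.2, "dairies": 0.05}
print(labels_from_scores(scores, 0.1))   # ['snacks', 'sweet-snacks']
print(labels_from_scores(scores, 0.25))  # ['sweet-snacks']
```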