code embedding

Details

File: mistral/embeddings/code_embedding.ipynb

Type: Jupyter Notebook

Use Cases: Code embedding

Content

Notebook content (JSON format):

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Jih5WuYmeC0d"
   },
   "source": [
    "# Mistral Code Embedding and Retrieval Evaluation\n",
    "\n",
    "This notebook demonstrates a pipeline for code embedding, chunking, indexing, retrieval, and evaluation using the Mistral API, FAISS, and the SWE-bench Lite dataset.  \n",
    "\n",
    "It uses the Mistral code embedding model `codestral-embed` to generate code embeddings and FAISS for fast similarity search. The workflow includes flattening repository structures, chunking code files into smaller segments, and generating embeddings for each chunk. These embeddings are indexed to enable efficient retrieval of relevant code snippets in response to user queries. The notebook evaluates retrieval performance on the SWE-bench Lite dataset, using recall metrics to measure effectiveness. This methodology is especially valuable for applications such as code search, code comprehension, and automated software maintenance.\n",
    "\n",
    "\n",
    "## Environment Setup\n",
    "Install required packages for code embedding, retrieval, and dataset handling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m173.4/173.4 kB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m31.3/31.3 MB\u001b[0m \u001b[31m31.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m372.3/372.3 kB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.5/6.5 MB\u001b[0m \u001b[31m32.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "torch 2.6.0+cu124 requires nvidia-cublas-cu12==12.4.5.8; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cublas-cu12 12.5.3.2 which is incompatible.\n",
      "torch 2.6.0+cu124 requires nvidia-cuda-cupti-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cuda-cupti-cu12 12.5.82 which is incompatible.\n",
      "torch 2.6.0+cu124 requires nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cuda-nvrtc-cu12 12.5.82 which is incompatible.\n",
      "torch 2.6.0+cu124 requires nvidia-cuda-runtime-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cuda-runtime-cu12 12.5.82 which is incompatible.\n",
      "torch 2.6.0+cu124 requires nvidia-cudnn-cu12==9.1.0.70; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cudnn-cu12 9.3.0.75 which is incompatible.\n",
      "torch 2.6.0+cu124 requires nvidia-cufft-cu12==11.2.1.3; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cufft-cu12 11.2.3.61 which is incompatible.\n",
      "torch 2.6.0+cu124 requires nvidia-curand-cu12==10.3.5.147; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-curand-cu12 10.3.6.82 which is incompatible.\n",
      "torch 2.6.0+cu124 requires nvidia-cusolver-cu12==11.6.1.9; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cusolver-cu12 11.6.3.83 which is incompatible.\n",
      "torch 2.6.0+cu124 requires nvidia-cusparse-cu12==12.3.1.170; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cusparse-cu12 12.5.1.3 which is incompatible.\n",
      "torch 2.6.0+cu124 requires nvidia-nvjitlink-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-nvjitlink-cu12 12.5.82 which is incompatible.\n",
      "gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2023.9.2 which is incompatible.\u001b[0m\u001b[31m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install -q faiss-cpu mistralai mistral-common datasets fsspec==2023.9.2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SeV0QmlLeU_x"
   },
   "source": [
    "## Imports and Tokenizer Initialization\n",
    "\n",
    "Import necessary libraries and initialize the tokenizer for code embedding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:104: UserWarning: \n",
      "Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.\n",
      "You are not authenticated with the Hugging Face Hub in this notebook.\n",
      "If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2483f1ff84244b08af95921f61e0335a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tekken.json:   0%|          | 0.00/14.8M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.11/dist-packages/mistral_common/tokens/tokenizers/tekken.py:184: FutureWarning: Special tokens not found in /root/.cache/huggingface/hub/models--mistralai--Mistral-Small-3.1-24B-Base-2503/snapshots/db7c968753c07380364d963090b5cf8cc131a0c3/tekken.json and default to ({'rank': 0, 'token_str': <SpecialTokens.unk: '<unk>'>, 'is_control': True}, {'rank': 1, 'token_str': <SpecialTokens.bos: '<s>'>, 'is_control': True}, {'rank': 2, 'token_str': <SpecialTokens.eos: '</s>'>, 'is_control': True}, {'rank': 3, 'token_str': <SpecialTokens.begin_inst: '[INST]'>, 'is_control': True}, {'rank': 4, 'token_str': <SpecialTokens.end_inst: '[/INST]'>, 'is_control': True}, {'rank': 5, 'token_str': <SpecialTokens.begin_tools: '[AVAILABLE_TOOLS]'>, 'is_control': True}, {'rank': 6, 'token_str': <SpecialTokens.end_tools: '[/AVAILABLE_TOOLS]'>, 'is_control': True}, {'rank': 7, 'token_str': <SpecialTokens.begin_tool_results: '[TOOL_RESULTS]'>, 'is_control': True}, {'rank': 8, 'token_str': <SpecialTokens.end_tool_results: '[/TOOL_RESULTS]'>, 'is_control': True}, {'rank': 9, 'token_str': <SpecialTokens.tool_calls: '[TOOL_CALLS]'>, 'is_control': True}, {'rank': 10, 'token_str': <SpecialTokens.img: '[IMG]'>, 'is_control': True}, {'rank': 11, 'token_str': <SpecialTokens.pad: '<pad>'>, 'is_control': True}, {'rank': 12, 'token_str': <SpecialTokens.img_break: '[IMG_BREAK]'>, 'is_control': True}, {'rank': 13, 'token_str': <SpecialTokens.img_end: '[IMG_END]'>, 'is_control': True}, {'rank': 14, 'token_str': <SpecialTokens.prefix: '[PREFIX]'>, 'is_control': True}, {'rank': 15, 'token_str': <SpecialTokens.middle: '[MIDDLE]'>, 'is_control': True}, {'rank': 16, 'token_str': <SpecialTokens.suffix: '[SUFFIX]'>, 'is_control': True}, {'rank': 17, 'token_str': <SpecialTokens.begin_system: '[SYSTEM_PROMPT]'>, 'is_control': True}, {'rank': 18, 'token_str': <SpecialTokens.end_system: '[/SYSTEM_PROMPT]'>, 'is_control': True}, {'rank': 19, 'token_str': <SpecialTokens.begin_tool_content: '[TOOL_CONTENT]'>, 'is_control': True}). This behavior will be deprecated going forward. Please update your tokenizer file and include all special tokens you need.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "from typing import Dict, List, Tuple, Set, Optional, Any\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from mistralai import Mistral\n",
    "from langchain.text_splitter import Language, RecursiveCharacterTextSplitter\n",
    "import faiss\n",
    "from collections import defaultdict\n",
    "import re\n",
    "from getpass import getpass\n",
    "\n",
    "from huggingface_hub import hf_hub_download\n",
    "from mistral_common.tokens.tokenizers.tekken import Tekkenizer\n",
    "\n",
    "# Download tokenizer from Hugging Face\n",
    "repo_id = \"mistralai/Mistral-Small-3.1-24B-Base-2503\"\n",
    "# adjust filename if the repo uses a different .json name\n",
    "tk_path = hf_hub_download(repo_id, filename=\"tekken.json\")\n",
    "\n",
    "tokenizer = Tekkenizer.from_file(tk_path)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pnIsjBGeebpj"
   },
   "source": [
    "## API Key Setup\n",
    "Set up your Mistral API key for authentication."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Enter your MISTRAL_API_KEY: ··········\n"
     ]
    }
   ],
   "source": [
    "api_key = getpass(\"Enter your MISTRAL_API_KEY: \").strip()\n",
    "os.environ[\"MISTRAL_API_KEY\"] = api_key\n",
    "\n",
    "\n",
    "client = Mistral(api_key=api_key.strip())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hdcaR7Akeryn"
   },
   "source": [
    "## Embedding and Chunking Configuration\n",
    "Define parameters for code embedding and chunking."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# embeddings\n",
    "TOP_K = 5\n",
    "EMBED_MODEL = \"codestral-embed\"\n",
    "MAX_BATCH_SIZE = 128  # for embedding\n",
    "MAX_TOTAL_TOKENS = 16384  # for embedding\n",
    "MAX_SEQUENCE_LENGTH = 8192  # for embedding\n",
    "\n",
    "# chunking\n",
    "DO_CHUNKING = True\n",
    "CHUNK_SIZE = 3000\n",
    "CHUNK_OVERLAP = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "a5v7NZ48H284"
   },
   "source": [
    "![image (2).png]()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_qQkWCr2H_rK"
   },
   "source": [
    "In our experiments, we find that chunking with small chunk size (3000 characters, ~512 tokens) and overlap (1000 characters), leads to much better retrieval for RAG.\n",
    "\n",
    "\n",
    "## Download and Prepare Repository Structures\n",
    "\n",
    "Download and extract repository structures for the SWE-bench Lite dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading...\n",
      "From (original): https://drive.google.com/uc?id=1wG1CcfVHi-70FoAd5wPI59WdI4g1LkpS\n",
      "From (redirected): https://drive.google.com/uc?id=1wG1CcfVHi-70FoAd5wPI59WdI4g1LkpS&confirm=t&uuid=ff0d2b10-c817-4e53-abd5-44d7755dd926\n",
      "To: /content/min_swebench_repo_structure.zip\n",
      "100%|██████████| 200M/200M [00:01<00:00, 139MB/s]\n"
     ]
    }
   ],
   "source": [
    "import gdown\n",
    "import zipfile\n",
    "\n",
    "USE_MIN_SWEBENCH = True\n",
    "\n",
    "if not USE_MIN_SWEBENCH:\n",
    "  # for all 300 repo_structures from Agentless for swebench lite - https://github.com/OpenAutoCoder/Agentless/blob/main/README_swebench.md#-setup\n",
    "  zip_url = \"https://drive.google.com/uc?id=15-4XjTmY48ystrsc_xcvtOkMs3Fx8RoW\"\n",
    "  zip_path = \"/content/swebench_repo_structure.zip\"\n",
    "  repo_structures_path = \"/content/repo_structures/repo_structures\"\n",
    "\n",
    "else:\n",
    "  # subset of 33 tasks from above for faster download\n",
    "  zip_url = \"https://drive.google.com/uc?id=1wG1CcfVHi-70FoAd5wPI59WdI4g1LkpS\"\n",
    "  zip_path = \"/content/min_swebench_repo_structure.zip\"\n",
    "  repo_structures_path = \"/content/min_repo_structures/repo_structures\"\n",
    "\n",
    "if not os.path.exists(repo_structures_path):\n",
    "  gdown.download(zip_url, zip_path, quiet=False)\n",
    "\n",
    "  with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
    "      zip_ref.extractall(\"/content/\")\n",
    "\n",
    "# Set paths\n",
    "index_dir: str = \"/content/swebench_indexes\"\n",
    "results_file: str = \"/content/swebench_results.json\"\n",
    "\n",
    "if DO_CHUNKING:\n",
    "    # make swebench_indexes to swebench_indexes_chunked_<chunk_size>_<chunk_overlap>\n",
    "    index_dir = f\"{index_dir}_chunked_size_{CHUNK_SIZE}_overlap_{CHUNK_OVERLAP}\"\n",
    "    Path(index_dir).mkdir(exist_ok=True)\n",
    "\n",
    "# Create index directory\n",
    "Path(index_dir).mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fg1oWFuqfPp1"
   },
   "source": [
    "## Utility Functions for Data Processing\n",
    "\n",
    "Define helper functions for flattening repository structures, chunking code, formatting documents, and extracting file paths from patches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten_repo_structure(\n",
    "    structure: Dict[str, Any], current_path: str = \"\"\n",
    ") -> Dict[str, str]:\n",
    "    \"\"\"\n",
    "    Recursively flatten nested repo structure into file paths and contents.\n",
    "    Only keeps non-empty Python files.\n",
    "    \"\"\"\n",
    "    flattened = {}\n",
    "\n",
    "    for key, value in structure.items():\n",
    "        # Build the path\n",
    "        path = os.path.join(current_path, key) if current_path else key\n",
    "\n",
    "        if isinstance(value, dict):\n",
    "            # Check if this is a file with content\n",
    "            if \"text\" in value and isinstance(value[\"text\"], list):\n",
    "                # This is a file with content\n",
    "                content = \"\\n\".join(value[\"text\"])\n",
    "\n",
    "                # Only keep Python files with non-empty content\n",
    "                if path.endswith(\".py\") and content.strip():\n",
    "                    flattened[path] = content\n",
    "            else:\n",
    "                # This is a directory, recurse\n",
    "                flattened.update(flatten_repo_structure(value, path))\n",
    "\n",
    "    return flattened\n",
    "\n",
    "\n",
    "def load_repository_structure(\n",
    "    repo_structures_path: str, instance_id: str\n",
    ") -> Dict[str, str]:\n",
    "    \"\"\"Load and flatten repository structure from JSON file.\"\"\"\n",
    "    json_path = Path(repo_structures_path) / f\"{instance_id}.json\"\n",
    "\n",
    "    if not json_path.exists():\n",
    "        print(f\"Warning: Repository structure not found for {instance_id}\")\n",
    "        return {}\n",
    "\n",
    "    with open(json_path, \"r\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    # The structure is usually under a \"structure\" key with the repo name\n",
    "    if \"structure\" in data:\n",
    "        structure = data[\"structure\"]\n",
    "        # Get the first (and usually only) key which is the repo name\n",
    "        # repo_name = list(structure.keys())[0] if structure else \"\"\n",
    "        # if repo_name:\n",
    "        #     structure = structure[repo_name]\n",
    "\n",
    "        # Flatten the structure\n",
    "        return flatten_repo_structure(structure)\n",
    "\n",
    "    # Fallback: assume the entire JSON is the structure\n",
    "    return flatten_repo_structure(data)\n",
    "\n",
    "\n",
    "def get_language_from_path(path: str) -> Optional[Language]:\n",
    "    \"\"\"Get language from file extension.\"\"\"\n",
    "    EXTENSION_TO_LANGUAGE = {\n",
    "        \".cpp\": Language.CPP,\n",
    "        \".cc\": Language.CPP,\n",
    "        \".cxx\": Language.CPP,\n",
    "        \".c++\": Language.CPP,\n",
    "        \".go\": Language.GO,\n",
    "        \".java\": Language.JAVA,\n",
    "        \".kt\": Language.KOTLIN,\n",
    "        \".kts\": Language.KOTLIN,\n",
    "        \".js\": Language.JS,\n",
    "        \".mjs\": Language.JS,\n",
    "        \".ts\": Language.TS,\n",
    "        \".php\": Language.PHP,\n",
    "        \".proto\": Language.PROTO,\n",
    "        \".py\": Language.PYTHON,\n",
    "        \".pyw\": Language.PYTHON,\n",
    "        \".rst\": Language.RST,\n",
    "        \".rb\": Language.RUBY,\n",
    "        \".rs\": Language.RUST,\n",
    "        \".scala\": Language.SCALA,\n",
    "        \".swift\": Language.SWIFT,\n",
    "        \".md\": Language.MARKDOWN,\n",
    "        \".markdown\": Language.MARKDOWN,\n",
    "        \".tex\": Language.LATEX,\n",
    "        \".html\": Language.HTML,\n",
    "        \".htm\": Language.HTML,\n",
    "        \".sol\": Language.SOL,\n",
    "        \".cs\": Language.CSHARP,\n",
    "        \".cbl\": Language.COBOL,\n",
    "        \".cob\": Language.COBOL,\n",
    "        \".c\": Language.C,\n",
    "        \".h\": Language.C,\n",
    "        \".lua\": Language.LUA,\n",
    "        \".pl\": Language.PERL,\n",
    "        \".pm\": Language.PERL,\n",
    "        \".hs\": Language.HASKELL,\n",
    "        \".ex\": Language.ELIXIR,\n",
    "        \".exs\": Language.ELIXIR,\n",
    "        \".ps1\": Language.POWERSHELL,\n",
    "    }\n",
    "    _, ext = os.path.splitext(path)\n",
    "    return EXTENSION_TO_LANGUAGE.get(ext.lower())\n",
    "\n",
    "\n",
    "def chunk_corpus(\n",
    "    corpus: Dict[str, Dict[str, str]], chunk_size: int, chunk_overlap: int\n",
    ") -> Dict[str, Dict[str, str]]:\n",
    "    \"\"\"Chunk the corpus using language-specific splitters.\"\"\"\n",
    "    new_corpus = {}\n",
    "\n",
    "    for orig_id, doc in corpus.items():\n",
    "        title = doc.get(\"title\", \"\").strip()\n",
    "        text = doc.get(\"text\", \"\").strip()\n",
    "\n",
    "        # Skip empty texts\n",
    "        if not text:\n",
    "            continue\n",
    "\n",
    "        # Get language-specific splitter\n",
    "        language = get_language_from_path(title)\n",
    "        if language:\n",
    "            try:\n",
    "                splitter = RecursiveCharacterTextSplitter.from_language(\n",
    "                    language=language,\n",
    "                    chunk_size=chunk_size,\n",
    "                    chunk_overlap=chunk_overlap,\n",
    "                )\n",
    "            except:\n",
    "                # Fallback to generic splitter\n",
    "                splitter = RecursiveCharacterTextSplitter(\n",
    "                    chunk_size=chunk_size,\n",
    "                    chunk_overlap=chunk_overlap,\n",
    "                )\n",
    "        else:\n",
    "            splitter = RecursiveCharacterTextSplitter(\n",
    "                chunk_size=chunk_size,\n",
    "                chunk_overlap=chunk_overlap,\n",
    "            )\n",
    "\n",
    "        # Split only the text\n",
    "        chunks = splitter.split_text(text)\n",
    "        if not chunks:\n",
    "            new_corpus[orig_id] = doc\n",
    "            continue\n",
    "\n",
    "        for i, chunk_text in enumerate(chunks):\n",
    "            chunk_id = f\"{orig_id}_<chunk>_{i}\"\n",
    "            new_corpus[chunk_id] = {\n",
    "                \"title\": title,\n",
    "                \"text\": chunk_text,\n",
    "            }\n",
    "\n",
    "    return new_corpus\n",
    "\n",
    "\n",
    "def format_doc(doc: Dict[str, str]) -> str:\n",
    "    \"\"\"Format document for embedding.\"\"\"\n",
    "    assert \"title\" in doc and \"text\" in doc\n",
    "    title = doc.get(\"title\", \"\").strip()\n",
    "    text = doc.get(\"text\", \"\").strip()\n",
    "    return f\"{title}\\n{text}\" if title else text\n",
    "\n",
    "\n",
    "def get_embeddings_batch(texts: List[str]) -> List[List[float]]:\n",
    "    \"\"\"Get embeddings for a batch of texts using Mistral API with token limits.\"\"\"\n",
    "    if not texts:\n",
    "        return []\n",
    "\n",
    "\n",
    "    # Filter texts by token count and prepare batches\n",
    "    valid_texts = []\n",
    "    for text in texts:\n",
    "        tokens = tokenizer.encode(text, bos=False, eos=False)\n",
    "        if len(tokens) <= MAX_SEQUENCE_LENGTH:  # Max tokens per individual text\n",
    "            valid_texts.append(text)\n",
    "        else:\n",
    "            # Truncate text instead of skipping\n",
    "            truncated_tokens = tokens[:MAX_SEQUENCE_LENGTH]\n",
    "            truncated_text = tokenizer.decode(truncated_tokens)\n",
    "            valid_texts.append(truncated_text)\n",
    "            print(\n",
    "                f\"Truncated text from {len(tokens)} to {len(truncated_tokens)} tokens\"\n",
    "            )\n",
    "\n",
    "    if not valid_texts:\n",
    "        return []\n",
    "\n",
    "    # Create batches respecting token and size limits\n",
    "    batches = []\n",
    "    current_batch = []\n",
    "    current_batch_tokens = 0\n",
    "\n",
    "    for text in valid_texts:\n",
    "        tokens = tokenizer.encode(text, bos=False, eos=False)\n",
    "        text_token_count = len(tokens)\n",
    "\n",
    "        # Check if adding this text would exceed limits\n",
    "        if (len(current_batch) >= MAX_BATCH_SIZE or  # Max batch size\n",
    "            current_batch_tokens + text_token_count > MAX_TOTAL_TOKENS):  # Max total tokens\n",
    "\n",
    "            if current_batch:\n",
    "                batches.append(current_batch)\n",
    "                current_batch = []\n",
    "                current_batch_tokens = 0\n",
    "\n",
    "        current_batch.append(text)\n",
    "        current_batch_tokens += text_token_count\n",
    "\n",
    "    # Add the last batch if it's not empty\n",
    "    if current_batch:\n",
    "        batches.append(current_batch)\n",
    "\n",
    "    # Process batches\n",
    "    all_embeddings = []\n",
    "    for batch in tqdm(batches, desc=\"Processing embedding batches\"):\n",
    "        try:\n",
    "            response = client.embeddings.create(\n",
    "                model=EMBED_MODEL,\n",
    "                inputs=batch,\n",
    "            )\n",
    "            batch_embeddings = [data.embedding for data in response.data]\n",
    "            all_embeddings.extend(batch_embeddings)\n",
    "        except Exception as e:\n",
    "            print(f\"Error getting embeddings for batch: {e}\")\n",
    "            # Add empty embeddings for failed batch\n",
    "            all_embeddings.extend([[] for _ in batch])\n",
    "\n",
    "    return all_embeddings\n",
    "\n",
    "\n",
    "def parse_patch_for_files(patch: str) -> Set[str]:\n",
    "    \"\"\"Extract file paths from a patch.\"\"\"\n",
    "    files = set()\n",
    "\n",
    "    # Look for diff headers\n",
    "    diff_pattern = r\"^diff --git a/(.*?) b/(.*?)$\"\n",
    "    for line in patch.split(\"\\n\"):\n",
    "        match = re.match(diff_pattern, line)\n",
    "        if match:\n",
    "            # Usually both paths are the same, but take both just in case\n",
    "            files.add(match.group(1))\n",
    "            files.add(match.group(2))\n",
    "\n",
    "    # Also look for --- and +++ lines\n",
    "    file_pattern = r\"^[\\-\\+]{3} [ab]/(.*?)(?:\\s|$)\"\n",
    "    for line in patch.split(\"\\n\"):\n",
    "        match = re.match(file_pattern, line)\n",
    "        if match and match.group(1) != \"/dev/null\":\n",
    "            files.add(match.group(1))\n",
    "\n",
    "    return files\n",
    "\n",
    "\n",
    "def load_swebench_lite():\n",
    "    \"\"\"Load SWE-bench Lite dataset and extract ground truth.\"\"\"\n",
    "    print(\"Loading SWE-bench Lite dataset...\")\n",
    "    dataset = load_dataset(\"princeton-nlp/SWE-bench_Lite\", split=\"test\", download_mode=\"force_redownload\")\n",
    "\n",
    "    ground_truth_dict = {}\n",
    "    instances = []\n",
    "\n",
    "    for item in dataset:\n",
    "        instance_id = item[\"instance_id\"]\n",
    "        problem_statement = item[\"problem_statement\"]\n",
    "        patch = item[\"patch\"]\n",
    "\n",
    "        # Extract files from patch\n",
    "        files_changed = parse_patch_for_files(patch)\n",
    "\n",
    "        ground_truth_dict[instance_id] = list(files_changed)\n",
    "        instances.append(\n",
    "            {\n",
    "                \"instance_id\": instance_id,\n",
    "                \"problem_statement\": problem_statement,\n",
    "                \"patch\": patch,\n",
    "                \"files_changed\": list(files_changed),\n",
    "            }\n",
    "        )\n",
    "\n",
    "    return instances, ground_truth_dict\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "J-RHQ2xjf_vb"
   },
   "source": [
    "## Embedding, Indexing, and Retrieval Functions\n",
    "\n",
    "Functions for generating embeddings, building FAISS indexes, retrieving relevant files, and evaluating recall."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def index_repository(repo_content: Dict[str, str], instance_id: str, index_dir: str):\n",
    "    \"\"\"Index a repository and save the index.\"\"\"\n",
    "    print(f\"\\nIndexing repository for {instance_id}...\")\n",
    "    print(f\"Found {len(repo_content)} Python files\")\n",
    "\n",
    "    if not repo_content:\n",
    "        print(f\"No Python files found for {instance_id}\")\n",
    "        return\n",
    "\n",
    "    # Create corpus format expected by chunking function\n",
    "    corpus = {}\n",
    "    for file_path, content in repo_content.items():\n",
    "        corpus[file_path] = {\"title\": file_path, \"text\": content}\n",
    "\n",
    "    # Chunk the corpus only if DO_CHUNKING is True\n",
    "    if DO_CHUNKING:\n",
    "        print(f\"Chunking {len(corpus)} files...\")\n",
    "        chunked_corpus = chunk_corpus(corpus, CHUNK_SIZE, CHUNK_OVERLAP)\n",
    "        print(f\"Created {len(chunked_corpus)} chunks from {len(corpus)} files (size increase: {len(chunked_corpus)/len(corpus):.1f}x)\")\n",
    "    else:\n",
    "        print(\"Skipping chunking (DO_CHUNKING=False)\")\n",
    "        chunked_corpus = corpus\n",
    "\n",
    "    if not chunked_corpus:\n",
    "        print(f\"No chunks created for {instance_id}\")\n",
    "        return\n",
    "\n",
    "    # Prepare texts for embedding\n",
    "    texts_to_embed = []\n",
    "    chunk_ids = []\n",
    "    chunk_to_file = {}  # Map chunk_id to original file path\n",
    "\n",
    "    print(\"Preparing texts for embedding...\")\n",
    "    for chunk_id, chunk_doc in chunked_corpus.items():\n",
    "        text = format_doc(chunk_doc)\n",
    "        texts_to_embed.append(text)\n",
    "        chunk_ids.append(chunk_id)\n",
    "\n",
    "        # Extract original file path from chunk_id\n",
    "        if DO_CHUNKING and \"_<chunk>_\" in chunk_id:\n",
    "            original_file = chunk_id.split(\"_<chunk>_\")[0]\n",
    "        else:\n",
    "            original_file = chunk_id\n",
    "        chunk_to_file[chunk_id] = original_file\n",
    "\n",
    "    # Get embeddings in batches\n",
    "    print(\"Getting embeddings...\")\n",
    "    all_embeddings = get_embeddings_batch(texts_to_embed)\n",
    "\n",
    "    if not all_embeddings or len(all_embeddings) != len(texts_to_embed):\n",
    "        print(f\"Failed to get embeddings for {instance_id}\")\n",
    "        return\n",
    "\n",
    "    # Convert to numpy array\n",
    "    print(\"Creating FAISS index...\")\n",
    "    embeddings_array = np.array(all_embeddings, dtype=np.float32)\n",
    "\n",
    "    # Create FAISS index\n",
    "    dimension = embeddings_array.shape[1]\n",
    "    index = faiss.IndexFlatIP(dimension)  # Inner product for cosine similarity\n",
    "\n",
    "    # Normalize for cosine similarity\n",
    "    faiss.normalize_L2(embeddings_array)\n",
    "    index.add(embeddings_array)\n",
    "\n",
    "    # Save index and metadata\n",
    "    instance_index_dir = Path(index_dir) / instance_id\n",
    "    instance_index_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # Save FAISS index\n",
    "    faiss.write_index(index, str(instance_index_dir / \"index.faiss\"))\n",
    "\n",
    "    # Save metadata\n",
    "    metadata = {\n",
    "        \"chunk_ids\": chunk_ids,\n",
    "        \"chunk_to_file\": chunk_to_file,\n",
    "        \"dimension\": dimension,\n",
    "        \"num_chunks\": len(chunk_ids),\n",
    "        \"num_files\": len(corpus),\n",
    "    }\n",
    "\n",
    "    with open(instance_index_dir / \"metadata.pkl\", \"wb\") as f:\n",
    "        pickle.dump(metadata, f)\n",
    "\n",
    "    print(f\"Saved index for {instance_id} with {len(chunk_ids)} chunks\")\n",
    "\n",
    "\n",
    "def retrieve_files(\n",
    "    query: str, instance_id: str, index_dir: str, top_k: int = 5\n",
    ") -> List[Tuple[str, float]]:\n",
    "    \"\"\"Retrieve top-k files for a query using max pooling over chunks.\"\"\"\n",
    "    instance_index_dir = Path(index_dir) / instance_id\n",
    "\n",
    "    if not instance_index_dir.exists():\n",
    "        print(f\"Index not found for {instance_id}\")\n",
    "        return []\n",
    "\n",
    "    # Load index and metadata\n",
    "    index = faiss.read_index(str(instance_index_dir / \"index.faiss\"))\n",
    "    with open(instance_index_dir / \"metadata.pkl\", \"rb\") as f:\n",
    "        metadata = pickle.load(f)\n",
    "\n",
    "    # Get query embedding\n",
    "    embeddings = get_embeddings_batch([query])\n",
    "    if not embeddings:\n",
    "        print(f\"Failed to get query embedding for {instance_id}\")\n",
    "        return []\n",
    "\n",
    "    query_embedding = embeddings[0]\n",
    "    query_vec = np.array(query_embedding, dtype=np.float32).reshape(1, -1)\n",
    "    faiss.normalize_L2(query_vec)\n",
    "\n",
    "    # Search for similar chunks\n",
    "    k = min(100, index.ntotal)  # Get more chunks for max pooling\n",
    "    distances, indices = index.search(query_vec, k)\n",
    "\n",
    "    # Max pool by file\n",
    "    file_scores = defaultdict(float)\n",
    "    for idx, score in zip(indices[0], distances[0]):\n",
    "        if idx < len(metadata[\"chunk_ids\"]):\n",
    "            chunk_id = metadata[\"chunk_ids\"][idx]\n",
    "            file_path = metadata[\"chunk_to_file\"][chunk_id]\n",
    "            file_scores[file_path] = max(file_scores[file_path], score)\n",
    "\n",
    "    # Sort by score\n",
    "    sorted_files = sorted(file_scores.items(), key=lambda x: x[1], reverse=True)\n",
    "\n",
    "    return sorted_files[:top_k]\n",
    "\n",
    "\n",
    "def evaluate_recall_at_k(\n",
    "    retrieved_files: List[str], ground_truth_files: List[str], k: int = 5\n",
    ") -> float:\n",
    "    \"\"\"Calculate recall@k.\"\"\"\n",
    "    if not ground_truth_files:\n",
    "        return 0.0\n",
    "\n",
    "    retrieved_set = set(retrieved_files[:k])\n",
    "    ground_truth_set = set(ground_truth_files)\n",
    "\n",
    "    return len(retrieved_set & ground_truth_set) / len(ground_truth_set)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qfhMTmjZf3qB"
   },
   "source": [
    "## Load SWE-bench Lite Dataset\n",
    "\n",
    "Load the SWE-bench Lite dataset and extract ground truth file changes for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading SWE-bench Lite dataset...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 300 instances from SWE-bench Lite\n"
     ]
    }
   ],
   "source": [
    "# Load SWE-bench Lite instances and their ground-truth file lists\n",
    "instances, ground_truth_dict = load_swebench_lite()\n",
    "print(f\"Loaded {len(instances)} instances from SWE-bench Lite\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xTaPH6ekfvTW"
   },
   "source": [
    "## Main Evaluation Loop\n",
    "\n",
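    "For each instance in the dataset, index the repository (unless an index already exists on disk), retrieve the top 5 files for the problem statement, and compute recall@5. As written, the loop stops after the first instance; remove the final `break` to evaluate on more instances."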
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Processing instance 1 of 300\n",
      "\n",
      "Indexing repository for astropy__astropy-12907...\n",
      "Found 872 Python files\n",
      "Chunking 872 files...\n",
      "Created 6660 chunks from 872 files (size increase: 7.6x)\n",
      "Preparing texts for embedding...\n",
      "Getting embeddings...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing embedding batches: 100%|██████████| 263/263 [05:22<00:00,  1.23s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating FAISS index...\n",
      "Saved index for astropy__astropy-12907 with 6660 chunks\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing embedding batches: 100%|██████████| 1/1 [00:00<00:00,  3.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Retrieved files: ['astropy/modeling/separable.py', 'astropy/modeling/tests/test_separable.py', 'astropy/modeling/tests/test_models.py', 'astropy/modeling/core.py', 'astropy/modeling/tests/test_compound.py']\n",
      "Ground truth files: ['astropy/modeling/separable.py']\n",
      "astropy__astropy-12907: Recall@5 = 1.000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "recall_scores = []\n",
    "\n",
    "for i, instance in enumerate(instances):\n",
    "    print(f'\\n\\nProcessing instance {i+1} of {len(instances)}')\n",
    "    instance_id = instance[\"instance_id\"]\n",
    "    problem_statement = instance[\"problem_statement\"]\n",
    "    ground_truth_files = ground_truth_dict[instance_id]\n",
    "\n",
    "    # Skip if no ground truth files\n",
    "    if not ground_truth_files:\n",
    "        print(f\"No ground truth files for {instance_id}, skipping...\")\n",
    "        continue\n",
    "\n",
    "    # Load repository structure\n",
    "    repo_content = load_repository_structure(repo_structures_path, instance_id)\n",
    "    if not repo_content:\n",
    "        continue\n",
    "\n",
    "    # Index repository if not already indexed\n",
    "    instance_index_dir = Path(index_dir) / instance_id\n",
    "    if not instance_index_dir.exists():\n",
    "        index_repository(repo_content, instance_id, index_dir)\n",
    "\n",
    "    # Retrieve files for the problem statement\n",
    "    retrieved_files = retrieve_files(\n",
    "        problem_statement, instance_id, index_dir, top_k=5\n",
    "    )\n",
    "    retrieved_file_paths = [f[0] for f in retrieved_files]\n",
    "    print(f\"Retrieved files: {retrieved_file_paths}\")\n",
    "    print(f\"Ground truth files: {ground_truth_files}\")\n",
    "    # Calculate recall@5\n",
    "    recall_at_5 = evaluate_recall_at_k(\n",
    "        retrieved_file_paths, ground_truth_files, k=5\n",
    "    )\n",
    "    recall_scores.append(recall_at_5)\n",
    "\n",
    "    # Convert numpy floats to regular floats for JSON serialization\n",
    "    retrieved_files_serializable = [(file_path, float(score)) for file_path, score in retrieved_files[:5]]\n",
    "\n",
    "    # Store results\n",
    "    result = {\n",
    "        \"instance_id\": instance_id,\n",
    "        \"ground_truth_files\": ground_truth_files,\n",
    "        \"retrieved_files\": retrieved_files_serializable,  # Store with scores\n",
    "        \"recall_at_5\": recall_at_5,\n",
    "    }\n",
    "    results.append(result)\n",
    "    print(f\"{instance_id}: Recall@5 = {recall_at_5:.3f}\")\n",
    "\n",
    "    # 🚨 Remove this break to evaluate on more instances 🚨\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Qao_M75Jfpbg"
   },
   "source": [
    "## Results and Summary\n",
    "\n",
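    "Calculate the average recall@5 over all evaluated instances and save it, together with the per-instance results, to a JSON file. With the `break` left in the loop above, the average covers a single instance."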
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Evaluation complete!\n",
      "Average Recall@5: 1.000\n",
      "Results saved to /content/swebench_results.json\n"
     ]
    }
   ],
   "source": [
    "# Calculate average recall\n",
    "avg_recall = np.mean(recall_scores) if recall_scores else 0.0\n",
    "\n",
    "# Save detailed results\n",
    "final_results = {\n",
    "    \"instances\": results,\n",
    "    \"average_recall_at_5\": avg_recall,\n",
    "    \"num_instances\": len(results),\n",
    "}\n",
    "\n",
    "with open(results_file, \"w\") as f:\n",
    "    json.dump(final_results, f, indent=2)\n",
    "\n",
    "print(\"\\nEvaluation complete!\")\n",
    "print(f\"Average Recall@5: {avg_recall:.3f}\")\n",
    "print(f\"Results saved to {results_file}\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}