← Back to Cookbook
dequantization
Details
File: mistral/embeddings/dequantization.ipynb
Type: Jupyter Notebook
Use Cases: Dequantization
Content
Notebook content (JSON format):
{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "FaMNLtXRK8wS" }, "source": [ "# Dequantization embeddings\n", "\n", "This notebook demonstrates how to dequantize various quantized representations (int8, uint8, binary, ubinary) of embeddings back to float values using PyTorch and NumPy.\n", "\n", "## Installation\n", "\n", "Install the required packages:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (2.0.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.13.2)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.4.2)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.3.2)\n", "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\n", " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\n", " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)\n", " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n", " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)\n", " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)\n", " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-curand-cu12==10.3.5.147 (from torch)\n", " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)\n", " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)\n", " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)\n", " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n", "Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m19.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m28.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m12.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m11.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m65.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n", " Attempting uninstall: nvidia-nvjitlink-cu12\n", " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", " Attempting uninstall: nvidia-curand-cu12\n", " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", " Attempting uninstall: nvidia-cufft-cu12\n", " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", " Attempting uninstall: nvidia-cuda-runtime-cu12\n", " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", " Attempting uninstall: nvidia-cuda-cupti-cu12\n", " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", " Attempting uninstall: nvidia-cublas-cu12\n", " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", " Attempting uninstall: nvidia-cusparse-cu12\n", " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", " Attempting uninstall: nvidia-cudnn-cu12\n", " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", " Attempting uninstall: nvidia-cusolver-cu12\n", " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", "Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127\n" ] } ], "source": [ "!pip install torch numpy" ] }, { "cell_type": "markdown", "metadata": { "id": "v4NQgKwtLM2P" }, "source": [ "## define the dequantization functions" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "\n", "def _unpack_sign_bits(x_packed: torch.Tensor, signed: bool, C: int) -> torch.Tensor:\n", " \"\"\"\n", " Unpacks a bit-packed tensor into a ±1 float32 tensor.\n", "\n", " Parameters:\n", " ----------\n", " x_packed : torch.Tensor\n", " A tensor containing packed bit values (usually uint8) representing the sign bits.\n", " signed : bool\n", " Indicates if the original data was signed ('binary') or unsigned ('ubinary').\n", " C : int\n", " The number of original channels or dimensions to unpack (used to trim padding bits).\n", "\n", " Returns:\n", " -------\n", " torch.Tensor\n", " A float32 tensor on the same device as x_packed, with values ±1 and shape (..., C).\n", " \"\"\"\n", " arr = (x_packed.cpu().numpy() + (128 if signed else 0)).astype(np.uint8)\n", " bits = np.unpackbits(arr, axis=-1)[..., :C] # remove pad bits\n", " return torch.from_numpy(bits * 2 - 1).float().to(x_packed.device)\n", "\n", "def dequantize(\n", " q: torch.Tensor | list, quant: str, orig_dim: int | None = None\n", ") -> torch.Tensor:\n", " \"\"\"\n", " Dequantizes a quantized tensor or list back to float32 values in [-1, 1].\n", "\n", " Parameters:\n", " ----------\n", " q : torch.Tensor or list\n", " The quantized data, either as a tensor or list.\n", " quant : str\n", " The quantization type. Supported values:\n", " - 'fp8': already float, just converted.\n", " - 'int8': scaled by 127.\n", " - 'uint8': scaled and shifted to [-1,1].\n", " - 'binary' or 'ubinary': unpacked sign bits (requires orig_dim).\n", " orig_dim : int or None, optional\n", " The original number of dimensions/channels before packing,\n", " required when quant is 'binary' or 'ubinary' to correctly unpack bits.\n", "\n", " Returns:\n", " -------\n", " torch.Tensor\n", " A float32 tensor of shape (B, T, C) with values in [-1, 1].\n", "\n", " Raises:\n", " ------\n", " ValueError\n", " If an unsupported quantization type is provided or if `orig_dim` is missing\n", " for 'binary'/'ubinary' unpacking.\n", " \"\"\"\n", " if isinstance(q, list):\n", " q = torch.tensor(q)\n", " if quant == \"fp8\":\n", " return q.float()\n", " if quant == \"int8\":\n", " return q.float() / 127.0\n", " if quant == \"uint8\":\n", " return q.float() / 127.5 - 1.0\n", " if quant in {\"binary\", \"ubinary\"}:\n", " if orig_dim is None:\n", " raise ValueError(\"orig_dim needed for (u)binary unpack\")\n", " return _unpack_sign_bits(q, quant == \"binary\", orig_dim)\n", " raise ValueError(f\"Invalid quantization {quant}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "96Mwm3dMLcnX" }, "source": [ "## Examples" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "embed_float = [-0.11944580078125,-0.2734375,0.040771484375,0.3056640625,-0.1470947265625,-0.11749267578125,0.0799560546875,0.08282470703125,-0.04205322265625,0.220947265625,0.0015048980712890625,-0.00397491455078125,-0.01099395751953125,-0.052642822265625,0.0504150390625,0.01605224609375,0.029693603515625,-0.024078369140625]\n", "embed_int = [-15,-35,5,39,-19,-15,10,11,-5,28,0,0,-2,-7,6,2,4,-3]\n", "embed_uint = [112,93,133,166,109,112,138,138,122,156,128,127,126,121,134,130,131,124]\n", "embed_bin = [-77,-29,0]\n", "embed_ubin = [51,99,128]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([255., 255., 1., 1., 255., 255., 1., 1., 255., 1., 1., 255.,\n", " 255., 255., 1., 1., 1., 255.])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dequantize(embed_bin, quant=\"binary\", orig_dim=18)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([255., 255., 1., 1., 255., 255., 1., 1., 255., 1., 1., 255.,\n", " 255., 255., 1., 1., 1., 255.])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dequantize(embed_ubin, quant=\"ubinary\", orig_dim=18)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-0.1181, -0.2756, 0.0394, 0.3071, -0.1496, -0.1181, 0.0787, 0.0866,\n", " -0.0394, 0.2205, 0.0000, 0.0000, -0.0157, -0.0551, 0.0472, 0.0157,\n", " 0.0315, -0.0236])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dequantize(embed_int, quant=\"int8\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-0.1216, -0.2706, 0.0431, 0.3020, -0.1451, -0.1216, 0.0824, 0.0824,\n", " -0.0431, 0.2235, 0.0039, -0.0039, -0.0118, -0.0510, 0.0510, 0.0196,\n", " 0.0275, -0.0275])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dequantize(embed_uint, quant=\"uint8\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }