Reach out
← Back to Cookbook

dequantization

Details

File: mistral/embeddings/dequantization.ipynb

Type: Jupyter Notebook

Use Cases: Dequantization

Content

Notebook content (JSON format):

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FaMNLtXRK8wS"
   },
   "source": [
    "# Dequantization embeddings\n",
    "\n",
    "This notebook demonstrates how to dequantize various quantized representations (int8, uint8, binary, ubinary) of embeddings back to float values using PyTorch and NumPy.\n",
    "\n",
    "## Installation\n",
    "\n",
    "Install the required packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (2.0.2)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
      "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.13.2)\n",
      "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.4.2)\n",
      "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
      "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.3.2)\n",
      "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\n",
      "  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
      "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\n",
      "  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
      "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)\n",
      "  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
      "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n",
      "  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
      "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)\n",
      "  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
      "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)\n",
      "  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
      "Collecting nvidia-curand-cu12==10.3.5.147 (from torch)\n",
      "  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
      "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)\n",
      "  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
      "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)\n",
      "  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
      "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
      "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)\n",
      "  Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
      "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n",
      "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
      "Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m19.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m28.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m12.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m11.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m65.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n",
      "  Attempting uninstall: nvidia-nvjitlink-cu12\n",
      "    Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n",
      "    Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n",
      "      Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n",
      "  Attempting uninstall: nvidia-curand-cu12\n",
      "    Found existing installation: nvidia-curand-cu12 10.3.6.82\n",
      "    Uninstalling nvidia-curand-cu12-10.3.6.82:\n",
      "      Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n",
      "  Attempting uninstall: nvidia-cufft-cu12\n",
      "    Found existing installation: nvidia-cufft-cu12 11.2.3.61\n",
      "    Uninstalling nvidia-cufft-cu12-11.2.3.61:\n",
      "      Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n",
      "  Attempting uninstall: nvidia-cuda-runtime-cu12\n",
      "    Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n",
      "    Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n",
      "      Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n",
      "  Attempting uninstall: nvidia-cuda-nvrtc-cu12\n",
      "    Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n",
      "    Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n",
      "      Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n",
      "  Attempting uninstall: nvidia-cuda-cupti-cu12\n",
      "    Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n",
      "    Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n",
      "      Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n",
      "  Attempting uninstall: nvidia-cublas-cu12\n",
      "    Found existing installation: nvidia-cublas-cu12 12.5.3.2\n",
      "    Uninstalling nvidia-cublas-cu12-12.5.3.2:\n",
      "      Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n",
      "  Attempting uninstall: nvidia-cusparse-cu12\n",
      "    Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n",
      "    Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n",
      "      Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n",
      "  Attempting uninstall: nvidia-cudnn-cu12\n",
      "    Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n",
      "    Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n",
      "      Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n",
      "  Attempting uninstall: nvidia-cusolver-cu12\n",
      "    Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n",
      "    Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n",
      "      Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n",
      "Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127\n"
     ]
    }
   ],
   "source": [
    "!pip install torch numpy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v4NQgKwtLM2P"
   },
   "source": [
    "## define the dequantization functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "def _unpack_sign_bits(x_packed: torch.Tensor, signed: bool, C: int) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Unpacks a bit-packed tensor into a ±1 float32 tensor.\n",
    "\n",
    "    Parameters:\n",
    "    ----------\n",
    "    x_packed : torch.Tensor\n",
    "        A tensor containing packed bit values (usually uint8) representing the sign bits.\n",
    "    signed : bool\n",
    "        Indicates if the original data was signed ('binary') or unsigned ('ubinary').\n",
    "    C : int\n",
    "        The number of original channels or dimensions to unpack (used to trim padding bits).\n",
    "\n",
    "    Returns:\n",
    "    -------\n",
    "    torch.Tensor\n",
    "        A float32 tensor on the same device as x_packed, with values ±1 and shape (..., C).\n",
    "    \"\"\"\n",
    "    arr = (x_packed.cpu().numpy() + (128 if signed else 0)).astype(np.uint8)\n",
    "    bits = np.unpackbits(arr, axis=-1)[..., :C]  # remove pad bits\n",
    "    return torch.from_numpy(bits * 2 - 1).float().to(x_packed.device)\n",
    "\n",
    "def dequantize(\n",
    "    q: torch.Tensor | list, quant: str, orig_dim: int | None = None\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    Dequantizes a quantized tensor or list back to float32 values in [-1, 1].\n",
    "\n",
    "    Parameters:\n",
    "    ----------\n",
    "    q : torch.Tensor or list\n",
    "        The quantized data, either as a tensor or list.\n",
    "    quant : str\n",
    "        The quantization type. Supported values:\n",
    "        - 'fp8': already float, just converted.\n",
    "        - 'int8': scaled by 127.\n",
    "        - 'uint8': scaled and shifted to [-1,1].\n",
    "        - 'binary' or 'ubinary': unpacked sign bits (requires orig_dim).\n",
    "    orig_dim : int or None, optional\n",
    "        The original number of dimensions/channels before packing,\n",
    "        required when quant is 'binary' or 'ubinary' to correctly unpack bits.\n",
    "\n",
    "    Returns:\n",
    "    -------\n",
    "    torch.Tensor\n",
    "        A float32 tensor of shape (B, T, C) with values in [-1, 1].\n",
    "\n",
    "    Raises:\n",
    "    ------\n",
    "    ValueError\n",
    "        If an unsupported quantization type is provided or if `orig_dim` is missing\n",
    "        for 'binary'/'ubinary' unpacking.\n",
    "    \"\"\"\n",
    "    if isinstance(q, list):\n",
    "        q = torch.tensor(q)\n",
    "    if quant == \"fp8\":\n",
    "        return q.float()\n",
    "    if quant == \"int8\":\n",
    "        return q.float() / 127.0\n",
    "    if quant == \"uint8\":\n",
    "        return q.float() / 127.5 - 1.0\n",
    "    if quant in {\"binary\", \"ubinary\"}:\n",
    "        if orig_dim is None:\n",
    "            raise ValueError(\"orig_dim needed for (u)binary unpack\")\n",
    "        return _unpack_sign_bits(q, quant == \"binary\", orig_dim)\n",
    "    raise ValueError(f\"Invalid quantization {quant}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "96Mwm3dMLcnX"
   },
   "source": [
    "## Examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "embed_float = [-0.11944580078125,-0.2734375,0.040771484375,0.3056640625,-0.1470947265625,-0.11749267578125,0.0799560546875,0.08282470703125,-0.04205322265625,0.220947265625,0.0015048980712890625,-0.00397491455078125,-0.01099395751953125,-0.052642822265625,0.0504150390625,0.01605224609375,0.029693603515625,-0.024078369140625]\n",
    "embed_int = [-15,-35,5,39,-19,-15,10,11,-5,28,0,0,-2,-7,6,2,4,-3]\n",
    "embed_uint = [112,93,133,166,109,112,138,138,122,156,128,127,126,121,134,130,131,124]\n",
    "embed_bin = [-77,-29,0]\n",
    "embed_ubin = [51,99,128]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([255., 255.,   1.,   1., 255., 255.,   1.,   1., 255.,   1.,   1., 255.,\n",
       "        255., 255.,   1.,   1.,   1., 255.])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dequantize(embed_bin, quant=\"binary\", orig_dim=18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([255., 255.,   1.,   1., 255., 255.,   1.,   1., 255.,   1.,   1., 255.,\n",
       "        255., 255.,   1.,   1.,   1., 255.])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dequantize(embed_ubin, quant=\"ubinary\", orig_dim=18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.1181, -0.2756,  0.0394,  0.3071, -0.1496, -0.1181,  0.0787,  0.0866,\n",
       "        -0.0394,  0.2205,  0.0000,  0.0000, -0.0157, -0.0551,  0.0472,  0.0157,\n",
       "         0.0315, -0.0236])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dequantize(embed_int, quant=\"int8\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.1216, -0.2706,  0.0431,  0.3020, -0.1451, -0.1216,  0.0824,  0.0824,\n",
       "        -0.0431,  0.2235,  0.0039, -0.0039, -0.0118, -0.0510,  0.0510,  0.0196,\n",
       "         0.0275, -0.0275])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dequantize(embed_uint, quant=\"uint8\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}