← Back to Cookbook
dequantization
Details
File: mistral/embeddings/dequantization.ipynb
Type: Jupyter Notebook
Use Cases: Dequantization
Content
Notebook content (JSON format):
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FaMNLtXRK8wS"
},
"source": [
"# Dequantizing embeddings\n",
"\n",
"This notebook demonstrates how to dequantize various quantized representations (int8, uint8, binary, ubinary) of embeddings back to float values using PyTorch and NumPy.\n",
"\n",
"## Installation\n",
"\n",
"Install the required packages:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (2.0.2)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.13.2)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.4.2)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.3.2)\n",
"Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\n",
" Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\n",
" Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)\n",
" Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n",
" Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)\n",
" Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)\n",
" Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-curand-cu12==10.3.5.147 (from torch)\n",
" Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)\n",
" Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)\n",
" Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)\n",
" Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n",
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
"Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m19.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m28.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m12.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m11.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m65.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n",
" Attempting uninstall: nvidia-nvjitlink-cu12\n",
" Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n",
" Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n",
" Attempting uninstall: nvidia-curand-cu12\n",
" Found existing installation: nvidia-curand-cu12 10.3.6.82\n",
" Uninstalling nvidia-curand-cu12-10.3.6.82:\n",
" Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n",
" Attempting uninstall: nvidia-cufft-cu12\n",
" Found existing installation: nvidia-cufft-cu12 11.2.3.61\n",
" Uninstalling nvidia-cufft-cu12-11.2.3.61:\n",
" Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n",
" Attempting uninstall: nvidia-cuda-runtime-cu12\n",
" Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n",
" Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n",
" Attempting uninstall: nvidia-cuda-nvrtc-cu12\n",
" Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n",
" Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n",
" Attempting uninstall: nvidia-cuda-cupti-cu12\n",
" Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n",
" Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n",
" Attempting uninstall: nvidia-cublas-cu12\n",
" Found existing installation: nvidia-cublas-cu12 12.5.3.2\n",
" Uninstalling nvidia-cublas-cu12-12.5.3.2:\n",
" Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n",
" Attempting uninstall: nvidia-cusparse-cu12\n",
" Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n",
" Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n",
" Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n",
" Attempting uninstall: nvidia-cudnn-cu12\n",
" Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n",
" Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n",
" Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n",
" Attempting uninstall: nvidia-cusolver-cu12\n",
" Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n",
" Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n",
" Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n",
"Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127\n"
]
}
],
"source": [
"!pip install torch numpy"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v4NQgKwtLM2P"
},
"source": [
"## Define the dequantization functions"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"\n",
"def _unpack_sign_bits(x_packed: torch.Tensor, signed: bool, C: int) -> torch.Tensor:\n",
"    \"\"\"\n",
"    Unpacks a bit-packed tensor into a ±1 float32 tensor.\n",
"\n",
"    Every byte along the last axis is expanded into 8 bits, most significant bit\n",
"    first; a set bit becomes +1.0 and a cleared bit becomes -1.0.\n",
"\n",
"    Parameters:\n",
"    ----------\n",
"    x_packed : torch.Tensor\n",
"        A tensor of packed bytes: int8-range values when `signed` is True,\n",
"        uint8-range values otherwise.\n",
"    signed : bool\n",
"        Indicates if the data is signed ('binary', bytes stored with a -128 offset)\n",
"        or unsigned ('ubinary').\n",
"    C : int\n",
"        The number of original channels or dimensions to unpack (used to trim padding bits).\n",
"\n",
"    Returns:\n",
"    -------\n",
"    torch.Tensor\n",
"        A float32 tensor on the same device as x_packed, with values ±1 and shape (..., C).\n",
"    \"\"\"\n",
"    # Widen first: adding 128 to an int8 array can overflow (NumPy 2 raises\n",
"    # OverflowError for a Python int that does not fit the array dtype).\n",
"    packed = x_packed.cpu().numpy().astype(np.int64)\n",
"    if signed:\n",
"        packed = packed + 128  # undo the signed offset -> byte values in [0, 255]\n",
"    bits = np.unpackbits(packed.astype(np.uint8), axis=-1)[..., :C]  # remove pad bits\n",
"    # `bits` is uint8, where 0 * 2 - 1 wraps around to 255, so map {0, 1} to\n",
"    # {-1, +1} in float32 instead.\n",
"    signs = bits.astype(np.float32) * 2.0 - 1.0\n",
"    return torch.from_numpy(signs).to(x_packed.device)\n",
"\n",
"def dequantize(\n",
"    q: torch.Tensor | list, quant: str, orig_dim: int | None = None\n",
") -> torch.Tensor:\n",
"    \"\"\"\n",
"    Dequantizes a quantized tensor or list back to float32 values (nominally in [-1, 1]).\n",
"\n",
"    Parameters:\n",
"    ----------\n",
"    q : torch.Tensor or list\n",
"        The quantized data, either as a tensor or list (lists are converted with torch.tensor).\n",
"    quant : str\n",
"        The quantization type. Supported values:\n",
"        - 'fp8': already float, just cast to float32.\n",
"        - 'int8': divided by 127.\n",
"        - 'uint8': divided by 127.5 and shifted by -1, i.e. [0, 255] -> [-1, 1].\n",
"        - 'binary' or 'ubinary': unpacked sign bits (requires orig_dim).\n",
"    orig_dim : int or None, optional\n",
"        The original number of dimensions/channels before packing,\n",
"        required when quant is 'binary' or 'ubinary' to correctly unpack bits.\n",
"\n",
"    Returns:\n",
"    -------\n",
"    torch.Tensor\n",
"        A float32 tensor. For 'fp8', 'int8' and 'uint8' it has the same shape as q;\n",
"        for 'binary' and 'ubinary' the last axis is unpacked to length orig_dim.\n",
"\n",
"    Raises:\n",
"    ------\n",
"    ValueError\n",
"        If an unsupported quantization type is provided or if `orig_dim` is missing\n",
"        for 'binary'/'ubinary' unpacking.\n",
"    \"\"\"\n",
"    if isinstance(q, list):\n",
"        q = torch.tensor(q)\n",
"    if quant == \"fp8\":\n",
"        return q.float()\n",
"    if quant == \"int8\":\n",
"        return q.float() / 127.0\n",
"    if quant == \"uint8\":\n",
"        return q.float() / 127.5 - 1.0\n",
"    if quant in {\"binary\", \"ubinary\"}:\n",
"        if orig_dim is None:\n",
"            raise ValueError(\"orig_dim needed for (u)binary unpack\")\n",
"        return _unpack_sign_bits(q, quant == \"binary\", orig_dim)\n",
"    raise ValueError(f\"Invalid quantization {quant}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "96Mwm3dMLcnX"
},
"source": [
"## Examples"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"embed_float = [-0.11944580078125,-0.2734375,0.040771484375,0.3056640625,-0.1470947265625,-0.11749267578125,0.0799560546875,0.08282470703125,-0.04205322265625,0.220947265625,0.0015048980712890625,-0.00397491455078125,-0.01099395751953125,-0.052642822265625,0.0504150390625,0.01605224609375,0.029693603515625,-0.024078369140625]\n",
"embed_int = [-15,-35,5,39,-19,-15,10,11,-5,28,0,0,-2,-7,6,2,4,-3]\n",
"embed_uint = [112,93,133,166,109,112,138,138,122,156,128,127,126,121,134,130,131,124]\n",
"embed_bin = [-77,-29,0]\n",
"embed_ubin = [51,99,128]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-1., -1.,  1.,  1., -1., -1.,  1.,  1., -1.,  1.,  1., -1., -1., -1.,\n",
"         1.,  1.,  1., -1.])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dequantize(embed_bin, quant=\"binary\", orig_dim=18)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-1., -1.,  1.,  1., -1., -1.,  1.,  1., -1.,  1.,  1., -1., -1., -1.,\n",
"         1.,  1.,  1., -1.])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dequantize(embed_ubin, quant=\"ubinary\", orig_dim=18)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.1181, -0.2756, 0.0394, 0.3071, -0.1496, -0.1181, 0.0787, 0.0866,\n",
" -0.0394, 0.2205, 0.0000, 0.0000, -0.0157, -0.0551, 0.0472, 0.0157,\n",
" 0.0315, -0.0236])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dequantize(embed_int, quant=\"int8\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.1216, -0.2706, 0.0431, 0.3020, -0.1451, -0.1216, 0.0824, 0.0824,\n",
" -0.0431, 0.2235, 0.0039, -0.0039, -0.0118, -0.0510, 0.0510, 0.0196,\n",
" 0.0275, -0.0275])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dequantize(embed_uint, quant=\"uint8\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}