neon text to sql
Details
File: third_party/Neon/neon_text_to_sql.ipynb
Type: Jupyter Notebook
Use Cases: SQL, Database
Integrations: Neon
Content
Notebook content (JSON format):
{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Building a Text-to-SQL conversion system with Mistral AI, Neon, and LangChain\n", "\n", "Translating natural language queries into SQL statements is a powerful application of large language models (LLMs). While it's possible to ask an LLM directly to generate SQL based on a natural language prompt, this approach has limitations:\n", "\n", "1. The LLM may generate SQL that is syntactically incorrect since the SQL dialect varies across relational databases.\n", "2. The LLM doesn't have access to the full database schema, table and column names or indexes, which limits its ability to generate accurate and efficient queries. Passing in the full schema as input to the LLM every time can get expensive.\n", "3. Pretrained LLMs can't adapt to user feedback and evolving text queries.\n", "\n", "### Finetuning\n", "\n", "An alternative is to finetune the LLM on your specific text-to-SQL dataset, which might include query logs from your database and other relevant context. While this approach can improve the LLM's ability to generate accurate SQL queries, it still cannot adapt continuously. Finetuning can also be expensive, which might limit how frequently you can update the model.\n", "\n", "### RAG systems\n", "\n", "LLMs are great at in-context learning, so by feeding them relevant information in the prompt, we can improve their outputs. 
This is the idea behind Retrieval Augmented Generation (RAG) systems, which combine information retrieval with LLMs to generate more informed and contextual responses to queries.\n", "\n", "By retrieving relevant information from a knowledge base (database schemas, which tables to query, and previously generated SQL queries), we can leverage LLMs to generate SQL queries that are more accurate and efficient.\n", "\n", "### RAG for text-to-SQL\n", "\n", "In this post, we'll walk through building a RAG system using [Mistral AI](https://mistral.ai/) for embeddings and language modeling, and [Neon Postgres](https://neon.tech/) as the vector database. `Neon` is a fully managed serverless PostgreSQL database. It separates storage and compute to offer features such as instant branching and automatic scaling. With the `pgvector` extension, Neon can be used as a vector database to store text embeddings and query them.\n", "\n", "We'll set up a sample database, generate and store embeddings for a knowledge base about it, and then retrieve relevant snippets to answer a query. We use the popular [LangChain](https://www.langchain.com/) library to tie it all together.\n", "\n", "Let's dive in!\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup and Dependencies\n", "\n", "### Mistral AI API\n", "\n", "Sign up at [Mistral AI](https://mistral.ai/) and navigate to the console. From the sidebar, go to the `API keys` section and create a new API key.\n", "\n", "You'll need this key to access Mistral AI's embedding and language models. Assign it to the variable below.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "MISTRAL_API_KEY = \"your-mistral-api-key\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Neon Database\n", "\n", "Sign up at [Neon](https://neon.tech/) if you don't already have an account. 
Your Neon project comes with a ready-to-use Postgres database named `neondb`, which we'll use in this notebook.\n", "\n", "Log in to the Neon Console and navigate to the Connection Details section to find your database connection string. It should look similar to this:\n", "\n", "```text\n", "postgres://alex:AbC123dEf@ep-cool-darkness-123456.us-east-2.aws.neon.tech/dbname?sslmode=require\n", "```\n", "\n", "Set the variable below to the Neon connection string. SQLAlchemy 1.4 and later do not recognize the `postgres://` scheme, so change the prefix to `postgresql://` if your connection string starts with it.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "NEON_CONNECTION_STRING = \"your-neon-connection-string\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Python Libraries\n", "\n", "Install the necessary libraries to create the RAG system.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install langchain langchain-mistralai langchain-postgres" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`langchain-postgres` provides a `vectorstores` module that allows us to store and query embeddings in a Postgres database with `pgvector` installed. We also need `langchain-mistralai` to interact with `Mistral` models.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preparing the Data\n", "\n", "For our example, we'll leverage the commonly used Northwind sample dataset. It models a fictional trading company called `Northwind Traders` that sells products to customers. 
It has tables representing entities such as `Customers`, `Orders`, `Products`, and `Employees`, interconnected through relationships, allowing users to query and analyze data related to sales, inventory and other business operations.\n", "\n", "We want to provide two pieces of information as context when calling the Mistral LLM:\n", "\n", "- Relevant table and index information from the Northwind database schema.\n", "- Some sample (text question, SQL query) pairs for the LLM to learn from.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will set up retrieval by leveraging a vector database to store the schema and the sample (text, SQL) pairs. We create embeddings using the `mistral-embed` model for each piece of information and, at query time, retrieve the relevant snippets by comparing the query embedding with the stored embeddings.\n", "\n", "We'll use the `langchain-postgres` library to store embeddings of the data in the database.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sqlalchemy\n", "\n", "# Connect to the database\n", "engine = sqlalchemy.create_engine(\n", "    url=NEON_CONNECTION_STRING, pool_pre_ping=True, pool_recycle=300\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain_mistralai.embeddings import MistralAIEmbeddings\n", "from langchain_postgres.vectorstores import PGVector\n", "from langchain_core.documents import Document\n", "\n", "embeds_model = MistralAIEmbeddings(model=\"mistral-embed\", api_key=MISTRAL_API_KEY)\n", "vector_store = PGVector(\n", "    embeddings=embeds_model,\n", "    connection=engine,\n", "    use_jsonb=True,\n", "    collection_name=\"text-to-sql-context\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we generate embeddings for the Northwind schema and sample queries.\n", "\n", "The `add_documents` method on a LangChain vector store, such as `PGVector` here, uses 
the specified embeddings model to generate embeddings for the input text and stores them in the database.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**NOTE**: If working in Colab, download the database setup and sample query scripts by uncommenting and running the cell below.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# import os\n", "# import requests\n", "\n", "# repo_url = \"https://raw.githubusercontent.com/neondatabase/mistral-neon-text-to-sql/main/data/\"\n", "# fnames = [\"northwind-schema.sql\", \"northwind-queries.jsonl\"]\n", "\n", "# os.makedirs(\"data\", exist_ok=True)\n", "# for fname in fnames:\n", "#     response = requests.get(repo_url + fname)\n", "#     with open(f\"data/{fname}\", \"w\") as file:\n", "#         file.write(response.text)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# DDL statements to create the Northwind database\n", "\n", "_all_stmts = []\n", "with open(\"data/northwind-schema.sql\", \"r\") as f:\n", "    stmt = \"\"\n", "    for line in f:\n", "        if line.strip() == \"\" or line.startswith(\"--\"):\n", "            continue\n", "        else:\n", "            stmt += line\n", "            if \";\" in stmt:\n", "                _all_stmts.append(stmt.strip())\n", "                stmt = \"\"\n", "\n", "ddl_stmts = [x for x in _all_stmts if x.startswith((\"CREATE\", \"ALTER\"))]\n", "\n", "docs = [\n", "    Document(page_content=stmt, metadata={\"id\": f\"ddl-{i}\", \"topic\": \"ddl\"})\n", "    for i, stmt in enumerate(ddl_stmts)\n", "]\n", "vector_store.add_documents(docs, ids=[doc.metadata[\"id\"] for doc in docs])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sample question-query pairs\n", "\n", "with open(\"data/northwind-queries.jsonl\", \"r\") as f:\n", "    docs = [\n", "        Document(\n", "            page_content=pair,\n", "            metadata={\"id\": f\"query-{i}\", \"topic\": \"query\"},\n", "        )\n", "        for i, pair in enumerate(f)\n", "    ]\n", "\n", "vector_store.add_documents(docs, ids=[doc.metadata[\"id\"] 
for doc in docs])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will also create the Northwind tables in our Neon database, so we can execute the LLM output and have a working pipeline from natural language to query results.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run the DDL script to create the Northwind tables\n", "with engine.connect() as conn:\n", "    with open(\"data/northwind-schema.sql\") as f:\n", "        conn.execute(sqlalchemy.text(f.read()))\n", "    conn.commit()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Retrieving Relevant Information\n", "\n", "With our knowledge base set up, we can now retrieve relevant information for a given query.\n", "\n", "Consider a user asking the query below.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "question = \"Find the employee who has processed the most orders and display their full name and the number of orders they have processed?\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the `similarity_search` method on the vector store to retrieve the snippets most similar to the query.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "relevant_ddl_stmts = vector_store.similarity_search(\n", "    query=question, k=5, filter={\"topic\": {\"$eq\": \"ddl\"}}\n", ")\n", "\n", "# relevant_ddl_stmts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also fetch some similar queries from our example corpus to add to the LLM prompt. 
`Few shot` prompting by providing examples of the text-to-sql conversion task in this manner helps the LLM generate more relevant output.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "similar_queries = vector_store.similarity_search(\n", " query=question, k=3, filter={\"topic\": {\"$eq\": \"query\"}}\n", ")\n", "\n", "# similar_queries" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generating the SQL output\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we'll use Mistral AI's chat model to generate a SQL statement based on the retrieved context.\n", "\n", "We first construct the prompt we pass to the Mistral AI model. The prompt includes the query, the retrieved schema snippets, and some similar queries from the corpus.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "prompt = \"\"\"\n", "You are an AI assistant that converts natural language questions into SQL queries. To do this, you will be provided with three key pieces of information:\n", "\n", "1. Some DDL statements describing tables, columns and indexes in the database:\n", "<schema>\n", "{SCHEMA}\n", "</schema>\n", "\n", "2. Some example pairs demonstrating how to convert natural language text into a corresponding SQL query for this schema: \n", "<examples>\n", "{EXAMPLES}\n", "</examples>\n", "\n", "3. The actual natural language question to convert into an SQL query:\n", "<question>\n", "{QUESTION}\n", "</question>\n", "\n", "Follow the instructions below:\n", "1. Your task is to generate an SQL query that will retrieve the data needed to answer the question, based on the database schema. \n", "2. First, carefully study the provided schema and examples to understand the structure of the database and how the examples map natural language to SQL for this schema.\n", "3. 
Your answer should have two parts: \n", "- Inside <scratchpad> XML tag, write out step-by-step reasoning to explain how you are generating the query based on the schema, example, and question. \n", "- Then, inside <sql> XML tag, output your generated SQL. \n", "\"\"\"\n", "\n", "schema = \"\"\n", "for stmt in relevant_ddl_stmts:\n", "    schema += stmt.page_content + \"\\n\\n\"\n", "\n", "examples = \"\"\n", "for stmt in similar_queries:\n", "    text_sql_pair = json.loads(stmt.page_content)\n", "    examples += \"Question: \" + text_sql_pair[\"question\"] + \"\\n\"\n", "    examples += \"SQL: \" + text_sql_pair[\"query\"] + \"\\n\\n\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Prompting the LLM to think step by step often improves the quality of the generated output. Hence, we instruct the LLM to output its reasoning and the SQL query in separate blocks in the output text.\n", "\n", "We then call the Mistral AI model to generate the SQL statement.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import re\n", "from langchain_mistralai.chat_models import ChatMistralAI\n", "from langchain_core.messages import HumanMessage\n", "\n", "chat_model = ChatMistralAI(api_key=MISTRAL_API_KEY)\n", "response = chat_model.invoke(\n", "    [\n", "        HumanMessage(\n", "            content=prompt.format(QUESTION=question, SCHEMA=schema, EXAMPLES=examples)\n", "        )\n", "    ]\n", ")\n", "\n", "sql_query = re.search(r\"<sql>(.*?)</sql>\", response.content, re.DOTALL).group(1)\n", "print(sql_query)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We extract the SQL statement from the Mistral AI model's output and execute it on the Neon database to verify that it is valid.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sqlalchemy import text\n", "\n", "with engine.connect() as conn:\n", "    result = conn.execute(text(sql_query))\n", "    for row in result:\n", "        print(row._mapping)" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "We now have a working text-to-SQL system that uses the `Mistral AI` API for both chat and embedding models, and `Neon` as the vector database.\n", "\n", "To use it in production, there are some other considerations to keep in mind:\n", "\n", "1. Validate the generated SQL query before executing it, especially for destructive operations like `DELETE` and `UPDATE`. Since the text input comes from a user, a malicious prompt could also steer the model into generating harmful SQL, much like a SQL injection attack.\n", "\n", "2. Monitor the system's performance and accuracy over time. We might need to update the LLM used and the knowledge base embeddings as the data evolves.\n", "\n", "3. Better metadata. While similar examples and database schema are useful, information like data lineage and dashboard logs can add more context to the LLM API calls.\n", "\n", "4. Improving retrieval. For complex queries, we might need to increase the schema information passed to the LLM. Further, our similarity search heuristic is pretty naive in that we are matching text queries to SQL statements. 
Techniques like HyDE (Hypothetical Document Embeddings) can improve the quality of the retrieved snippets.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Appendix\n", "\n", "We fetched the Northwind database setup script and some sample queries from the following helpful repositories:\n", "\n", "- [Northwind Psql](https://github.com/pthom/northwind_psql/blob/master/northwind.sql)\n", "- [Sample queries](https://github.com/eirkostop/SQL-Northwind-exercises)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 2 }
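
The first production consideration in the notebook's conclusion (validate the generated SQL before executing it) could be sketched roughly as follows. This is a minimal, conservative guard and not part of the original notebook: the helper names `extract_sql` and `is_read_only` are hypothetical, and the keyword list is an assumption that will reject some legitimate queries (for example, a string literal containing the word `delete`). It uses only the standard library and should not replace database-level protections such as a read-only role.

```python
import re

# Hypothetical helpers for illustration; neither exists in the notebook.
# Keywords that indicate a write or schema change (an assumed, non-exhaustive list).
_WRITE_KEYWORDS = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|create|grant|revoke)\b",
    re.IGNORECASE,
)


def extract_sql(llm_output: str) -> str:
    """Return the text inside the first <sql>...</sql> block, or raise ValueError."""
    match = re.search(r"<sql>(.*?)</sql>", llm_output, re.DOTALL)
    if match is None:
        raise ValueError("no <sql> block found in model output")
    return match.group(1).strip()


def is_read_only(sql: str) -> bool:
    """Accept only a single statement that starts with SELECT or WITH
    and contains none of the write keywords above."""
    body = sql.strip().rstrip(";").strip()
    if not body or ";" in body:  # empty input, or more than one statement
        return False
    if not re.match(r"(select|with)\b", body, re.IGNORECASE):
        return False
    return _WRITE_KEYWORDS.search(body) is None


if __name__ == "__main__":
    sample = "<scratchpad>count orders</scratchpad><sql>SELECT COUNT(*) FROM orders;</sql>"
    query = extract_sql(sample)
    print(query, "->", "run it" if is_read_only(query) else "reject it")
```

In the notebook, a check like this would sit between the cell that extracts `sql_query` and the cell that calls `conn.execute`, so that a rejected query is never sent to the database.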