Property Graph Custom Retriever
Details
File: third_party/LlamaIndex/propertygraphs/property_graph_custom_retriever.ipynb
Type: Jupyter Notebook
Use Cases: Property graph, Custom retriever
Integrations: LlamaIndex
Content
Notebook content (JSON format):
{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Custom Retriever in PropertyGraph\n", "\n", "In this notebook we demonstrate how to define a custom retriever for a property graph. Although this approach is more complex than using the standard graph retrievers, it gives you fine-grained control over the retrieval process, so you can tailor it to your application.\n", "\n", "We also walk through an advanced retrieval workflow that uses the property graph store directly: we run both vector search and text-to-Cypher retrieval, then combine the results with a reranking module.\n", "\n", "We use the Cohere reranker, so you will need a Cohere API key." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install llama-index-core\n", "%pip install llama-index-llms-mistralai\n", "%pip install llama-index-embeddings-mistralai\n", "%pip install llama-index-graph-stores-neo4j\n", "%pip install llama-index-postprocessor-cohere-rerank" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import nest_asyncio\n", "\n", "nest_asyncio.apply()\n", "\n", "from IPython.display import Markdown, display" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['MISTRAL_API_KEY'] = 'YOUR MISTRAL API KEY'" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from llama_index.embeddings.mistralai import MistralAIEmbedding\n", "from llama_index.llms.mistralai import MistralAI\n", "\n", "llm = MistralAI(model='mistral-large-latest')\n", "embed_model = MistralAIEmbedding()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download Data" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "--2024-07-05 09:33:47-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt\n", "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8000::154, 2606:50c0:8001::154, 2606:50c0:8002::154, ...\n", "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8000::154|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 75042 (73K) [text/plain]\n", "Saving to: ‘data/paul_graham/paul_graham_essay.txt’\n", "\n", "data/paul_graham/pa 100%[===================>] 73.28K --.-KB/s in 0.02s \n", "\n", "2024-07-05 09:33:47 (4.74 MB/s) - ‘data/paul_graham/paul_graham_essay.txt’ saved [75042/75042]\n", "\n" ] } ], "source": [ "!mkdir -p 'data/paul_graham/'\n", "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Data" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from llama_index.core import SimpleDirectoryReader\n", "\n", "documents = SimpleDirectoryReader(\"./data/paul_graham/\").load_data()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Docker Setup\n", "\n", "You need to login and set password for the first time.\n", "\n", "1. username: neo4j\n", "\n", "2. 
password: neo4j" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!docker run \\\n", " -p 7474:7474 -p 7687:7687 \\\n", " -v $PWD/data:/data -v $PWD/plugins:/plugins \\\n", " --name neo4j-apoc \\\n", " -e NEO4J_apoc_export_file_enabled=true \\\n", " -e NEO4J_apoc_import_file_enabled=true \\\n", " -e NEO4J_apoc_import_file_use__neo4j__config=true \\\n", " -e NEO4JLABS_PLUGINS=\\[\\\"apoc\\\"\\] \\\n", " neo4j:latest" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup Neo4j GraphStore" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Received notification from DBMS server: {severity: WARNING} {code: Neo.ClientNotification.Statement.FeatureDeprecationWarning} {category: DEPRECATION} {title: This feature is deprecated and will be removed in future versions.} {description: The procedure has a deprecated field. ('config' used by 'apoc.meta.graphSample' is deprecated.)} {position: line: 1, column: 1, offset: 0} for query: \"CALL apoc.meta.graphSample() YIELD nodes, relationships RETURN nodes, [rel in relationships | {name:apoc.any.property(rel, 'type'), count: apoc.any.property(rel, 'count')}] AS relationships\"\n" ] } ], "source": [ "from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore\n", "\n", "graph_store = Neo4jPropertyGraphStore(\n", " username=\"neo4j\",\n", " password=\"llamaindex\",\n", " url=\"bolt://localhost:7687\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PropertyGraphIndex Construction" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Parsing nodes: 100%|██████████| 1/1 [00:00<00:00, 19.73it/s]\n", "Extracting paths from text: 100%|██████████| 22/22 [00:24<00:00, 1.12s/it]\n", "Extracting implicit paths: 100%|██████████| 22/22 [00:00<00:00, 11714.45it/s]\n", "Generating 
embeddings: 100%|██████████| 3/3 [00:01<00:00, 2.18it/s]\n", "Generating embeddings: 100%|██████████| 46/46 [00:14<00:00, 3.16it/s]\n", "Received notification from DBMS server: {severity: WARNING} {code: Neo.ClientNotification.Statement.FeatureDeprecationWarning} {category: DEPRECATION} {title: This feature is deprecated and will be removed in future versions.} {description: The procedure has a deprecated field. ('config' used by 'apoc.meta.graphSample' is deprecated.)} {position: line: 1, column: 1, offset: 0} for query: \"CALL apoc.meta.graphSample() YIELD nodes, relationships RETURN nodes, [rel in relationships | {name:apoc.any.property(rel, 'type'), count: apoc.any.property(rel, 'count')}] AS relationships\"\n" ] } ], "source": [ "from llama_index.core import PropertyGraphIndex\n", "\n", "index = PropertyGraphIndex.from_documents(\n", " documents,\n", " llm=llm,\n", " embed_model=embed_model,\n", " property_graph_store=graph_store,\n", " show_progress=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## CustomRetriever\n", "\n", "Define a custom retriever that combines VectorContextRetriever, TextToCypherRetriever and Reranker." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "from llama_index.core.retrievers import (\n", "    CustomPGRetriever,\n", "    VectorContextRetriever,\n", "    TextToCypherRetriever,\n", ")\n", "from llama_index.core.graph_stores import PropertyGraphStore\n", "from llama_index.core.vector_stores.types import VectorStore\n", "from llama_index.core.embeddings import BaseEmbedding\n", "from llama_index.core.prompts import PromptTemplate\n", "from llama_index.core.llms import LLM\n", "from llama_index.postprocessor.cohere_rerank import CohereRerank\n", "\n", "\n", "from typing import Optional, Any, Union\n", "\n", "\n", "class CustomRetriever(CustomPGRetriever):\n", "    \"\"\"Custom retriever with cohere reranking.\"\"\"\n", "\n", "    def init(\n", "        self,\n", "        ## vector context retriever params\n", "        embed_model: Optional[BaseEmbedding] = None,\n", "        vector_store: Optional[VectorStore] = None,\n", "        similarity_top_k: int = 4,\n", "        path_depth: int = 1,\n", "        ## text-to-cypher params\n", "        llm: Optional[LLM] = None,\n", "        text_to_cypher_template: Optional[Union[PromptTemplate, str]] = None,\n", "        ## cohere reranker params\n", "        cohere_api_key: Optional[str] = None,\n", "        cohere_top_n: int = 2,\n", "        **kwargs: Any,\n", "    ) -> None:\n", "        \"\"\"Uses any kwargs passed in from class constructor.\"\"\"\n", "\n", "        self.vector_retriever = VectorContextRetriever(\n", "            self.graph_store,\n", "            include_text=self.include_text,\n", "            embed_model=embed_model,\n", "            vector_store=vector_store,\n", "            similarity_top_k=similarity_top_k,\n", "            path_depth=path_depth,\n", "        )\n", "\n", "        self.cypher_retriever = TextToCypherRetriever(\n", "            self.graph_store,\n", "            llm=llm,\n", "            text_to_cypher_template=text_to_cypher_template,\n", "            ## NOTE: you can attach other parameters here if you'd like\n", "        )\n", "\n", "        self.reranker = CohereRerank(\n", "            api_key=cohere_api_key, top_n=cohere_top_n\n", "        )\n", "\n", "    def custom_retrieve(self, query_str: str) -> str:\n", "        \"\"\"Define custom retriever with reranking.\n", "\n", "        Could return `str`, `TextNode`, `NodeWithScore`, or a list of those.\n", "        \"\"\"\n", "        nodes_1 = self.vector_retriever.retrieve(query_str)\n", "        nodes_2 = self.cypher_retriever.retrieve(query_str)\n", "        reranked_nodes = self.reranker.postprocess_nodes(\n", "            nodes_1 + nodes_2, query_str=query_str\n", "        )\n", "\n", "        ## TMP: please change\n", "        final_text = \"\\n\\n\".join(\n", "            [n.get_content(metadata_mode=\"llm\") for n in reranked_nodes]\n", "        )\n", "\n", "        return final_text\n", "\n", "    # optional async method\n", "    # async def acustom_retrieve(self, query_str: str) -> str:\n", "    #     ..." ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "from llama_index.core import Settings\n", "Settings.llm = llm\n", "Settings.embed_model = embed_model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup CustomRetriever" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "custom_sub_retriever = CustomRetriever(\n", "    index.property_graph_store,\n", "    include_text=True,\n", "    vector_store=index.vector_store,\n", "    cohere_api_key=\"YOUR COHERE API KEY\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## QueryEngine" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "from llama_index.core.query_engine import RetrieverQueryEngine\n", "\n", "query_engine = RetrieverQueryEngine.from_args(\n", "    index.as_retriever(sub_retrievers=[custom_sub_retriever]), llm=llm\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Received notification from DBMS server: {severity: WARNING} {code: Neo.ClientNotification.Statement.UnknownRelationshipTypeWarning} {category: UNRECOGNIZED} {title: The provided relationship type is not in the database.} {description: One of the
relationship types in your query is not available in the database, make sure you didn't misspell it or that the label is available when you run this statement in your application (the missing relationship type is: WORKS_AT)} {position: line: 1, column: 20, offset: 19} for query: \"MATCH (p:PERSON)-[:WORKS_AT]->(o:ORGANIZATION)\\nWHERE p.name = 'author' AND o.name = 'Interleaf'\\nRETURN p.name, o.name, p.last_modified_date AS last_modified_date_person, o.last_modified_date AS last_modified_date_organization, p.creation_date AS creation_date_person, o.creation_date AS creation_date_organization\"\n" ] }, { "data": { "text/markdown": [ "The author worked at Interleaf, a software company that made software for creating documents, similar to Microsoft Word. He was hired as a Lisp hacker to write in their scripting language, which was a dialect of Lisp and inspired by Emacs. However, he admits to being a bad employee due to his lack of knowledge and unwillingness to learn C, as their Lisp was just a thin layer on top of a large C codebase. He also found the conventional office hours unnatural, which led to friction. Towards the end of his time there, he spent much of his time working on his own project, \"On Lisp\". Despite his shortcomings as an employee, he was well compensated, which helped him save enough money to return to RISD and pay off his college loans." 
], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "response = query_engine.query(\"What did author do at Interleaf?\")\n", "display(Markdown(f\"{response.response}\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "llamaindex", "language": "python", "name": "llamaindex" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 2 }
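
The `custom_retrieve` method in the notebook concatenates the vector-search results with the text-to-Cypher results and hands the combined list to the Cohere reranker, which keeps the top `cohere_top_n` entries. The sketch below illustrates that merge-then-rerank flow in plain Python so it can be run without Neo4j or any API key. Everything in it is an assumption made for illustration: `ScoredText`, `overlap_score`, and `merge_and_rerank` are hypothetical names, the token-overlap scorer is a toy stand-in and not how Cohere's reranking model scores text, and the deduplication step is an addition that the notebook does not perform.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass(frozen=True)
class ScoredText:
    """Hypothetical stand-in for a retrieved node: a text and its score."""

    text: str
    score: float = 0.0


def overlap_score(query: str, text: str) -> float:
    """Toy relevance score: fraction of query tokens found in the text.

    This only stands in for a real reranking model.
    """
    query_tokens = set(query.lower().split())
    text_tokens = set(text.lower().split())
    if not query_tokens:
        return 0.0
    return len(query_tokens & text_tokens) / len(query_tokens)


def merge_and_rerank(
    query: str,
    vector_hits: List[str],
    cypher_hits: List[str],
    top_n: int = 2,
    scorer: Callable[[str, str], float] = overlap_score,
) -> List[ScoredText]:
    """Merge two result lists, drop exact duplicates, keep the top_n by score."""
    seen = set()
    merged = []
    for text in vector_hits + cypher_hits:
        if text not in seen:  # keep the first occurrence only
            seen.add(text)
            merged.append(text)

    scored = [ScoredText(text, scorer(query, text)) for text in merged]
    # sort() is stable, so ties keep their original retrieval order
    scored.sort(key=lambda item: item.score, reverse=True)
    return scored[:top_n]


# Hypothetical retrieval results for the notebook's example question.
vector_hits = ["author worked at Interleaf", "Interleaf made document software"]
cypher_hits = ["author worked at Interleaf", "author -> WROTE -> On Lisp"]

results = merge_and_rerank(
    "What did author do at Interleaf?", vector_hits, cypher_hits, top_n=2
)
final_text = "\n\n".join(item.text for item in results)
print(final_text)
```

Passing the scorer as a parameter keeps the merge logic independent of the ranking model, which mirrors how the notebook's retriever delegates ranking to a separate `CohereRerank` postprocessor.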