Apache AGEでGraphRAGを実装する(その1)

IT

GraphRAGの一歩手前、Knowledge Graphを作るところまで、この辺を参考にやってみた。生成されたhtmlをアップしておいた。Apache AGEで実装できるかを試すのが目的なので、GraphRAGの説明は無し。適宜検索して欲しい。

さて。langchain_community.graphs.age_graphにAGEGraphというのがあるが、これがいまいち動いてくれないので、そのコードを参照しつつ実装したら以下のようになった。参考にしたNotebookで使われているNeo4jなら特に問題なく動くと思われるが、現時点ではApache AGEを使う場合、泥臭くコードを書いていかないと動かない、といったところ。それほど難易度は高くない。

#!/usr/bin/env python3.11

# built-in modules
import json
import os
import pathlib
import pickle
import re
import sys
from typing import Any, Dict, Tuple, List, Optional, Union

# Interpreter name derived from the running Python's major/minor version
# (e.g. "python3.11"); the import fallbacks below use it to run
# "<interpreter> -m pip install <package>" when a module is missing.
cur_python = f"python{sys.version_info[0]}.{sys.version_info[1]}"

# third-party modules
try:
    from langchain.chains import GraphCypherQAChain
    from langchain.text_splitter import TokenTextSplitter
except ModuleNotFoundError:
    os.system(f'{cur_python} -m pip install langchain')
    from langchain.chains import GraphCypherQAChain
    from langchain.text_splitter import TokenTextSplitter

try:
    from langchain_community.document_loaders import WikipediaLoader
    from langchain_community.graphs.graph_document import GraphDocument
except ModuleNotFoundError:
    os.system(f'{cur_python} -m pip install langchain_community')
    from langchain_community.document_loaders import WikipediaLoader
    from langchain_community.graphs.graph_document import GraphDocument

try:
    from langchain_experimental.graph_transformers import LLMGraphTransformer
except ModuleNotFoundError:
    os.system(f'{cur_python} -m pip install langchain_experimental')
    from langchain_experimental.graph_transformers import LLMGraphTransformer

try:
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
except ModuleNotFoundError:
    os.system(f'{cur_python} -m pip install langchain_openai')
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings

try:
    import psycopg as pg
    from psycopg.rows import dict_row, namedtuple_row
except ModuleNotFoundError:
    os.system(f'{cur_python} -m pip install psycopg')
    import psycopg as pg
    from psycopg.rows import dict_row, namedtuple_row

try:
    from pyvis.network import Network
except ModuleNotFoundError:
    os.system(f'{cur_python} -m pip install pyvis')
    from pyvis.network import Network

# global variables
# Name of the AGE graph: interpolated into every cypher() call and into the
# drop_graph()/create_graph() statements issued by main().
ag_graph_name = "graphrag"
# Pickle file in which main() caches the generated graph documents between runs.
graph_docs_path = pathlib.Path("graph_docs.pkl")

def load_documents():
    """Load the Wikipedia results for "Elizabeth I" and split them into chunks.

    Only the first three documents returned by the loader are chunked.

    Returns:
        The chunks produced by ``TokenTextSplitter`` (chunk_size=512,
        chunk_overlap=24).
    """
    # Fetch the article(s) matching the query
    wiki_pages = WikipediaLoader(query="Elizabeth I").load()

    # Chunk only the leading three pages
    splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
    return splitter.split_documents(wiki_pages[:3])

def convert_to_graph_documents(documents):
    """Convert chunked documents into graph documents using an OpenAI chat model.

    Args:
        documents: The documents to convert (e.g. the output of
            ``load_documents``).

    Returns:
        The result of ``LLMGraphTransformer.convert_to_graph_documents`` for
        the given documents.
    """
    transformer = LLMGraphTransformer(
        llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0125"),
    )
    return transformer.convert_to_graph_documents(documents)

def format_properties(properties: Dict[str, Any], id: Union[str, None] = None) -> str:
    """Render a property mapping as a Cypher map literal.

    Every key is wrapped in backticks and every value is serialized with
    ``json.dumps``.

    Args:
        properties: Property names mapped to JSON-serializable values.
        id: Optional identifier appended as an ``id`` entry, but only when
            ``properties`` has no ``"id"`` key of its own. String ids are
            JSON-quoted; any other value is interpolated as-is.

    Returns:
        The map literal, e.g. ``{`name`: "x", id: "y"}``.
    """
    rendered = [
        f"`{key}`: {json.dumps(value)}" for key, value in properties.items()
    ]
    if id is not None and "id" not in properties:
        if isinstance(id, str):
            rendered.append(f"id: {json.dumps(id)}")
        else:
            rendered.append(f"id: {id}")
    return "{" + ", ".join(rendered) + "}"

def clean_graph_labels(label: str) -> str:
    """Replace every run of characters other than 0-9, a-z, A-Z in *label* with "_"."""
    non_alnum_run = re.compile("[^0-9a-zA-Z]+")
    return non_alnum_run.sub("_", label)

def add_graph_documents(cur: pg.cursor, graph_documents: List[GraphDocument], include_source: bool = False) -> None:
    """Insert the nodes and relationships of graph documents into the AGE graph.

    Every node and relationship is written with a Cypher ``MERGE`` wrapped in
    a ``cypher()`` SQL call against the graph named ``ag_graph_name``. Labels
    are passed through ``clean_graph_labels`` and property maps are built by
    ``format_properties``.

    Side effects on the input: ``node.properties["id"]`` is set to ``node.id``
    for every node (including relationship endpoints), and, when
    ``include_source`` is true, a missing ``doc.source.metadata["id"]`` is
    filled with the MD5 hex digest of the source's page content.

    Args:
        cur: Open cursor used to execute the statements.
        graph_documents: Graph documents whose ``nodes`` and ``relationships``
            are inserted.
        include_source: When true, also MERGE a ``Document`` node built from
            ``doc.source.metadata`` and a ``MENTIONS`` edge from it to each
            entity node.
    """
    # Imported locally because the file's import section does not provide md5.
    from hashlib import md5

    # NOTE(review): the Cypher text is embedded between "$$" dollar quotes, so a
    # property value containing "$$" would end the quoted body early. Not handled.

    # query for inserting nodes
    node_insert_query = "MERGE (n:`{label}` {properties})"
    if include_source:
        node_insert_query += " MERGE (d:Document {d_properties}) MERGE (d)-[:MENTIONS]->(n)"
    # query for inserting edges
    edge_insert_query = """
        MERGE (from:`{f_label}` {f_properties})
        MERGE (to:`{t_label}` {t_properties})
        MERGE (from)-[:`{r_label}` {r_properties}]->(to)
        """
    for doc in graph_documents:
        node_inputs = {}
        if include_source:
            # if we are adding sources, create an id for the source
            if not doc.source.metadata.get("id"):
                doc.source.metadata["id"] = md5(
                    doc.source.page_content.encode("utf-8")
                ).hexdigest()
            # The source properties are identical for every node of this document.
            node_inputs["d_properties"] = format_properties(doc.source.metadata)
        # insert entity nodes
        for node in doc.nodes:
            node.properties["id"] = node.id
            # Always clean the label so nodes created here match the labels
            # used when the relationships below MERGE their endpoints.
            node_inputs["label"] = clean_graph_labels(node.type)
            node_inputs["properties"] = format_properties(node.properties)
            query = node_insert_query.format(**node_inputs)
            cur.execute(f"SELECT * FROM cypher('{ag_graph_name}', $$ {query} $$) AS (n agtype)")
        # insert relationships
        for edge in doc.relationships:
            edge.source.properties["id"] = edge.source.id
            edge.target.properties["id"] = edge.target.id
            inputs = {
                "f_label": clean_graph_labels(edge.source.type),
                "f_properties": format_properties(edge.source.properties),
                "t_label": clean_graph_labels(edge.target.type),
                "t_properties": format_properties(edge.target.properties),
                "r_label": clean_graph_labels(edge.type).upper(),
                "r_properties": format_properties(edge.properties),
            }
            query = edge_insert_query.format(**inputs)
            cur.execute(f"SELECT * FROM cypher('{ag_graph_name}', $$ {query} $$) AS (n agtype)")

def showGraph(cur: pg.cursor):
    """Fetch up to 50 edges from the AGE graph and display them with pyvis.

    Each returned row carries a source vertex, an edge and a target vertex as
    agtype text; the trailing ``::vertex`` / ``::edge`` annotation is cut off
    so the remainder can be parsed as JSON. The network is handed to
    ``net.show("graphrag.html", ...)`` and the file is then opened with the
    ``open`` command.

    Args:
        cur: Open cursor whose rows expose the columns ``s``, ``r`` and ``t``
            as attributes.
    """
    # The reference code filtered with (s)-[r:!MENTIONS]->(t); that '!MENTIONS'
    # expression doesn't work with AGE, so every relationship type is matched.
    cur.execute(f"""
        SELECT *
            FROM cypher ('{ag_graph_name}', $$
            MATCH (s)-[r]->(t)
            RETURN s,r,t
            LIMIT 50
            $$) AS (s agtype, r agtype, t agtype)
    """)
    records = cur.fetchall()

    # Node colour per vertex label; any other label falls back to "blue".
    label_colors = {
        "Person": "#F79667",
        "Country": "#57C7E3",
        "Monarch": "#D9C8AE",
        "Title": "#F16767",
    }
    vertex_suffix_len = len("::vertex")
    edge_suffix_len = len("::edge")

    net = Network(height="1600px", width="100%")
    for record in records:
        source = json.loads(record.s[:-vertex_suffix_len])
        target = json.loads(record.t[:-vertex_suffix_len])
        relation = json.loads(record.r[:-edge_suffix_len])
        for vertex in (source, target):
            color = label_colors.get(vertex["label"], "blue")
            print(vertex)
            net.add_node(vertex["id"], label=vertex["properties"]["id"], color=color)
        net.add_edge(source["id"], target["id"], label=relation["label"])
    net.toggle_physics(True)
    net.show("graphrag.html", notebook=False)
    os.system('open graphrag.html')

def main() -> None:
    """Build the knowledge graph in Apache AGE and display it.

    Steps:
      1. Obtain the graph documents: load them from ``graph_docs_path`` when
         that pickle exists, otherwise generate them and write the pickle.
      2. Connect to PostgreSQL, drop the graph ``ag_graph_name`` if it is
         already registered in ``ag_graph``, and create it again.
      3. Insert the graph documents and render the result with ``showGraph``.
    """
    print("Loading documents...")
    if graph_docs_path.exists():
        # Reuse the cached conversion result instead of generating it again.
        with graph_docs_path.open(mode="rb") as cache_file:
            graph_docs = pickle.load(cache_file)
    else:
        docs = load_documents()
        print("Converting documents to graph documents...")
        graph_docs = convert_to_graph_documents(docs)
        with graph_docs_path.open(mode="wb") as cache_file:
            pickle.dump(graph_docs, cache_file)

    connection_settings = {
        "host": "localhost",
        "port": 5455,
        "dbname": "postgres",
        "user": os.environ.get('USER'),
        "password": "passw0rd",
        # Put ag_catalog first on the search path for this session.
        "options": "-c search_path=ag_catalog,'$user',public",
    }
    with pg.connect(**connection_settings) as conn:
        with conn.cursor(row_factory=namedtuple_row) as cur:
            print("Creating graph database...")
            cur.execute(f"SELECT * FROM ag_graph WHERE name='{ag_graph_name}'")
            graph_already_exists = cur.fetchone() is not None
            if graph_already_exists:
                # Start from an empty graph on every run.
                cur.execute(f"SELECT drop_graph('{ag_graph_name}', true)")
            cur.execute(f"SELECT create_graph('{ag_graph_name}')")
            print("Adding graph documents to AGE...")
            add_graph_documents(cur, graph_docs)

        with conn.cursor(row_factory=namedtuple_row) as cur:
            print("Showing graph...")
            showGraph(cur)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

処理の大まかな流れは以下。

  • 182〜190行:OpenAIのAPIを毎回叩くと料金が辛いので、グラフドキュメントを一度生成したら、Pickleでシリアライズしておく。デバッグ中に、1回$0.01でも何度も叩くうちにアカンってなったw
  • 61行・load_documents():Wikipediaから取ってきたドキュメントをチャンキングして返すだけ。
  • 71行・convert_to_graph_documents():OpenAIのAPIでLLMを取得して、ドキュメントをグラフドキュメントにする。中身は以下のようになっているので、自前でパースしても大したことはない。
[
    GraphDocument(
        nodes=[
            Node(id='Elizabeth I', type='Person'), 
            Node(id='England', type='Country'), 
            Node(id='Ireland', type='Country'), 
            ......
        ], 
        relationships=[
            Relationship(source=Node(id='Elizabeth I', type='Person'), target=Node(id='England', type='Country'), type='RULER'), 
            Relationship(source=Node(id='Elizabeth I', type='Person'), target=Node(id='Ireland', type='Country'), type='RULER'), 
            Relationship(source=Node(id='Elizabeth I', type='Person'), target=Node(id='House Of Tudor', type='Royal family'), type='MEMBER'), 
            ......
        ], 
        source=Document(
            metadata={
                'title': 'Elizabeth I', 
                'summary': 'Elizabeth I (7 September 1533 – 24 March 1603) was Queen of ......
                'source': 'https://en.wikipedia.org/wiki/Elizabeth_I'
            }, 
            page_content='Elizabeth I (7 September 1533 – 24 March 1603) was Queen .....
......
  • 192〜208行:グラフドキュメントをApache AGEを有効にしたPostgreSQLに投入する。
  • 98行・add_graph_documents():本来は、langchain_community.graphs.age_graph.AGEGraphで出来るはずだが、現時点では色々な問題があることが分かったので、自前で。要はCypherでVertexとEdgeを投入しているだけ。
  • 146行・showGraph():上で投入したグラフを50件だけ取得して、pyvisのネットワークに変換して表示している。先日書いたJaalだとこれはうまくいかないので、pyvisを使うことにした。

Knowledge Graphが出来たので、次はこのデータに対してクエリを実行するだけ。

タイトルとURLをコピーしました