少し書き直してみた(6.4倍ぐらい速くなった)

Azure

一昨日書いたコードを思い返しながら、もうちょっと速くならんか、ちょっと試してみた。

#!/usr/bin/env python3.11
# -*- coding: utf-8 -*-

# builtin modules
import argparse
import json
import os
import sys
import threading

# third party modules
# Each third-party dependency is imported directly. If it is missing, it is
# installed once with pip (for the interpreter running this script) and the
# import is retried a single time. If that still fails, the ImportError (or
# the pip failure) propagates instead of retrying forever.
def _pip_install(package: str) -> None:
    """Install *package* into the running interpreter and refresh import caches."""
    import importlib
    import subprocess
    # check_call raises CalledProcessError if pip itself fails.
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
    # The import system caches directory listings; invalidate them so the
    # freshly installed package is found by the retry.
    importlib.invalidate_caches()

try:
    import pandas as pd
except ImportError:
    _pip_install("pandas")
    import pandas as pd

try:
    import psycopg as pg
    from psycopg.rows import dict_row, namedtuple_row
except ImportError:
    _pip_install("psycopg")
    import psycopg as pg
    from psycopg.rows import dict_row, namedtuple_row

try:
    from pyvis.network import Network
except ImportError:
    _pip_install("pyvis")
    from pyvis.network import Network

# Constants
# Name of the Apache AGE graph. It is also the file stem of the CSV that is
# loaded and of the HTML file that is written.
AG_GRAPH_NAME: str = "actorfilms"

# Global Variables
# Database connection shared by the functions below; assigned in main().
pg_conn: "pg.Connection | None" = None

# exit with a message shown in red
def exit_with_error(msg) -> None:
    """Terminate the program, reporting *msg* wrapped in red ANSI colour codes.

    ``sys.exit`` with a string argument prints it to stderr and exits with
    status 1, so this function never returns normally.
    """
    red, reset = "\033[31m", "\033[0m"
    sys.exit("".join((red, msg, reset)))

def create_edges(cur: pg.Cursor, actor: str, films: list[str] | tuple[str, ...], row_num: int, ln: int, threadId: int) -> None:
    """Create ACTED_IN edges from one Actor vertex to each of its Film vertices.

    A single Cypher statement is issued per actor. It MATCHes the actor and
    every film once, then CREATEs one relationship per film. Nothing is
    committed here; the caller owns the transaction.

    Args:
        cur: Cursor to execute on. It must not be shared with other threads.
        actor: Actor name. It is interpolated as-is, so it must already be
            escaped for a single-quoted Cypher string literal.
        films: Film titles, escaped the same way as *actor*. An empty sequence
            is a no-op.
        row_num: Number of CSV rows processed so far, shown in the status line.
        ln: Total number of CSV rows, shown in the status line.
        threadId: Batch/worker identifier, shown in the status line.
    """
    if not films:
        # With no films the query would read "MATCH (n...), CREATE", which is
        # invalid Cypher.
        return
    query_films = ', '.join(f"(m{idx}:Film {{title: '{film}'}})" for idx, film in enumerate(films))
    query_rels = ', '.join(f"(n)-[:ACTED_IN]->(m{idx})" for idx in range(len(films)))
    query = f"SELECT * FROM cypher('{AG_GRAPH_NAME}', $$ MATCH (n:Actor {{name: '{actor}'}}), {query_films} CREATE {query_rels} $$) AS (a agtype);"
    cur.execute(query)
    # '\r' keeps the progress on one line; the padding clears leftovers from
    # a previous, longer message.
    print(f"{len(films)} films for '{actor}' created, {row_num} / {ln} processed by thread {threadId}.{' '*10}", end = '\r')

# create graph database and add data
def createGraph() -> None:
    """(Re)create the AGE graph and load actors, films and ACTED_IN edges from CSV.

    Steps:
      1. Drop the graph if it already exists, then create it empty.
      2. Read ``<AG_GRAPH_NAME>.csv`` (Actor / ActorID / Film columns), escape
         the names for Cypher string literals, and drop incomplete rows.
      3. Create one Actor vertex per distinct name and one Film vertex per
         distinct title. The values are unique, so CREATE is used, not MERGE.
      4. Build GIN indices on the vertex properties for the MATCHes done by
         create_edges().
      5. Create the edges, one batch per run of consecutive rows with the same
         actor, on a bounded thread pool. Wait for every batch, then COMMIT.

    Exits via exit_with_error() if the AGE catalog table is missing. If any
    edge batch fails, its exception is re-raised and nothing is committed.
    """
    import concurrent.futures
    import itertools

    with pg_conn.cursor() as cur:
        print("Creating graph database...")
        try:
            cur.execute(f"SELECT * FROM ag_graph WHERE name='{AG_GRAPH_NAME}'")
        except pg.errors.UndefinedTable:
            exit_with_error("Failed to create graph database.\nPlease check if AGE extension is installed.")
        if cur.fetchone() is not None:
            cur.execute(f"SELECT drop_graph('{AG_GRAPH_NAME}', true)")
        cur.execute(f"SELECT create_graph('{AG_GRAPH_NAME}')")

        # read csv file downloaded from https://www.kaggle.com/datasets/darinhawley/imdb-films-by-actor-for-10k-actors
        # dtype=str keeps purely numeric titles (e.g. "1917") as strings, so
        # the .str accessor below never turns them into NaN.
        df = pd.read_csv(f"{AG_GRAPH_NAME}.csv", usecols = ["Actor", "ActorID", "Film"], dtype = str)
        # Values are interpolated into single-quoted Cypher string literals:
        # escape backslashes first, then single quotes.
        for column in ("Actor", "Film"):
            df[column] = (
                df[column]
                .str.replace("\\", "\\\\", regex = False)
                .str.replace("'", "\\'", regex = False)
            )
        # A row without an actor or a film cannot form an edge. Dropping it
        # also stops NaN from being rendered as a literal 'nan' vertex.
        df = df.dropna(subset = ["Actor", "Film"])

        if df.empty:
            print("No rows to load.")
            cur.execute("COMMIT;")
            return

        actors = df["Actor"].unique()
        films = df["Film"].unique()

        print(f"Creating {len(actors)} vertices...")
        cur.execute(''.join([f"SELECT * FROM cypher('{AG_GRAPH_NAME}', $$ CREATE (n:Actor {{name: '{actor}'}}) $$) AS (a agtype);" for actor in actors]))
        print(f"Creating {len(films)} vertices...")
        cur.execute(''.join([f"SELECT * FROM cypher('{AG_GRAPH_NAME}', $$ CREATE (m:Film {{title: '{film}'}}) $$) AS (a agtype);" for film in films]))

        print("Creating indices...")
        cur.execute(f'CREATE INDEX ON {AG_GRAPH_NAME}."Actor" USING GIN (properties);')
        cur.execute(f'CREATE INDEX ON {AG_GRAPH_NAME}."Film" USING GIN (properties);')

        ln = len(df)

        def run_batch(actor: str, actor_films: list[str], row_num: int, thread_id: int) -> None:
            # A psycopg connection may be used from several threads, but a
            # cursor must not be shared between them, so each batch opens its own.
            with pg_conn.cursor() as batch_cur:
                create_edges(batch_cur, actor, actor_films, row_num, ln, thread_id)

        # Rows are grouped by consecutive Actor values in CSV order. An actor
        # whose rows are not contiguous is simply handled in several batches.
        futures = []
        with concurrent.futures.ThreadPoolExecutor() as executor:
            row_num = 0
            groups = itertools.groupby(df.itertuples(), key = lambda row: row.Actor)
            for thread_id, (actor, rows) in enumerate(groups, start = 1):
                actor_films = [row.Film for row in rows]
                row_num += len(actor_films)
                futures.append(executor.submit(run_batch, actor, actor_films, row_num, thread_id))
        # Leaving the with-block waited for every batch. result() re-raises any
        # error a batch hit, instead of letting it vanish inside a worker thread.
        for future in futures:
            future.result()

        # Commit only after every edge has been created.
        cur.execute("COMMIT;")

# show graph
def showGraph(limit: int, film: str, actor: str) -> None:
    """Fetch (Actor)-[ACTED_IN]->(Film) rows and render them as an HTML network.

    Args:
        limit: Maximum number of relationships to fetch (Cypher ``LIMIT``).
        film: Exact film title to filter on; ``None`` or ``''`` means no filter.
        actor: Exact actor name to filter on; ``None`` or ``''`` means no filter.

    The graph is written to ``<AG_GRAPH_NAME>.html`` through pyvis and, on
    macOS, opened with ``open``. Exits via exit_with_error() if nothing matches.
    """
    def escape(value: str) -> str:
        # Make user input safe inside a single-quoted Cypher string literal.
        return value.replace("\\", "\\\\").replace("'", "\\'")

    with pg_conn.cursor(row_factory = namedtuple_row) as cur:
        # If createGraph() already ran a cypher() call on this session, this
        # warm-up succeeds. On a fresh session the first cypher() call fails
        # because 'age.so' is not properly loaded yet.
        # See, https://github.com/apache/age/issues/41
        try:
            cur.execute(f"SELECT * FROM cypher('{AG_GRAPH_NAME}', $$ MATCH (n:Actor)-[r:ACTED_IN]->(m:Film) RETURN n,r,m LIMIT 1 $$) AS (n agtype, r agtype, m agtype);")
        except pg.errors.InternalError_:
            # Clear the failed statement so the real query below can still run
            # when the connection is not in autocommit mode. In autocommit mode
            # there is no open transaction, and rollback() does nothing.
            pg_conn.rollback()

        wk_actor = f" {{name: '{escape(actor)}'}}" if actor else ''
        wk_film = f" {{title: '{escape(film)}'}}" if film else ''

        query = f"""
            SELECT *
                FROM cypher('{AG_GRAPH_NAME}', $$
                MATCH (n:Actor{wk_actor})-[r:ACTED_IN]->(m:Film{wk_film})
                RETURN n,r,m
                LIMIT {int(limit)}
                $$) AS (n agtype, r agtype, m agtype);
        """

        cur.execute(query)
        rows = cur.fetchall()

    if not rows:
        exit_with_error("No records found.")

    # Node colour and caption property per vertex label. Vertices with any
    # other label are not drawn.
    colors = {"Actor": "#F79667", "Film": "#57C7E3"}
    caption_keys = {"Actor": "name", "Film": "title"}

    net = Network(height = "1600px", width = "100%")
    for row in rows:
        # agtype values arrive as JSON text followed by a type suffix.
        start = json.loads(row.n.removesuffix("::vertex"))
        end = json.loads(row.m.removesuffix("::vertex"))
        edge = json.loads(row.r.removesuffix("::edge"))
        for vertex in (start, end):
            label = vertex["label"]
            if label in colors:
                net.add_node(vertex["id"], label = vertex["properties"][caption_keys[label]], color = colors[label])
        net.add_edge(start["id"], end["id"], label = edge["label"])

    net.toggle_physics(True)
    net.show(f"{AG_GRAPH_NAME}.html", notebook = False)

    # open the browser if the platform is macOS
    if sys.platform == "darwin":
        os.system(f"open {AG_GRAPH_NAME}.html")

def main() -> None:
    """Parse CLI options, connect to PostgreSQL, optionally rebuild the graph, then show it.

    Connection parameters default to the placeholder values below. Each can
    be overridden with the standard libpq environment variables PGHOST,
    PGPORT, PGDATABASE, PGUSER and PGPASSWORD, so credentials need not be
    edited into the source.
    """
    global pg_conn
    # pass --initdb to create the graph database
    parser = argparse.ArgumentParser(description = "IMDB Graph")
    parser.add_argument("--initdb", "-i", action = 'store_true', help = "Initialize the database if specified")
    parser.add_argument("--limit", "-l", type = int, default = 1000, help = "Number of records to display in the graph")
    parser.add_argument("--film", "-f", default = '', help = "Film to search for")
    parser.add_argument("--actor", "-a", default = '', help = "Actor to search for")
    args = parser.parse_args()

    # autocommit depends on --initdb because of the error on a fresh
    # connection's first cypher() call (handled in showGraph()). In autocommit
    # mode a failed statement leaves no aborted transaction behind, so
    # try/except/pass is enough there. With --initdb, createGraph() runs in
    # one transaction and COMMITs it itself.
    try:
        pg_conn = pg.connect(
            host = os.environ.get("PGHOST", "your_server.postgres.database.azure.com"),
            port = int(os.environ.get("PGPORT", "5432")),
            dbname = os.environ.get("PGDATABASE", "postgres"),
            user = os.environ.get("PGUSER", "your_account"),
            password = os.environ.get("PGPASSWORD", "your_password"),
            options = "-c search_path=ag_catalog,'$user',public",
            autocommit = not args.initdb,
        )
    except pg.Error as exc:
        # Only database errors are reported here; Ctrl-C and SystemExit still
        # propagate normally.
        exit_with_error(f"Failed to connect to the database.\n{exc}")

    try:
        # Run once to create the graph database
        if args.initdb:
            createGraph()

        # Show the graph
        showGraph(args.limit, args.film, args.actor)
    finally:
        # Runs even when createGraph()/showGraph() leave via sys.exit().
        pg_conn.close()

if __name__ == "__main__":
    main()

変更点は以下。

  • まず、ActorとFilmのVertexを先に作成することにした。unique()してあるので、MERGEではなくCREATE。
  • Vertexを作成後、ActorとFilmにインデックスを張る。
  • Actorの出演作品をまとめて処理するcreate_edges()を追加した。その際に、スレッドで投入するようにしてみた。

で、実行してみた。

./imdb-graph.py -i  9.68s user 4.88s system 4% cpu 5:47.99 total

修正前は 37:21.04 total（約2,241秒）だったので、今回の 5:47.99 total（約348秒）と比べると6.4倍ぐらい速くなった。

追記(2024.10.27):コード中のインデックスについては、「Apache AGEのインデックス」の記事を参照。