追記(2024.10.28):最新版
一昨日書いたコードを思い返しながら、もうちょっと速くならんか、ちょっと試してみた。
#!/usr/bin/env python3.11
# -*- coding: utf-8 -*-
# builtin modules
import argparse
import json
import os
import sys
import threading
# third party modules — installed on demand at first run.
def _pip_install(package: str) -> None:
    """Install *package* with the running interpreter's pip; abort on failure.

    Previously a bare ``except:`` retried forever even when pip itself failed
    (and swallowed KeyboardInterrupt); now a non-zero pip exit aborts instead
    of spinning in an infinite install loop.
    """
    if os.system(f"{sys.executable} -m pip install {package}") != 0:
        sys.exit(f"Failed to install {package}; please install it manually.")

while True:
    try:
        import pandas as pd
        break
    except ModuleNotFoundError:
        _pip_install("pandas")
while True:
    try:
        import psycopg as pg
        from psycopg.rows import dict_row, namedtuple_row
        break
    except ModuleNotFoundError:
        _pip_install("psycopg")
while True:
    try:
        from pyvis.network import Network
        break
    except ModuleNotFoundError:
        _pip_install("pyvis")
# Constants
AG_GRAPH_NAME = "actorfilms"  # AGE graph name; also the base name of the input CSV and output HTML
# Global Variables
pg_conn = None  # psycopg connection; assigned in main(), read by createGraph()/showGraph()
# abort the program, printing the message in red
def exit_with_error(msg) -> None:
    """Terminate the process, emitting *msg* wrapped in red ANSI codes."""
    sys.exit(f"\033[31m{msg}\033[0m")
def create_edges(cur: pg.cursor, actor: str, films: tuple, row_num: int, ln: int, threadId: int) -> None:
    """Create one ACTED_IN edge from *actor* to each title in *films*.

    All edges for the actor are created by a single cypher statement that
    MATCHes the actor vertex plus every film vertex, then CREATEs the edges.

    NOTE(review): this runs on worker threads but receives a cursor shared
    with the caller; psycopg connections are not designed for concurrent
    use from multiple threads — confirm this is intentional.
    """
    film_patterns = []
    edge_patterns = []
    for i, title in enumerate(films):
        film_patterns.append(f"(m{i}:Film {{title: '{title}'}})")
        edge_patterns.append(f"(n)-[:ACTED_IN]->(m{i})")
    match_part = ', '.join(film_patterns)
    create_part = ', '.join(edge_patterns)
    cur.execute(
        f"SELECT * FROM cypher('{AG_GRAPH_NAME}', $$ MATCH (n:Actor {{name: '{actor}'}}), "
        f"{match_part} CREATE {create_part} $$) AS (a agtype);"
    )
    # progress line is rewritten in place via carriage return
    print(f"{len(films)} films for '{actor}' created, {row_num} / {ln} processed by thread {threadId}.{' '*10}", end = '\r')
# create graph database and add data
def createGraph() -> None:
    """(Re)build the AGE graph named AG_GRAPH_NAME from '<AG_GRAPH_NAME>.csv'.

    Steps: verify the AGE extension is usable, drop any existing graph of the
    same name, bulk-create Actor and Film vertices, index them, then spawn one
    thread per actor to create the ACTED_IN edges, and finally index the edges
    and COMMIT.

    Uses the module-level ``pg_conn`` connection; exits via exit_with_error()
    when the AGE catalog table is missing.
    """
    cur = pg_conn.cursor()
    print("Creating graph database...")
    try:
        cur.execute(f"SELECT * FROM ag_graph WHERE name='{AG_GRAPH_NAME}'")
    except pg.errors.UndefinedTable:
        exit_with_error("Failed to create graph database.\nPlease check if AGE extension is installed.")
    # drop a stale graph of the same name before recreating it
    if cur.fetchone() is not None:
        cur.execute(f"SELECT drop_graph('{AG_GRAPH_NAME}', true)")
    cur.execute(f"SELECT create_graph('{AG_GRAPH_NAME}')")
    # read csv file downloaded from https://www.kaggle.com/datasets/darinhawley/imdb-films-by-actor-for-10k-actors
    df = pd.read_csv(f"{AG_GRAPH_NAME}.csv", usecols = ["Actor", "ActorID", "Film"])
    # escape single quotes so names/titles can be embedded in cypher string literals
    df["Actor"] = df["Actor"].str.replace("'", r"\'")
    df["Film"] = df["Film"].str.replace("'", r"\'")
    actors = df["Actor"].unique()
    films = df["Film"].unique()
    # vertices are unique here, so CREATE (not MERGE) is safe and faster
    print(f"Creating {len(actors)} vertices...")
    cur.execute(''.join([f"SELECT * FROM cypher('{AG_GRAPH_NAME}', $$ CREATE (n:Actor {{name: '{actor}'}}) $$) AS (a agtype);" for actor in actors]))
    print(f"Creating {len(films)} vertices...")
    cur.execute(''.join([f"SELECT * FROM cypher('{AG_GRAPH_NAME}', $$ CREATE (m:Film {{title: '{film}'}}) $$) AS (a agtype);" for film in films]))
    # index vertices before the MATCH-heavy edge creation starts
    print("Creating indices...")
    cur.execute(f'CREATE INDEX ON {AG_GRAPH_NAME}."Actor" USING GIN (properties);')
    cur.execute(f'CREATE INDEX ON {AG_GRAPH_NAME}."Actor" (id);')
    cur.execute(f'CREATE INDEX ON {AG_GRAPH_NAME}."Film" USING GIN (properties);')
    cur.execute(f'CREATE INDEX ON {AG_GRAPH_NAME}."Film" (id);')
    # group consecutive rows by actor (the CSV is ordered by actor) and hand
    # each actor's film list to a worker thread
    threads = []
    threadId = 1
    ln = len(df)
    row_num = 1
    film_buf = []  # renamed from `films` to avoid shadowing the unique-films array above
    saved_actor = actors[0]
    for row in df.itertuples():
        if row.Actor == saved_actor:
            film_buf.append(row.Film)
        else:
            t = threading.Thread(target = create_edges, args = (cur, saved_actor, film_buf, row_num, ln, threadId))
            t.start()
            threads.append(t)
            threadId += 1
            saved_actor = row.Actor
            film_buf = [row.Film]
        row_num += 1
    # flush the last actor's accumulated films
    if len(film_buf) > 0:
        t = threading.Thread(target = create_edges, args = (cur, saved_actor, film_buf, row_num, ln, threadId))
        t.start()
        threads.append(t)
    # BUG FIX: wait for every edge-creation thread to finish. Previously the
    # edge indices were built and COMMIT issued while worker threads could
    # still be inserting edges, risking a partially committed graph.
    for t in threads:
        t.join()
    # added 2024.10.28
    cur.execute(f'CREATE INDEX ON {AG_GRAPH_NAME}."ACTED_IN" (start_id);')
    cur.execute(f'CREATE INDEX ON {AG_GRAPH_NAME}."ACTED_IN" (end_id);')
    cur.execute("COMMIT;")
# show graph
def showGraph(limit: int, film: str, actor: str) -> None:
    """Render up to *limit* Actor-[:ACTED_IN]->Film matches as an HTML graph.

    Empty or None *film* / *actor* act as wildcards. Writes the pyvis network
    to '<AG_GRAPH_NAME>.html' (and opens it on macOS); exits with an error
    when the query matches nothing.
    """
    cur = pg_conn.cursor(row_factory = namedtuple_row)
    # BUG FIX: escape single quotes in the CLI-supplied filters, consistent
    # with the escaping applied at load time in createGraph(); previously an
    # apostrophe in a name (e.g. O'Brien) broke the cypher literal.
    if actor is not None and actor != '':
        wk_actor = " {{name: '{0}'}}".format(actor.replace("'", r"\'"))
    else:
        wk_actor = ''
    if film is not None and film != '':
        wk_film = " {{title: '{0}'}}".format(film.replace("'", r"\'"))
    else:
        wk_film = ''
    query = f"""
    SELECT *
    FROM cypher('{AG_GRAPH_NAME}', $$
        MATCH (n:Actor{wk_actor})-[r:ACTED_IN]->(m:Film{wk_film})
        RETURN n,r,m
        LIMIT {limit}
    $$) AS (n agtype, r agtype, m agtype);
    """
    cur.execute(query)
    rows = cur.fetchall()
    if len(rows) == 0:
        exit_with_error("No records found.")
    net = Network(height = "1600px", width = "100%")
    # agtype values carry a '::vertex' / '::edge' suffix that must be stripped
    # before json.loads()
    colors = {"Actor": "#F79667", "Film": "#57C7E3"}
    for row in rows:
        vertices = [json.loads(row.n[:-len("::vertex")]), json.loads(row.m[:-len("::vertex")])]
        edge = json.loads(row.r[:-len("::edge")])
        for v in vertices:
            if v["label"] == "Actor":
                net.add_node(v["id"], label = v["properties"]["name"], color = colors["Actor"])
            elif v["label"] == "Film":
                net.add_node(v["id"], label = v["properties"]["title"], color = colors["Film"])
            # no other labels can appear for this MATCH pattern; anything else
            # is silently skipped (the original set an unused color here)
        net.add_edge(vertices[0]["id"], vertices[1]["id"], label = edge["label"])
    net.toggle_physics(True)
    net.show(f"{AG_GRAPH_NAME}.html", notebook = False)
    # open the browser if the platform is macOS
    if sys.platform == "darwin":
        os.system(f"open {AG_GRAPH_NAME}.html")
def main() -> None:
    """Parse CLI options, connect to PostgreSQL, optionally rebuild the
    graph (--initdb), then render it with showGraph()."""
    global pg_conn
    # pass initdb to create the graph database
    parser = argparse.ArgumentParser(description = "IMDB Graph")
    parser.add_argument("--initdb", "-i", action = 'store_true', help = "Initialize the database if specified")
    parser.add_argument("--limit", "-l", type = int, default = 1000, help = "Number of records to display in the graph")
    parser.add_argument("--film", "-f", default = '', help = "Film to search for")
    parser.add_argument("--actor", "-a", default = '', help = "Actor to search for")
    args = parser.parse_args()
    try:
        pg_conn = pg.connect(
            host = "your_server.postgres.database.azure.com",
            port = 5432,
            dbname = "postgres",
            user = "your_account",
            password = "your_password",
            options = "-c search_path=ag_catalog,'$user',public"
        )
    except Exception:
        # narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate
        exit_with_error("Failed to connect to the database.")
    try:
        # Run once to create the graph database
        if args.initdb:
            createGraph()
        # Show the graph
        showGraph(args.limit, args.film, args.actor)
    finally:
        # always release the connection, even when a query fails
        pg_conn.close()
if __name__ == "__main__":
    main()
変更点は以下。
- まず、ActorとFilmのVertexを先に作成することにした。unique()してあるので、MERGEではなくCREATE。
- Vertexを作成後、ActorとFilmにインデックスを張る。
- Actorの出演作品をまとめて処理するcreate_edges()を追加した。その際に、スレッドで投入するようにしてみた。
で、実行してみた。
./imdb-graph.py -i 9.68s user 4.88s system 4% cpu 5:47.99 total
修正前は、cpu 37:21.04 totalだったので、6.4倍ぐらい速くなった。
追記(2024.10.27):コード中のインデックスについては、Apache AGEのインデックスを参照。