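The code I wrote up to yesterday wasn't general-purpose, so I tidied it up a bit and turned it into a class. I'm still undecided about whether to make it a proper module. I'd like to polish it a bit more.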
```python
# builtin modules
import concurrent.futures
import itertools
import resource
import time
from typing import Self  # requires Python 3.11+

# third party modules
import pandas as pd
import psycopg
from psycopg import sql
from psycopg.rows import namedtuple_row
from psycopg_pool import ConnectionPool


# increase # of file descriptors to handle many connections during execution
def increaseFileDescriptors(max_fds: int = 5000) -> None:
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    # raise only the soft limit; it may not exceed the hard limit, which is kept as is
    if hard != resource.RLIM_INFINITY:
        max_fds = min(max_fds, hard)
    if max_fds > soft:
        resource.setrlimit(resource.RLIMIT_NOFILE, (max_fds, hard))


# escape a value for a single-quoted Cypher string literal (backslashes first, then quotes)
def escapeCypherString(value) -> str:
    return str(value).replace("\\", "\\\\").replace("'", "\\'")


# wrap a Cypher statement in AGE's cypher() call; the graph name is inlined as a SQL literal
def cypherQuery(graph_name: str, body: str) -> sql.Composed:
    return sql.SQL("SELECT * FROM cypher({graph}, $$ {body} $$) AS (a agtype)").format(
        graph=sql.Literal(graph_name), body=sql.SQL(body))


class Age:
    def __init__(self, pool: ConnectionPool):
        self.pool = pool
        self.graphName = None
        self.name = "Pythonage"
        self.version = "0.0.1"
        self.description = "Pythonage is a Python package that helps you to create a graph database using Azure Database for PostgreSQL."
        self.author = "Rio Fujita"

    def __del__(self):
        self.close()

    def close(self) -> None:
        if getattr(self, "pool", None) is not None:
            self.pool.close()
            self.pool = None

    @classmethod
    def connect(cls, dsn: str, pool_max_size: int = 844, **kwargs) -> Self:
        increaseFileDescriptors(5000)
        pool = ConnectionPool(
            dsn + " options='-c search_path=ag_catalog,\"$user\",public'",
            min_size=min(16, pool_max_size), max_size=pool_max_size, open=False, **kwargs)
        pool.open()
        pool.wait(timeout=30.0)
        # return an instance (not the class itself) so that each Age object owns its pool
        return cls(pool)

    def loadFromDataframe(self, graph_name: str, df: pd.DataFrame, relation: str) -> None:
        # validate before touching the database, so an existing graph is not dropped for bad input
        if len(keys := df.columns) != 2:
            raise ValueError(f"Dataframe must have just 2 columns, but {len(keys)} columns were found.")
        self.setUpGraph(graph_name)
        # sort_values() returns a new DataFrame (the caller's df is left untouched);
        # createEdges() relies on the rows being sorted by the start column
        df = df.sort_values(by=keys[0])
        for key in keys:
            df[key] = df[key].map(escapeCypherString)
            self.createVertices(key, df[key].unique())
        self.createEdges(keys, zip(df[keys[0]], df[keys[1]]), relation)
        # no explicit COMMIT needed: every `with pool.connection()` block commits on normal exit

    def loadFromCSV(self, graph_name: str, csv_path: str) -> None:
        # not implemented yet: only (re)creates the graph, csv_path is unused
        self.setUpGraph(graph_name)

    def setUpGraph(self, graph_name: str) -> None:
        graph = sql.Literal(graph_name)
        with self.pool.connection() as conn:
            with conn.cursor(row_factory=namedtuple_row) as cur:
                # drop the graph if it already exists, then create it from scratch
                cur.execute(sql.SQL("SELECT count(*) AS graph_count FROM ag_graph WHERE name = {}").format(graph))
                if (row := cur.fetchone()) is not None and row.graph_count == 1:
                    cur.execute(sql.SQL("SELECT drop_graph({}, true)").format(graph))
                cur.execute(sql.SQL("SELECT create_graph({})").format(graph))
        self.graphName = graph_name

    def createVertices(self, label: str, names) -> None:
        # one cypher() call per vertex, sent together as a single multi-statement query
        query = sql.SQL(";").join(
            cypherQuery(self.graphName, f"CREATE (n:{label} {{name: '{name}'}})") for name in names)
        table = sql.Identifier(self.graphName, label)
        with self.pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(query)
                cur.execute(sql.SQL("CREATE INDEX ON {} USING GIN (properties)").format(table))
                cur.execute(sql.SQL("CREATE INDEX ON {} USING BTREE (id)").format(table))

    def createEdges(self, keys, node_pairs, relation: str) -> None:
        # one task per start node (node_pairs must be sorted by start node for groupby);
        # worker threads are bounded by the pool size instead of one thread per start node
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.pool.max_size) as executor:
            futures = [
                executor.submit(self.createEdgesPerStarts, keys, start_node, [end for _, end in group], relation)
                for start_node, group in itertools.groupby(node_pairs, key=lambda pair: pair[0])
            ]
            for future in concurrent.futures.as_completed(futures):
                future.result()  # re-raise any exception from a worker thread
        table = sql.Identifier(self.graphName, relation)
        with self.pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(sql.SQL("CREATE INDEX ON {} (start_id)").format(table))
                cur.execute(sql.SQL("CREATE INDEX ON {} (end_id)").format(table))

    def createEdgesPerStarts(self, keys, start_node: str, end_nodes: list, relation: str, max_retries: int = 10) -> None:
        # a single MATCH ... CREATE statement creates every edge leaving start_node
        query_ends = ", ".join(f"(m{idx}:{keys[1]} {{name: '{end}'}})" for idx, end in enumerate(end_nodes))
        query_rels = ", ".join(f"(n)-[:{relation}]->(m{idx})" for idx in range(len(end_nodes)))
        query = cypherQuery(self.graphName, f"MATCH (n:{keys[0]} {{name: '{start_node}'}}), {query_ends} CREATE {query_rels}")
        for attempt in range(max_retries):
            try:
                with self.pool.connection() as conn:
                    with conn.cursor() as cur:
                        cur.execute(query)
                return
            except psycopg.Error:
                # retry transient failures (pool timeouts, deadlocks, concurrent label creation)
                # a bounded number of times instead of looping forever with a bare except
                if attempt == max_retries - 1:
                    raise
                time.sleep(0.1 * (attempt + 1))
```
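The sketch below does not need a database. It shows what the batching in `createEdgesPerStarts` produces: one SQL statement that creates every edge for a single start node. `escape_cypher_string` and `build_edge_query` are hypothetical helper names that mirror the f-strings in the class. Escaping backslashes as well as quotes is my own assumption: it presumes AGE's Cypher parser accepts `\\` as an escape. The sketch also assumes the graph name contains no quotes.

```python
def escape_cypher_string(value: str) -> str:
    """Escape backslashes first, then single quotes, for a single-quoted Cypher literal."""
    return value.replace("\\", "\\\\").replace("'", "\\'")


def build_edge_query(graph: str, start_label: str, end_label: str,
                     relation: str, start: str, ends: list[str]) -> str:
    """Build one SQL statement that creates an edge from `start` to every node in `ends`."""
    start = escape_cypher_string(start)
    # one MATCH pattern per end node: (m0:Film {name: '...'}), (m1:Film {name: '...'}), ...
    patterns = ", ".join(
        f"(m{i}:{end_label} {{name: '{escape_cypher_string(end)}'}})" for i, end in enumerate(ends))
    # one relationship per end node, all sharing the same start node n
    rels = ", ".join(f"(n)-[:{relation}]->(m{i})" for i in range(len(ends)))
    cypher = f"MATCH (n:{start_label} {{name: '{start}'}}), {patterns} CREATE {rels}"
    # the graph name is inlined as a plain SQL string literal (assumes it has no quotes)
    return f"SELECT * FROM cypher('{graph}', $$ {cypher} $$) AS (a agtype);"


if __name__ == "__main__":
    print(build_edge_query("actorfilms", "Actor", "Film", "ACTED_IN",
                           "Alice", ["Film A", "Film B"]))
```

Both ACTED_IN edges for `Alice` end up in a single `MATCH ... CREATE` statement. The number of round trips therefore equals the number of distinct start nodes rather than the number of CSV rows.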
The calling side is just this:
```python
#!/usr/bin/env python3.11
# -*- coding: utf-8 -*-
import pandas as pd

from pythonage import Age

AG_GRAPH_NAME = "actorfilms"

# Global Variables
connection_string = """
host=your_server.postgres.database.azure.com
port=5432
dbname=postgres
user=your_account
password=admin_password
"""

df = pd.read_csv(f"{AG_GRAPH_NAME}.csv", usecols=["Actor", "Film"])
ag = Age.connect(dsn=connection_string, pool_max_size=256)
ag.loadFromDataframe(graph_name=AG_GRAPH_NAME, df=df, relation="ACTED_IN")
```
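With the caller above, each distinct Actor becomes one edge-creation task, because `createEdges` groups the (start, end) pairs with `itertools.groupby`. `groupby` only merges adjacent items with equal keys, so the pairs have to be sorted by the start column beforehand. Note that `DataFrame.sort_values` returns a new frame unless `inplace=True`, so its result has to be kept. Below is a pure-Python sketch that uses tuples instead of a DataFrame; `batches_by_start` is a hypothetical name.

```python
import itertools


def batches_by_start(pairs: list[tuple[str, str]]) -> dict[str, list[str]]:
    """Group (start, end) pairs into one list of end nodes per start node."""
    batches: dict[str, list[str]] = {}
    # groupby() only merges *adjacent* items with equal keys, so sort by the start node first
    ordered = sorted(pairs, key=lambda pair: pair[0])
    for start, group in itertools.groupby(ordered, key=lambda pair: pair[0]):
        batches[start] = [end for _, end in group]
    return batches


if __name__ == "__main__":
    rows = [("Alice", "Film A"), ("Bob", "Film B"), ("Alice", "Film C")]
    # unsorted input: "Alice" comes out as two separate groups (two tasks instead of one)
    print([start for start, _ in itertools.groupby(rows, key=lambda pair: pair[0])])
    print(batches_by_start(rows))
```

Without sorting, the edges are still created. However, an actor whose rows are not adjacent in the CSV gets split into several smaller `MATCH ... CREATE` statements.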
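I ran a benchmark, varying the instance size and the number of connections in the connection pool. The table shows the elapsed time for the whole load (m:ss.xx, or plain seconds when under a minute).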
Pool connections | Standard_D4ds_v4 | Standard_D8ds_v4 | Standard_D16ds_v4 | Standard_D32ds_v4
---|---|---|---|---
16 | 2:27.57 | 1:18.56 | 43.085 | 33.336
32 | 2:31.32 | 1:10.90 | 42.149 | 25.492
64 | 2:35.02 | 1:17.40 | 43.396 | 25.148
128 | 2:29.92 | 1:16.09 | 42.328 | 24.107
192 | 2:37.78 | 1:20.03 | 42.725 | 24.416
256 | 2:38.55 | 1:16.53 | 42.194 | 24.476
384 | 2:39.37 | 1:19.97 | 43.041 | 24.387
512 | 2:38.77 | 1:20.50 | 42.609 | 24.493
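The data volume is small, so the run time seems to be determined almost entirely by the number of CPU cores. Each step up in instance size roughly halves the time. Pool size makes no consistent difference beyond 32 connections; only D32 improved noticeably between 16 and 32 connections. During the benchmark only the CPU was saturated, while memory and IOPS had plenty of headroom.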
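To put a number on "determined almost entirely by the core count", the times can be converted to seconds and compared between adjacent instance sizes. In the Dds_v4 series the number in the name is the vCPU count, so each step doubles the cores. Here is a small sketch using the 128-connection row; the helper names are mine.

```python
def to_seconds(elapsed: str) -> float:
    """Convert 'm:ss.xx' or plain 'ss.xxx' from the benchmark table to seconds."""
    if ":" in elapsed:
        minutes, seconds = elapsed.split(":")
        return int(minutes) * 60 + float(seconds)
    return float(elapsed)


# the 128-connection row of the table, ordered by instance size (4, 8, 16, 32 vCPUs)
ROW_128 = ["2:29.92", "1:16.09", "42.328", "24.107"]


def speedups(row: list[str]) -> list[float]:
    """Speedup from each instance size to the next one (each step doubles the vCPUs)."""
    times = [to_seconds(t) for t in row]
    return [slower / faster for slower, faster in zip(times, times[1:])]


if __name__ == "__main__":
    print([round(s, 2) for s in speedups(ROW_128)])
```

For that row the speedups come out at about 1.97x, 1.80x and 1.76x. Scaling is close to linear at first and slightly sublinear toward D32.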