AGE用のクラスを作った

昨日までに書いたコードは汎用性が無いので、ちょっと整理してクラスにした。モジュールにするかどうかは悩むなぁ。もうちょっとちゃんとしたいところ。

# builtin modules
import itertools
import resource
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Self

# third party modules
import pandas as pd

import psycopg
from psycopg import sql
from psycopg.conninfo import make_conninfo
from psycopg.rows import namedtuple_row

from psycopg_pool import ConnectionPool

# increase # of file descriptors to handle many connections during execution
def increaseFileDescriptors(max_fds : int = 5000) -> None:
    # increase soft limit, but keep hard limit as is
    resource.setrlimit(resource.RLIMIT_NOFILE, (max_fds, resource.getrlimit(resource.RLIMIT_NOFILE)[1]))

class Age:
    """Helper for loading tabular data into an Apache AGE graph.

    Typical use::

        ag = Age.connect(dsn=..., pool_max_size=256)
        ag.loadFromDataframe(graph_name=..., df=..., relation=...)

    ``connect`` is a class method: it stores the connection pool on the
    *class* and returns the class itself.  The loader class methods then pass
    ``cls`` explicitly as ``self`` to the helper methods, so ``pool`` and
    ``graphName`` live in class attributes shared within the process.
    """

    def __init__(self):
        # Instances get their own ``pool = None``, which shadows any class-level
        # pool created by connect().
        self.pool = None
        self.graphName = None
        self.name = "Pythonage"
        self.version = "0.0.1"
        self.description = "Pythonage is a Python package that helps you to create a graph database using Azure Database for PostgreSQL."
        self.author = "Rio Fujita"

    def __del__(self):
        # Close only a pool owned by this instance.  The shared class-level pool
        # from connect() must not be closed by an unrelated instance, and the
        # default ``None`` has no close().
        pool = vars(self).get("pool")
        if pool is not None:
            pool.close()

    @classmethod
    def connect(cls, dsn : str = None, pool_max_size : int = 844, **kwargs) -> type[Self]:
        """Open the shared connection pool and return the ``Age`` class.

        Args:
            dsn: libpq connection string (key/value or URI form).  ``None`` is
                treated as an empty string, so libpq defaults/environment apply.
            pool_max_size: maximum number of pooled connections.
            **kwargs: forwarded to ``psycopg_pool.ConnectionPool``.  ``min_size``
                defaults to ``min(16, pool_max_size)``.

        Returns:
            The class itself (not an instance), with ``pool`` set.
        """
        # Many concurrent connections need many file descriptors.
        increaseFileDescriptors(5000)
        # Put ag_catalog on the search path so cypher(), agtype, ag_graph and the
        # graph management functions resolve without schema qualification.
        # make_conninfo() merges this into the DSN correctly for URI DSNs too.
        conninfo = make_conninfo(dsn or "", options = '-c search_path=ag_catalog,"$user",public')
        # ConnectionPool rejects min_size > max_size, so cap the default.
        kwargs.setdefault("min_size", min(16, pool_max_size))
        cls.pool = ConnectionPool(conninfo, max_size = pool_max_size, **kwargs)
        cls.pool.open()
        cls.pool.wait(timeout = 30.0)
        return cls

    @classmethod
    def loadFromDataframe(cls, graph_name : str = None, df : pd.DataFrame = None, relation : str = None) -> None:
        """(Re)create graph ``graph_name`` and load a two-column DataFrame into it.

        Each distinct value of the first column becomes a vertex labelled with
        that column's name, and likewise for the second column.  Each row becomes
        a ``relation`` edge from its first-column vertex to its second-column
        vertex.

        Args:
            graph_name: graph name.  An existing graph with this name is dropped.
            df: DataFrame with exactly two columns holding strings (the ``.str``
                accessor is used on them).  The caller's frame is not modified.
            relation: label of the created edges.

        Raises:
            ValueError: if ``df`` is missing or does not have exactly two columns.
        """
        # Validate before setUpGraph(), which drops any existing graph of this name.
        if df is None:
            raise ValueError("A DataFrame must be given.")
        if len(keys := df.keys()) != 2:
            raise ValueError(f"Dataframe must have just 2 columns, but {len(keys)} columns were found.")
        cls.setUpGraph(cls, graph_name)
        # sort_values() returns a new frame.  Keeping the result makes rows with
        # the same start vertex adjacent for itertools.groupby in createEdges(),
        # and it means the escaping below does not touch the caller's DataFrame.
        df = df.sort_values(by = keys[0])
        for key in keys:
            # Escape backslashes first, then single quotes, so the values can be
            # embedded in single-quoted Cypher string literals.
            df[key] = df[key].str.replace("\\", "\\\\", regex = False).str.replace("'", "\\'", regex = False)
            cls.createVertices(cls, key, df[key].unique())
        # Each pooled connection commits when its `with` block exits, so no
        # explicit COMMIT is needed here.
        cls.createEdges(cls, keys, zip(df[keys[0]].tolist(), df[keys[1]].tolist()), relation)

    @classmethod
    def loadFromCSV(cls, graph_name : str = None, csv_path : str = None) -> None:
        """(Re)create graph ``graph_name``.

        TODO: loading from ``csv_path`` is not implemented yet; the argument is
        currently ignored.
        """
        cls.setUpGraph(cls, graph_name)

    def setUpGraph(self, graph_name : str = None) -> None:
        """Drop graph ``graph_name`` if it exists, create it anew and make it current.

        Args:
            graph_name: name of the AGE graph (looked up in ``ag_graph``).
        """
        graph = sql.Literal(graph_name)
        with self.pool.connection() as conn:
            with conn.cursor(row_factory = namedtuple_row) as cur:
                cur.execute(sql.SQL("SELECT count(*) FROM ag_graph WHERE name = {}").format(graph))
                if (row := cur.fetchone()) is not None:
                    if row.count == 1:
                        # The second argument `true` cascades the drop to the graph's contents.
                        cur.execute(sql.SQL("SELECT drop_graph({}, true)").format(graph))
                    cur.execute(sql.SQL("SELECT create_graph({})").format(graph))
                self.graphName = graph_name

    def createVertices(self, key : str = None, nodes : list = None) -> None:
        """Create one ``key``-labelled vertex per name in ``nodes`` and index the label table.

        Args:
            key: vertex label (used verbatim in Cypher).
            nodes: vertex names, already escaped for single-quoted Cypher strings.
                Does nothing when empty.
        """
        # len() rather than truthiness: `nodes` may be a numpy array (Series.unique()).
        if nodes is None or len(nodes) == 0:
            return
        graph = sql.Literal(self.graphName).as_string()
        # NOTE(review): a name containing "$$" would end the dollar-quoted Cypher body.
        query = "".join(f"SELECT * FROM cypher({graph}, $$ CREATE (n:{key} {{name: '{node}'}}) $$) AS (a agtype);" for node in nodes)
        # AGE stores each label as a table in the schema named after the graph.
        table = sql.Identifier(self.graphName, key)
        with self.pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(query)
                cur.execute(sql.SQL("CREATE INDEX ON {} USING GIN (properties)").format(table))
                cur.execute(sql.SQL("CREATE INDEX ON {} USING BTREE (id)").format(table))

    def createEdges(self, keys : list = None, node_pairs : list = None, relation : str = None) -> None:
        """Create ``relation`` edges for all (start, end) pairs concurrently, then index them.

        Args:
            keys: ``[start_label, end_label]``.
            node_pairs: iterable of (start name, end name) pairs.  Pairs that share
                a start should be adjacent (sorted by start); each adjacent run is
                sent as one query.
            relation: edge label.

        Raises:
            Any exception raised while creating edges for some start vertex.
        """
        groups = [(start, list(members)) for start, members in itertools.groupby(node_pairs or (), key = lambda pair: pair[0])]
        if not groups:
            return  # no edges: the relation table (and thus its indexes) does not exist
        # Each worker holds at most one pooled connection at a time, so more
        # workers than pool connections would only queue on the pool.
        workers = max(1, min(len(groups), self.pool.max_size))
        with ThreadPoolExecutor(max_workers = workers) as executor:
            futures = [
                executor.submit(self.createEdgesPerStarts, self.pool, self.graphName, keys, start, members, relation, threadId)
                for threadId, (start, members) in enumerate(groups, start = 1)
            ]
        # result() re-raises a worker's exception instead of silently losing it.
        for future in futures:
            future.result()

        table = sql.Identifier(self.graphName, relation)
        with self.pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(sql.SQL("CREATE INDEX ON {} (start_id)").format(table))
                cur.execute(sql.SQL("CREATE INDEX ON {} (end_id)").format(table))

    @staticmethod
    def createEdgesPerStarts(pool : ConnectionPool = None, graphName : str = None, keys : list = None, start_node : str = None, pairs : list = None, relation : str = None, threadId: int = 0, max_retries : int | None = None, retry_delay : float = 0.1) -> None:
        """Create all ``relation`` edges leaving one start vertex with a single Cypher query.

        Args:
            pool: connection pool to run the query on.
            graphName: AGE graph name.
            keys: ``[start_label, end_label]``.
            start_node: name of the start vertex (already Cypher-escaped).
            pairs: non-empty list of (start, end) pairs; only the end names are used.
            relation: edge label.
            threadId: caller-assigned sequence number; not used by the query.
            max_retries: retries after transient errors; ``None`` retries forever.
            retry_delay: base back-off in seconds, growing linearly per attempt (capped at 5 s).

        Raises:
            psycopg.OperationalError: after ``max_retries`` failed retries.
            psycopg.Error: any non-transient database error, immediately.
        """
        query_ends = ', '.join([f"(m{idx}:{keys[1]} {{name: '{pair[1]}'}})" for idx, pair in enumerate(pairs)])
        query_rels = ', '.join([f"(n)-[:{relation}]->(m{idx})" for idx, _ in enumerate(pairs)])
        query = f"SELECT * FROM cypher({sql.Literal(graphName).as_string()}, $$ MATCH (n:{keys[0]} {{name: '{start_node}'}}), {query_ends} CREATE {query_rels} $$) AS (a agtype);"
        attempt = 0
        while True:
            try:
                with pool.connection() as conn:
                    with conn.cursor() as cur:
                        cur.execute(query)
                return
            # OperationalError covers transient failures: pool timeouts
            # (psycopg_pool.PoolTimeout), deadlocks and serialization failures,
            # and dropped connections.  Permanent errors such as syntax errors
            # propagate instead of looping forever.
            except psycopg.OperationalError:
                attempt += 1
                if max_retries is not None and attempt > max_retries:
                    raise
                time.sleep(min(retry_delay * attempt, 5.0))

呼び出す側はこれだけ。

#!/usr/bin/env python3.11
# -*- coding: utf-8 -*-

"""Load Actor/Film pairs from actorfilms.csv into an Apache AGE graph."""

import pandas as pd
from pythonage import Age

AG_GRAPH_NAME = "actorfilms"

# libpq connection string; replace the placeholders with your server settings.
# NOTE: consider omitting `password` here and supplying it via PGPASSWORD or
# ~/.pgpass instead of keeping it in source code.
connection_string = """
    host=your_server.postgres.database.azure.com
    port=5432
    dbname=postgres
    user=your_account
    password=admin_password
"""


def main() -> None:
    """Read the CSV and load it as (Actor)-[:ACTED_IN]->(Film) edges."""
    # read_csv ignores the order of `usecols` and keeps the file's column order.
    # loadFromDataframe treats the first column as the edge start, so select the
    # columns explicitly to guarantee Actor -> Film.
    df = pd.read_csv(f"{AG_GRAPH_NAME}.csv", usecols = ["Actor", "Film"])[["Actor", "Film"]]

    ag = Age.connect(dsn = connection_string, pool_max_size = 256)
    ag.loadFromDataframe(graph_name = AG_GRAPH_NAME, df = df, relation = "ACTED_IN")


if __name__ == "__main__":
    main()

ベンチマークを取った。インスタンスサイズとコネクションプールのコネクション数を変えながら、ロードにかかった時間を計測した結果は以下の通り。

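| コネクション数 | Standard_D4ds_v4 | Standard_D8ds_v4 | Standard_D16ds_v4 | Standard_D32ds_v4 |
|---:|---:|---:|---:|---:|
| 16 | 2:27.57 | 1:18.56 | 43.085 | 33.336 |
| 32 | 2:31.32 | 1:10.90 | 42.149 | 25.492 |
| 64 | 2:35.02 | 1:17.40 | 43.396 | 25.148 |
| 128 | 2:29.92 | 1:16.09 | 42.328 | 24.107 |
| 192 | 2:37.78 | 1:20.03 | 42.725 | 24.416 |
| 256 | 2:38.55 | 1:16.53 | 42.194 | 24.476 |
| 384 | 2:39.37 | 1:19.97 | 43.041 | 24.387 |
| 512 | 2:38.77 | 1:20.50 | 42.609 | 24.493 |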

データ容量は大したこと無いので、ほぼCPUのコア数だけで決まるかな。ベンチマーク中にサチってるのはCPUだけで、メモリやIOPSは余裕。