""" PostgreSQL connectivity. PGPool can be used as context manager. It takes postgresql configuration parameters and gives a connection pool. """ import logging import sys from io import TextIOBase from pathlib import Path from traceback import format_exc from typing import Dict import asyncpg from .utils.json import json_dumps, json_loads logger = logging.getLogger(__name__) class PGPool: """ Database connectivity: Provide a connection pool. Can be used either as async context manager (giving a pool), or as a class using async init and the shutdown method and having the pool attribute. After startup self.pool contains a PostgreSQL connection pool (instance of :class:`asyncpg.pool.Pool`). Startup also runs schema migrations (cf. directory `migrations`). """ def __init__( self, postgresql_config: dict, out: TextIOBase = None, check: bool = True, ) -> None: self.conf = postgresql_config self.out = out or sys.stdout self.check = check self.pool = None def __await__(self): return self.__ainit__().__await__() async def __ainit__(self): await self.__aenter__() return self async def __aenter__(self): """ Return the connection pool after an optional check. The check tests basic database access and runs missing migrations. If the check fails, return None. """ pool_params = { key: val for key, val in self.conf.items() if key in ( 'host', 'port', 'database', 'user', 'password', 'max_size', 'min_size', ) } pool_params['command_timeout'] = 30 self.pool = await asyncpg.create_pool(**pool_params, init=self._init) if self.check: async with self.pool.acquire() as conn: if await self.check_or_migrate(conn): return self.pool @staticmethod async def _init(conn) -> None: """ Add JSON encoding and decoding to the given connection. """ await conn.set_type_codec( 'jsonb', encoder=json_dumps, decoder=json_loads, schema='pg_catalog', ) async def __aexit__(self, exc_type, exc, tb) -> None: """ Close the connection pool. """ await self.shutdown() async def shutdown(self): """ Close the pool. """ await self.pool.close() async def check_or_migrate(self, conn: asyncpg.Connection) -> bool: """ Check database connectivity. Return whether database connectivity is working. """ row = await conn.fetchrow('SELECT 1+1 AS result') if not row or row.get('result') != 2: msg = 'Database SELECT 1+1 not working; missing privileges?' print(msg, file=self.out) logger.critical(msg) return False # determine current schema_version try: sql = "SELECT value::int FROM kvs WHERE key='schema_version'" schema_version = await conn.fetchval(sql) except: schema_version = 0 # run missing migrations migrations = get_migrations() for number, text in sorted(migrations.items()): if number > schema_version: cmds = text.split('\n----\n') for cmd in cmds: if not cmd.strip(): continue try: await conn.execute(cmd) except: msg = ( f'Exception during migration {number} in ' f'statement\n{cmd}' ) print(msg, file=self.out) logger.critical(msg) print(format_exc(), file=self.out) logger.critical(format_exc()) return False # return success return True def get_migrations() -> Dict[int, str]: """ Return migrations (number and text content of migration file). """ migrations_dir = Path(__file__).parent / 'migrations' migrations = {} for migration_file in migrations_dir.glob('*.sql'): migration_number = int(migration_file.name[:-4]) with migration_file.open() as mig_file: content = mig_file.read() migrations[migration_number] = content return migrations