163 lines
4.6 KiB
Python
163 lines
4.6 KiB
Python
"""
|
|
PostgreSQL connectivity.
|
|
|
|
PGPool can be used as an async context manager. It takes PostgreSQL configuration
|
|
parameters and gives a connection pool.
|
|
"""
|
|
|
|
import logging
|
|
import sys
|
|
from io import TextIOBase
|
|
from pathlib import Path
|
|
from traceback import format_exc
|
|
from typing import Dict
|
|
|
|
import asyncpg
|
|
|
|
from .utils.json import json_dumps, json_loads
|
|
|
|
# Module-level logger named after this module; used below to report
# failed database checks and failed migrations at CRITICAL level.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PGPool:
    """
    Database connectivity: Provide a connection pool.

    Can be used either as async context manager (giving a pool),
    or as a class using async init (``pg = await PGPool(conf)``) and the
    shutdown method and having the pool attribute.

    After startup self.pool contains a PostgreSQL connection pool
    (instance of :class:`asyncpg.pool.Pool`).

    Unless `check` is false, startup also verifies basic database access
    and runs missing schema migrations (cf. directory `migrations`).
    """

    # Keys of the configuration dict that are forwarded to
    # asyncpg.create_pool; all other keys are ignored.
    _POOL_PARAM_KEYS = (
        'host',
        'port',
        'database',
        'user',
        'password',
        'max_size',
        'min_size',
    )

    def __init__(
        self,
        postgresql_config: dict,
        out: TextIOBase = None,
        check: bool = True,
    ) -> None:
        """
        Store the configuration; no connection is opened here.

        :param postgresql_config: PostgreSQL settings; only the keys listed
            in `_POOL_PARAM_KEYS` are used
        :param out: stream for human-readable error messages
            (defaults to sys.stdout)
        :param check: whether startup checks database access and runs
            missing migrations
        """
        self.conf = postgresql_config
        self.out = out or sys.stdout
        self.check = check
        self.pool = None

    def __await__(self):
        """
        Make ``await PGPool(...)`` start the pool and return the instance.
        """
        return self.__ainit__().__await__()

    async def __ainit__(self):
        """
        Start the pool (like entering the context manager); return self.
        """
        await self.__aenter__()
        return self

    async def __aenter__(self):
        """
        Return the connection pool after an optional check.

        The check tests basic database access and runs missing migrations.
        If the check fails, return None. If no check was requested,
        the pool is returned right away.

        If the check raises, the pool is closed before the exception is
        propagated, because `__aexit__` is not called in that case.
        """
        pool_params = {
            key: val
            for key, val in self.conf.items()
            if key in self._POOL_PARAM_KEYS
        }
        pool_params['command_timeout'] = 30
        self.pool = await asyncpg.create_pool(**pool_params, init=self._init)
        if not self.check:
            return self.pool
        try:
            async with self.pool.acquire() as conn:
                check_ok = await self.check_or_migrate(conn)
        except Exception:
            # Do not leak the freshly created pool on unexpected errors.
            await self.shutdown()
            raise
        return self.pool if check_ok else None

    @staticmethod
    async def _init(conn) -> None:
        """
        Add JSON encoding and decoding to the given connection.
        """
        await conn.set_type_codec(
            'jsonb',
            encoder=json_dumps,
            decoder=json_loads,
            schema='pg_catalog',
        )

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """
        Close the connection pool.
        """
        await self.shutdown()

    async def shutdown(self):
        """
        Close the pool.

        Does nothing if the pool has not been created.
        """
        if self.pool is not None:
            await self.pool.close()

    async def check_or_migrate(self, conn: asyncpg.Connection) -> bool:
        """
        Check database connectivity and run missing migrations.

        First a trivial query is run. Then the current schema version is
        read from table `kvs` (a failing query or a missing value counts
        as version 0) and all migrations with a higher number are executed
        in ascending order. Statements within a migration file are
        separated by lines consisting of `----`.

        Problems are printed to `self.out` and logged as critical.

        Return whether the check and all required migrations succeeded.
        """
        row = await conn.fetchrow('SELECT 1+1 AS result')
        if not row or row.get('result') != 2:
            msg = 'Database SELECT 1+1 not working; missing privileges?'
            print(msg, file=self.out)
            logger.critical(msg)
            return False

        # determine current schema_version
        try:
            sql = "SELECT value::int FROM kvs WHERE key='schema_version'"
            schema_version = await conn.fetchval(sql)
        except Exception:
            # e.g. table kvs does not exist yet: start from scratch
            schema_version = 0
        if schema_version is None:
            # table exists, but no (non-NULL) schema_version is stored
            schema_version = 0

        # run missing migrations
        migrations = get_migrations()
        for number, text in sorted(migrations.items()):
            if number <= schema_version:
                continue
            for cmd in text.split('\n----\n'):
                if not cmd.strip():
                    continue
                try:
                    await conn.execute(cmd)
                except Exception:
                    # Exception (not a bare except) lets task cancellation
                    # and interpreter exit propagate.
                    msg = (
                        f'Exception during migration {number} in '
                        f'statement\n{cmd}'
                    )
                    print(msg, file=self.out)
                    logger.critical(msg)
                    trace = format_exc()
                    print(trace, file=self.out)
                    logger.critical(trace)
                    return False

        # return success
        return True
|
|
|
|
|
|
def get_migrations() -> Dict[int, str]:
    """
    Collect the schema migrations stored next to this module.

    Every `*.sql` file in the sibling directory `migrations` is read.
    The file name without its last four characters (the `.sql` extension)
    is parsed as the migration number.

    Return a mapping from migration number to the file's text content.
    Raise ValueError if such a file name is not an integer.
    """
    sql_dir = Path(__file__).parent / 'migrations'
    return {
        int(sql_path.name[:-4]): sql_path.read_text()
        for sql_path in sql_dir.glob('*.sql')
    }
|