atextcrawler/src/atextcrawler/db.py

163 lines
4.6 KiB
Python

"""
PostgreSQL connectivity.
PGPool can be used as an async context manager. It takes PostgreSQL
configuration parameters and gives a connection pool.
"""
import logging
import sys
from io import TextIOBase
from pathlib import Path
from traceback import format_exc
from typing import Dict
import asyncpg
from .utils.json import json_dumps, json_loads
logger = logging.getLogger(__name__)
class PGPool:
    """
    Database connectivity: Provide a connection pool.

    Can be used either as async context manager (giving a pool),
    or as a class using async init (``await PGPool(...)``) and the
    :meth:`shutdown` method and having the pool attribute.

    After startup self.pool contains a PostgreSQL connection pool
    (instance of :class:`asyncpg.pool.Pool`).

    If `check` is true, startup also tests basic database access and
    runs missing schema migrations (cf. directory `migrations`).
    """

    # Keys of the configuration dict that are handed on to
    # asyncpg.create_pool(); any other key in the config is ignored.
    _POOL_PARAM_KEYS = (
        'host',
        'port',
        'database',
        'user',
        'password',
        'max_size',
        'min_size',
    )

    def __init__(
        self,
        postgresql_config: dict,
        out: TextIOBase = None,
        check: bool = True,
    ) -> None:
        """
        Store the settings; no database connection is opened here.

        :param postgresql_config: PostgreSQL connection and pool parameters
        :param out: stream for error messages; falls back to sys.stdout
        :param check: whether startup runs :meth:`check_or_migrate`
        """
        self.conf = postgresql_config
        self.out = out or sys.stdout
        self.check = check
        self.pool = None

    def __await__(self):
        """
        Support ``await PGPool(...)``: start up and give the instance.
        """
        return self.__ainit__().__await__()

    async def __ainit__(self):
        """
        Create the pool (cf. :meth:`__aenter__`) and return self.
        """
        await self.__aenter__()
        return self

    async def __aenter__(self):
        """
        Return the connection pool after an optional check.

        The check tests basic database access and runs missing migrations.
        If the check fails, return None. If the check is disabled,
        return the pool without testing it.

        In every case the created pool is stored in self.pool.
        """
        pool_params = {
            key: val
            for key, val in self.conf.items()
            if key in self._POOL_PARAM_KEYS
        }
        pool_params['command_timeout'] = 30
        self.pool = await asyncpg.create_pool(**pool_params, init=self._init)
        if not self.check:
            # The check is optional: without it the pool is handed out
            # as is (previously None was returned in this case).
            return self.pool
        async with self.pool.acquire() as conn:
            if await self.check_or_migrate(conn):
                return self.pool
        return None

    @staticmethod
    async def _init(conn) -> None:
        """
        Add JSON encoding and decoding to the given connection.

        Passed as `init` callback to the pool; registers json_dumps and
        json_loads as codec for the `jsonb` type.
        """
        await conn.set_type_codec(
            'jsonb',
            encoder=json_dumps,
            decoder=json_loads,
            schema='pg_catalog',
        )

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """
        Close the connection pool.
        """
        await self.shutdown()

    async def shutdown(self):
        """
        Close the pool.

        Does nothing if no pool has been created.
        """
        if self.pool is not None:
            await self.pool.close()

    async def check_or_migrate(self, conn: 'asyncpg.Connection') -> bool:
        """
        Check database connectivity and run missing migrations.

        Return whether database connectivity is working and all missing
        migrations were applied. Failures are printed to self.out and
        logged with level critical.
        """
        row = await conn.fetchrow('SELECT 1+1 AS result')
        if not row or row.get('result') != 2:
            msg = 'Database SELECT 1+1 not working; missing privileges?'
            print(msg, file=self.out)
            logger.critical(msg)
            return False

        # determine current schema_version
        try:
            sql = "SELECT value::int FROM kvs WHERE key='schema_version'"
            schema_version = await conn.fetchval(sql)
        except Exception:
            # The query fails e.g. if table kvs does not exist yet; then
            # no migration has been applied. Only Exception is caught, so
            # that task cancellation and KeyboardInterrupt propagate.
            schema_version = 0
        if schema_version is None:
            # Table kvs exists, but has no schema_version entry; avoid
            # comparing int with None below.
            schema_version = 0

        # run missing migrations
        # NOTE(review): the statements are not wrapped in a transaction,
        # so a failing migration may be left partially applied.
        migrations = get_migrations()
        for number, text in sorted(migrations.items()):
            if number <= schema_version:
                continue
            # statements within a migration file are separated by a
            # line containing only ----
            for cmd in text.split('\n----\n'):
                if not cmd.strip():
                    continue
                try:
                    await conn.execute(cmd)
                except Exception:
                    msg = (
                        f'Exception during migration {number} in '
                        f'statement\n{cmd}'
                    )
                    trace = format_exc()
                    print(msg, file=self.out)
                    logger.critical(msg)
                    print(trace, file=self.out)
                    logger.critical(trace)
                    return False

        # return success
        return True
def get_migrations() -> Dict[int, str]:
    """
    Collect the SQL migrations located next to this module.

    Every ``*.sql`` file in the sibling directory `migrations` is read.
    The result maps the migration number (the file name without its last
    four characters, converted to int) to the text content of the file.
    """
    sql_dir = Path(__file__).parent / 'migrations'
    # The key is evaluated before the value, so a non-numeric file name
    # raises ValueError before the file is read.
    return {
        int(sql_file.name[:-4]): sql_file.read_text()
        for sql_file in sql_dir.glob('*.sql')
    }