# Source code for sirbot.plugins.postgres.plugin

import os
import logging

import asyncpg
import aiofiles
from aiocontext import async_contextmanager
import ujson

LOG = logging.getLogger(__name__)

class PgPlugin:
    """
    Handle database connection and sql migration for postgresql.

    Database migration to new version are automatically handled at startup
    when the ``version`` and ``sql_migration_directory`` argument are passed.
    The ``sql_migration_directory`` should be a directory with a single sql
    file per version and ``version`` should follow semantic versioning.

    Args:
        version: Current version of the bot.
        sql_migration_directory: Directory where migration sql files are
            located.
        **kwargs: Arguments for :func:`asyncpg.pool.create_pool`.

    **Variables**:
        * **pool**: Instance of :class:`asyncpg.pool.Pool`.
    """

    __name__ = "pg"

    def __init__(self, *, sql_migration_directory=None, version=None, **kwargs):
        self.pool_kwargs = kwargs
        # The pool is only created in ``startup``; until then it stays None.
        self.pool = None
        self.version = version

        if sql_migration_directory:
            # Store an absolute path so later lookups do not depend on the
            # working directory at migration time.
            if not os.path.isabs(sql_migration_directory):
                sql_migration_directory = os.path.abspath(sql_migration_directory)
            self.sql_migration_directory = sql_migration_directory
        else:
            self.sql_migration_directory = None

    def load(self, sirbot):
        """Register the startup and shutdown hooks of the plugin on ``sirbot``."""
        LOG.info("Loading postgres plugin")
        # Inserted at index 0 so this startup hook is first in the list.
        sirbot.on_startup.insert(0, self.startup)
        sirbot.on_shutdown.append(self.shutdown)

    async def startup(self, sirbot):
        """Create the connection pool and run the migration when configured."""
        self.pool = await asyncpg.create_pool(
            **self.pool_kwargs, init=self._init_connection
        )

        # Migration is opt-in: both a directory and a target version are needed.
        if self.sql_migration_directory and self.version:
            await self.migrate()

    async def shutdown(self, sirbot):
        """Close the connection pool, if one was created."""
        # ``startup`` may never have run (or may have failed before assigning
        # the pool); closing ``None`` would raise AttributeError.
        if self.pool is not None:
            await self.pool.close()

    @async_contextmanager
    async def connection(self):
        """
        Acquire a connection from the pool

        :return: Instance of :class:`asyncpg.connection.Connection`
        """
        async with self.pool.acquire() as pg_con:
            yield pg_con

    async def migrate(self):
        """
        Bring the database to ``self.version``.

        Nothing is executed when the version stored in the database already
        equals ``self.version``. Otherwise the initial setup (when no
        ``metadata`` table exists), every pending migration file and the
        version update are run inside a single transaction.
        """
        LOG.info("Start of database migration")
        current_version = [int(n) for n in self.version.split(".")]

        async with self.connection() as connection:
            old_version = await self._check_database_version(connection)

            if current_version != old_version:
                async with connection.transaction():
                    # ``None`` means the ``metadata`` table does not exist yet.
                    if old_version is None:
                        await self._init_database(connection)
                        old_version = [0, 0, 0]

                    for version in self._find_update_version(
                        start=old_version, end=current_version
                    ):
                        await self._execute_sql_file(connection, version)
                    await self._update_db_version(connection, current_version)
        LOG.info("End of database migration")

    def _find_update_version(self, start, end):
        """
        List the migration versions to apply, in ascending order.

        Args:
            start: Version currently in the database, as a list of int
                (exclusive bound).
            end: Target version, as a list of int (inclusive bound).

        Returns:
            Dotted version strings of the ``.sql`` files in the migration
            directory whose version satisfies ``start < version <= end``.

        Raises:
            ValueError: If a ``.sql`` file other than ``init.sql`` has a name
                that is not made of dot-separated integers.
        """
        versions = []
        for file in os.listdir(self.sql_migration_directory):
            if file == "init.sql":
                continue

            name, extension = os.path.splitext(file)
            # Only ``<version>.sql`` can be opened by ``_execute_sql_file``;
            # other entries (README, backups, directories) are not migrations.
            if extension != ".sql":
                continue

            file_version = [int(n) for n in name.split(".")]
            if end >= file_version > start:
                versions.append(file_version)

        # Sort on the integer lists so that e.g. 0.10.0 comes after 0.9.0.
        versions.sort()
        files = [".".join(str(part) for part in version) for version in versions]
        LOG.debug("Database migration versions: %s", files)
        return files

    async def _init_database(self, connection):
        """
        Create the ``metadata`` table and run ``init.sql`` when it exists.

        Args:
            connection: Connection used to execute the statements.
        """
        LOG.debug("Executing initial migration")
        await connection.execute(
            """
            CREATE TABLE metadata (db_version TEXT);
            INSERT INTO metadata (db_version) VALUES ('0.0.0');
        """
        )

        if os.path.exists(os.path.join(self.sql_migration_directory, "init.sql")):
            await self._execute_sql_file(connection, "init")

    async def _execute_sql_file(self, connection, version):
        """
        Execute ``<version>.sql`` from the migration directory.

        Args:
            connection: Connection used to execute the file content.
            version: File name without the ``.sql`` extension.
        """
        LOG.debug("Database migration to version %s: STARTED", version)
        async with aiofiles.open(
            os.path.join(self.sql_migration_directory, f"{version}.sql"), mode="r"
        ) as f:
            await connection.execute((await f.read()))
        LOG.debug("Database migration to version %s: OK", version)

    @staticmethod
    async def _check_database_version(connection):
        """
        Read the version stored in the ``metadata`` table.

        Returns:
            The version as a list of int, or ``None`` when the ``metadata``
            table does not exist.
        """
        try:
            metadata = await connection.fetchrow("""SELECT * FROM metadata""")
            return [int(n) for n in metadata["db_version"].split(".")]
        except asyncpg.exceptions.UndefinedTableError:
            LOG.debug('No "metadata" table found in database')

    @staticmethod
    async def _update_db_version(connection, version):
        """
        Store ``version`` in the ``metadata`` table.

        Args:
            connection: Connection used to execute the update.
            version: Iterable of version parts, joined with ``.`` before
                being stored.
        """
        await connection.execute(
            """UPDATE metadata SET db_version=$1""",
            ".".join(str(part) for part in version),
        )

    async def _init_connection(self, connection):
        """Register the ujson based ``jsonb`` codec on a new pool connection."""
        await connection.set_type_codec(
            "jsonb",
            encoder=self._json_encoder,
            decoder=self._json_decoder,
            schema="pg_catalog",
        )

    @staticmethod
    def _json_encoder(value):
        """Serialize ``value`` with :func:`ujson.dumps`."""
        return ujson.dumps(value)

    @staticmethod
    def _json_decoder(value):
        """Deserialize ``value`` with :func:`ujson.loads`."""
        return ujson.loads(value)