from ephemeral_port_reserve import reserve from glob import glob import logging import os import psycopg2 import random import re import shutil import signal import sqlite3 import string import subprocess import time class Sqlite3Db(object): def __init__(self, path): self.path = path def get_dsn(self): """SQLite3 doesn't provide a DSN, resulting in no CLI-option. """ return None def query(self, query): orig = os.path.join(self.path) copy = self.path + ".copy" shutil.copyfile(orig, copy) db = sqlite3.connect(copy) db.row_factory = sqlite3.Row c = db.cursor() c.execute(query) rows = c.fetchall() result = [] for row in rows: result.append(dict(zip(row.keys(), row))) db.commit() c.close() db.close() return result def execute(self, query): db = sqlite3.connect(self.path) c = db.cursor() c.execute(query) db.commit() c.close() db.close() class PostgresDb(object): def __init__(self, dbname, port): self.dbname = dbname self.port = port self.conn = psycopg2.connect("dbname={dbname} user=postgres host=localhost port={port}".format( dbname=dbname, port=port )) cur = self.conn.cursor() cur.execute('SELECT 1') cur.close() def get_dsn(self): return "postgres://postgres:password@localhost:{port}/{dbname}".format( port=self.port, dbname=self.dbname ) def query(self, query): cur = self.conn.cursor() cur.execute(query) # Collect the results into a list of dicts. res = [] for r in cur: t = {} # Zip the column definition with the value to get its name. for c, v in zip(cur.description, r): t[c.name] = v res.append(t) cur.close() return res def execute(self, query): with self.conn, self.conn.cursor() as cur: cur.execute(query) class SqliteDbProvider(object): def __init__(self, directory): self.directory = directory def start(self): pass def get_db(self, node_directory, testname, node_id): path = os.path.join( node_directory, 'lightningd.sqlite3' ) return Sqlite3Db(path) def stop(self): pass class PostgresDbProvider(object): def __init__(self, directory): self.directory = directory self.port = None self.proc = None print("Starting PostgresDbProvider") def locate_path(self): prefix = '/usr/lib/postgresql/*' matches = glob(prefix) candidates = {} for m in matches: g = re.search(r'([0-9]+[\.0-9]*)', m) if not g: continue candidates[float(g.group(1))] = m if len(candidates) == 0: raise ValueError("Could not find `postgres` and `initdb` binaries in {}. Is postgresql installed?".format(prefix)) # Now iterate in reverse order through matches for k, v in sorted(candidates.items())[::-1]: initdb = os.path.join(v, 'bin', 'initdb') postgres = os.path.join(v, 'bin', 'postgres') if os.path.isfile(initdb) and os.path.isfile(postgres): logging.info("Found `postgres` and `initdb` in {}".format(os.path.join(v, 'bin'))) return initdb, postgres raise ValueError("Could not find `postgres` and `initdb` in any of the possible paths: {}".format(candidates.values())) def start(self): passfile = os.path.join(self.directory, "pgpass.txt") self.pgdir = os.path.join(self.directory, 'pgsql') # Need to write a tiny file containing the password so `initdb` can pick it up with open(passfile, 'w') as f: f.write('cltest\n') initdb, postgres = self.locate_path() subprocess.check_call([ initdb, '--pwfile={}'.format(passfile), '--pgdata={}'.format(self.pgdir), '--auth=trust', '--username=postgres', ]) self.port = reserve() self.proc = subprocess.Popen([ postgres, '-k', '/tmp/', # So we don't use /var/lib/... '-D', self.pgdir, '-p', str(self.port), '-F', '-i', ]) # Hacky but seems to work ok (might want to make the postgres proc a TailableProc as well if too flaky). time.sleep(1) self.conn = psycopg2.connect("dbname=template1 user=postgres host=localhost port={}".format(self.port)) # Required for CREATE DATABASE to work self.conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) def get_db(self, node_directory, testname, node_id): # Random suffix to avoid collisions on repeated tests nonce = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8)) dbname = "{}_{}_{}".format(testname, node_id, nonce) cur = self.conn.cursor() cur.execute("CREATE DATABASE {};".format(dbname)) cur.close() db = PostgresDb(dbname, self.port) return db def stop(self): # Send fast shutdown signal see [1] for details: # # SIGINT # # This is the Fast Shutdown mode. The server disallows new connections # and sends all existing server processes SIGTERM, which will cause # them to abort their current transactions and exit promptly. It then # waits for all server processes to exit and finally shuts down. If # the server is in online backup mode, backup mode will be terminated, # rendering the backup useless. # # [1] https://www.postgresql.org/docs/9.1/server-shutdown.html self.proc.send_signal(signal.SIGINT) self.proc.wait()