#268 Rework database connections system
Merged 4 years ago by puiterwijk. Opened 4 years ago by puiterwijk.
puiterwijk/ipsilon db-connections into master

@@ -79,7 +79,6 @@ 

          q = self._query(self._db, 'association', UNIQUE_DATA_TABLE,

                          trans=False)

          q.create()

-         q._con.close()  # pylint: disable=protected-access

  

      def _upgrade_schema(self, old_version):

          if old_version == 1:

@@ -21,7 +21,6 @@ 

          q = self._query(self._db, 'client', OPTIONS_TABLE,

                          trans=False)

          q.create()

-         q._con.close()  # pylint: disable=protected-access

  

      def _upgrade_schema(self, old_version):

          raise NotImplementedError()
@@ -404,15 +403,12 @@ 

          q = self._query(self._db, 'client', UNIQUE_DATA_TABLE,

                          trans=False)

          q.create()

-         q._con.close()  # pylint: disable=protected-access

          q = self._query(self._db, 'token', UNIQUE_DATA_TABLE,

                          trans=False)

          q.create()

-         q._con.close()  # pylint: disable=protected-access

          q = self._query(self._db, 'userinfo', UNIQUE_DATA_TABLE,

                          trans=False)

          q.create()

-         q._con.close()  # pylint: disable=protected-access

  

      def _upgrade_schema(self, old_version):

          raise NotImplementedError()

file modified
+222 -160
@@ -52,7 +52,7 @@ 

      __instances = {}

  

      @classmethod

-     def get_connection(cls, name):

+     def get_instance(cls, name):

          if name not in cls.__instances:

              if cherrypy.config.get('db.conn.log', False):

                  logging.debug('SqlStore new: %s', name)
@@ -102,25 +102,99 @@ 

      def engine(self):

          return self._dbengine

  

-     def connection(self):

+     def connection(self, will_close=False):

+         """Function that makes a connection to the database.

+ 

+         will_close indicates whether the client will take responsibility for

+         closing the connection after it's done with it."""

          self.debug('SqlStore connect: %s' % self.name)

          conn = self._dbengine.connect()

  

          def cleanup_connection():

              self.debug('SqlStore cleanup: %s' % self.name)

              conn.close()

-         cherrypy.request.hooks.attach('on_end_request', cleanup_connection)

+         if not will_close:

+             cherrypy.request.hooks.attach('on_end_request', cleanup_connection)

          return conn

  

  

- class SqlQuery(Log):

+ class BaseQuery(Log):

+ 

+     def commit(self):

+         """Function to override to commit the transaction."""

+         pass

+ 

+     def rollback(self):

+         """Function to override to roll the transaction back."""

+         pass

+ 

+     def _setup_connection(self):

+         """Function to override to get a transaction and connection."""

+         pass

+ 

+     def _teardown_connection(self):

+         """Function to override to close transactions and connections."""

+         pass

+ 

+     def __enter__(self):

+         """Context Manager enter method.

+ 

+         This calls the setup connections method.

+         """

+         self._setup_connection()

+         return self

+ 

+     def __exit__(self, exc_class, exc, tb):

+         """ Context Manager exit method.

+ 

+         This automatically rolls back the transaction if an error occurred and

+         the engine supports it, and otherwise runs the database commit method.

+         After this, it will run the teardown method.

+ 

+         All the arguments are defined by PEP#343.

+         """

+         if exc is None:

+             self.commit()

+         else:

+             self.rollback()

+         self._teardown_connection()

+ 

+ 

+ class SqlQuery(BaseQuery):

  

      def __init__(self, db_obj, table, table_def, trans=True):

          self._db = db_obj

-         self._con = self._db.connection()

-         self._trans = self._con.begin() if trans else None

+         self.__con = None

+         self._trans = None

+         self._use_trans = trans

          self._table = self._get_table(table, table_def)

  

+     def _setup_connection(self):

+         if self.__con is not None:

+             self.debug('Multiple setup_connection calls without teardown')

+             return

+         self.__con = self._db.connection(True)

+         self._trans = self._con.begin() if self._use_trans else None

+ 

+     def _teardown_connection(self):

+         self.__con.close()

+         self.__con = None

+         self._trans = None


+ 

+     @property

+     def _con(self):

+         """Function that makes sure there is an active connection.

+ 

+         This is for backwards compatibility if other classes use SqlQuery

+         without the context manager handling."""

+         if not self.__con:

+             self.error('DEPRECATED: SqlQuery used without context manager!')

+             # Since we will not get notified when the user is done, we will

+             # need to get the conn closed after the request.

+             self.__con = self._db.connection(will_close=False)

+             self._trans = self._con.begin() if self._use_trans else None

+         return self.__con

+ 

      def _get_table(self, name, table_def):

          if isinstance(table_def, list):

              table_def = {'columns': table_def,
@@ -163,9 +237,13 @@ 

          return cols

  

      def rollback(self):

+         if not self._trans:

+             return

          self._trans.rollback()

  

      def commit(self):

+         if not self._trans:

+             return

          self._trans.commit()

  

      def create(self):
@@ -227,7 +305,7 @@ 

          raise NotImplementedError()

  

  

- class FileQuery(Log):

+ class FileQuery(BaseQuery):

  

      def __init__(self, fstore, table, table_def, trans=True):

          # We don't need indexes in a FileQuery, so drop that info
@@ -365,7 +443,7 @@ 

          return

  

  

- class EtcdQuery(Log):

+ class EtcdQuery(BaseQuery):

      """

      Class to store stuff in Etcd key-value stores.

  
@@ -633,7 +711,7 @@ 

              self._db = EtcdStore(name)

              self._query = EtcdQuery

          else:

-             self._db = SqlStore.get_connection(name)

+             self._db = SqlStore.get_instance(name)

              self._query = SqlQuery

  

          if not self._is_upgrade:
@@ -699,7 +777,8 @@ 

          for table in self._auto_cleanup_tables:

              self.debug('Auto-cleaning %s' % table)

              q = self._query(self._db, table, UNIQUE_DATA_TABLE)

-             cleaned_table = q.perform_auto_cleanup()

+             with q:

+                 cleaned_table = q.perform_auto_cleanup()

              self.debug('Cleaned up %i entries' % cleaned_table)

              cleaned += cleaned_table

          return cleaned
@@ -722,7 +801,6 @@ 

          #  the main codebase, and even in the same database.

          q = self._query(self._db, 'dbinfo', OPTIONS_TABLE, trans=False)

          q.create()

-         q._con.close()  # pylint: disable=protected-access

          cls_name = self.__class__.__name__

          current_version = self.load_options('dbinfo').get('%s_schema'

                                                            % cls_name, {})
@@ -840,12 +918,13 @@ 

  

      def _load_data(self, table, columns, kvfilter=None):

          rows = []

-         try:

-             q = self._query(self._db, table, columns, trans=False)

-             rows = q.select(kvfilter)

-         except Exception, e:  # pylint: disable=broad-except

-             self.error("Failed to load data for table %s for store %s: [%s]"

-                        % (table, self.__class__.__name__, e))

+         q = self._query(self._db, table, columns, trans=False)

+         with q:

+             try:

+                 rows = q.select(kvfilter)

+             except Exception, e:  # pylint: disable=broad-except

+                 self.error("Failed to load data for table %s for store %s:[%s]"

+                            % (table, self.__class__.__name__, e))

          return self._rows_to_dict_tree(rows)

  

      def load_config(self):
@@ -864,47 +943,42 @@ 

      def save_options(self, table, name, options):

          curvals = dict()

          q = None

-         try:

-             q = self._query(self._db, table, OPTIONS_TABLE)

-             rows = q.select({'name': name}, ['option', 'value'])

-             for row in rows:

-                 curvals[row[0]] = row[1]

- 

-             for opt in options:

-                 if opt in curvals:

-                     q.update({'value': options[opt]},

-                              {'name': name, 'option': opt})

-                 else:

-                     q.insert((name, opt, options[opt]))

+         q = self._query(self._db, table, OPTIONS_TABLE)

+         with q:

+             try:

+                 rows = q.select({'name': name}, ['option', 'value'])

+                 for row in rows:

+                     curvals[row[0]] = row[1]

  

-             for opt in curvals:

-                 if opt not in options:

-                     q.delete({'name': name, 'option': opt})

+                 for opt in options:

+                     if opt in curvals:

+                         q.update({'value': options[opt]},

+                                  {'name': name, 'option': opt})

+                     else:

+                         q.insert((name, opt, options[opt]))

  

-             q.commit()

-         except Exception, e:  # pylint: disable=broad-except

-             if q:

-                 q.rollback()

-             self.error("Failed to save options: [%s]" % e)

-             raise

+                 for opt in curvals:

+                     if opt not in options:

+                         q.delete({'name': name, 'option': opt})

+ 

+             except Exception, e:  # pylint: disable=broad-except

+                 self.error("Failed to save options: [%s]" % e)

+                 raise

  

      def delete_options(self, table, name, options=None):

          kvfilter = {'name': name}

-         q = None

-         try:

-             q = self._query(self._db, table, OPTIONS_TABLE)

-             if options is None:

-                 q.delete(kvfilter)

-             else:

-                 for opt in options:

-                     kvfilter['option'] = opt

+         q = self._query(self._db, table, OPTIONS_TABLE)

+         with q:

+             try:

+                 if options is None:

                      q.delete(kvfilter)

-             q.commit()

-         except Exception, e:  # pylint: disable=broad-except

-             if q:

-                 q.rollback()

-             self.error("Failed to delete from %s: [%s]" % (table, e))

-             raise

+                 else:

+                     for opt in options:

+                         kvfilter['option'] = opt

+                         q.delete(kvfilter)

+             except Exception, e:  # pylint: disable=broad-except

+                 self.error("Failed to delete from %s: [%s]" % (table, e))

+                 raise

  

      def new_unique_data(self, table, data, ttl=None, expiration_time=None):

          if expiration_time:
@@ -915,19 +989,16 @@ 

              raise ValueError('Negative TTL specified: %s' % ttl)

  

          newid = str(uuid.uuid4())

-         q = None

-         try:

-             q = self._query(self._db, table, UNIQUE_DATA_TABLE)

-             for name in data:

-                 q.insert((newid, name, data[name]), ttl)

-             if expiration_time:

-                 q.insert((newid, 'expiration_time', expiration_time), ttl)

-             q.commit()

-         except Exception, e:  # pylint: disable=broad-except

-             if q:

-                 q.rollback()

-             self.error("Failed to store %s data: [%s]" % (table, e))

-             raise

+         q = self._query(self._db, table, UNIQUE_DATA_TABLE)

+         with q:

+             try:

+                 for name in data:

+                     q.insert((newid, name, data[name]), ttl)

+                 if expiration_time:

+                     q.insert((newid, 'expiration_time', expiration_time), ttl)

+             except Exception, e:  # pylint: disable=broad-except

+                 self.error("Failed to store %s data: [%s]" % (table, e))

+                 raise

          return newid

  

      def get_unique_data(self, table, uuidval=None, name=None, value=None):
@@ -948,55 +1019,53 @@ 

          if ttl and ttl < 0:

              raise ValueError('Negative TTL specified: %s' % ttl)

  

-         q = None

-         try:

-             q = self._query(self._db, table, UNIQUE_DATA_TABLE)

-             for uid in data:

-                 curvals = dict()

-                 rows = q.select({'uuid': uid}, ['name', 'value'])

-                 for r in rows:

-                     curvals[r[0]] = r[1]

- 

-                 datum = data[uid]

-                 if expiration_time:

-                     datum['expiration_time'] = expiration_time

-                 for name in datum:

-                     if name in curvals:

-                         if datum[name] is None:

-                             q.delete({'uuid': uid, 'name': name})

+         q = self._query(self._db, table, UNIQUE_DATA_TABLE)

+         with q:

+             try:

+                 for uid in data:

+                     curvals = dict()

+                     rows = q.select({'uuid': uid}, ['name', 'value'])

+                     for r in rows:

+                         curvals[r[0]] = r[1]

+ 

+                     datum = data[uid]

+                     if expiration_time:

+                         datum['expiration_time'] = expiration_time

+                     for name in datum:

+                         if name in curvals:

+                             if datum[name] is None:

+                                 q.delete({'uuid': uid, 'name': name})

+                             else:

+                                 q.update({'value': datum[name]},

+                                          {'uuid': uid, 'name': name})

                          else:

-                             q.update({'value': datum[name]},

-                                      {'uuid': uid, 'name': name})

-                     else:

-                         if datum[name] is not None:

-                             q.insert((uid, name, datum[name]), ttl)

- 

-             q.commit()

-         except Exception, e:  # pylint: disable=broad-except

-             if q:

-                 q.rollback()

-             self.error("Failed to store data in %s: [%s]" % (table, e))

-             raise

+                             if datum[name] is not None:

+                                 q.insert((uid, name, datum[name]), ttl)

+ 

+             except Exception, e:  # pylint: disable=broad-except

+                 self.error("Failed to store data in %s: [%s]" % (table, e))

+                 raise

  

      def del_unique_data(self, table, uuidval):

          kvfilter = {'uuid': uuidval}

-         try:

-             q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)

-             q.delete(kvfilter)

-         except Exception, e:  # pylint: disable=broad-except

-             self.error("Failed to delete data from %s: [%s]" % (table, e))

+         q = self._query(self._db, table, UNIQUE_DATA_TABLE, trans=False)

+         with q:

+             try:

+                 q.delete(kvfilter)

+             except Exception, e:  # pylint: disable=broad-except

+                 self.error("Failed to delete data from %s: [%s]" % (table, e))

  

      def _reset_data(self, table):

-         q = None

-         try:

-             q = self._query(self._db, table, UNIQUE_DATA_TABLE)

-             q.drop()

-             q.create()

-             q.commit()

-         except Exception, e:  # pylint: disable=broad-except

-             if q:

-                 q.rollback()

-             self.error("Failed to erase all data from %s: [%s]" % (table, e))

+         q = self._query(self._db, table, UNIQUE_DATA_TABLE)

+         with q:

+             try:

+                 q.drop()

+                 q.create()

+             except Exception, e:  # pylint: disable=broad-except

+                 if q:

+                     q.rollback()

+                 self.error("Failed to erase all data from %s: [%s]"

+                            % (table, e))

  

  

  class AdminStore(Store):
@@ -1031,7 +1100,6 @@ 

                        'authz_config']:

              q = self._query(self._db, table, OPTIONS_TABLE, trans=False)

              q.create()

-             q._con.close()  # pylint: disable=protected-access

  

      def _upgrade_schema(self, old_version):

          if old_version == 1:
@@ -1052,7 +1120,6 @@ 

              q = self._query(self._db, 'authz_config', OPTIONS_TABLE,

                              trans=False)

              q.create()

-             q._con.close()  # pylint: disable=protected-access

              self.save_options('authz_config', 'global', {'enabled': 'allow'})

              return 3

          else:
@@ -1064,7 +1131,6 @@ 

              q = self._query(self._db, table, UNIQUE_DATA_TABLE,

                              trans=False)

              q.create()

-             q._con.close()  # pylint: disable=protected-access

  

  

  class UserStore(Store):
@@ -1093,65 +1159,63 @@ 

  

      def store_consent(self, user, provider, clientid, parameters):

          q = None

-         try:

-             key = self._cons_key(provider, clientid)

-             q = self._query(self._db, 'user_consent', OPTIONS_TABLE)

-             rows = q.select({'name': user, 'option': key}, ['value'])

-             if len(list(rows)) > 0:

-                 q.update({'value': parameters}, {'name': user, 'option': key})

-             else:

-                 q.insert((user, key, parameters))

-             q.commit()

-         except Exception, e:  # pylint: disable=broad-except

-             if q:

-                 q.rollback()

-             self.error('Failed to store consent: [%s]' % e)

-             raise

+         q = self._query(self._db, 'user_consent', OPTIONS_TABLE)

+         with q:

+             try:

+                 key = self._cons_key(provider, clientid)

+                 rows = q.select({'name': user, 'option': key}, ['value'])

+                 if len(list(rows)) > 0:

+                     q.update({'value': parameters}, {'name': user,

+                                                      'option': key})

+                 else:

+                     q.insert((user, key, parameters))

+             except Exception, e:  # pylint: disable=broad-except

+                 self.error('Failed to store consent: [%s]' % e)

+                 raise

  

      def delete_consent(self, user, provider, clientid):

-         q = None

-         try:

-             q = self._query(self._db, 'user_consent', OPTIONS_TABLE)

-             q.delete({'name': user,

-                       'option': self._cons_key(provider, clientid)})

-             q.commit()

-         except Exception, e:  # pylint: disable=broad-except

-             if q:

-                 q.rollback()

-             self.error('Failed to delete consent: [%s]' % e)

-             raise

+         q = self._query(self._db, 'user_consent', OPTIONS_TABLE)

+         with q:

+             try:

+                 q.delete({'name': user,

+                           'option': self._cons_key(provider, clientid)})

+             except Exception, e:  # pylint: disable=broad-except

+                 self.error('Failed to delete consent: [%s]' % e)

+                 raise

  

      def get_consent(self, user, provider, clientid):

-         try:

-             q = self._query(self._db, 'user_consent', OPTIONS_TABLE)

-             rows = q.select({'name': user,

-                              'option': self._cons_key(provider, clientid)},

-                             ['value'])

-             data = list(rows)

-             if len(data) > 0:

-                 return data[0][0]

-             else:

-                 return None

-         except Exception, e:  # pylint: disable=broad-except

-             self.error('Failed to get consent: [%s]' % e)

-             return None

+         q = self._query(self._db, 'user_consent', OPTIONS_TABLE)

+         with q:

+             try:

+                 rows = q.select({'name': user,

+                                  'option': self._cons_key(provider, clientid)},

+                                 ['value'])

+                 data = list(rows)

+                 if len(data) > 0:

+                     return data[0][0]

+                 else:

+                     return None

+             except Exception, e:  # pylint: disable=broad-except

+                 self.error('Failed to get consent: [%s]' % e)

+                 raise

  

      def get_all_consents(self, user):

          d = []

-         try:

-             q = self._query(self._db, 'user_consent', OPTIONS_TABLE)

-             rows = q.select({'name': user}, ['option', 'value'])

-             for r in rows:

-                 prov, clientid = self._split_cons_key(r[0])

-                 d.append((prov, clientid, r[1]))

-         except Exception, e:  # pylint: disable=broad-except

-             self.error('Failed to get consents: [%s]' % e)

+         q = self._query(self._db, 'user_consent', OPTIONS_TABLE)

+         with q:

+             try:

+                 rows = q.select({'name': user}, ['option', 'value'])

+                 for r in rows:

+                     prov, clientid = self._split_cons_key(r[0])

+                     d.append((prov, clientid, r[1]))

+             except Exception, e:  # pylint: disable=broad-except

+                 self.error('Failed to get consents: [%s]' % e)

+                 raise

          return d

  

      def _initialize_table(self, tablename):

          q = self._query(self._db, tablename, OPTIONS_TABLE, trans=False)

          q.create()

-         q._con.close()  # pylint: disable=protected-access

  

      def _initialize_schema(self):

          self._initialize_table('users')
@@ -1191,7 +1255,6 @@ 

          q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,

                          trans=False)

          q.create()

-         q._con.close()  # pylint: disable=protected-access

  

      def _upgrade_schema(self, old_version):

          if old_version == 1:
@@ -1290,7 +1353,6 @@ 

          q = self._query(self._db, self.table, UNIQUE_DATA_TABLE,

                          trans=False)

          q.create()

-         q._con.close()  # pylint: disable=protected-access

  

      def _upgrade_schema(self, old_version):

          if old_version == 1:

file modified
+13 -17
@@ -27,7 +27,6 @@ 

          q = self._query(self._db, 'sessions', SESSION_TABLE,

                          trans=False)

          q.create()

-         q._con.close()  # pylint: disable=protected-access

  

      def _upgrade_schema(self, old_version):

          if old_version == 1:
@@ -75,33 +74,30 @@ 

  

      def _exists(self):

          q = SqlQuery(self._db, 'sessions', SESSION_TABLE)

-         result = q.select({'id': self.id})

-         return True if result.fetchone() else False

+         with q:

+             result = q.select({'id': self.id})

+             return True if result.fetchone() else False

  

      def _load(self):

          q = SqlQuery(self._db, 'sessions', SESSION_TABLE)

-         result = q.select({'id': self.id})

-         r = result.fetchone()

-         if r:

-             data = str(base64.b64decode(r[1]))

-             return pickle.loads(data)

+         with q:

+             result = q.select({'id': self.id})

+             r = result.fetchone()

+             if r:

+                 data = str(base64.b64decode(r[1]))

+                 return pickle.loads(data)

  

      def _save(self, expiration_time):

-         q = None

-         try:

-             q = SqlQuery(self._db, 'sessions', SESSION_TABLE, trans=True)

+         q = SqlQuery(self._db, 'sessions', SESSION_TABLE, trans=True)

+         with q:

              q.delete({'id': self.id})

              data = pickle.dumps((self._data, expiration_time), self._proto)

              q.insert((self.id, base64.b64encode(data), expiration_time))

-             q.commit()

-         except Exception:  # pylint: disable=broad-except

-             if q:

-                 q.rollback()

-             raise

  

      def _delete(self):

          q = SqlQuery(self._db, 'sessions', SESSION_TABLE)

-         q.delete({'id': self.id})

+         with q:

+             q.delete({'id': self.id})

  

      # copy what RamSession does for now

      def acquire_lock(self):

file modified
+1 -1
@@ -127,7 +127,7 @@ 

      try:

          page = sess.revoke_all_consent(idpname)

      except ValueError, e:

-         print >> sys.stderr, "" % repr(e)

+         print >> sys.stderr, " ERROR: %s" % repr(e)

          sys.exit(1)

      print " SUCCESS"

  

This new system makes sure that connections are closed (returned to the SQLAlchemy pool) as soon as a specific query block is finished, rather than at the end of the full request.
As a result, even with a lot of modules enabled, we no longer hold many open database connections during a single request.
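For reference, the usage pattern this introduces looks roughly like the following; a minimal sketch reusing names from the diff above, where the store name and the inserted row are purely illustrative:

    # Sketch of the new query-block pattern (names taken from the diff above;
    # the store name and row values are illustrative, not a real call site).
    db = SqlStore.get_instance('adminconfig')
    q = SqlQuery(db, 'dbinfo', OPTIONS_TABLE)
    with q:
        q.insert(('dbinfo', 'some_option', 'some_value'))
    # On leaving the block, __exit__ commits (or rolls back on an exception)
    # and _teardown_connection() closes the connection, returning it to the
    # SQLAlchemy pool immediately instead of waiting for cherrypy's
    # on_end_request hook.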

It might be good to do something in _teardown_connection if self._trans is not None, like raise an error, or abort the transaction, or both?
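One possible shape for that guard, purely as an illustration (not part of the merged change; it assumes SQLAlchemy's Transaction.is_active attribute):

    def _teardown_connection(self):
        # Illustrative sketch: abort any transaction still active at teardown,
        # which would mean neither commit() nor rollback() was reached.
        if self._trans is not None and self._trans.is_active:
            self.error('Tearing down a connection with an active transaction')
            self._trans.rollback()
        self.__con.close()
        self.__con = None
        self._trans = None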

Looks good to me, though I might suggest using an off-the-shelf connection manager rather than a home-made one. That would be a suggestion for another pull request, of course, since this code already uses a home-made one anyway.

We use SQLAlchemy's built-in connection manager: "close"-ing a connection just returns it to the SA pool.
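For context, that behaviour is standard SQLAlchemy pooling; a small standalone illustration (the database URL is arbitrary, and string execution is the pre-1.4 style this codebase uses):

    from sqlalchemy import create_engine

    # The engine maintains a connection pool internally.
    engine = create_engine('sqlite:///example.db')

    conn = engine.connect()       # checks a connection out of the pool
    try:
        conn.execute('SELECT 1')  # run a trivial statement
    finally:
        conn.close()              # does not drop the DBAPI connection;
                                  # it is returned to the pool for reuse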

LGTM. I tend to do "with blah() as foo:" myself, but no objection to your style.
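Both styles work with the context manager as implemented here, since __enter__ returns self; for example, the session delete from the diff could equally be written either way (self._db and self.id come from the surrounding session class):

    # Style used in this pull request:
    q = SqlQuery(self._db, 'sessions', SESSION_TABLE)
    with q:
        q.delete({'id': self.id})

    # Equivalent "with ... as" style:
    with SqlQuery(self._db, 'sessions', SESSION_TABLE) as q:
        q.delete({'id': self.id})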

Commit 7f2f44f fixes this pull-request

Pull-Request has been merged by puiterwijk@redhat.com, 4 years ago.