diff --git a/playhouse/pool.py b/playhouse/pool.py index 360dcabe..945cc222 100644 --- a/playhouse/pool.py +++ b/playhouse/pool.py @@ -4,7 +4,6 @@ import logging import threading import time from collections import namedtuple -from itertools import chain from peewee import MySQLDatabase from peewee import PostgresqlDatabase @@ -47,11 +46,18 @@ class PooledDatabase(object): if self._wait_timeout == 0: self._wait_timeout = float('inf') + # Lock for pool operations and condition for notifying when connection + # is released back to pool. self._pool_lock = threading.RLock() + self._pool_available = threading.Condition(self._pool_lock) # Available / idle connections stored in a heap, sorted oldest first. self._connections = [] + # Counter used for tie-breaker in heap (so we don't try comparing + # connection against connection). + self._heap_counter = 0 + # Mapping of connection id to PoolConnection. Ordinarily we would want # to use something like a WeakKeyDictionary, but Python typically won't # allow us to create weak references to connection objects. @@ -81,58 +87,51 @@ class PooledDatabase(object): if not self._wait_timeout: return super(PooledDatabase, self).connect(reuse_if_open) - expires = time.time() + self._wait_timeout - while expires > time.time(): + deadline = time.monotonic() + self._wait_timeout + while True: try: - ret = super(PooledDatabase, self).connect(reuse_if_open) + return super(PooledDatabase, self).connect(reuse_if_open) except MaxConnectionsExceeded: - time.sleep(0.1) - else: - return ret - raise MaxConnectionsExceeded('Max connections exceeded, timed out ' - 'attempting to connect.') + remaining = deadline - time.monotonic() + if remaining <= 0: + raise MaxConnectionsExceeded( + 'Max connections exceeded, timed out attempting to ' + 'connect.') + with self._pool_available: + self._pool_available.wait(timeout=min(remaining, 1.0)) @locked def _connect(self): - while True: + while self._connections: try: # Remove the oldest connection from the heap. - ts, _, c_conn = heapq.heappop(self._connections) - conn = c_conn - key = self.conn_key(conn) + ts, _counter, conn = heapq.heappop(self._connections) except IndexError: - ts = conn = None - logger.debug('No connection available in pool.') break - else: - if self._is_closed(conn): - # This connecton was closed, but since it was not stale - # it got added back to the queue of available conns. We - # then closed it and marked it as explicitly closed, so - # it's safe to throw it away now. - # (Because Database.close() calls Database._close()). - logger.debug('Connection %s was closed.', key) - ts = conn = None - elif self._stale_timeout and self._is_stale(ts): - # If we are attempting to check out a stale connection, - # then close it. We don't need to mark it in the "closed" - # set, because it is not in the list of available conns - # anymore. - logger.debug('Connection %s was stale, closing.', key) - self._close(conn, True) - ts = conn = None - else: - break - if conn is None: - if self._max_connections and ( - len(self._in_use) >= self._max_connections): - raise MaxConnectionsExceeded('Exceeded maximum connections.') - conn = super(PooledDatabase, self)._connect() - ts = time.time() key = self.conn_key(conn) - logger.debug('Created new connection %s.', key) + if self._is_closed(conn): + # Connection closed either by user or by driver - discard. + logger.debug('Connection %s was closed, discarding.', key) + continue + if self._stale_timeout and self._is_stale(ts): + logger.debug('Connection %s was stale, closing.', key) + self._close_raw(conn) + continue + + # Connection OK to use. + self._in_use[key] = PoolConnection(ts, conn, time.time()) + return conn + + if self._max_connections and ( + len(self._in_use) >= self._max_connections): + raise MaxConnectionsExceeded('Exceeded maximum connections.') + + conn = super(PooledDatabase, self)._connect() + ts = time.time() + key = self.conn_key(conn) + logger.debug('Created new connection %s.', key) self._in_use[key] = PoolConnection(ts, conn, time.time()) return conn @@ -148,22 +147,43 @@ class PooledDatabase(object): # Called on check-in to make sure the connection can be re-used. return True + def _close_raw(self, conn): + try: + super(PooledDatabase, self)._close(conn) + except Exception: + logger.debug('Error closing connection %s.', self.conn_key(conn), + exc_info=True) + @locked def _close(self, conn, close_conn=False): + # if close_conn == True, close underlying driver connection and remove + # from _in_use tracking. Do not return to available conns. key = self.conn_key(conn) + if close_conn: - super(PooledDatabase, self)._close(conn) - elif key in self._in_use: - pool_conn = self._in_use.pop(key) - if self._stale_timeout and self._is_stale(pool_conn.timestamp): - logger.debug('Closing stale connection %s.', key) - super(PooledDatabase, self)._close(conn) - elif self._can_reuse(conn): - logger.debug('Returning %s to pool.', key) - heapq.heappush(self._connections, - (pool_conn.timestamp, _sentinel(), conn)) - else: - logger.debug('Closed %s.', key) + self._in_use.pop(key, None) + self._close_raw(conn) + return + + if key not in self._in_use: + logger.debug('Connection %s not in use, ignoring close.', key) + return + + pool_conn = self._in_use.pop(key) + if self._stale_timeout and self._is_stale(pool_conn.timestamp): + logger.debug('Closing stale connection %s on check-in.', key) + self._close_raw(conn) + elif not self._can_reuse(conn): + logger.debug('Connection %s not reusable, closing.', key) + self._close_raw(conn) + else: + logger.debug('Returning %s to pool.', key) + self._heap_counter += 1 + heapq.heappush(self._connections, + (pool_conn.timestamp, self._heap_counter, conn)) + + # Wake up thread that may be waiting on connection. + self._pool_available.notify() @locked def manual_close(self): @@ -175,36 +195,36 @@ class PooledDatabase(object): # Obtain reference to the connection in-use by the calling thread. conn = self.connection() + key = self.conn_key(conn) - # A connection will only be re-added to the available list if it is - # marked as "in use" at the time it is closed. We will explicitly - # remove it from the "in use" list, call "close()" for the - # side-effects, and then explicitly close the connection. - self._in_use.pop(self.conn_key(conn), None) + # Remove from _in_use so that subsequent self.close() won't try to + # restore it to the pool. + self._in_use.pop(key, None) self.close() - self._close(conn, close_conn=True) + self._close_raw(conn) @locked def close_idle(self): # Close any open connections that are not currently in-use. - for _, _, conn in self._connections: - self._close(conn, close_conn=True) + idle = self._connections self._connections = [] + for _, _, conn in idle: + self._close_raw(conn) @locked def close_stale(self, age=600): # Close any connections that are in-use but were checked out quite some - # time ago and can be considered stale. - in_use = {} + # time ago and can be considered stale. May close connections in use by + # running threads. cutoff = time.time() - age n = 0 - for key, pool_conn in self._in_use.items(): + for key, pool_conn in list(self._in_use.items()): if pool_conn.checked_out < cutoff: - self._close(pool_conn.connection, close_conn=True) + self._close_raw(pool_conn.connection) + del self._in_use[key] n += 1 - else: - in_use[key] = pool_conn - self._in_use = in_use + + self._pool_available.notify_all() return n @locked @@ -212,12 +232,12 @@ class PooledDatabase(object): # Close all connections -- available and in-use. Warning: may break any # active connections used by other threads. self.close() - for _, _, conn in self._connections: - self._close(conn, close_conn=True) - for pool_conn in self._in_use.values(): - self._close(pool_conn.connection, close_conn=True) - self._connections = [] - self._in_use = {} + self.close_idle() + in_use, self._in_use = self._in_use, {} + for pool_conn in in_use.values(): + self._close_raw(pool_conn.connection) + + self._pool_available.notify_all() class _PooledMySQLDatabase(PooledDatabase): @@ -230,8 +250,7 @@ class _PooledMySQLDatabase(PooledDatabase): conn.ping(*args) except: return True - else: - return False + return False class PooledMySQLDatabase(_PooledMySQLDatabase, MySQLDatabase): pass @@ -256,8 +275,7 @@ class _PooledSqliteDatabase(PooledDatabase): conn.total_changes except: return True - else: - return False + return False class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase): pass diff --git a/tests/pool.py b/tests/pool.py index e7da610c..2a3829c5 100644 --- a/tests/pool.py +++ b/tests/pool.py @@ -62,6 +62,12 @@ class PooledTestDatabase(PooledDatabase, SqliteDatabase): pass +def push_conn(db, timestamp, conn): + # Push a connection onto the pool heap with a proper monotonic counter. + db._heap_counter += 1 + heapq.heappush(db._connections, (timestamp, db._heap_counter, conn)) + + class TestPooledDatabase(BaseTestCase): def setUp(self): super(TestPooledDatabase, self).setUp() @@ -183,9 +189,10 @@ class TestPooledDatabase(BaseTestCase): db = FakePooledDatabase('testing', counter=3) now = time.time() - heapq.heappush(db._connections, (now - 10, None, 3)) - heapq.heappush(db._connections, (now - 5, None, 2)) - heapq.heappush(db._connections, (now - 1, None, 1)) + now = time.time() + push_conn(db, now - 10, 3) + push_conn(db, now - 5, 2) + push_conn(db, now - 1, 1) self.assertEqual(db.connection(), 3) self.assertTrue(3 in db._in_use) @@ -217,9 +224,9 @@ class TestPooledDatabase(BaseTestCase): db = FakePooledDatabase('testing', counter=3) now = time.time() - heapq.heappush(db._connections, (now - 10, None, 3)) - heapq.heappush(db._connections, (now - 5, None, 2)) - heapq.heappush(db._connections, (now - 1, None, 1)) + push_conn(db, now - 10, 3) + push_conn(db, now - 5, 2) + push_conn(db, now - 1, 1) self.assertEqual(db.connection(), 3) self.assertTrue(3 in db._in_use) @@ -233,18 +240,19 @@ class TestPooledDatabase(BaseTestCase): now = time.time() db = FakePooledDatabase('testing', stale_timeout=10) conns = [ - (now - 20, None, 1), - (now - 15, None, 2), - (now - 5, None, 3), - (now, None, 4), + (now - 20, 1), + (now - 15, 2), + (now - 5, 3), + (now, 4), ] - for ts_conn in conns: - heapq.heappush(db._connections, ts_conn) + for ts, conn in conns: + push_conn(db, ts, conn) self.assertEqual(db.connection(), 3) self.assertEqual(len(db._in_use), 1) self.assertTrue(3 in db._in_use) - self.assertEqual(db._connections, [(now, None, 4)]) + self.assertEqual(len(db._connections), 1) + self.assertEqual(db._connections[0][2], 4) def test_connect_cascade(self): now = time.time() @@ -255,14 +263,14 @@ class TestPooledDatabase(BaseTestCase): db = ClosedPooledDatabase('testing', stale_timeout=10) conns = [ - (now - 15, None, 1), # Skipped due to being stale. - (now - 5, None, 2), # Will appear closed. - (now - 3, None, 3), - (now, None, 4), # Will appear closed. + (now - 15, 1), # Skipped due to being stale. + (now - 5, 2), # Will appear closed. + (now - 3, 3), + (now, 4), # Will appear closed. ] db.counter = 4 # The next connection we create will have id=5. - for ts_conn in conns: - heapq.heappush(db._connections, ts_conn) + for ts, conn in conns: + push_conn(db, ts, conn) # Conn 3 is not stale or closed, so we will get it. self.assertEqual(db.connection(), 3) @@ -271,7 +279,10 @@ class TestPooledDatabase(BaseTestCase): pool_conn = db._in_use[3] self.assertEqual(pool_conn.timestamp, now - 3) self.assertEqual(pool_conn.connection, 3) - self.assertEqual(db._connections, [(now, None, 4)]) + + # Only conn 4 remains in the idle pool. + self.assertEqual(len(db._connections), 1) + self.assertEqual(db._connections[0][2], 4) # Since conn 4 is closed, we will open a new conn. db._state.closed = True # Pretend we're in a different thread. @@ -314,6 +325,340 @@ class TestPooledDatabase(BaseTestCase): self.assertEqual(len(self.db._connections), 5) self.assertEqual(len(self.db._in_use), 0) + def test_heap_counter_deterministic_ordering(self): + # Verify that connections pushed with the same timestamp are returned + # in order. + now = time.time() + push_conn(self.db, now, 'a') + push_conn(self.db, now, 'b') + push_conn(self.db, now, 'c') + + results = [] + while self.db._connections: + ts, counter, conn = heapq.heappop(self.db._connections) + results.append(conn) + self.assertEqual(results, ['a', 'b', 'c']) + + def test_close_conn_removes_from_in_use(self): + # _close(conn, close_conn=True) should pop the key from _in_use AND + # close the underlying driver conn. + self.assertEqual(self.db.connection(), 1) + self.assertTrue(1 in self.db._in_use) + + closed_before = self.db.closed_counter + self.db._close(1, close_conn=True) + + self.assertNotIn(1, self.db._in_use) + self.assertEqual(self.db.closed_counter, closed_before + 1) + + def test_double_close_is_noop(self): + # Calling _close on a connection not in _in_use (and close_conn=False) + # should be a safe no-op rather than raising or leaking. + self.assertEqual(self.db.connection(), 1) + self.db.close() # Returns conn 1 to the pool. + + self.assertNotIn(1, self.db._in_use) + closed_before = self.db.closed_counter + # Second close should do nothing. + self.db._close(1) + self.assertEqual(self.db.closed_counter, closed_before) + # Pool state unchanged. + self.assertEqual(len(self.db._connections), 1) + + def test_can_reuse_false_closes_connection(self): + # When _can_reuse returns False on check-in, the connection should be + # closed at the driver level and not returned to the pool. + class NotReusablePooledDatabase(FakePooledDatabase): + def _can_reuse(self, conn): + return False + + db = NotReusablePooledDatabase('testing') + self.assertEqual(db.connection(), 1) + closed_before = db.closed_counter + + db.close() + + # Connection should have been driver-closed, not pooled. + self.assertEqual(db.closed_counter, closed_before + 1) + self.assertEqual(len(db._connections), 0) + self.assertEqual(db._in_use, {}) + + # Next connect creates a brand new connection. + self.assertEqual(db.connection(), 2) + + def test_close_raw_swallows_exception(self): + called = [] + # _close_raw should not propagate exceptions from the driver. + class BrokenDriverClose(FakeDatabase): + def _close(self, conn): + called.append(conn) + raise RuntimeError('failed') + + class BrokenPool(FakePooledDatabase, BrokenDriverClose): + pass + + db = BrokenPool('testing') + db._close_raw(1337) + self.assertEqual(called, [1337]) + + def test_close_stale_removes_from_in_use(self): + # Verify that close_stale both driver-closes the connection AND + # removes it from _in_use (no dangling keys). + db = FakePooledDatabase('testing', counter=2) + + now = time.time() + db._in_use[1] = PoolConnection(now - 1000, 1, now - 1000) + db._in_use[2] = PoolConnection(now, 2, now) + + closed_before = db.closed_counter + self.assertEqual(db.close_stale(age=500), 1) + self.assertNotIn(1, db._in_use) + self.assertIn(2, db._in_use) + self.assertEqual(db.closed_counter, closed_before + 1) + + def test_close_all_clears_both_pools(self): + # close_all should leave both _connections and _in_use completely + # empty, and driver-close every connection. + db = FakePooledDatabase('testing', counter=3) + + now = time.time() + push_conn(db, now - 5, 1) + push_conn(db, now - 1, 2) + + # Simulate two in-use connections. + db._in_use[3] = PoolConnection(now, 3, now) + db._in_use[4] = PoolConnection(now, 4, now) + + # One more for the "current thread" via normal connect path so + # self.close() inside close_all has something to reset. + db._state.closed = True + db.connect() + conn = db.connection() + self.assertIn(db.conn_key(conn), db._in_use) + + closed_before = db.closed_counter + db.close_all() + + self.assertEqual(db._connections, []) + self.assertEqual(db._in_use, {}) + # 2 idle + 2 manually-added in_use + the current thread's conn = 5. + # (close_all calls self.close() which triggers _close for the current + # thread's conn, but that goes through the return-to-pool path, not + # _close_raw. The subsequent loop over the snapshot handles it.) + self.assertGreaterEqual(db.closed_counter, closed_before + 4) + + def test_connect_timeout_with_condition_variable(self): + # Verify that connect() with a timeout raises after the timeout + # expires when the pool is exhausted. + db = FakePooledDatabase('testing', max_connections=1, timeout=0.15) + self.assertEqual(db.connection(), 1) + + errors = [] + def try_connect(): + db._state.closed = True # Appear as a new thread. + try: + db.connect() + except MaxConnectionsExceeded: + errors.append(True) + + t = threading.Thread(target=try_connect) + start = time.monotonic() + t.start() + t.join(timeout=2) + elapsed = time.monotonic() - start + + # Should have waited roughly the timeout duration. + self.assertEqual(len(errors), 1) + self.assertGreaterEqual(elapsed, 0.1) + + def test_connect_timeout_wakes_on_return(self): + # Verify that a waiting thread unblocks promptly when a connection + # is returned to the pool (via the Condition variable notify). + db = FakePooledDatabase('testing', max_connections=1, timeout=5) + self.assertEqual(db.connection(), 1) + + results = [] + def try_connect(): + db._state.closed = True + try: + db.connect() + results.append(db.connection()) + except MaxConnectionsExceeded: + results.append('timeout') + + t = threading.Thread(target=try_connect) + t.start() + + # Give the thread a moment to start waiting. + time.sleep(0.05) + + # Return conn 1 to the pool — should wake the waiting thread. + db.close() + + t.join(timeout=2) + self.assertFalse(t.is_alive(), 'Thread did not wake up.') + self.assertEqual(len(results), 1) + self.assertEqual(results[0], 1) # Got the recycled connection. + + def test_connect_timeout_zero_becomes_infinite(self): + # A timeout of 0 should be treated as infinite (no immediate failure). + db = FakePooledDatabase('testing', max_connections=1, timeout=0) + self.assertEqual(db._wait_timeout, float('inf')) + + def test_close_all_wakes_waiters(self): + # Threads blocked in connect() should be woken by close_all() so they + # can create fresh connections. + db = FakePooledDatabase('testing', max_connections=1, timeout=5) + self.assertEqual(db.connection(), 1) + + results = [] + def try_connect(): + db._state.closed = True + try: + db.connect() + results.append(db.connection()) + except MaxConnectionsExceeded: + results.append('timeout') + + t = threading.Thread(target=try_connect) + t.start() + time.sleep(0.05) + + # close_all frees the slot and calls notify_all. + db.close_all() + + t.join(timeout=2) + self.assertFalse(t.is_alive(), 'Thread was not woken by close_all.') + self.assertEqual(len(results), 1) + # After close_all, the thread should have gotten a fresh connection. + self.assertEqual(results[0], 2) + + def test_close_stale_iteration(self): + db = FakePooledDatabase('testing', counter=10) + now = time.time() + for i in range(1, 11): + db._in_use[i] = PoolConnection(now - 1000, i, now - 1000) + + # All 10 should be closed. + self.assertEqual(db.close_stale(age=500), 10) + self.assertEqual(db._in_use, {}) + + def test_concurrent_close_stale_and_return(self): + # Exercise close_stale running while other threads are actively + # returning connections (calling close()). The snapshot-before-mutate + # pattern and the RLock should keep everything consistent. + db = FakePooledDatabase('testing', max_connections=20) + barrier = threading.Barrier(11) # 10 workers + main thread. + errors = [] + + def worker(n): + """Check out a connection, wait for all workers to be ready, + then return it.""" + try: + db._state.closed = True + db.connect() + barrier.wait(timeout=2) + # Small stagger so close_stale and close() overlap. + time.sleep(0.001 * (n % 3)) + db.close() + except Exception as exc: + errors.append(exc) + + # Spin up 10 threads that each grab and return a connection. + threads = [threading.Thread(target=worker, args=(i,)) + for i in range(10)] + for t in threads: t.start() + + # Wait until all threads hold a connection. + while len(db._in_use) < 10: + time.sleep(.005) + + # Artificially back-date half the checked_out times so that + # close_stale will try to close them while threads are returning. + now = time.time() + for i, key in enumerate(list(db._in_use)): + if i % 2 == 0: + pc = db._in_use[key] + db._in_use[key] = PoolConnection(pc.timestamp, pc.connection, + now - 10000) + + # Release the barrier so threads start returning connections, and + # simultaneously run close_stale from the main thread. + barrier.wait(timeout=2) + closed = db.close_stale(age=5000) + for t in threads: t.join(timeout=2) + + self.assertEqual(errors, []) + for key in db._in_use: + for _, _, conn in db._connections: + self.assertNotEqual(db.conn_key(conn), key) + + def test_manual_close_when_already_closed(self): + # manual_close on an already-closed database should return False. + self.assertFalse(self.db.manual_close()) # Never opened. + + self.db.connect() + self.db.close() + self.assertFalse(self.db.manual_close()) # Already closed. + + def test_close_idle_driver_closes_all(self): + # Every idle connection should be driver-closed. + db = FakePooledDatabase('testing', counter=5) + now = time.time() + for i in range(1, 6): + push_conn(db, now - i, i) + + closed_before = db.closed_counter + db.close_idle() + self.assertEqual(db._connections, []) + self.assertEqual(db.closed_counter, closed_before + 5) + + def test_max_connections_zero_means_unlimited(self): + # max_connections=0 (falsy) should mean no limit. + db = FakePooledDatabase('testing', max_connections=0) + for i in range(50): + db._state.closed = True + db.connect() + self.assertEqual(len(db._in_use), 50) + + def test_stale_and_closed_all_skipped(self): + # If every connection in the pool is either stale or closed, a new one + # should be created. + class AllClosedDatabase(FakePooledDatabase): + def _is_closed(self, conn): + return True + + db = AllClosedDatabase('testing', stale_timeout=10) + now = time.time() + push_conn(db, now - 20, 1) # Stale. + push_conn(db, now, 2) # Closed (per _is_closed override). + db.counter = 2 + + self.assertEqual(db.connection(), 3) + self.assertEqual(db._connections, []) + self.assertEqual(list(db._in_use.keys()), [3]) + + def test_init_updates_pool_parameters(self): + # The init() method should allow updating pool parameters after + # initial construction. + db = FakePooledDatabase('testing', max_connections=5, stale_timeout=10, + timeout=2) + self.assertEqual(db._max_connections, 5) + self.assertEqual(db._stale_timeout, 10) + self.assertEqual(db._wait_timeout, 2) + + db.init('testing', max_connections=50, stale_timeout=100, timeout=20) + self.assertEqual(db._max_connections, 50) + self.assertEqual(db._stale_timeout, 100) + self.assertEqual(db._wait_timeout, 20) + + def test_init_timeout_zero_becomes_infinite(self): + db = FakePooledDatabase('testing', timeout=5) + self.assertEqual(db._wait_timeout, 5) + + db.init('testing', timeout=0) + self.assertEqual(db._wait_timeout, float('inf')) + class TestLivePooledDatabase(ModelTestCase): database = PooledTestDatabase('test_pooled.db')