mirror of
https://github.com/coleifer/peewee.git
synced 2026-05-06 07:56:41 -04:00
Clean up waiting logic & synchronization in connection pool.
Add tests for a missing behaviors/logic.
This commit is contained in:
+96
-78
@@ -4,7 +4,6 @@ import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from itertools import chain
|
||||
|
||||
from peewee import MySQLDatabase
|
||||
from peewee import PostgresqlDatabase
|
||||
@@ -47,11 +46,18 @@ class PooledDatabase(object):
|
||||
if self._wait_timeout == 0:
|
||||
self._wait_timeout = float('inf')
|
||||
|
||||
# Lock for pool operations and condition for notifying when connection
|
||||
# is released back to pool.
|
||||
self._pool_lock = threading.RLock()
|
||||
self._pool_available = threading.Condition(self._pool_lock)
|
||||
|
||||
# Available / idle connections stored in a heap, sorted oldest first.
|
||||
self._connections = []
|
||||
|
||||
# Counter used for tie-breaker in heap (so we don't try comparing
|
||||
# connection against connection).
|
||||
self._heap_counter = 0
|
||||
|
||||
# Mapping of connection id to PoolConnection. Ordinarily we would want
|
||||
# to use something like a WeakKeyDictionary, but Python typically won't
|
||||
# allow us to create weak references to connection objects.
|
||||
@@ -81,58 +87,51 @@ class PooledDatabase(object):
|
||||
if not self._wait_timeout:
|
||||
return super(PooledDatabase, self).connect(reuse_if_open)
|
||||
|
||||
expires = time.time() + self._wait_timeout
|
||||
while expires > time.time():
|
||||
deadline = time.monotonic() + self._wait_timeout
|
||||
while True:
|
||||
try:
|
||||
ret = super(PooledDatabase, self).connect(reuse_if_open)
|
||||
return super(PooledDatabase, self).connect(reuse_if_open)
|
||||
except MaxConnectionsExceeded:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
return ret
|
||||
raise MaxConnectionsExceeded('Max connections exceeded, timed out '
|
||||
'attempting to connect.')
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
raise MaxConnectionsExceeded(
|
||||
'Max connections exceeded, timed out attempting to '
|
||||
'connect.')
|
||||
with self._pool_available:
|
||||
self._pool_available.wait(timeout=min(remaining, 1.0))
|
||||
|
||||
@locked
|
||||
def _connect(self):
|
||||
while True:
|
||||
while self._connections:
|
||||
try:
|
||||
# Remove the oldest connection from the heap.
|
||||
ts, _, c_conn = heapq.heappop(self._connections)
|
||||
conn = c_conn
|
||||
key = self.conn_key(conn)
|
||||
ts, _counter, conn = heapq.heappop(self._connections)
|
||||
except IndexError:
|
||||
ts = conn = None
|
||||
logger.debug('No connection available in pool.')
|
||||
break
|
||||
else:
|
||||
if self._is_closed(conn):
|
||||
# This connecton was closed, but since it was not stale
|
||||
# it got added back to the queue of available conns. We
|
||||
# then closed it and marked it as explicitly closed, so
|
||||
# it's safe to throw it away now.
|
||||
# (Because Database.close() calls Database._close()).
|
||||
logger.debug('Connection %s was closed.', key)
|
||||
ts = conn = None
|
||||
elif self._stale_timeout and self._is_stale(ts):
|
||||
# If we are attempting to check out a stale connection,
|
||||
# then close it. We don't need to mark it in the "closed"
|
||||
# set, because it is not in the list of available conns
|
||||
# anymore.
|
||||
logger.debug('Connection %s was stale, closing.', key)
|
||||
self._close(conn, True)
|
||||
ts = conn = None
|
||||
else:
|
||||
break
|
||||
|
||||
if conn is None:
|
||||
if self._max_connections and (
|
||||
len(self._in_use) >= self._max_connections):
|
||||
raise MaxConnectionsExceeded('Exceeded maximum connections.')
|
||||
conn = super(PooledDatabase, self)._connect()
|
||||
ts = time.time()
|
||||
key = self.conn_key(conn)
|
||||
logger.debug('Created new connection %s.', key)
|
||||
if self._is_closed(conn):
|
||||
# Connection closed either by user or by driver - discard.
|
||||
logger.debug('Connection %s was closed, discarding.', key)
|
||||
continue
|
||||
|
||||
if self._stale_timeout and self._is_stale(ts):
|
||||
logger.debug('Connection %s was stale, closing.', key)
|
||||
self._close_raw(conn)
|
||||
continue
|
||||
|
||||
# Connection OK to use.
|
||||
self._in_use[key] = PoolConnection(ts, conn, time.time())
|
||||
return conn
|
||||
|
||||
if self._max_connections and (
|
||||
len(self._in_use) >= self._max_connections):
|
||||
raise MaxConnectionsExceeded('Exceeded maximum connections.')
|
||||
|
||||
conn = super(PooledDatabase, self)._connect()
|
||||
ts = time.time()
|
||||
key = self.conn_key(conn)
|
||||
logger.debug('Created new connection %s.', key)
|
||||
self._in_use[key] = PoolConnection(ts, conn, time.time())
|
||||
return conn
|
||||
|
||||
@@ -148,22 +147,43 @@ class PooledDatabase(object):
|
||||
# Called on check-in to make sure the connection can be re-used.
|
||||
return True
|
||||
|
||||
def _close_raw(self, conn):
|
||||
try:
|
||||
super(PooledDatabase, self)._close(conn)
|
||||
except Exception:
|
||||
logger.debug('Error closing connection %s.', self.conn_key(conn),
|
||||
exc_info=True)
|
||||
|
||||
@locked
|
||||
def _close(self, conn, close_conn=False):
|
||||
# if close_conn == True, close underlying driver connection and remove
|
||||
# from _in_use tracking. Do not return to available conns.
|
||||
key = self.conn_key(conn)
|
||||
|
||||
if close_conn:
|
||||
super(PooledDatabase, self)._close(conn)
|
||||
elif key in self._in_use:
|
||||
pool_conn = self._in_use.pop(key)
|
||||
if self._stale_timeout and self._is_stale(pool_conn.timestamp):
|
||||
logger.debug('Closing stale connection %s.', key)
|
||||
super(PooledDatabase, self)._close(conn)
|
||||
elif self._can_reuse(conn):
|
||||
logger.debug('Returning %s to pool.', key)
|
||||
heapq.heappush(self._connections,
|
||||
(pool_conn.timestamp, _sentinel(), conn))
|
||||
else:
|
||||
logger.debug('Closed %s.', key)
|
||||
self._in_use.pop(key, None)
|
||||
self._close_raw(conn)
|
||||
return
|
||||
|
||||
if key not in self._in_use:
|
||||
logger.debug('Connection %s not in use, ignoring close.', key)
|
||||
return
|
||||
|
||||
pool_conn = self._in_use.pop(key)
|
||||
if self._stale_timeout and self._is_stale(pool_conn.timestamp):
|
||||
logger.debug('Closing stale connection %s on check-in.', key)
|
||||
self._close_raw(conn)
|
||||
elif not self._can_reuse(conn):
|
||||
logger.debug('Connection %s not reusable, closing.', key)
|
||||
self._close_raw(conn)
|
||||
else:
|
||||
logger.debug('Returning %s to pool.', key)
|
||||
self._heap_counter += 1
|
||||
heapq.heappush(self._connections,
|
||||
(pool_conn.timestamp, self._heap_counter, conn))
|
||||
|
||||
# Wake up thread that may be waiting on connection.
|
||||
self._pool_available.notify()
|
||||
|
||||
@locked
|
||||
def manual_close(self):
|
||||
@@ -175,36 +195,36 @@ class PooledDatabase(object):
|
||||
|
||||
# Obtain reference to the connection in-use by the calling thread.
|
||||
conn = self.connection()
|
||||
key = self.conn_key(conn)
|
||||
|
||||
# A connection will only be re-added to the available list if it is
|
||||
# marked as "in use" at the time it is closed. We will explicitly
|
||||
# remove it from the "in use" list, call "close()" for the
|
||||
# side-effects, and then explicitly close the connection.
|
||||
self._in_use.pop(self.conn_key(conn), None)
|
||||
# Remove from _in_use so that subsequent self.close() won't try to
|
||||
# restore it to the pool.
|
||||
self._in_use.pop(key, None)
|
||||
self.close()
|
||||
self._close(conn, close_conn=True)
|
||||
self._close_raw(conn)
|
||||
|
||||
@locked
|
||||
def close_idle(self):
|
||||
# Close any open connections that are not currently in-use.
|
||||
for _, _, conn in self._connections:
|
||||
self._close(conn, close_conn=True)
|
||||
idle = self._connections
|
||||
self._connections = []
|
||||
for _, _, conn in idle:
|
||||
self._close_raw(conn)
|
||||
|
||||
@locked
|
||||
def close_stale(self, age=600):
|
||||
# Close any connections that are in-use but were checked out quite some
|
||||
# time ago and can be considered stale.
|
||||
in_use = {}
|
||||
# time ago and can be considered stale. May close connections in use by
|
||||
# running threads.
|
||||
cutoff = time.time() - age
|
||||
n = 0
|
||||
for key, pool_conn in self._in_use.items():
|
||||
for key, pool_conn in list(self._in_use.items()):
|
||||
if pool_conn.checked_out < cutoff:
|
||||
self._close(pool_conn.connection, close_conn=True)
|
||||
self._close_raw(pool_conn.connection)
|
||||
del self._in_use[key]
|
||||
n += 1
|
||||
else:
|
||||
in_use[key] = pool_conn
|
||||
self._in_use = in_use
|
||||
|
||||
self._pool_available.notify_all()
|
||||
return n
|
||||
|
||||
@locked
|
||||
@@ -212,12 +232,12 @@ class PooledDatabase(object):
|
||||
# Close all connections -- available and in-use. Warning: may break any
|
||||
# active connections used by other threads.
|
||||
self.close()
|
||||
for _, _, conn in self._connections:
|
||||
self._close(conn, close_conn=True)
|
||||
for pool_conn in self._in_use.values():
|
||||
self._close(pool_conn.connection, close_conn=True)
|
||||
self._connections = []
|
||||
self._in_use = {}
|
||||
self.close_idle()
|
||||
in_use, self._in_use = self._in_use, {}
|
||||
for pool_conn in in_use.values():
|
||||
self._close_raw(pool_conn.connection)
|
||||
|
||||
self._pool_available.notify_all()
|
||||
|
||||
|
||||
class _PooledMySQLDatabase(PooledDatabase):
|
||||
@@ -230,8 +250,7 @@ class _PooledMySQLDatabase(PooledDatabase):
|
||||
conn.ping(*args)
|
||||
except:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
class PooledMySQLDatabase(_PooledMySQLDatabase, MySQLDatabase):
|
||||
pass
|
||||
@@ -256,8 +275,7 @@ class _PooledSqliteDatabase(PooledDatabase):
|
||||
conn.total_changes
|
||||
except:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return False
|
||||
|
||||
class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase):
|
||||
pass
|
||||
|
||||
+365
-20
@@ -62,6 +62,12 @@ class PooledTestDatabase(PooledDatabase, SqliteDatabase):
|
||||
pass
|
||||
|
||||
|
||||
def push_conn(db, timestamp, conn):
|
||||
# Push a connection onto the pool heap with a proper monotonic counter.
|
||||
db._heap_counter += 1
|
||||
heapq.heappush(db._connections, (timestamp, db._heap_counter, conn))
|
||||
|
||||
|
||||
class TestPooledDatabase(BaseTestCase):
|
||||
def setUp(self):
|
||||
super(TestPooledDatabase, self).setUp()
|
||||
@@ -183,9 +189,10 @@ class TestPooledDatabase(BaseTestCase):
|
||||
db = FakePooledDatabase('testing', counter=3)
|
||||
|
||||
now = time.time()
|
||||
heapq.heappush(db._connections, (now - 10, None, 3))
|
||||
heapq.heappush(db._connections, (now - 5, None, 2))
|
||||
heapq.heappush(db._connections, (now - 1, None, 1))
|
||||
now = time.time()
|
||||
push_conn(db, now - 10, 3)
|
||||
push_conn(db, now - 5, 2)
|
||||
push_conn(db, now - 1, 1)
|
||||
|
||||
self.assertEqual(db.connection(), 3)
|
||||
self.assertTrue(3 in db._in_use)
|
||||
@@ -217,9 +224,9 @@ class TestPooledDatabase(BaseTestCase):
|
||||
db = FakePooledDatabase('testing', counter=3)
|
||||
|
||||
now = time.time()
|
||||
heapq.heappush(db._connections, (now - 10, None, 3))
|
||||
heapq.heappush(db._connections, (now - 5, None, 2))
|
||||
heapq.heappush(db._connections, (now - 1, None, 1))
|
||||
push_conn(db, now - 10, 3)
|
||||
push_conn(db, now - 5, 2)
|
||||
push_conn(db, now - 1, 1)
|
||||
self.assertEqual(db.connection(), 3)
|
||||
self.assertTrue(3 in db._in_use)
|
||||
|
||||
@@ -233,18 +240,19 @@ class TestPooledDatabase(BaseTestCase):
|
||||
now = time.time()
|
||||
db = FakePooledDatabase('testing', stale_timeout=10)
|
||||
conns = [
|
||||
(now - 20, None, 1),
|
||||
(now - 15, None, 2),
|
||||
(now - 5, None, 3),
|
||||
(now, None, 4),
|
||||
(now - 20, 1),
|
||||
(now - 15, 2),
|
||||
(now - 5, 3),
|
||||
(now, 4),
|
||||
]
|
||||
for ts_conn in conns:
|
||||
heapq.heappush(db._connections, ts_conn)
|
||||
for ts, conn in conns:
|
||||
push_conn(db, ts, conn)
|
||||
|
||||
self.assertEqual(db.connection(), 3)
|
||||
self.assertEqual(len(db._in_use), 1)
|
||||
self.assertTrue(3 in db._in_use)
|
||||
self.assertEqual(db._connections, [(now, None, 4)])
|
||||
self.assertEqual(len(db._connections), 1)
|
||||
self.assertEqual(db._connections[0][2], 4)
|
||||
|
||||
def test_connect_cascade(self):
|
||||
now = time.time()
|
||||
@@ -255,14 +263,14 @@ class TestPooledDatabase(BaseTestCase):
|
||||
db = ClosedPooledDatabase('testing', stale_timeout=10)
|
||||
|
||||
conns = [
|
||||
(now - 15, None, 1), # Skipped due to being stale.
|
||||
(now - 5, None, 2), # Will appear closed.
|
||||
(now - 3, None, 3),
|
||||
(now, None, 4), # Will appear closed.
|
||||
(now - 15, 1), # Skipped due to being stale.
|
||||
(now - 5, 2), # Will appear closed.
|
||||
(now - 3, 3),
|
||||
(now, 4), # Will appear closed.
|
||||
]
|
||||
db.counter = 4 # The next connection we create will have id=5.
|
||||
for ts_conn in conns:
|
||||
heapq.heappush(db._connections, ts_conn)
|
||||
for ts, conn in conns:
|
||||
push_conn(db, ts, conn)
|
||||
|
||||
# Conn 3 is not stale or closed, so we will get it.
|
||||
self.assertEqual(db.connection(), 3)
|
||||
@@ -271,7 +279,10 @@ class TestPooledDatabase(BaseTestCase):
|
||||
pool_conn = db._in_use[3]
|
||||
self.assertEqual(pool_conn.timestamp, now - 3)
|
||||
self.assertEqual(pool_conn.connection, 3)
|
||||
self.assertEqual(db._connections, [(now, None, 4)])
|
||||
|
||||
# Only conn 4 remains in the idle pool.
|
||||
self.assertEqual(len(db._connections), 1)
|
||||
self.assertEqual(db._connections[0][2], 4)
|
||||
|
||||
# Since conn 4 is closed, we will open a new conn.
|
||||
db._state.closed = True # Pretend we're in a different thread.
|
||||
@@ -314,6 +325,340 @@ class TestPooledDatabase(BaseTestCase):
|
||||
self.assertEqual(len(self.db._connections), 5)
|
||||
self.assertEqual(len(self.db._in_use), 0)
|
||||
|
||||
def test_heap_counter_deterministic_ordering(self):
|
||||
# Verify that connections pushed with the same timestamp are returned
|
||||
# in order.
|
||||
now = time.time()
|
||||
push_conn(self.db, now, 'a')
|
||||
push_conn(self.db, now, 'b')
|
||||
push_conn(self.db, now, 'c')
|
||||
|
||||
results = []
|
||||
while self.db._connections:
|
||||
ts, counter, conn = heapq.heappop(self.db._connections)
|
||||
results.append(conn)
|
||||
self.assertEqual(results, ['a', 'b', 'c'])
|
||||
|
||||
def test_close_conn_removes_from_in_use(self):
|
||||
# _close(conn, close_conn=True) should pop the key from _in_use AND
|
||||
# close the underlying driver conn.
|
||||
self.assertEqual(self.db.connection(), 1)
|
||||
self.assertTrue(1 in self.db._in_use)
|
||||
|
||||
closed_before = self.db.closed_counter
|
||||
self.db._close(1, close_conn=True)
|
||||
|
||||
self.assertNotIn(1, self.db._in_use)
|
||||
self.assertEqual(self.db.closed_counter, closed_before + 1)
|
||||
|
||||
def test_double_close_is_noop(self):
|
||||
# Calling _close on a connection not in _in_use (and close_conn=False)
|
||||
# should be a safe no-op rather than raising or leaking.
|
||||
self.assertEqual(self.db.connection(), 1)
|
||||
self.db.close() # Returns conn 1 to the pool.
|
||||
|
||||
self.assertNotIn(1, self.db._in_use)
|
||||
closed_before = self.db.closed_counter
|
||||
# Second close should do nothing.
|
||||
self.db._close(1)
|
||||
self.assertEqual(self.db.closed_counter, closed_before)
|
||||
# Pool state unchanged.
|
||||
self.assertEqual(len(self.db._connections), 1)
|
||||
|
||||
def test_can_reuse_false_closes_connection(self):
|
||||
# When _can_reuse returns False on check-in, the connection should be
|
||||
# closed at the driver level and not returned to the pool.
|
||||
class NotReusablePooledDatabase(FakePooledDatabase):
|
||||
def _can_reuse(self, conn):
|
||||
return False
|
||||
|
||||
db = NotReusablePooledDatabase('testing')
|
||||
self.assertEqual(db.connection(), 1)
|
||||
closed_before = db.closed_counter
|
||||
|
||||
db.close()
|
||||
|
||||
# Connection should have been driver-closed, not pooled.
|
||||
self.assertEqual(db.closed_counter, closed_before + 1)
|
||||
self.assertEqual(len(db._connections), 0)
|
||||
self.assertEqual(db._in_use, {})
|
||||
|
||||
# Next connect creates a brand new connection.
|
||||
self.assertEqual(db.connection(), 2)
|
||||
|
||||
def test_close_raw_swallows_exception(self):
|
||||
called = []
|
||||
# _close_raw should not propagate exceptions from the driver.
|
||||
class BrokenDriverClose(FakeDatabase):
|
||||
def _close(self, conn):
|
||||
called.append(conn)
|
||||
raise RuntimeError('failed')
|
||||
|
||||
class BrokenPool(FakePooledDatabase, BrokenDriverClose):
|
||||
pass
|
||||
|
||||
db = BrokenPool('testing')
|
||||
db._close_raw(1337)
|
||||
self.assertEqual(called, [1337])
|
||||
|
||||
def test_close_stale_removes_from_in_use(self):
|
||||
# Verify that close_stale both driver-closes the connection AND
|
||||
# removes it from _in_use (no dangling keys).
|
||||
db = FakePooledDatabase('testing', counter=2)
|
||||
|
||||
now = time.time()
|
||||
db._in_use[1] = PoolConnection(now - 1000, 1, now - 1000)
|
||||
db._in_use[2] = PoolConnection(now, 2, now)
|
||||
|
||||
closed_before = db.closed_counter
|
||||
self.assertEqual(db.close_stale(age=500), 1)
|
||||
self.assertNotIn(1, db._in_use)
|
||||
self.assertIn(2, db._in_use)
|
||||
self.assertEqual(db.closed_counter, closed_before + 1)
|
||||
|
||||
def test_close_all_clears_both_pools(self):
|
||||
# close_all should leave both _connections and _in_use completely
|
||||
# empty, and driver-close every connection.
|
||||
db = FakePooledDatabase('testing', counter=3)
|
||||
|
||||
now = time.time()
|
||||
push_conn(db, now - 5, 1)
|
||||
push_conn(db, now - 1, 2)
|
||||
|
||||
# Simulate two in-use connections.
|
||||
db._in_use[3] = PoolConnection(now, 3, now)
|
||||
db._in_use[4] = PoolConnection(now, 4, now)
|
||||
|
||||
# One more for the "current thread" via normal connect path so
|
||||
# self.close() inside close_all has something to reset.
|
||||
db._state.closed = True
|
||||
db.connect()
|
||||
conn = db.connection()
|
||||
self.assertIn(db.conn_key(conn), db._in_use)
|
||||
|
||||
closed_before = db.closed_counter
|
||||
db.close_all()
|
||||
|
||||
self.assertEqual(db._connections, [])
|
||||
self.assertEqual(db._in_use, {})
|
||||
# 2 idle + 2 manually-added in_use + the current thread's conn = 5.
|
||||
# (close_all calls self.close() which triggers _close for the current
|
||||
# thread's conn, but that goes through the return-to-pool path, not
|
||||
# _close_raw. The subsequent loop over the snapshot handles it.)
|
||||
self.assertGreaterEqual(db.closed_counter, closed_before + 4)
|
||||
|
||||
def test_connect_timeout_with_condition_variable(self):
|
||||
# Verify that connect() with a timeout raises after the timeout
|
||||
# expires when the pool is exhausted.
|
||||
db = FakePooledDatabase('testing', max_connections=1, timeout=0.15)
|
||||
self.assertEqual(db.connection(), 1)
|
||||
|
||||
errors = []
|
||||
def try_connect():
|
||||
db._state.closed = True # Appear as a new thread.
|
||||
try:
|
||||
db.connect()
|
||||
except MaxConnectionsExceeded:
|
||||
errors.append(True)
|
||||
|
||||
t = threading.Thread(target=try_connect)
|
||||
start = time.monotonic()
|
||||
t.start()
|
||||
t.join(timeout=2)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# Should have waited roughly the timeout duration.
|
||||
self.assertEqual(len(errors), 1)
|
||||
self.assertGreaterEqual(elapsed, 0.1)
|
||||
|
||||
def test_connect_timeout_wakes_on_return(self):
|
||||
# Verify that a waiting thread unblocks promptly when a connection
|
||||
# is returned to the pool (via the Condition variable notify).
|
||||
db = FakePooledDatabase('testing', max_connections=1, timeout=5)
|
||||
self.assertEqual(db.connection(), 1)
|
||||
|
||||
results = []
|
||||
def try_connect():
|
||||
db._state.closed = True
|
||||
try:
|
||||
db.connect()
|
||||
results.append(db.connection())
|
||||
except MaxConnectionsExceeded:
|
||||
results.append('timeout')
|
||||
|
||||
t = threading.Thread(target=try_connect)
|
||||
t.start()
|
||||
|
||||
# Give the thread a moment to start waiting.
|
||||
time.sleep(0.05)
|
||||
|
||||
# Return conn 1 to the pool — should wake the waiting thread.
|
||||
db.close()
|
||||
|
||||
t.join(timeout=2)
|
||||
self.assertFalse(t.is_alive(), 'Thread did not wake up.')
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0], 1) # Got the recycled connection.
|
||||
|
||||
def test_connect_timeout_zero_becomes_infinite(self):
|
||||
# A timeout of 0 should be treated as infinite (no immediate failure).
|
||||
db = FakePooledDatabase('testing', max_connections=1, timeout=0)
|
||||
self.assertEqual(db._wait_timeout, float('inf'))
|
||||
|
||||
def test_close_all_wakes_waiters(self):
|
||||
# Threads blocked in connect() should be woken by close_all() so they
|
||||
# can create fresh connections.
|
||||
db = FakePooledDatabase('testing', max_connections=1, timeout=5)
|
||||
self.assertEqual(db.connection(), 1)
|
||||
|
||||
results = []
|
||||
def try_connect():
|
||||
db._state.closed = True
|
||||
try:
|
||||
db.connect()
|
||||
results.append(db.connection())
|
||||
except MaxConnectionsExceeded:
|
||||
results.append('timeout')
|
||||
|
||||
t = threading.Thread(target=try_connect)
|
||||
t.start()
|
||||
time.sleep(0.05)
|
||||
|
||||
# close_all frees the slot and calls notify_all.
|
||||
db.close_all()
|
||||
|
||||
t.join(timeout=2)
|
||||
self.assertFalse(t.is_alive(), 'Thread was not woken by close_all.')
|
||||
self.assertEqual(len(results), 1)
|
||||
# After close_all, the thread should have gotten a fresh connection.
|
||||
self.assertEqual(results[0], 2)
|
||||
|
||||
def test_close_stale_iteration(self):
|
||||
db = FakePooledDatabase('testing', counter=10)
|
||||
now = time.time()
|
||||
for i in range(1, 11):
|
||||
db._in_use[i] = PoolConnection(now - 1000, i, now - 1000)
|
||||
|
||||
# All 10 should be closed.
|
||||
self.assertEqual(db.close_stale(age=500), 10)
|
||||
self.assertEqual(db._in_use, {})
|
||||
|
||||
def test_concurrent_close_stale_and_return(self):
|
||||
# Exercise close_stale running while other threads are actively
|
||||
# returning connections (calling close()). The snapshot-before-mutate
|
||||
# pattern and the RLock should keep everything consistent.
|
||||
db = FakePooledDatabase('testing', max_connections=20)
|
||||
barrier = threading.Barrier(11) # 10 workers + main thread.
|
||||
errors = []
|
||||
|
||||
def worker(n):
|
||||
"""Check out a connection, wait for all workers to be ready,
|
||||
then return it."""
|
||||
try:
|
||||
db._state.closed = True
|
||||
db.connect()
|
||||
barrier.wait(timeout=2)
|
||||
# Small stagger so close_stale and close() overlap.
|
||||
time.sleep(0.001 * (n % 3))
|
||||
db.close()
|
||||
except Exception as exc:
|
||||
errors.append(exc)
|
||||
|
||||
# Spin up 10 threads that each grab and return a connection.
|
||||
threads = [threading.Thread(target=worker, args=(i,))
|
||||
for i in range(10)]
|
||||
for t in threads: t.start()
|
||||
|
||||
# Wait until all threads hold a connection.
|
||||
while len(db._in_use) < 10:
|
||||
time.sleep(.005)
|
||||
|
||||
# Artificially back-date half the checked_out times so that
|
||||
# close_stale will try to close them while threads are returning.
|
||||
now = time.time()
|
||||
for i, key in enumerate(list(db._in_use)):
|
||||
if i % 2 == 0:
|
||||
pc = db._in_use[key]
|
||||
db._in_use[key] = PoolConnection(pc.timestamp, pc.connection,
|
||||
now - 10000)
|
||||
|
||||
# Release the barrier so threads start returning connections, and
|
||||
# simultaneously run close_stale from the main thread.
|
||||
barrier.wait(timeout=2)
|
||||
closed = db.close_stale(age=5000)
|
||||
for t in threads: t.join(timeout=2)
|
||||
|
||||
self.assertEqual(errors, [])
|
||||
for key in db._in_use:
|
||||
for _, _, conn in db._connections:
|
||||
self.assertNotEqual(db.conn_key(conn), key)
|
||||
|
||||
def test_manual_close_when_already_closed(self):
|
||||
# manual_close on an already-closed database should return False.
|
||||
self.assertFalse(self.db.manual_close()) # Never opened.
|
||||
|
||||
self.db.connect()
|
||||
self.db.close()
|
||||
self.assertFalse(self.db.manual_close()) # Already closed.
|
||||
|
||||
def test_close_idle_driver_closes_all(self):
|
||||
# Every idle connection should be driver-closed.
|
||||
db = FakePooledDatabase('testing', counter=5)
|
||||
now = time.time()
|
||||
for i in range(1, 6):
|
||||
push_conn(db, now - i, i)
|
||||
|
||||
closed_before = db.closed_counter
|
||||
db.close_idle()
|
||||
self.assertEqual(db._connections, [])
|
||||
self.assertEqual(db.closed_counter, closed_before + 5)
|
||||
|
||||
def test_max_connections_zero_means_unlimited(self):
|
||||
# max_connections=0 (falsy) should mean no limit.
|
||||
db = FakePooledDatabase('testing', max_connections=0)
|
||||
for i in range(50):
|
||||
db._state.closed = True
|
||||
db.connect()
|
||||
self.assertEqual(len(db._in_use), 50)
|
||||
|
||||
def test_stale_and_closed_all_skipped(self):
|
||||
# If every connection in the pool is either stale or closed, a new one
|
||||
# should be created.
|
||||
class AllClosedDatabase(FakePooledDatabase):
|
||||
def _is_closed(self, conn):
|
||||
return True
|
||||
|
||||
db = AllClosedDatabase('testing', stale_timeout=10)
|
||||
now = time.time()
|
||||
push_conn(db, now - 20, 1) # Stale.
|
||||
push_conn(db, now, 2) # Closed (per _is_closed override).
|
||||
db.counter = 2
|
||||
|
||||
self.assertEqual(db.connection(), 3)
|
||||
self.assertEqual(db._connections, [])
|
||||
self.assertEqual(list(db._in_use.keys()), [3])
|
||||
|
||||
def test_init_updates_pool_parameters(self):
|
||||
# The init() method should allow updating pool parameters after
|
||||
# initial construction.
|
||||
db = FakePooledDatabase('testing', max_connections=5, stale_timeout=10,
|
||||
timeout=2)
|
||||
self.assertEqual(db._max_connections, 5)
|
||||
self.assertEqual(db._stale_timeout, 10)
|
||||
self.assertEqual(db._wait_timeout, 2)
|
||||
|
||||
db.init('testing', max_connections=50, stale_timeout=100, timeout=20)
|
||||
self.assertEqual(db._max_connections, 50)
|
||||
self.assertEqual(db._stale_timeout, 100)
|
||||
self.assertEqual(db._wait_timeout, 20)
|
||||
|
||||
def test_init_timeout_zero_becomes_infinite(self):
|
||||
db = FakePooledDatabase('testing', timeout=5)
|
||||
self.assertEqual(db._wait_timeout, 5)
|
||||
|
||||
db.init('testing', timeout=0)
|
||||
self.assertEqual(db._wait_timeout, float('inf'))
|
||||
|
||||
|
||||
class TestLivePooledDatabase(ModelTestCase):
|
||||
database = PooledTestDatabase('test_pooled.db')
|
||||
|
||||
Reference in New Issue
Block a user