Skip to content

Commit

Permalink
Merge pull request #435 from Pylons/bugfix/remove-race-condition
Browse files Browse the repository at this point in the history
Remove race condition when creating new HTTPChannel
  • Loading branch information
digitalresistor authored Jun 8, 2024
2 parents 8565e0d + 9d99c89 commit 1ae4e89
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 159 deletions.
9 changes: 1 addition & 8 deletions src/waitress/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def __init__(self, server, sock, addr, adj, map=None):
self.outbuf_lock = threading.Condition()

wasyncore.dispatcher.__init__(self, sock, map=map)

# Don't let wasyncore.dispatcher throttle self.addr on us.
self.connected = True
self.addr = addr
self.requests = []

Expand All @@ -92,13 +91,7 @@ def handle_write(self):
# Precondition: there's data in the out buffer to be sent, or
# there's a pending will_close request

if not self.connected:
# we dont want to close the channel twice

return

# try to flush any pending output

if not self.requests:
# 1. There are no running tasks, so we don't need to try to lock
# the outbuf before sending
Expand Down
69 changes: 7 additions & 62 deletions src/waitress/wasyncore.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,22 +297,6 @@ def __init__(self, sock=None, map=None):
# get a socket from a blocking source.
sock.setblocking(0)
self.set_socket(sock, map)
self.connected = True
# The constructor no longer requires that the socket
# passed be connected.
try:
self.addr = sock.getpeername()
except OSError as err:
if err.args[0] in (ENOTCONN, EINVAL):
# To handle the case where we got an unconnected
# socket.
self.connected = False
else:
# The socket is broken in some unknown way, alert
# the user and remove it from the map (to prevent
# polling of broken sockets).
self.del_channel(map)
raise
else:
self.socket = None

Expand Down Expand Up @@ -394,23 +378,6 @@ def bind(self, addr):
self.addr = addr
return self.socket.bind(addr)

def connect(self, address):
self.connected = False
self.connecting = True
err = self.socket.connect_ex(address)
if (
err in (EINPROGRESS, EALREADY, EWOULDBLOCK)
or err == EINVAL
and os.name == "nt"
): # pragma: no cover
self.addr = address
return
if err in (0, EISCONN):
self.addr = address
self.handle_connect_event()
else:
raise OSError(err, errorcode[err])

def accept(self):
# XXX can return either an address pair or None
try:
Expand Down Expand Up @@ -469,6 +436,8 @@ def close(self):
if why.args[0] not in (ENOTCONN, EBADF):
raise

self.socket = None

# log and log_info may be overridden to provide more sophisticated
# logging and warning methods. In general, log is for 'hit' logging
# and 'log_info' is for informational, warning and error logging.
Expand Down Expand Up @@ -519,7 +488,11 @@ def handle_expt_event(self):
# handle_expt_event() is called if there might be an error on the
# socket, or if there is OOB data
# check for the error condition first
err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
err = (
self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if self.socket is not None
else 1
)
if err != 0:
# we can get here when select.select() says that there is an
# exceptional condition on the socket
Expand Down Expand Up @@ -572,34 +545,6 @@ def handle_close(self):
self.close()


# ---------------------------------------------------------------------------
# adds simple buffered output capability, useful for simple clients.
# [for more sophisticated usage use asynchat.async_chat]
# ---------------------------------------------------------------------------


class dispatcher_with_send(dispatcher):
def __init__(self, sock=None, map=None):
dispatcher.__init__(self, sock, map)
self.out_buffer = b""

def initiate_send(self):
num_sent = 0
num_sent = dispatcher.send(self, self.out_buffer[:65536])
self.out_buffer = self.out_buffer[num_sent:]

handle_write = initiate_send

def writable(self):
return (not self.connected) or len(self.out_buffer)

def send(self, data):
if self.debug: # pragma: no cover
self.log_info("sending %s" % repr(data))
self.out_buffer = self.out_buffer + data
self.initiate_send()


def close_all(map=None, ignore_all=False):
if map is None: # pragma: no cover
map = socket_map
Expand Down
107 changes: 18 additions & 89 deletions tests/test_wasyncore.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import _thread as thread
import contextlib
import errno
from errno import EALREADY, EINPROGRESS, EINVAL, EISCONN, EWOULDBLOCK, errorcode
import functools
import gc
from io import BytesIO
Expand Down Expand Up @@ -641,62 +642,6 @@ def test_strerror(self):
self.assertTrue(err != "")


class dispatcherwithsend_noread(asyncore.dispatcher_with_send): # pragma: no cover
def readable(self):
return False

def handle_connect(self):
pass


class DispatcherWithSendTests(unittest.TestCase):
def setUp(self):
pass

def tearDown(self):
asyncore.close_all()

@reap_threads
def test_send(self):
evt = threading.Event()
sock = socket.socket()
sock.settimeout(3)
port = bind_port(sock)

cap = BytesIO()
args = (evt, cap, sock)
t = threading.Thread(target=capture_server, args=args)
t.start()
try:
# wait a little longer for the server to initialize (it sometimes
# refuses connections on slow machines without this wait)
time.sleep(0.2)

data = b"Suppose there isn't a 16-ton weight?"
d = dispatcherwithsend_noread()
d.create_socket()
d.connect((HOST, port))

# give time for socket to connect
time.sleep(0.1)

d.send(data)
d.send(data)
d.send(b"\n")

n = 1000

while d.out_buffer and n > 0: # pragma: no cover
asyncore.poll()
n -= 1

evt.wait()

self.assertEqual(cap.getvalue(), data * 2)
finally:
join_thread(t, timeout=TIMEOUT)


@unittest.skipUnless(
hasattr(asyncore, "file_wrapper"), "asyncore.file_wrapper required"
)
Expand Down Expand Up @@ -839,6 +784,23 @@ def __init__(self, family, address):
self.create_socket(family)
self.connect(address)

def connect(self, address):
self.connected = False
self.connecting = True
err = self.socket.connect_ex(address)
if (
err in (EINPROGRESS, EALREADY, EWOULDBLOCK)
or err == EINVAL
and os.name == "nt"
): # pragma: no cover
self.addr = address
return
if err in (0, EISCONN):
self.addr = address
self.handle_connect_event()
else:
raise OSError(err, errorcode[err])

def handle_connect(self):
pass

Expand Down Expand Up @@ -1454,17 +1416,6 @@ def _makeOne(self, sock=None, map=None):

return dispatcher(sock=sock, map=map)

def test_unexpected_getpeername_exc(self):
sock = dummysocket()

def getpeername():
raise OSError(errno.EBADF)

map = {}
sock.getpeername = getpeername
self.assertRaises(socket.error, self._makeOne, sock=sock, map=map)
self.assertEqual(map, {})

def test___repr__accepting(self):
sock = dummysocket()
map = {}
Expand Down Expand Up @@ -1500,13 +1451,6 @@ def setsockopt(*arg, **kw):
inst.set_reuse_addr()
self.assertTrue(sock.errored)

def test_connect_raise_socket_error(self):
sock = dummysocket()
map = {}
sock.connect_ex = lambda *arg: 1
inst = self._makeOne(sock=sock, map=map)
self.assertRaises(socket.error, inst.connect, 0)

def test_accept_raise_TypeError(self):
sock = dummysocket()
map = {}
Expand Down Expand Up @@ -1675,21 +1619,6 @@ def test_handle_accepted(self):
self.assertTrue(sock.closed)


class Test_dispatcher_with_send(unittest.TestCase):
def _makeOne(self, sock=None, map=None):
from waitress.wasyncore import dispatcher_with_send

return dispatcher_with_send(sock=sock, map=map)

def test_writable(self):
sock = dummysocket()
map = {}
inst = self._makeOne(sock=sock, map=map)
inst.out_buffer = b"123"
inst.connected = True
self.assertTrue(inst.writable())


class Test_close_all(unittest.TestCase):
def _callFUT(self, map=None, ignore_all=False):
from waitress.wasyncore import close_all
Expand Down

0 comments on commit 1ae4e89

Please sign in to comment.