Created
February 1, 2025 18:16
-
-
Save txomon/ce8bc4959461f64102403bdc96571760 to your computer and use it in GitHub Desktop.
Detect when the client closes its side of the socket and raise an exception in the server thread handling that request, so the in-flight work is interrupted.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ctypes | |
import enum | |
import logging | |
import os | |
import socket | |
import struct | |
import sys | |
try: | |
import tenacity.retry | |
except ImportError: | |
tenacity = None | |
import threading | |
import time | |
import typing | |
import weakref | |
import django.http as http | |
from django.utils.deprecation import MiddlewareMixin | |
# Module-wide logger; its level is pinned to INFO at import time.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Exception types that signal "client closed the connection".
# ConnectionAbortedError covers wsgiref, which is always available.
CLIENT_DISCONNECT_EXCEPTIONS = [ConnectionAbortedError]

# gunicorn is optional: its error type is registered only when the package
# can be imported.
try:
    import gunicorn.http.errors as gunicorn_errors
except ImportError:
    pass
else:
    CLIENT_DISCONNECT_EXCEPTIONS.append(gunicorn_errors.NoMoreData)
class ClientClosedConnectionError(*CLIENT_DISCONNECT_EXCEPTIONS):
    """Exception raised when the client has closed the connection.

    It must inherit from every error/exception type that the web worker
    managers use to signify a client closing the connection; the bases are
    taken from ``CLIENT_DISCONNECT_EXCEPTIONS``.

    Right now it works for:
    Gunicorn: https://github.com/benoitc/gunicorn/blob/ab9c8301cb9ae573ba597154ddeea16f0326fc15/gunicorn/workers/gthread.py#L285
    wsgiref (AKA django's manage.py runserver): https://github.com/python/cpython/blob/16c9415fba4972743f1944ebc44946e475e68bc4/Lib/wsgiref/handlers.py#L139
    """
# tcpi_state values taken from https://github.com/torvalds/linux/blob/master/include/net/tcp_states.h
# explanations https://github.com/torvalds/linux/blob/52a93d39b17dc7eb98b6aa3edb93943248e03b2f/net/ipv4/tcp.c#L209-L241
class TcpiState(enum.IntEnum):
    """Numeric TCP connection states, as found in ``tcp_info.tcpi_state``."""

    TCP_ESTABLISHED = 1
    TCP_SYN_SENT = enum.auto()
    TCP_SYN_RECV = enum.auto()
    TCP_FIN_WAIT1 = enum.auto()
    TCP_FIN_WAIT2 = enum.auto()
    TCP_TIME_WAIT = enum.auto()
    TCP_CLOSE = enum.auto()
    TCP_CLOSE_WAIT = enum.auto()
    TCP_LAST_ACK = enum.auto()
    TCP_LISTEN = enum.auto()
    TCP_CLOSING = enum.auto()
    TCP_NEW_SYN_RECV = enum.auto()
    TCP_MAX_STATES = enum.auto()


# Precompute the set of states that are closing states for us
CLOSING_STATES = (
    TcpiState.TCP_FIN_WAIT1,
    TcpiState.TCP_FIN_WAIT2,
    TcpiState.TCP_TIME_WAIT,
    TcpiState.TCP_CLOSE,
    TcpiState.TCP_CLOSE_WAIT,
    TcpiState.TCP_LAST_ACK,
    TcpiState.TCP_CLOSING,
)


def is_closed(sock):
    """Return True when *sock* is closed or is in one of ``CLOSING_STATES``.

    The state is read from the TCP layer through the ``TCP_INFO`` socket
    option.  The original author notes that this check is reliable but slow
    (around 0.5s), so it should not be called while holding a lock that
    request threads need.

    Args:
        sock: socket-like object exposing ``getsockopt``; a truthy ``_closed``
            attribute short-circuits the check.

    Returns:
        True if the socket is marked closed, if the state cannot be read, or
        if the reported TCP state is a closing state; False otherwise.
    """
    # If the socket is marked as closed, don't even bother
    if getattr(sock, "_closed", False):
        return True
    try:
        # tcpi_state is the first member of struct tcp_info and occupies a
        # single byte, so ask for exactly one byte.  Without an explicit
        # buffer length getsockopt() decodes the option as a native int,
        # which mixes tcpi_state with the bytes that follow it
        # (tcpi_ca_state, tcpi_retransmits, tcpi_probes) and makes the
        # membership test below miss closing sockets.
        raw_state = sock.getsockopt(socket.SOL_TCP, socket.TCP_INFO, 1)
    # For whatever reason, the system call might fail, we will assume this means
    # the socket is closed, given that it's usually due to some abnormal state
    except OSError:
        return True
    if not raw_state:
        # An empty answer is treated like a failed call.
        return True
    return raw_state[0] in CLOSING_STATES
def raise_client_disconnect_error_in_thread(thread: threading.Thread):
    """Asynchronously raise ``ClientClosedConnectionError`` inside *thread*.

    Uses the CPython C API ``PyThreadState_SetAsyncExc`` through ``ctypes``.

    Args:
        thread: the thread that should receive the exception; its ``ident``
            is used as the thread id.

    Raises:
        ValueError: if no thread state matches ``thread.ident``.
    """
    # Ideas taken from:
    # https://stackoverflow.com/questions/36484151/throw-an-exception-into-another-thread
    # https://gist.github.com/liuw/2407154
    # We are all aware it's a hideous hack, but there are no contexts
    # like in golang, so...

    # The C API takes the thread id as an unsigned long.
    thread_id = ctypes.c_ulong(thread.ident)
    modified = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        thread_id,
        ctypes.py_object(ClientClosedConnectionError),
    )
    if modified == 0:
        raise ValueError(f"Invalid thread identifier {thread.ident}")
    if modified > 1:
        # More than one thread state was touched: clear the pending exception
        # again so that weird stuff does not happen.  None is marshalled as a
        # NULL pointer, which is how the API is told to clear it.
        ctypes.pythonapi.PyThreadState_SetAsyncExc(thread_id, None)
def after_retry_logger(retry_state):
    """Log a finished retry attempt at CRITICAL level.

    Args:
        retry_state: object exposing ``fn``, ``attempt_number`` and
            ``outcome``; ``outcome.exception()`` is attached as ``exc_info``.
    """
    outcome = retry_state.outcome
    message = (
        f"Library error, please report bug. Retrying {retry_state.fn} "
        f"attempt {retry_state.attempt_number} ended with: {outcome}"
    )
    logger.critical(message, exc_info=outcome.exception())
if tenacity:

    class retry_exceptions(tenacity.retry_base):
        """Retry predicate that is true exactly when the attempt failed."""

        def __call__(self, retry_state: "tenacity.RetryCallState") -> bool:
            outcome = retry_state.outcome
            if outcome is None:
                raise RuntimeError("__call__() called before outcome was set")
            return True if outcome.failed else False
class ClientConnectionMonitor(threading.Thread):
    """Thread that polls registered sockets and interrupts their threads.

    A thread registers the socket it is serving (at most one per thread) with
    ``register_socket`` and removes it with ``deregister_socket``.  The
    monitor loop checks every registered socket with ``is_closed()`` and, when
    it reports closed, raises ``ClientClosedConnectionError`` in the owning
    thread, at most once per socket.
    """

    def __init__(self, *a, **kw):
        # Because is_closed() is so slow, we can't have one global lock held
        # around it, as that would delay registering new sockets.
        # registration_lock only guards short dictionary operations: creating
        # entries in thread_locks, writing thread_watchers, and copying
        # thread_watchers.  Nothing slow may run while it is held.
        self.registration_lock = threading.Lock()
        # One lock per registering thread, so that socket register /
        # deregister and the "is it still registered? then raise" step of the
        # monitor are serialized for that thread.  is_closed() must never be
        # called from within a thread_locks protected block, because the
        # request thread would have to wait for it.
        self.thread_locks: typing.Dict[
            threading.Thread, threading.Lock
        ] = weakref.WeakKeyDictionary()
        # The sockets being watched, max one per thread.  Removal can happen
        # through weak key recycling or by deregistering manually.
        self.thread_watchers: typing.Dict[
            threading.Thread,
            socket.socket,
        ] = weakref.WeakKeyDictionary()
        # We only want to raise an exception once per socket, hence we keep
        # track of the sockets we already raised for.
        self.sockets_raised = weakref.WeakSet()
        # Explicit stop flag for the loop.  The private Thread._is_stopped
        # attribute is not usable for this: it only becomes true after run()
        # has returned, and it is not part of the public threading API.
        self._stop_event = threading.Event()
        kw.setdefault("name", "ClientConnectionMonitor")
        super().__init__(*a, **kw)

    def stop(self):
        """Ask the monitor loop to exit once the current pass is finished."""
        self._stop_event.set()

    def _run(self):
        """Poll the registered sockets until ``stop()`` is called."""
        while not self._stop_event.is_set():
            with self.registration_lock:
                # Always work on a copy: the WeakKeyDictionary may lose
                # entries asynchronously, and writers hold registration_lock
                # too, so the copy cannot race with a register/deregister.
                # The copy also keeps hard references to the threads alive
                # for the duration of this pass.
                thread_watchers = dict(self.thread_watchers)
            for thread, sock in thread_watchers.items():
                # sockets_raised is only touched by this thread, and
                # is_closed() has to run outside of the thread lock.
                if sock in self.sockets_raised or not is_closed(sock):
                    continue
                thread_lock = self.thread_locks[thread]
                # Re-check under the thread lock that the socket is still the
                # one registered, so that no exception is raised after the
                # request thread has deregistered it, and raise only once.
                with thread_lock:
                    registered_sock = self.thread_watchers.get(thread, None)
                    if registered_sock == sock:
                        raise_client_disconnect_error_in_thread(thread)
                        self.sockets_raised.add(sock)
            # Same pause as before between passes, but it ends early when
            # stop() is called.
            self._stop_event.wait(0.2)
        logger.error("Stopping work on thread shutdown")

    if tenacity:

        @tenacity.retry(
            wait=tenacity.wait_fixed(1),
            # tenacity calls this object with the retry state, so it has to
            # be an instance of the strategy, not the class itself.
            retry=retry_exceptions(),
            stop=tenacity.stop_never,
            after=after_retry_logger,
        )
        def run(self):
            """Thread entry point; restarts ``_run`` if it raises."""
            self._run()

    else:

        def run(self):
            """Thread entry point."""
            self._run()

    def register_socket(
        self,
        sock: socket.socket,
        thread: typing.Optional[threading.Thread] = None,
    ):
        """Start watching *sock* on behalf of *thread*.

        Args:
            sock: the socket to watch.
            thread: owning thread; defaults to the calling thread.
        """
        logger.debug(f"Register {sock=}")
        thread = thread or threading.current_thread()
        thread_lock = self.thread_locks.get(thread)
        if thread_lock is None:
            with self.registration_lock:
                # Look again now that the lock is held: another caller may
                # have created the entry in the meantime.
                thread_lock = self.thread_locks.get(thread)
                if thread_lock is None:
                    thread_lock = self.thread_locks[thread] = threading.Lock()
        with thread_lock, self.registration_lock:
            self.thread_watchers[thread] = sock

    def deregister_socket(self, sock, thread=None):
        """Stop watching the socket registered for *thread*.

        Args:
            sock: the socket expected to be registered; a mismatch is logged.
            thread: owning thread; defaults to the calling thread.

        Raises:
            RuntimeError: if *thread* never registered a socket.
            KeyError: if *thread* has no socket registered at the moment.
        """
        logger.debug(f"Deregister {sock=}")
        thread = thread or threading.current_thread()
        thread_lock = self.thread_locks.get(thread)
        if thread_lock is None:
            raise RuntimeError("Trying to unregister an already unregistered socket")
        with thread_lock, self.registration_lock:
            existing_socket = self.thread_watchers.pop(thread)
        if sock != existing_socket:
            logger.error(
                f"Removing different socket to registered {sock=} VS {existing_socket=}"
            )
# Maps a process id to the monitor thread started in that process.
CONNECTION_MONITOR = {}


def get_client_connection_monitor():
    """Return the monitor of the current process, starting it on first use.

    The monitor is looked up by ``os.getpid()``.  When none is stored for
    this pid, a new daemon ``ClientConnectionMonitor`` is created, stored and
    started.
    """
    global CONNECTION_MONITOR
    existing = CONNECTION_MONITOR.get(os.getpid())
    if existing:
        return existing

    logger.debug("Initializing worker thread")
    monitor = ClientConnectionMonitor()
    # With this implementation we don't leak memory: the mapping is replaced
    # rather than extended, so it only ever holds one entry.
    CONNECTION_MONITOR = {os.getpid(): monitor}
    monitor.daemon = True
    monitor.start()
    return monitor
def get_socket(request):
    """Find the socket that carries *request*.

    Under gunicorn the socket is read from ``request.environ``.  Otherwise
    (django's runserver) the call stack is walked upwards, at most 500
    frames, and the first local variable that is a ``socket.socket`` wins.

    Raises:
        ValueError: if the stack ends, or 500 frames were inspected, without
            finding a socket.
    """
    environ = request.environ
    # If using gunicorn
    if "gunicorn.socket" in environ:
        return environ["gunicorn.socket"]

    # If using runserver from django
    frame = sys._getframe().f_back
    frames_left = 500
    while frames_left:
        frames_left -= 1
        for candidate in frame.f_locals.values():
            if isinstance(candidate, socket.socket):
                return candidate
        frame = frame.f_back
        if frame is None:
            raise ValueError("Reached the end with no eligible frame")
    raise ValueError("No frame holding the socket is available")
class SocketMonitorMiddleware(MiddlewareMixin):
    """Middleware that registers each request's socket with the monitor.

    The socket is registered before the request is processed and
    deregistered when the response comes back.  A ``ConnectionAbortedError``
    raised by a view is turned into a 499 response.
    """

    @staticmethod
    def process_request(request):
        """Register the request's socket for the current thread."""
        client_socket = get_socket(request)
        get_client_connection_monitor().register_socket(client_socket)

    @staticmethod
    def process_response(request, response):
        """Deregister the request's socket and pass *response* through."""
        try:
            client_socket = get_socket(request)
            get_client_connection_monitor().deregister_socket(client_socket)
        except ConnectionAbortedError:
            # A disconnect signalled at this point is deliberately ignored;
            # the response is returned regardless.
            pass
        return response

    @staticmethod
    def process_exception(request, exception):
        """Answer client-disconnect exceptions with status 499.

        Returns None for any other exception.
        """
        if not isinstance(exception, ConnectionAbortedError):
            return None
        return http.HttpResponse(
            status=499,
            reason="Client closed connection",
        )
# Debugging utils
# Struct extracted from https://github.com/torvalds/linux/blob/master/include/uapi/linux/tcp.h#L214
# Each entry is (struct format character, field name), in struct order.
TCP_INFO_STRUCT = (
    ("B", "tcpi_state"),
    ("B", "tcpi_ca_state"),
    ("B", "tcpi_retransmits"),
    ("B", "tcpi_probes"),
    ("B", "tcpi_backoff"),
    ("B", "tcpi_options"),
    ("B", "tcpi_features1"),  # ("B", "tcpi_snd_wscale : 4, tcpi_rcv_wscale : 4"),
    (
        "B",
        "tcpi_features2",
    ),  # ("B", "tcpi_delivery_rate_app_limited:1, tcpi_fastopen_client_fail:2"),
    ("I", "tcpi_rto"),
    ("I", "tcpi_ato"),
    ("I", "tcpi_snd_mss"),
    ("I", "tcpi_rcv_mss"),
    ("I", "tcpi_unacked"),
    ("I", "tcpi_sacked"),
    ("I", "tcpi_lost"),
    ("I", "tcpi_retrans"),
    ("I", "tcpi_fackets"),
    ("I", "tcpi_last_data_sent"),
    ("I", "tcpi_last_ack_sent"),
    ("I", "tcpi_last_data_recv"),
    ("I", "tcpi_last_ack_recv"),
    ("I", "tcpi_pmtu"),
    ("I", "tcpi_rcv_ssthresh"),
    ("I", "tcpi_rtt"),
    ("I", "tcpi_rttvar"),
    ("I", "tcpi_snd_ssthresh"),
    ("I", "tcpi_snd_cwnd"),
    ("I", "tcpi_advmss"),
    ("I", "tcpi_reordering"),
    ("I", "tcpi_rcv_rtt"),
    ("I", "tcpi_rcv_space"),
    ("I", "tcpi_total_retrans"),
    ("Q", "tcpi_pacing_rate"),
    ("Q", "tcpi_max_pacing_rate"),
    ("Q", "tcpi_bytes_acked"),
    ("Q", "tcpi_bytes_received"),
    ("I", "tcpi_segs_out"),
    ("I", "tcpi_segs_in"),
    ("I", "tcpi_notsent_bytes"),
    ("I", "tcpi_min_rtt"),
    ("I", "tcpi_data_segs_in"),
    ("I", "tcpi_data_segs_out"),
    ("Q", "tcpi_delivery_rate"),
    ("Q", "tcpi_busy_time"),
    ("Q", "tcpi_rwnd_limited"),
    ("Q", "tcpi_sndbuf_limited"),
    ("I", "tcpi_delivered"),
    ("I", "tcpi_delivered_ce"),
    ("Q", "tcpi_bytes_sent"),
    ("Q", "tcpi_bytes_retrans"),
    ("I", "tcpi_dsack_dups"),
    ("I", "tcpi_reord_seen"),
    ("I", "tcpi_rcv_ooopack"),
    ("I", "tcpi_snd_wnd"),
    ("I", "tcpi_rcv_wnd"),
    ("I", "tcpi_rehash"),
)
TCP_INFO_STRUCT_FORMAT_STRING = "@" + "".join(x for x, _ in TCP_INFO_STRUCT)
TCP_INFO_STRUCT_SIZE = struct.calcsize(TCP_INFO_STRUCT_FORMAT_STRING)
# Byte offset at which each field ends (native alignment included); used to
# tell which fields a shorter-than-expected answer actually contains.
_TCP_INFO_FIELD_ENDS = tuple(
    struct.calcsize("@" + "".join(fmt for fmt, _ in TCP_INFO_STRUCT[:count]))
    for count in range(1, len(TCP_INFO_STRUCT) + 1)
)


def get_tcp_info(sock):
    """Read and decode the ``TCP_INFO`` socket option of *sock*.

    Args:
        sock: socket-like object whose ``getsockopt`` accepts a buffer length.

    Returns:
        A dict mapping field names from ``TCP_INFO_STRUCT`` to their integer
        values, in struct order.  If fewer than ``TCP_INFO_STRUCT_SIZE``
        bytes come back, only the fields fully contained in the answer are
        included, rather than failing to unpack.
    """
    tcp_info_struct = sock.getsockopt(
        socket.SOL_TCP, socket.TCP_INFO, TCP_INFO_STRUCT_SIZE
    )
    received = len(tcp_info_struct)
    # struct.unpack needs exactly TCP_INFO_STRUCT_SIZE bytes; pad a short
    # answer with zeros and drop the padded fields below.
    padded = tcp_info_struct.ljust(TCP_INFO_STRUCT_SIZE, b"\x00")
    tcp_info_struct_values = struct.unpack(TCP_INFO_STRUCT_FORMAT_STRING, padded)
    return {
        name: value
        for value, (_, name), field_end in zip(
            tcp_info_struct_values, TCP_INFO_STRUCT, _TCP_INFO_FIELD_ENDS
        )
        if field_end <= received
    }
def print_tcp_info(sock):
    """Print the dict returned by ``get_tcp_info(sock)`` (debugging aid)."""
    print(f"tcp_info={get_tcp_info(sock)!r}")
# Names of the ``socket`` module constants queried at SOL_SOCKET level.
# SO_ERROR appears only once: reading it clears the pending error, so a
# second read would overwrite the real value with 0.
# NOTE(review): SOMAXCONN is the listen() backlog limit rather than a socket
# option; it is kept because the original table queried it — confirm whether
# it should stay.
_SOL_SOCKET_OPTION_NAMES = (
    "SO_ERROR",
    "SOMAXCONN",
    "SO_ACCEPTCONN",
    "SO_BINDTODEVICE",
    "SO_BROADCAST",
    "SO_DEBUG",
    "SO_DOMAIN",
    "SO_DONTROUTE",
    "SO_INCOMING_CPU",
    "SO_J1939_ERRQUEUE",
    "SO_J1939_FILTER",
    "SO_J1939_PROMISC",
    "SO_J1939_SEND_PRIO",
    "SO_KEEPALIVE",
    "SO_LINGER",
    "SO_MARK",
    "SO_OOBINLINE",
    "SO_PASSCRED",
    "SO_PASSSEC",
    "SO_PEERCRED",
    "SO_PEERSEC",
    "SO_PRIORITY",
    "SO_PROTOCOL",
    "SO_RCVBUF",
    "SO_RCVLOWAT",
    "SO_RCVTIMEO",
    "SO_REUSEADDR",
    "SO_REUSEPORT",
    "SO_SNDBUF",
    "SO_SNDLOWAT",
    "SO_SNDTIMEO",
    "SO_TYPE",
    "SO_VM_SOCKETS_BUFFER_MAX_SIZE",
    "SO_VM_SOCKETS_BUFFER_MIN_SIZE",
    "SO_VM_SOCKETS_BUFFER_SIZE",
)
# Names of the ``socket`` module constants queried at SOL_TCP level.
_SOL_TCP_OPTION_NAMES = (
    "TCP_CONGESTION",
    "TCP_CORK",
    "TCP_DEFER_ACCEPT",
    "TCP_FASTOPEN",
    "TCP_INFO",
    "TCP_KEEPCNT",
    "TCP_KEEPIDLE",
    "TCP_KEEPINTVL",
    "TCP_LINGER2",
    "TCP_MAXSEG",
    "TCP_NODELAY",
    "TCP_NOTSENT_LOWAT",
    "TCP_QUICKACK",
    "TCP_SYNCNT",
    "TCP_USER_TIMEOUT",
    "TCP_WINDOW_CLAMP",
)
# (level, option value, option name) triples.  Constants are looked up by
# name so that one missing from this Python build or platform is skipped
# instead of raising AttributeError when the module is imported.
OPERATIONS = tuple(
    (level, getattr(socket, name), name)
    for level, names in (
        (socket.SOL_SOCKET, _SOL_SOCKET_OPTION_NAMES),
        (socket.SOL_TCP, _SOL_TCP_OPTION_NAMES),
    )
    for name in names
    if hasattr(socket, name)
)


def get_socket_status(sock):
    """Query every option in ``OPERATIONS`` on *sock* (debugging aid).

    Args:
        sock: socket-like object exposing ``getsockopt(level, optname)``.

    Returns:
        A dict mapping option name to the value returned by ``getsockopt``.
        Options whose query raises ``OSError`` are logged at debug level and
        left out.
    """
    operation_results = {}
    for level, op, name in OPERATIONS:
        try:
            res = sock.getsockopt(level, op)
        except OSError:
            logger.debug(f"Failed to get {level=} {op=}")
            continue
        operation_results[name] = res
    return operation_results
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Settings snippet: appends the socket monitor middleware to an existing
# MIDDLEWARE list.
# NOTE(review): MIDDLEWARE is not defined in this snippet, and the dotted path
# assumes the middleware module is importable as `client_disconnect_handler` —
# confirm both against the project's settings.
MIDDLEWARE += [
    "client_disconnect_handler.SocketMonitorMiddleware",
]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment