Skip to content

Instantly share code, notes, and snippets.

@txomon
Created February 1, 2025 18:16
Show Gist options
  • Save txomon/ce8bc4959461f64102403bdc96571760 to your computer and use it in GitHub Desktop.
Save txomon/ce8bc4959461f64102403bdc96571760 to your computer and use it in GitHub Desktop.
Detect a client-side socket disconnection and raise an exception in the server's request thread to interrupt processing
import ctypes
import enum
import logging
import os
import socket
import struct
import sys
try:
import tenacity.retry
except ImportError:
tenacity = None
import threading
import time
import typing
import weakref
import django.http as http
from django.utils.deprecation import MiddlewareMixin
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Exceptions that mean "the client closed the connection".
# wsgiref support needs no optional import, so its exception is always listed.
CLIENT_DISCONNECT_EXCEPTIONS = [ConnectionAbortedError]
# Gunicorn is optional: its NoMoreData error is only added when it is installed.
try:
    import gunicorn.http.errors as gunicorn_errors
except ImportError:
    pass
else:
    CLIENT_DISCONNECT_EXCEPTIONS.append(gunicorn_errors.NoMoreData)
class ClientClosedConnectionError(*CLIENT_DISCONNECT_EXCEPTIONS):
    """Error injected into a request thread whose client has disconnected.

    The class derives from every entry of CLIENT_DISCONNECT_EXCEPTIONS, i.e.
    from each exception the supported web worker managers use to signal that
    the client closed the connection.

    Right now it works for:
    Gunicorn: https://github.com/benoitc/gunicorn/blob/ab9c8301cb9ae573ba597154ddeea16f0326fc15/gunicorn/workers/gthread.py#L285
    wsgiref (AKA django's manage.py runserver): https://github.com/python/cpython/blob/16c9415fba4972743f1944ebc44946e475e68bc4/Lib/wsgiref/handlers.py#L139
    """
# tcpi_state values taken from https://github.com/torvalds/linux/blob/master/include/net/tcp_states.h
# explanations https://github.com/torvalds/linux/blob/52a93d39b17dc7eb98b6aa3edb93943248e03b2f/net/ipv4/tcp.c#L209-L241
class TcpiState(enum.IntEnum):
    """Numeric values of the ``tcpi_state`` field of ``struct tcp_info``.

    Members after TCP_ESTABLISHED rely on ``enum.auto()`` counting upwards
    from 1, so their order must not be changed.
    """

    TCP_ESTABLISHED = 1
    TCP_SYN_SENT = enum.auto()
    TCP_SYN_RECV = enum.auto()
    TCP_FIN_WAIT1 = enum.auto()
    TCP_FIN_WAIT2 = enum.auto()
    TCP_TIME_WAIT = enum.auto()
    TCP_CLOSE = enum.auto()
    TCP_CLOSE_WAIT = enum.auto()
    TCP_LAST_ACK = enum.auto()
    TCP_LISTEN = enum.auto()
    TCP_CLOSING = enum.auto()
    TCP_NEW_SYN_RECV = enum.auto()
    TCP_MAX_STATES = enum.auto()


# Precompute the set of states that are closing states for us
CLOSING_STATES = (
    TcpiState.TCP_FIN_WAIT1,
    TcpiState.TCP_FIN_WAIT2,
    TcpiState.TCP_TIME_WAIT,
    TcpiState.TCP_CLOSE,
    TcpiState.TCP_CLOSE_WAIT,
    TcpiState.TCP_LAST_ACK,
    TcpiState.TCP_CLOSING,
)


def is_closed(sock):
    """Return True when the TCP connection behind ``sock`` is closed or closing.

    The check asks the TCP layer for the connection state (TCP_INFO) and
    compares it with CLOSING_STATES. The original author notes that the check
    is reliable but extremely slow (around 0.5s).

    Args:
        sock: socket-like object exposing ``getsockopt``; an object whose
            ``_closed`` attribute is truthy is reported as closed without
            querying the TCP layer.

    Returns:
        True if the socket is marked closed, if querying it fails, or if its
        TCP state is one of CLOSING_STATES; False otherwise.
    """
    # If the socket is marked as closed, don't even bother
    if getattr(sock, "_closed", False):
        return True
    try:
        # tcpi_state is the first attribute of struct tcp_info and occupies a
        # single byte, so exactly one byte is requested. Without a buffer
        # length getsockopt() decodes the option as a C int, which mixes the
        # bytes that follow tcpi_state (tcpi_ca_state, tcpi_retransmits,
        # tcpi_probes) into the value and makes the comparison below fail
        # whenever one of them is non-zero.
        raw_state = sock.getsockopt(socket.SOL_TCP, socket.TCP_INFO, 1)
    # For whatever reason, the system call might fail, we will assume this means
    # the socket is closed, given that it's usually due to some abnormal state
    except OSError:
        return True
    if not raw_state:
        # An empty answer carries no state; treat it like a failed call.
        return True
    return raw_state[0] in CLOSING_STATES
def raise_client_disconnect_error_in_thread(thread: threading.Thread):
    """Asynchronously raise ClientClosedConnectionError inside ``thread``.

    Ideas taken from:
    https://stackoverflow.com/questions/36484151/throw-an-exception-into-another-thread
    https://gist.github.com/liuw/2407154
    We are all aware it's a hideous hack, but there are no contexts
    like in golang, so...

    Args:
        thread: the thread in which the exception has to be raised; its
            ``ident`` is used to address it.

    Raises:
        ValueError: if the interpreter reports that no thread state matched
            ``thread.ident``.
    """
    # The C signature is PyThreadState_SetAsyncExc(unsigned long id,
    # PyObject *exc), hence c_ulong rather than c_long for the identifier.
    thread_id = ctypes.c_ulong(thread.ident)
    modified_states = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        thread_id,
        ctypes.py_object(ClientClosedConnectionError),
    )
    if modified_states == 0:
        raise ValueError(f"Invalid thread identifier {thread.ident}")
    if modified_states > 1:
        # More than one thread state was touched, so clear the pending
        # exception again so that weird stuff might not happen. The clearing
        # call needs a NULL exception pointer: ctypes marshals None as NULL,
        # whereas a bare 0 would be passed as a C int.
        ctypes.pythonapi.PyThreadState_SetAsyncExc(thread_id, None)
def after_retry_logger(retry_state):
    """Log, at CRITICAL level, an attempt that is about to be retried.

    Used as the ``after`` callback of the tenacity retry wrapper.

    Args:
        retry_state: object providing ``fn``, ``attempt_number`` and
            ``outcome``; ``outcome.exception()`` supplies the traceback that
            is attached to the log record.
    """
    message = (
        f"Library error, please report bug. Retrying {retry_state.fn} "
        f"attempt {retry_state.attempt_number} ended with: {retry_state.outcome}"
    )
    failure = retry_state.outcome.exception()
    logger.critical(message, exc_info=failure)
if tenacity:

    class retry_exceptions(tenacity.retry_base):
        """Retry strategy that asks for a retry whenever the attempt failed."""

        def __call__(self, retry_state: "tenacity.RetryCallState") -> bool:
            """Return True when the recorded outcome of the attempt is a failure.

            Raises:
                RuntimeError: if no outcome has been recorded yet.
            """
            outcome = retry_state.outcome
            if outcome is None:
                raise RuntimeError("__call__() called before outcome was set")
            return bool(outcome.failed)
class ClientConnectionMonitor(threading.Thread):
    """Background thread that watches the client socket of each request thread.

    Request threads register the socket they are serving with
    register_socket() and remove it with deregister_socket(). The monitor
    loop polls every registered socket with is_closed() and, the first time a
    socket is found closed, raises ClientClosedConnectionError inside the
    thread that registered it.
    """

    def __init__(self, *a, **kw):
        # Because is_closed() is so slow, we can't have a global lock, as
        # it would add (num_sockets * 0.5s)/2 delay in average to registering
        # new sockets. We use registration_lock to make sure that registering
        # new thread_locks is thread safe. We need to make sure however that
        # reading thread_locks is never done directly (but by copy) because
        # thread_locks is a WeakKeyDictionary that might be purged in the middle
        # of iteration
        # The block that registration_lock is used on should only be for
        # writing or copying the registries, and never anything else
        self.registration_lock = threading.Lock()
        # With the registration lock we create a lock per thread. And that lock is
        # saved in thread_locks so that socket register / deregister happens
        # synchronously. This will allow us to have a synchronous block per thread.
        # However, this block can't be used for an is_closed() because it will
        # damage performance.
        #
        # Example: If we use this lock to de/register a socket on the side
        # of the web worker thread, and for a is_closed() + raise_exc() call,
        # the average time it will take to acquire the lock is (0.5s/2)/num_socks
        # being 0.5s the worst case scenario, which is terribly bad. Therefore,
        # we need to be careful and make sure that we never call is_closed()
        # from within a thread_locks protected block.
        self.thread_locks: typing.Dict[
            threading.Thread, threading.Lock
        ] = weakref.WeakKeyDictionary()
        # We protect thread_watchers with the above thread_locks. We need to
        # record all the sockets we are watching, max one per thread. Removal of
        # a socket from thread_watchers can happen through weak key recycling or
        # by unregistering them manually
        self.thread_watchers: typing.Dict[
            threading.Thread,
            socket.socket,
        ] = weakref.WeakKeyDictionary()
        # We only want to raise an exception once per socket, hence we need to
        # keep track of the sockets we already raised an exception for.
        self.sockets_raised = weakref.WeakSet()
        # Set by stop() to end the polling loop. The loop used to test the
        # private Thread._is_stopped attribute, which is not part of the
        # documented threading API and, where present, only becomes true once
        # run() has already returned, so it could never end the loop.
        self._shutdown_requested = threading.Event()
        kw.setdefault("name", "ClientConnectionMonitor")
        super().__init__(*a, **kw)

    def stop(self):
        """Ask the polling loop to finish after its current iteration."""
        self._shutdown_requested.set()

    def _run(self):
        """Poll the registered sockets until stop() is called."""
        while not self._shutdown_requested.is_set():
            with self.registration_lock:
                # Registration lock is necessary to add thread_watchers,
                # however because it's a weakkeyref, it might be deleted
                # asynchronously, so always make a copy. Also, StackOverflow lies:
                # https://stackoverflow.com/questions/12428026/safely-iterating-over-weakkeydictionary-and-weakvaluedictionary
                # Making this copy will ensure hard references to the threads exist
                # NOTE(review): register/deregister mutate thread_watchers
                # under the per-thread lock only, so this copy is not
                # protected against those writes — confirm that is acceptable.
                thread_watchers = dict(self.thread_watchers)
            for thread, sock in thread_watchers.items():
                # We don't need a lock for sockets_raised because it's a thread
                # local resource, and is_closed() needs to be run outside of the
                # thread_lock
                if sock in self.sockets_raised or not is_closed(sock):
                    continue
                thread_lock = self.thread_locks[thread]
                # is_closed() must not run under the thread lock because it
                # would block the request flow. However, we do need to make
                # sure that there is a before and after the socket
                # deregistration, so that we don't raise a
                # ClientClosedConnectionError after the views have been called.
                # Even if we did, it wouldn't be a huge problem because the exception
                # we raise is one that is handled, but it would break stuff like
                # tracing if any, so it's better to just raise the exception
                # a single time
                with thread_lock:
                    registered_sock = self.thread_watchers.get(thread, None)
                    if registered_sock is sock:
                        raise_client_disconnect_error_in_thread(thread)
                        self.sockets_raised.add(sock)
            # Pause between polling rounds; returns early once stop() is called.
            self._shutdown_requested.wait(0.2)
        logger.error("Stopping work on thread shutdown")

    if tenacity:

        # tenacity calls the ``retry`` argument with the retry state, so it
        # must be an instance of the strategy, not the class itself.
        @tenacity.retry(
            wait=tenacity.wait_fixed(1),
            retry=retry_exceptions(),
            stop=tenacity.stop_never,
            after=after_retry_logger,
        )
        def run(self):
            """Thread entry point: run the loop, restarting it if it raises."""
            self._run()

    else:

        def run(self):
            """Thread entry point: run the polling loop once."""
            self._run()

    def register_socket(
        self, sock: socket.socket, thread: typing.Optional[threading.Thread] = None
    ):
        """Start watching ``sock`` on behalf of ``thread``.

        Args:
            sock: socket serving the current request.
            thread: thread to interrupt when the socket closes; defaults to
                the calling thread. A previous registration for the same
                thread is replaced.
        """
        logger.debug(f"Register {sock=}")
        thread = thread or threading.current_thread()
        thread_lock = self.thread_locks.get(thread)
        if thread_lock is None:
            with self.registration_lock:
                # Look again under the lock: the per-thread lock may have
                # been created between the first lookup and acquiring it.
                thread_lock = self.thread_locks.get(thread)
                if thread_lock is None:
                    thread_lock = self.thread_locks[thread] = threading.Lock()
        with thread_lock:
            self.thread_watchers[thread] = sock

    def deregister_socket(self, sock, thread=None):
        """Stop watching the socket registered for ``thread``.

        Args:
            sock: socket that is expected to be registered; a mismatch with
                the registered one is logged as an error.
            thread: thread whose registration is removed; defaults to the
                calling thread.

        Raises:
            RuntimeError: if no socket is currently registered for the thread.
        """
        logger.debug(f"Deregister {sock=}")
        thread = thread or threading.current_thread()
        thread_lock = self.thread_locks.get(thread)
        if thread_lock is None:
            raise RuntimeError("Trying to unregister an already unregistered socket")
        with thread_lock:
            existing_socket = self.thread_watchers.pop(thread, None)
        if existing_socket is None:
            # The thread has a lock (it registered before) but nothing is
            # registered now: report it like the never-registered case
            # instead of leaking a bare KeyError from pop().
            raise RuntimeError("Trying to unregister an already unregistered socket")
        if sock != existing_socket:
            logger.error(
                f"Removing different socket to registered {sock=} VS {existing_socket=}"
            )
# Maps the current process id to its running ClientConnectionMonitor.
CONNECTION_MONITOR = {}
# Serialises creation of the monitor so that concurrent first requests do not
# each start their own monitor thread with its own registry.
_CONNECTION_MONITOR_LOCK = threading.Lock()


def get_client_connection_monitor():
    """Return the monitor thread of the current process, starting it on first use.

    Returns:
        The started ClientConnectionMonitor registered for ``os.getpid()``.
    """
    global CONNECTION_MONITOR
    pid = os.getpid()
    connection_monitor = CONNECTION_MONITOR.get(pid)
    if connection_monitor is not None:
        return connection_monitor
    with _CONNECTION_MONITOR_LOCK:
        # Look again now that we hold the lock: another thread may have
        # created the monitor while this one was waiting.
        connection_monitor = CONNECTION_MONITOR.get(pid)
        if connection_monitor is not None:
            return connection_monitor
        logger.debug("Initializing worker thread")
        connection_monitor = ClientConnectionMonitor()
        connection_monitor.daemon = True
        connection_monitor.start()
        # With this implementation we don't leak memory: the mapping is
        # replaced rather than extended, so an entry stored under another
        # process id is dropped.
        CONNECTION_MONITOR = {pid: connection_monitor}
    return connection_monitor
def get_socket(request):
    """Find the client socket that belongs to ``request``.

    Under gunicorn the socket is read from ``request.environ``. Otherwise
    (django's runserver) the caller's stack is searched, innermost frame
    first, for a local variable that is a ``socket.socket``.

    Args:
        request: object with an ``environ`` mapping.

    Returns:
        The value stored under "gunicorn.socket", or the first socket found
        in the locals of the calling frames.

    Raises:
        ValueError: if the outermost frame is reached, or 500 frames have
            been inspected, without finding a socket.
    """
    environ = request.environ
    # If using gunicorn
    if "gunicorn.socket" in environ:
        return environ["gunicorn.socket"]

    # If using runserver from django
    frame = sys._getframe().f_back
    frames_left = 500
    while frames_left:
        frames_left -= 1
        for candidate in frame.f_locals.values():
            if isinstance(candidate, socket.socket):
                return candidate
        frame = frame.f_back
        if frame is None:
            raise ValueError("Reached the end with no eligible frame")
    raise ValueError("No frame holding the socket is available")
class SocketMonitorMiddleware(MiddlewareMixin):
    """Django middleware that has the request's client socket monitored.

    The socket is registered with the connection monitor when the request
    comes in and deregistered when the response goes out. A
    ConnectionAbortedError reaching process_exception is turned into a
    499 response.
    """

    @staticmethod
    def process_request(request):
        """Register the socket of ``request`` with the connection monitor."""
        client_socket = get_socket(request)
        get_client_connection_monitor().register_socket(client_socket)

    @staticmethod
    def process_response(request, response):
        """Deregister the socket of ``request`` and pass ``response`` through."""
        try:
            client_socket = get_socket(request)
            get_client_connection_monitor().deregister_socket(client_socket)
        except ConnectionAbortedError:
            # NOTE(review): presumably the disconnect error can still be
            # raised in this thread while deregistering — confirm.
            pass
        return response

    @staticmethod
    def process_exception(request, exception):
        """Return a 499 response for ConnectionAbortedError, otherwise None."""
        if not isinstance(exception, ConnectionAbortedError):
            return None
        return http.HttpResponse(
            status=499,
            reason="Client closed connection",
        )
# Debugging utils
# Struct extracted from https://github.com/torvalds/linux/blob/master/include/uapi/linux/tcp.h#L214
# Each entry is (struct format code, field name), in the order of the C struct.
TCP_INFO_STRUCT = (
    ("B", "tcpi_state"),
    ("B", "tcpi_ca_state"),
    ("B", "tcpi_retransmits"),
    ("B", "tcpi_probes"),
    ("B", "tcpi_backoff"),
    ("B", "tcpi_options"),
    ("B", "tcpi_features1"),  # ("B", "tcpi_snd_wscale : 4, tcpi_rcv_wscale : 4"),
    (
        "B",
        "tcpi_features2",
    ),  # ("B", "tcpi_delivery_rate_app_limited:1, tcpi_fastopen_client_fail:2"),
    ("I", "tcpi_rto"),
    ("I", "tcpi_ato"),
    ("I", "tcpi_snd_mss"),
    ("I", "tcpi_rcv_mss"),
    ("I", "tcpi_unacked"),
    ("I", "tcpi_sacked"),
    ("I", "tcpi_lost"),
    ("I", "tcpi_retrans"),
    ("I", "tcpi_fackets"),
    ("I", "tcpi_last_data_sent"),
    ("I", "tcpi_last_ack_sent"),
    ("I", "tcpi_last_data_recv"),
    ("I", "tcpi_last_ack_recv"),
    ("I", "tcpi_pmtu"),
    ("I", "tcpi_rcv_ssthresh"),
    ("I", "tcpi_rtt"),
    ("I", "tcpi_rttvar"),
    ("I", "tcpi_snd_ssthresh"),
    ("I", "tcpi_snd_cwnd"),
    ("I", "tcpi_advmss"),
    ("I", "tcpi_reordering"),
    ("I", "tcpi_rcv_rtt"),
    ("I", "tcpi_rcv_space"),
    ("I", "tcpi_total_retrans"),
    ("Q", "tcpi_pacing_rate"),
    ("Q", "tcpi_max_pacing_rate"),
    ("Q", "tcpi_bytes_acked"),
    ("Q", "tcpi_bytes_received"),
    ("I", "tcpi_segs_out"),
    ("I", "tcpi_segs_in"),
    ("I", "tcpi_notsent_bytes"),
    ("I", "tcpi_min_rtt"),
    ("I", "tcpi_data_segs_in"),
    ("I", "tcpi_data_segs_out"),
    ("Q", "tcpi_delivery_rate"),
    ("Q", "tcpi_busy_time"),
    ("Q", "tcpi_rwnd_limited"),
    ("Q", "tcpi_sndbuf_limited"),
    ("I", "tcpi_delivered"),
    ("I", "tcpi_delivered_ce"),
    ("Q", "tcpi_bytes_sent"),
    ("Q", "tcpi_bytes_retrans"),
    ("I", "tcpi_dsack_dups"),
    ("I", "tcpi_reord_seen"),
    ("I", "tcpi_rcv_ooopack"),
    ("I", "tcpi_snd_wnd"),
    ("I", "tcpi_rcv_wnd"),
    ("I", "tcpi_rehash"),
)
# "@" selects native byte order, sizes and alignment for the concatenated codes.
TCP_INFO_STRUCT_FORMAT_STRING = "@" + "".join(x for x, _ in TCP_INFO_STRUCT)
TCP_INFO_STRUCT_SIZE = struct.calcsize(TCP_INFO_STRUCT_FORMAT_STRING)


def get_tcp_info(sock):
    """Read TCP_INFO from ``sock`` and decode it into a dict.

    Args:
        sock: socket-like object whose ``getsockopt`` returns the raw
            ``struct tcp_info`` bytes.

    Returns:
        A dict with one entry per field of TCP_INFO_STRUCT, keyed by field
        name. Fields beyond the bytes actually returned are reported as 0.
    """
    raw_info = sock.getsockopt(socket.SOL_TCP, socket.TCP_INFO, TCP_INFO_STRUCT_SIZE)
    # struct.unpack() needs exactly TCP_INFO_STRUCT_SIZE bytes, but a kernel
    # whose struct tcp_info is shorter than the layout above returns fewer
    # bytes. Zero-pad the answer instead of failing with struct.error.
    padded_info = raw_info.ljust(TCP_INFO_STRUCT_SIZE, b"\x00")
    field_values = struct.unpack(TCP_INFO_STRUCT_FORMAT_STRING, padded_info)
    return {
        name: value for value, (_, name) in zip(field_values, TCP_INFO_STRUCT)
    }
def print_tcp_info(sock):
    """Print the decoded TCP_INFO of ``sock`` to standard output.

    The local name ``tcp_info`` is part of the printed text, because the
    f-string uses the ``=`` specifier.
    """
    tcp_info = get_tcp_info(sock)
    print(f"{tcp_info=}")
# Names of the ``socket`` module constants to query, grouped by option level.
# They are resolved by name so that a constant missing from this Python build
# or platform is skipped instead of raising AttributeError at import time.
# SO_ERROR is listed once only: reading it clears the pending error, so a
# second read would overwrite the real value with 0.
_SOCKET_OPTION_NAMES = (
    (
        "SOL_SOCKET",
        (
            "SO_ERROR",
            # NOTE(review): SOMAXCONN looks like the listen() backlog limit
            # rather than a socket option — confirm it belongs here.
            "SOMAXCONN",
            "SO_ACCEPTCONN",
            "SO_BINDTODEVICE",
            "SO_BROADCAST",
            "SO_DEBUG",
            "SO_DOMAIN",
            "SO_DONTROUTE",
            "SO_INCOMING_CPU",
            # NOTE(review): the SO_J1939_* and SO_VM_SOCKETS_* names look like
            # options of the CAN J1939 / vsock socket families rather than
            # SOL_SOCKET options — confirm they belong under this level.
            "SO_J1939_ERRQUEUE",
            "SO_J1939_FILTER",
            "SO_J1939_PROMISC",
            "SO_J1939_SEND_PRIO",
            "SO_KEEPALIVE",
            "SO_LINGER",
            "SO_MARK",
            "SO_OOBINLINE",
            "SO_PASSCRED",
            "SO_PASSSEC",
            "SO_PEERCRED",
            "SO_PEERSEC",
            "SO_PRIORITY",
            "SO_PROTOCOL",
            "SO_RCVBUF",
            "SO_RCVLOWAT",
            "SO_RCVTIMEO",
            "SO_REUSEADDR",
            "SO_REUSEPORT",
            "SO_SNDBUF",
            "SO_SNDLOWAT",
            "SO_SNDTIMEO",
            "SO_TYPE",
            "SO_VM_SOCKETS_BUFFER_MAX_SIZE",
            "SO_VM_SOCKETS_BUFFER_MIN_SIZE",
            "SO_VM_SOCKETS_BUFFER_SIZE",
        ),
    ),
    (
        "SOL_TCP",
        (
            "TCP_CONGESTION",
            "TCP_CORK",
            "TCP_DEFER_ACCEPT",
            "TCP_FASTOPEN",
            "TCP_INFO",
            "TCP_KEEPCNT",
            "TCP_KEEPIDLE",
            "TCP_KEEPINTVL",
            "TCP_LINGER2",
            "TCP_MAXSEG",
            "TCP_NODELAY",
            "TCP_NOTSENT_LOWAT",
            "TCP_QUICKACK",
            "TCP_SYNCNT",
            "TCP_USER_TIMEOUT",
            "TCP_WINDOW_CLAMP",
        ),
    ),
)

# (level, option, option name) triples, in the order listed above.
OPERATIONS = tuple(
    (getattr(socket, level_name), getattr(socket, option_name), option_name)
    for level_name, option_names in _SOCKET_OPTION_NAMES
    if hasattr(socket, level_name)
    for option_name in option_names
    if hasattr(socket, option_name)
)
def get_socket_status(sock):
    """Query every option listed in OPERATIONS on ``sock``.

    Args:
        sock: socket-like object exposing ``getsockopt(level, option)``.

    Returns:
        A dict mapping option name to the value returned by ``getsockopt``.
        Options whose query raises OSError are logged at debug level and
        left out.
    """
    operation_results = {}
    for level, op, name in OPERATIONS:
        try:
            value = sock.getsockopt(level, op)
        except OSError:
            logger.debug(f"Failed to get {level=} {op=}")
        else:
            operation_results[name] = value
    return operation_results
# NOTE(review): MIDDLEWARE is not defined anywhere in this file; this snippet
# presumably belongs in the Django settings module, where it appends the
# middleware class by its dotted path — confirm before using it here.
MIDDLEWARE += [
"client_disconnect_handler.SocketMonitorMiddleware",
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment