Skip to content

Instantly share code, notes, and snippets.

@Kenan7
Last active October 5, 2024 17:04
Show Gist options
  • Save Kenan7/6248374de913f5a6a06c891c8e0b3858 to your computer and use it in GitHub Desktop.
middleware
import cProfile
import pstats
import io
import traceback
import time
import logging
from django.conf import settings
from django.db import connection, reset_queries
from django.utils.deprecation import MiddlewareMixin
class ProfileMiddleware(MiddlewareMixin):
    """Profile a view with cProfile when the request carries ``?prof=N``,
    appending the top-N cumulative-time stats to the response body.

    The profiler is kept on the *request*, not on ``self``: a middleware
    instance is shared by every request and thread, so instance state is a
    race condition, and the original never cleared ``self.profiler``, which
    leaked profiling output into later, non-profiled requests.
    """

    def process_view(self, request, view_func, view_args, view_kwargs):
        # Only profile when the client explicitly opts in via ?prof=...
        if not request.GET.get('prof'):
            return None
        # Start profiling; per-request attribute avoids cross-request races.
        request._profiler = cProfile.Profile()
        request._profiler.enable()
        return None

    def process_response(self, request, response):
        profiler = getattr(request, '_profiler', None)
        if profiler is None:
            # Profiling was never started for this request.
            return response
        # Clamp the requested line count to [1, 200]; fall back to 20
        # when ?prof= is not a valid integer.
        try:
            prof_value = int(request.GET.get('prof'))
            prof_value = min(max(1, prof_value), 200)
        except (ValueError, TypeError):
            prof_value = 20
        # Stop profiling and render stats sorted by cumulative time.
        profiler.disable()
        s = io.StringIO()
        ps = pstats.Stats(profiler, stream=s).sort_stats(pstats.SortKey.CUMULATIVE)
        ps.print_stats(prof_value)
        # NOTE(review): assumes a non-streaming response exposing bytes
        # ``.content`` — streaming responses would fail here; confirm usage.
        response.content += f"\n\n{str(s.getvalue())}".encode('utf-8')
        return response
# Module-level logger used by the query-logging middleware below.
logger = logging.getLogger('django')
# NOTE(review): mid-file import — conventionally this belongs in the
# import block at the top of the file.
import threading
class QueryLoggingMiddleware(MiddlewareMixin):
    """Log a per-request summary (count, SQL time, wall time) of the SQL
    queries Django recorded for the request.

    NOTE(review): ``connection.queries`` is only populated when DEBUG is
    enabled — confirm this middleware is installed in debug settings only.
    """

    def process_request(self, request):
        # Reset Django's query log and stamp the start time on the request:
        # the middleware instance is shared across threads, so an instance
        # attribute (the original ``self.start_time``) would race and could
        # raise AttributeError when process_request was skipped.
        reset_queries()
        request._query_log_start = time.time()

    def process_response(self, request, response):
        # Guard against an earlier middleware short-circuiting the request
        # so that process_request never ran.
        start = getattr(request, '_query_log_start', None)
        if start is None:
            return response
        total_time = time.time() - start
        queries = connection.queries
        total_queries = len(queries)
        total_sql_time = sum(float(query['time']) for query in queries)
        # Summary lines.
        logger.info(f"Total Queries: {total_queries}")
        logger.info(f"Total SQL Time: {total_sql_time:.3f}s")
        logger.info(f"Total Request Time: {total_time:.3f}s")
        # The stack captured here is the response-phase stack — identical
        # for every query — so capture and format it once, not per query.
        stack_trace = ''.join(traceback.format_stack())
        for query in queries:
            logger.debug(f"Query: {query['sql']} took {query['time']} seconds")
            logger.debug(f"Query Source: {stack_trace}")
        return response
# Thread-local storage to track middleware processing state
# NOTE(review): ``thread_local`` is never referenced by the middleware
# visible in this file — possibly dead code or used by code outside this
# view; confirm before removing.
thread_local = threading.local()
class DuplicateQueryDetectionMiddleware(MiddlewareMixin):
    """Detect SQL queries executed more than once during a request and
    append a human-readable report to ``duplicate_query_log.txt`` under
    ``settings.BASE_DIR``.

    Per-request state lives on the request object because the middleware
    instance is shared across threads.

    NOTE(review): ``connection.queries`` is only populated when DEBUG is
    enabled — confirm this middleware is debug-only.
    """

    def process_request(self, request):
        # Start the timer and initialise per-request tracking state.
        request.start_time = time.time()
        request.endpoint = request.path
        # (The original also did ``request.method = request.method`` — a
        # no-op, removed.)
        request.params = request.GET.dict()
        request.queries = {}
        request.duplicate_queries = {}
        request.query_tracebacks = {}

    def process_response(self, request, response):
        # Guard: an earlier middleware may have short-circuited the request
        # so process_request never ran; the original would raise
        # AttributeError here.
        start_time = getattr(request, 'start_time', None)
        if start_time is None:
            return response
        total_time = time.time() - start_time
        query_count = len(connection.queries)
        # The stack captured here is the response-phase stack and is the
        # same for every query, so capture and format it once rather than
        # once per loop iteration (the original recomputed it per query).
        stack = traceback.extract_stack()[:-1]  # drop this frame itself
        formatted_traceback = ''.join(traceback.format_list(stack[-3:]))
        # Count occurrences of each normalized query.
        for query in connection.queries:
            normalized_sql = self.normalize_query(query['sql'])
            entry = request.queries.get(normalized_sql)
            if entry is not None:
                entry['count'] += 1
                # Same dict object, so later count bumps are reflected.
                request.duplicate_queries[normalized_sql] = entry
            else:
                request.queries[normalized_sql] = {
                    'count': 1,
                    'traceback': formatted_traceback,
                }
        self.write_request_info_to_file(request, query_count, total_time)
        return response

    def normalize_query(self, sql):
        """Collapse all whitespace so equivalent queries compare equal."""
        return ' '.join(sql.split())

    def truncate_traceback(self, stack):
        """Return at most the last 5 stack lines that are not framework noise.

        NOTE(review): expects *formatted string* lines (``format_stack``
        output) — the substring tests would fail on ``FrameSummary``
        objects. Currently unused within this file; confirm callers.
        """
        filtered_stack = [
            line for line in stack
            if '/site-packages/' not in line
            and '/opt/homebrew/' not in line
            and '/wsgiref/' not in line
        ]
        # [-5:] already yields the whole list when it has 5 or fewer items.
        return filtered_stack[-5:]

    def write_request_info_to_file(self, request, query_count, total_time):
        """Append the duplicate-query report for this request to the log file."""
        output_file = settings.BASE_DIR / 'duplicate_query_log.txt'
        log_entry = (
            f"--- New Request ---\n"
            f"Endpoint: {request.endpoint}\n"
            f"Method: {request.method}\n"
            f"Parameters: {request.params}\n"
            f"Total Queries: {query_count}\n"
            f"Request Duration: {total_time:.4f} seconds\n"
            f"Duplicate Queries:\n"
        )
        if request.duplicate_queries:
            for query, data in request.duplicate_queries.items():
                if data['count'] > 1:
                    log_entry += f" - Query: {query}\n"
                    log_entry += f" Count: {data['count']}\n"
                    # 'traceback' is already a string; the original's
                    # ''.join(...) over it was an identity operation.
                    log_entry += " Traceback:\n" + data['traceback'] + "\n"
        else:
            log_entry += " - No duplicate queries found.\n"
        log_entry += "\n\n\n"  # blank space between requests
        # Append so successive requests accumulate in one file.
        with open(output_file, 'a') as f:
            f.write(log_entry)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment