Last active: October 5, 2024 17:04
-
-
Save Kenan7/6248374de913f5a6a06c891c8e0b3858 to your computer and use it in GitHub Desktop.
middleware
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cProfile | |
import pstats | |
import io | |
import traceback | |
import time | |
import logging | |
from django.conf import settings | |
from django.db import connection, reset_queries | |
from django.utils.deprecation import MiddlewareMixin | |
class ProfileMiddleware(MiddlewareMixin):
    """Profile a view with cProfile and append the report to its response.

    Profiling is opt-in per request: add ``?prof=N`` to the URL. ``N`` is the
    number of stats rows printed (sorted by cumulative time), clamped to the
    range 1..200; a value that is not an integer (e.g. ``?prof=yes``) prints
    20 rows.
    """

    # Name of the request attribute holding that request's profiler. Django
    # builds one middleware instance and reuses it for every request (and
    # thread), so per-request state must live on the request, not on ``self``.
    _profiler_attr = '_profile_middleware_profiler'

    def process_view(self, request, view_func, view_args, view_kwargs):
        """Start a profiler for this request when ``?prof`` is present.

        Always returns ``None`` so the view itself runs normally.
        """
        if not request.GET.get('prof'):
            return None
        profiler = cProfile.Profile()
        setattr(request, self._profiler_attr, profiler)
        profiler.enable()
        return None

    def process_response(self, request, response):
        """Stop this request's profiler and append its report to ``response``.

        Returns ``response`` unchanged when the request was not profiled or
        when the response is streaming (it has no assignable ``content``).
        """
        profiler = getattr(request, self._profiler_attr, None)
        if profiler is None:
            return response
        profiler.disable()

        try:
            prof_value = min(max(1, int(request.GET.get('prof'))), 200)
        except (ValueError, TypeError):
            prof_value = 20

        report = io.StringIO()
        stats = pstats.Stats(profiler, stream=report)
        stats.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(prof_value)

        if getattr(response, 'streaming', False):
            return response
        response.content += f"\n\n{report.getvalue()}".encode('utf-8')
        # The body just grew; keep an already-set Content-Length truthful.
        if response.has_header('Content-Length'):
            response['Content-Length'] = str(len(response.content))
        return response
import threading

# Module logger used by the query-logging middleware; it logs under Django's
# own ``django`` logger namespace.
logger = logging.getLogger(name='django')
class QueryLoggingMiddleware(MiddlewareMixin):
    """Log per-request SQL query count, SQL time and total request time.

    Summaries are logged at INFO and each individual query at DEBUG. Reads
    ``connection.queries``, which Django only records when DEBUG is enabled.
    """

    # Request attribute holding the request's start timestamp. Stored on the
    # request, not on ``self``: one middleware instance serves every
    # concurrent request, so an instance attribute would be overwritten.
    _start_attr = '_query_logging_start_time'

    def process_request(self, request):
        """Clear the query log and record when this request started."""
        reset_queries()
        setattr(request, self._start_attr, time.perf_counter())

    def process_response(self, request, response):
        """Log the query/timing summary for ``request`` and return ``response``."""
        total_time = time.perf_counter() - getattr(request, self._start_attr)

        queries = connection.queries
        total_sql_time = sum(float(query['time']) for query in queries)

        logger.info("Total Queries: %d", len(queries))
        logger.info("Total SQL Time: %.3fs", total_sql_time)
        logger.info("Total Request Time: %.3fs", total_time)

        if queries and logger.isEnabledFor(logging.DEBUG):
            # connection.queries does not record where a query was issued, so
            # this is the call stack at response time and is the same for
            # every query; format it once instead of once per query.
            stack_text = ''.join(traceback.format_stack())
            for query in queries:
                logger.debug("Query: %s took %s seconds", query['sql'], query['time'])
                logger.debug("Query Source: %s", stack_text)
        return response
# Thread-local storage intended to track middleware processing state.
# NOTE(review): nothing in the visible code reads or writes this object;
# confirm whether anything else depends on it before removing it.
thread_local = threading.local()
class DuplicateQueryDetectionMiddleware(MiddlewareMixin):
    """Append a per-request report of repeated SQL statements to a log file.

    Every response adds one entry to ``BASE_DIR/duplicate_query_log.txt``.
    The entry lists the endpoint, method, GET parameters, query count,
    duration and each whitespace-normalized statement that ran more than
    once. Relies on ``connection.queries``, which Django only records when
    DEBUG is enabled.
    """

    def process_request(self, request):
        """Attach the per-request bookkeeping read by ``process_response``."""
        request.start_time = time.time()
        request.endpoint = request.path
        request.params = request.GET.dict()
        request.queries = {}
        request.duplicate_queries = {}
        # Kept for anything that reads it; this class does not populate it.
        request.query_tracebacks = {}

    def process_response(self, request, response):
        """Count repeated statements for this request, log them, return ``response``."""
        total_time = time.time() - request.start_time
        executed = connection.queries
        query_count = len(executed)

        formatted_traceback = ''
        if executed:
            # connection.queries does not record where a query was issued, so
            # this is the stack at response time and is the same for every
            # query: extract it once instead of once per query.
            stack = traceback.extract_stack()[:-1]  # drop this frame itself
            formatted_traceback = ''.join(traceback.format_list(stack[-3:]))

        for query in executed:
            normalized_sql = self.normalize_query(query['sql'])
            entry = request.queries.get(normalized_sql)
            if entry is None:
                request.queries[normalized_sql] = {'count': 1, 'traceback': formatted_traceback}
            else:
                entry['count'] += 1
                request.duplicate_queries[normalized_sql] = entry

        self.write_request_info_to_file(request, query_count, total_time)
        return response

    def normalize_query(self, sql):
        """Collapse every run of whitespace in ``sql`` to a single space."""
        return ' '.join(sql.split())

    def truncate_traceback(self, stack):
        """Return at most the last 5 entries of ``stack`` outside library paths.

        ``stack`` may hold ``traceback.FrameSummary`` objects (as produced by
        ``traceback.extract_stack``), which are matched on their filename, or
        plain strings, which are matched directly. Entries whose path contains
        ``/site-packages/``, ``/opt/homebrew/`` or ``/wsgiref/`` are dropped.
        """
        excluded = ('/site-packages/', '/opt/homebrew/', '/wsgiref/')
        relevant = [
            entry for entry in stack
            if not any(marker in getattr(entry, 'filename', entry) for marker in excluded)
        ]
        return relevant[-5:]

    def write_request_info_to_file(self, request, query_count, total_time):
        """Append a human-readable report for ``request`` to the log file.

        The file is ``duplicate_query_log.txt`` under ``settings.BASE_DIR`` and
        is written as UTF-8 because SQL text may contain non-ASCII characters.
        """
        from pathlib import Path  # BASE_DIR is a plain str in older projects

        output_file = Path(settings.BASE_DIR) / 'duplicate_query_log.txt'
        parts = [
            "--- New Request ---\n",
            f"Endpoint: {request.endpoint}\n",
            f"Method: {request.method}\n",
            f"Parameters: {request.params}\n",
            f"Total Queries: {query_count}\n",
            f"Request Duration: {total_time:.4f} seconds\n",
            "Duplicate Queries:\n",
        ]
        if request.duplicate_queries:
            for query, data in request.duplicate_queries.items():
                if data['count'] > 1:
                    parts.append(f" - Query: {query}\n")
                    parts.append(f" Count: {data['count']}\n")
                    parts.append(" Traceback:\n" + data['traceback'] + "\n")
        else:
            parts.append(" - No duplicate queries found.\n")
        parts.append("\n\n\n")  # Add space between requests

        with open(output_file, 'a', encoding='utf-8') as f:
            f.write(''.join(parts))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.