Created
April 26, 2025 12:22
-
-
Save iamshreeram/329e8f29632300e97768bfbdc13f6e49 to your computer and use it in GitHub Desktop.
create db proxy for mssql server
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
################## | |
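'''
This script creates a TDS proxy for Microsoft SQL Server.
1. Start a SQL Server instance, for example with these Docker commands:
docker pull mcr.microsoft.com/azure-sql-edge
docker run -e "ACCEPT_EULA=1" -e "MSSQL_SA_PASSWORD=<YourStrong!Passw0rd>" -e "MSSQL_PID=Developer" -e "MSSQL_USER=sa" -p 1433:1433 -d --name=sql mcr.microsoft.com/azure-sql-edge
SQL Server refuses to start if the SA password does not meet its complexity
policy (at least 8 characters, drawn from three of: uppercase letters,
lowercase letters, digits, symbols). Choose a strong password and set
SQL_SERVER_PASSWORD in this script to the same value.
2. In a new terminal, save this script as mssql_proxy_tds.py and run it with
`python mssql_proxy_tds.py`. The proxy listens on port 1434 and forwards
queries to the server on port 1433.
Note:
This script requires the python-tds package (`pip install python-tds`), which provides the `pytds` module.
'''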
################## | |
import logging
import os
import socket
import struct
import threading
from logging.handlers import RotatingFileHandler

import pytds
# Configure logging: DEBUG and above goes both to a size-rotated file
# (5 MB per file, 3 backups) and to stderr. The thread name is included
# because every client connection is served on its own thread.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(threadName)s - %(levelname)s - %(message)s',
    handlers=[
        RotatingFileHandler('mssql_proxy.log', maxBytes=5 * 1024 * 1024, backupCount=3),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# Upstream SQL Server the proxy forwards queries to. Every value can be
# overridden through an environment variable, so credentials do not have to
# be edited into the source; the literal defaults apply when a variable is unset.
SQL_SERVER_HOST = os.environ.get('MSSQL_PROXY_SERVER_HOST', 'localhost')
SQL_SERVER_PORT = int(os.environ.get('MSSQL_PROXY_SERVER_PORT', '1433'))
SQL_SERVER_USERNAME = os.environ.get('MSSQL_PROXY_SERVER_USERNAME', 'sa')
SQL_SERVER_PASSWORD = os.environ.get('MSSQL_PROXY_SERVER_PASSWORD', 'password')
SQL_SERVER_DATABASE = os.environ.get('MSSQL_PROXY_SERVER_DATABASE', 'master')
class TDSPacketHandler:
    """Per-connection TDS front end that relays a client's SQL batches to SQL Server.

    Client side: speaks just enough of MS-TDS for a client to log in and run
    queries. PRELOGIN, LOGIN7 and SQL Batch messages are handled; every reply is
    a tabular-result message built from LOGINACK, ENVCHANGE, COLMETADATA, ROW,
    ERROR and DONE tokens.

    Server side: one pytds connection opened with the module-level SQL_SERVER_*
    settings. The credentials inside the client's LOGIN7 packet are NOT checked,
    so any client that can reach the proxy gets a session as SQL_SERVER_USERNAME.

    Limitations: no TLS (PRELOGIN answers ENCRYPT_NOT_SUP, so clients that insist
    on encryption will disconnect); only the first result set of a batch is
    returned; every column is sent as NVARCHAR(4000) holding ``str(value)``;
    other message types (RPC, attention, ...) are ignored.
    """

    HEADER_SIZE = 8            # every TDS packet starts with an 8-byte header
    MAX_PACKET_SIZE = 4096     # announced to the client via ENVCHANGE; replies are split at this size
    STATUS_EOM = 0x01          # header status bit marking the last packet of a message

    # Packet (message) types.
    PKT_SQL_BATCH = 0x01
    PKT_TABULAR_RESULT = 0x04
    PKT_LOGIN7 = 0x10
    PKT_PRELOGIN = 0x12

    # DONE token status bits.
    DONE_FINAL = 0x0000
    DONE_ERROR = 0x0002
    DONE_COUNT = 0x0010

    NVARCHAR_MAX_BYTES = 8000  # byte length of an NVARCHAR(4000) column
    # Collation SQL_Latin1_General_CP1_CI_AS (LCID 0x0409, sort id 52). NVARCHAR
    # values are always UTF-16LE, so this does not affect how clients decode them.
    COLLATION = bytes([0x09, 0x04, 0xD0, 0x00, 0x34])

    def __init__(self):
        self.buffer = bytearray()     # bytes received from the client, not yet split into packets
        self.client_socket = None
        self.server_conn = None
        self.server_cursor = None
        self.login_complete = False
        self._message = bytearray()   # payloads of a message whose EOM packet has not arrived yet
        self._message_type = None     # packet type of the first packet of that message

    def connect_to_server(self):
        """Open the upstream pytds connection and cursor.

        Returns:
            True on success; False if pytds raised (the error is logged).
        """
        try:
            # pytds takes the TDS version as an integer constant, not a string
            # such as '7.4', so its default protocol version is used instead.
            self.server_conn = pytds.connect(
                server=SQL_SERVER_HOST,
                port=SQL_SERVER_PORT,
                database=SQL_SERVER_DATABASE,
                user=SQL_SERVER_USERNAME,
                password=SQL_SERVER_PASSWORD,
                as_dict=False,
                autocommit=True,
            )
            self.server_cursor = self.server_conn.cursor()
            logger.info("Successfully connected to SQL Server")
            return True
        except Exception as e:
            logger.error(f"SQL Server connection failed: {e}")
            return False

    def process_packets(self, data, client_socket):
        """Buffer bytes received from the client and handle every complete message.

        A TDS message may span several packets. Packet payloads are accumulated
        until a packet with the EOM status bit arrives, and then the whole
        message is dispatched. Incomplete trailing data stays in ``self.buffer``
        for the next call.

        Args:
            data: bytes just read from ``client_socket``.
            client_socket: socket that replies are written to.
        """
        self.client_socket = client_socket
        self.buffer.extend(data)
        while len(self.buffer) >= self.HEADER_SIZE:
            packet_type = self.buffer[0]
            status = self.buffer[1]
            # Big-endian length of the whole packet, header included.
            length = struct.unpack('>H', self.buffer[2:4])[0]
            if length < self.HEADER_SIZE:
                # Such a packet can never be consumed; without this guard the
                # loop would spin forever deleting zero bytes.
                logger.error(f"Malformed TDS packet header (length {length}); discarding buffered data")
                self.buffer.clear()
                self._message.clear()
                self._message_type = None
                break
            if len(self.buffer) < length:
                break  # wait for the rest of this packet
            payload = bytes(self.buffer[self.HEADER_SIZE:length])
            # Consume the packet before handling it, so that a failure below
            # cannot make it be re-processed on every later call.
            del self.buffer[:length]
            if self._message_type is None:
                self._message_type = packet_type
            self._message.extend(payload)
            if not status & self.STATUS_EOM:
                continue
            message_type, message = self._message_type, bytes(self._message)
            self._message.clear()
            self._message_type = None
            try:
                self._dispatch(message_type, message)
            except Exception as e:
                logger.exception(f"Error processing packet: {e}")
                self._send_error(f"Error: {str(e)}")

    def _dispatch(self, message_type, payload):
        """Handle one complete client message (all of its packets reassembled)."""
        if message_type == self.PKT_PRELOGIN:
            logger.debug("Processing PreLogin packet")
            self.handle_prelogin(payload)
        elif message_type == self.PKT_LOGIN7:
            if self.login_complete:
                logger.debug("Ignoring Login packet on an already logged-in connection")
                return
            logger.debug("Processing Login packet")
            if self.connect_to_server():
                self.login_complete = True
                self.send_login_ack()
            else:
                # Without a reply the client would wait for a login response
                # forever. 18456 / class 14 is what SQL Server uses for failed logins.
                self._send_error(
                    "Login failed: the proxy could not connect to SQL Server",
                    number=18456,
                    severity=14,
                )
        elif message_type == self.PKT_SQL_BATCH:
            if not self.login_complete:
                logger.warning("Ignoring SQL Batch received before login completed")
                return
            self._handle_sql_batch(payload)
        else:
            logger.debug(f"Ignoring unsupported TDS message type 0x{message_type:02X}")

    def _handle_sql_batch(self, payload):
        """Decode a SQL Batch message and forward its text to SQL Server."""
        logger.debug(f"Raw SQL Batch packet data: {payload.hex()}")
        try:
            query = self._decode_sql_batch(payload)
        except UnicodeDecodeError as e:
            logger.error(f"Failed to decode or execute query: {e}")
            self._send_error(f"Error decoding query: {str(e)}")
            return
        logger.info(f"Received SQL query from client: {query}")
        print(f"[SQLCMD Query Received]: {query}")
        self.execute_query(query)

    @staticmethod
    def _decode_sql_batch(payload):
        """Return the SQL text of a SQL Batch payload, without any ALL_HEADERS prefix.

        TDS 7.2+ clients put an ALL_HEADERS block before the UTF-16LE SQL text.
        Its first little-endian DWORD is the block's total length. For real SQL
        text, those 4 bytes are two UTF-16 code units and form a value far larger
        than the payload, so the check below only strips genuine headers.

        Raises:
            UnicodeDecodeError: if the remaining bytes are not valid UTF-16LE.
        """
        if len(payload) >= 4:
            total_length = struct.unpack('<I', payload[:4])[0]
            if 4 <= total_length <= len(payload):
                payload = payload[total_length:]
        return bytes(payload).decode('utf-16-le').strip()

    def handle_prelogin(self, packet_data):
        """Answer a PRELOGIN request.

        The reply advertises version 15.0.2000, ENCRYPT_NOT_SUP (this proxy
        cannot do TLS; clients that require encryption will disconnect),
        INSTOPT 0x00, an empty THREADID and MARS off.

        Args:
            packet_data: the client's PRELOGIN payload; it is not inspected.
        """
        options = [
            (0x00, bytes([15, 0, 0x07, 0xD0, 0x00, 0x00])),  # VERSION 15.0.2000, sub-build 0
            (0x01, bytes([0x02])),                            # ENCRYPTION: ENCRYPT_NOT_SUP
            (0x02, bytes([0x00])),                            # INSTOPT
            (0x03, b''),                                      # THREADID (empty)
            (0x04, bytes([0x00])),                            # MARS: off
        ]
        # Option table of 5-byte (token, offset, length) big-endian entries and
        # a 0xFF terminator, followed by the option data. Offsets count from the
        # start of the payload.
        data_offset = 5 * len(options) + 1
        table = bytearray()
        data = bytearray()
        for token, value in options:
            table += struct.pack('>BHH', token, data_offset + len(data), len(value))
            data += value
        table.append(0xFF)
        self._send_message(self.PKT_TABULAR_RESULT, bytes(table + data))
        logger.debug("Sent PreLogin response")

    def send_login_ack(self):
        """Send the login response: LOGINACK, a packet-size ENVCHANGE and DONE."""
        ack = bytearray([0x01])                       # Interface: SQL_TSQL
        ack += struct.pack('>I', 0x74000004)          # TDS 7.4 (sent big-endian in LOGINACK)
        ack += self._b_varchar("Microsoft SQL Server")
        ack += bytes([15, 0, 0x07, 0xD0])             # ProgVersion 15.0.2000 (SQL Server 2019)
        size = str(self.MAX_PACKET_SIZE)
        # ENVCHANGE type 4 (packet size): new value, old value.
        env = bytes([0x04]) + self._b_varchar(size) + self._b_varchar(size)
        stream = (
            bytes([0xAD]) + struct.pack('<H', len(ack)) + bytes(ack)
            + bytes([0xE3]) + struct.pack('<H', len(env)) + env
            + self._done_token(self.DONE_FINAL)
        )
        self._send_message(self.PKT_TABULAR_RESULT, stream)
        logger.debug("Sent Login acknowledgement")

    def execute_query(self, query):
        """Run ``query`` on the upstream server and send the outcome to the client.

        A query that returns rows is answered with COLMETADATA/ROW tokens and a
        DONE carrying the row count. Any other statement gets a DONE that carries
        the cursor's row count when it is known. Failures are logged and sent to
        the client as an ERROR token.
        """
        cursor = self.server_cursor
        try:
            logger.debug(f"Forwarding query to SQL Server: {query}")
            cursor.execute(query)
            if cursor.description:
                description = cursor.description
                rows = cursor.fetchall()
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug("Query results:\n" + self.format_results(rows, description))
                stream = self._result_set_tokens(rows, description)
            else:
                row_count = cursor.rowcount
                if isinstance(row_count, int) and row_count >= 0:
                    stream = self._done_token(self.DONE_COUNT, row_count)
                else:
                    stream = self._done_token(self.DONE_FINAL)
        except Exception as e:
            error_msg = f"Error executing query: {str(e)}"
            logger.error(error_msg)
            self._send_error(error_msg)
            return
        # Sent outside the try, so a broken client socket reaches the caller
        # instead of being reported as a query error.
        self._send_message(self.PKT_TABULAR_RESULT, stream)

    def format_results(self, rows, description):
        """Format query results as a plain-text table (used for debug logging)."""
        if not rows:
            return "No results returned"
        # Get column names and widths
        columns = [col[0] for col in description]
        col_widths = [len(str(col)) for col in columns]
        # Calculate column widths
        for row in rows:
            for i, val in enumerate(row):
                col_widths[i] = max(col_widths[i], len(str(val)))
        # Build header
        header = " | ".join(f"{col:<{width}}" for col, width in zip(columns, col_widths))
        separator = "-+-".join("-" * width for width in col_widths)
        lines = [header, separator]
        # Add rows
        for row in rows:
            line = " | ".join(f"{str(val):<{width}}" for val, width in zip(row, col_widths))
            lines.append(line)
        return "\n".join(lines) + "\n"

    def _result_set_tokens(self, rows, description):
        """Encode rows as COLMETADATA + ROW tokens and a DONE carrying the row count.

        Every column is declared as nullable NVARCHAR(4000). Values are sent as
        ``str(value)``, truncated to 4000 UTF-16 code units; None becomes NULL.
        """
        out = bytearray([0x81])
        out += struct.pack('<H', len(description))
        for column in description:
            name = '' if column[0] is None else str(column[0])
            # UserType (ULONG), Flags (fNullable), then TYPE_INFO: NVARCHAR + max byte length.
            out += struct.pack('<IHBH', 0, 0x0001, 0xE7, self.NVARCHAR_MAX_BYTES)
            out += self.COLLATION
            out += self._b_varchar(name, max_units=128)
        for row in rows:
            out.append(0xD1)
            for value in row:
                if value is None:
                    out += struct.pack('<H', 0xFFFF)  # CHARBIN_NULL
                else:
                    data = self._utf16(str(value), self.NVARCHAR_MAX_BYTES // 2)
                    out += struct.pack('<H', len(data)) + data
        out += self._done_token(self.DONE_COUNT, len(rows))
        return bytes(out)

    def _send_error(self, message, number=50000, severity=16):
        """Send ``message`` to the client as an ERROR token followed by DONE_ERROR."""
        stream = self._error_token(message, number=number, severity=severity)
        stream += self._done_token(self.DONE_ERROR)
        self._send_message(self.PKT_TABULAR_RESULT, stream)

    def _error_token(self, message, number=50000, severity=16, state=1):
        """Build an ERROR token (LineNumber is a LONG, as in TDS 7.2+)."""
        body = struct.pack('<iBB', number, state, severity)
        body += self._us_varchar(message, max_units=4000)
        body += self._b_varchar('')        # ServerName
        body += self._b_varchar('')        # ProcName
        body += struct.pack('<i', 1)       # LineNumber
        return bytes([0xAA]) + struct.pack('<H', len(body)) + body

    @staticmethod
    def _done_token(status=0x0000, row_count=0):
        """Build a DONE token: Status, CurCmd, and an 8-byte row count."""
        return struct.pack('<BHHQ', 0xFD, status, 0, row_count)

    @staticmethod
    def _utf16(text, max_units):
        """Encode ``text`` as UTF-16LE, cut to at most ``max_units`` code units.

        A cut that would leave a dangling high surrogate drops it as well.
        """
        encoded = text.encode('utf-16-le', errors='surrogatepass')
        if len(encoded) > max_units * 2:
            encoded = encoded[: max_units * 2]
            # In UTF-16LE the last byte is the high byte of the last code unit.
            if encoded and 0xD8 <= encoded[-1] <= 0xDB:
                encoded = encoded[:-2]
        return encoded

    def _b_varchar(self, text, max_units=255):
        """Encode ``text`` as B_VARCHAR: a 1-byte code-unit count plus UTF-16LE."""
        encoded = self._utf16(text, max_units)
        return bytes([len(encoded) // 2]) + encoded

    def _us_varchar(self, text, max_units=65535):
        """Encode ``text`` as US_VARCHAR: a 2-byte code-unit count plus UTF-16LE."""
        encoded = self._utf16(text, max_units)
        return struct.pack('<H', len(encoded) // 2) + encoded

    def _send_message(self, packet_type, payload):
        """Send ``payload`` as one TDS message, split into MAX_PACKET_SIZE packets.

        Only the last packet carries the EOM status bit. Packet ids start at 1
        and wrap modulo 256. ``sendall`` is used so partial writes cannot drop data.
        """
        chunk_size = self.MAX_PACKET_SIZE - self.HEADER_SIZE
        chunks = [payload[i:i + chunk_size] for i in range(0, len(payload), chunk_size)] or [b'']
        packets = bytearray()
        for packet_id, chunk in enumerate(chunks, start=1):
            status = self.STATUS_EOM if packet_id == len(chunks) else 0x00
            packets += struct.pack(
                '>BBHHBB', packet_type, status, len(chunk) + self.HEADER_SIZE, 0, packet_id % 256, 0
            )
            packets += chunk
        self.client_socket.sendall(packets)
def handle_client(client_socket):
    """Serve one proxied client connection until it disconnects.

    Bytes read from ``client_socket`` are fed to a fresh TDSPacketHandler. On
    exit the upstream pytds connection (if one was opened) and the client
    socket are closed, even if closing the upstream connection fails.

    Args:
        client_socket: an accepted, connected TCP socket for one client.
    """
    handler = TDSPacketHandler()
    try:
        # getpeername() raises OSError if the peer is already gone, so it must
        # be inside the try for the finally below to close the socket.
        logger.info(f"New connection from {client_socket.getpeername()}")
        while True:
            # Receive data
            data = client_socket.recv(4096)
            if not data:
                break
            # Process TDS packets
            handler.process_packets(data, client_socket)
    except (ConnectionResetError, BrokenPipeError):
        logger.info("Client disconnected")
    except Exception as e:
        logger.exception(f"Connection error: {e}")
    finally:
        server_conn = getattr(handler, 'server_conn', None)
        if server_conn:
            try:
                server_conn.close()
            except Exception as e:
                # Never let a failing upstream close leak the client socket.
                logger.warning(f"Error closing SQL Server connection: {e}")
        client_socket.close()
        logger.info("Connection closed")
def start_proxy_server(host='0.0.0.0', port=1434, backlog=5):
    """Accept client connections forever, serving each one on a daemon thread.

    The listening socket is closed on every exit path, including a failed
    ``bind``/``listen``. Ctrl+C (KeyboardInterrupt) stops the accept loop.

    Args:
        host: interface to bind. The default listens on all IPv4 interfaces.
            The proxy does not check client credentials, so '127.0.0.1' is the
            safer choice outside of local testing.
        port: TCP port clients connect to.
        backlog: queue length passed to ``listen()``.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.bind((host, port))
        server_socket.listen(backlog)
        logger.info(f"Proxy server listening on port {port}...")
        try:
            while True:
                client_socket, addr = server_socket.accept()
                logger.info(f"Accepted connection from {addr}")
                threading.Thread(
                    target=handle_client,
                    args=(client_socket,),
                    daemon=True,
                ).start()
        except KeyboardInterrupt:
            logger.info("Shutting down proxy server")
if __name__ == "__main__":
    # Start listening only when run as a script, never on import.
    start_proxy_server()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment