# Source: GitHub Gist 329e8f29632300e97768bfbdc13f6e49 by @iamshreeram,
# created April 26, 2025.
# Description: Create a DB proxy for MSSQL Server.
##################
'''
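This script runs a TDS proxy in front of Microsoft SQL Server: it listens on
port 1434, logs every SQL batch a client sends, and executes it on the real
server through python-tds (pytds).
1. Start a SQL Server instance with Docker:
docker pull mcr.microsoft.com/azure-sql-edge
docker run -e "ACCEPT_EULA=1" -e "MSSQL_SA_PASSWORD=YourStr0ngPassword" -e "MSSQL_PID=Developer" -e "MSSQL_USER=sa" -p 1433:1433 -d --name=sql mcr.microsoft.com/azure-sql-edge
   SQL Server enforces its password-complexity policy for the SA login, so a
   weak value such as "password" makes the container exit at startup. Pick a
   strong password and set SQL_SERVER_PASSWORD below to the same value.
2. In a new terminal, save this script as mssql_proxy_tds.py and run it with
   `python mssql_proxy_tds.py`.
3. Point your SQL client at localhost,1434 (the proxy) instead of port 1433.
Note:
The script requires the python-tds package (`pip install python-tds`).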
'''
##################
import logging
import os
import socket
import struct
import threading
from logging.handlers import RotatingFileHandler

import pytds
# Configure logging: DEBUG and above go both to a rotating file
# (mssql_proxy.log, 5 MB per file, 3 backups) and to the console.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        RotatingFileHandler('mssql_proxy.log', maxBytes=5 * 1024 * 1024, backupCount=3),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# Connection settings for the real SQL Server the proxy forwards to.
# Each value can be overridden with an environment variable so credentials do
# not have to be edited into the source; the fallbacks are the previous
# hard-coded values.
SQL_SERVER_HOST = os.environ.get('MSSQL_HOST', 'localhost')
SQL_SERVER_PORT = int(os.environ.get('MSSQL_PORT', '1433'))
SQL_SERVER_USERNAME = os.environ.get('MSSQL_USER', 'sa')
SQL_SERVER_PASSWORD = os.environ.get('MSSQL_SA_PASSWORD', 'password')
SQL_SERVER_DATABASE = os.environ.get('MSSQL_DATABASE', 'master')
class TDSPacketHandler:
    """Per-connection TDS endpoint: answers the handshake itself and relays SQL batches.

    Incoming bytes are split into TDS packets (8-byte header + payload). The
    packets are reassembled into messages; a message ends at the packet whose
    status byte has the EOM bit set. Each message is then dispatched by packet
    type:

    * PRELOGIN  -> answered locally (see handle_prelogin).
    * LOGIN7    -> the proxy opens its own pytds connection with the configured
                   SQL_SERVER_* account and replies LOGINACK on success.
    * SQL batch -> the text is executed through pytds and the outcome is sent
                   back as TDS tokens.
    * ATTENTION -> acknowledged with DONE(DONE_ATTN).
    Any other message gets a TDS error reply so the client does not hang.

    Limitations:
    * No TLS: PRELOGIN answers ENCRYPT_NOT_SUP, so clients must connect with
      encryption disabled.
    * The credentials inside the client's LOGIN7 are not checked; every login
      is accepted and queries run as SQL_SERVER_USERNAME.
    * Every result column is sent as NVARCHAR(4000) text (str() of the value,
      truncated to 8000 bytes of UTF-16). Only the first result set of a batch
      is returned.
    """

    # --- framing ---------------------------------------------------------
    HEADER_SIZE = 8        # every TDS packet starts with an 8-byte header
    PACKET_SIZE = 4096     # size of reply packets; announced via ENVCHANGE at login
    STATUS_EOM = 0x01      # header status bit: last packet of a message

    # Packet types
    PKT_SQL_BATCH = 0x01
    PKT_TABULAR_RESULT = 0x04
    PKT_ATTENTION = 0x06
    PKT_LOGIN7 = 0x10
    PKT_PRELOGIN = 0x12

    # Token types used in replies
    TOKEN_COLMETADATA = 0x81
    TOKEN_ERROR = 0xAA
    TOKEN_LOGINACK = 0xAD
    TOKEN_ROW = 0xD1
    TOKEN_ENVCHANGE = 0xE3
    TOKEN_DONE = 0xFD

    # DONE token status flags
    DONE_FINAL = 0x0000
    DONE_ERROR = 0x0002
    DONE_COUNT = 0x0010
    DONE_ATTN = 0x0020

    ENCRYPT_NOT_SUP = 0x02
    NVARCHAR_TYPE = 0xE7
    NVARCHAR_MAX_BYTES = 8000                          # NVARCHAR(4000)
    # NOTE(review): believed to encode SQL_Latin1_General_CP1_CI_AS — confirm
    # against the COLMETADATA a real server sends.
    COLLATION = bytes([0x09, 0x04, 0xD0, 0x00, 0x34])
    SERVER_VERSION = bytes([0x0F, 0x00, 0x07, 0xD0])   # 15.0.2000 (SQL Server 2019)
    SERVER_NAME = 'mssql-proxy'

    def __init__(self):
        self.buffer = bytearray()      # received bytes not yet framed into packets
        self.message = bytearray()     # payloads of the message being reassembled
        self.message_type = None       # packet type of that message
        self.client_socket = None
        self.server_conn = None
        self.server_cursor = None
        self.login_complete = False

    def connect_to_server(self):
        """Open the pytds connection to the real SQL Server.

        Returns:
            True on success, False if connecting failed (the error is logged).
        """
        try:
            # tds_version is left at pytds' default (TDS 7.4): the library expects
            # one of its integer constants there, not the string '7.4'.
            self.server_conn = pytds.connect(
                server=SQL_SERVER_HOST,
                port=SQL_SERVER_PORT,
                database=SQL_SERVER_DATABASE,
                user=SQL_SERVER_USERNAME,
                password=SQL_SERVER_PASSWORD,
                as_dict=False,
                autocommit=True,
            )
            self.server_cursor = self.server_conn.cursor()
            logger.info("Successfully connected to SQL Server")
            return True
        except Exception as e:
            logger.error(f"SQL Server connection failed: {e}")
            return False

    def process_packets(self, data, client_socket):
        """Buffer *data* read from *client_socket* and handle every complete packet.

        Incomplete trailing packets stay buffered until more data arrives.

        Raises:
            ValueError: if a header declares a length shorter than the header
                itself. The stream cannot be resynchronised, so the caller
                should drop the connection.
            OSError: if replying to the client fails.
        """
        self.client_socket = client_socket
        self.buffer.extend(data)
        while len(self.buffer) >= self.HEADER_SIZE:
            packet_type = self.buffer[0]
            status = self.buffer[1]
            (length,) = struct.unpack_from('>H', self.buffer, 2)
            if length < self.HEADER_SIZE:
                # Such a packet would consume no bytes and loop forever.
                self.buffer.clear()
                raise ValueError(f"Invalid TDS packet length {length}")
            if len(self.buffer) < length:
                break  # wait for the rest of this packet
            payload = bytes(self.buffer[self.HEADER_SIZE:length])
            # Consume before handling so a failing packet is never re-processed.
            del self.buffer[:length]
            self._collect_packet(packet_type, status, payload)

    def _collect_packet(self, packet_type, status, payload):
        """Append one packet's payload to the current message; dispatch it on EOM."""
        if self.message and packet_type != self.message_type:
            logger.warning(
                f"Discarding unfinished message of type 0x{self.message_type:02X}"
            )
            self.message.clear()
        self.message_type = packet_type
        self.message.extend(payload)
        if status & self.STATUS_EOM:
            message = bytes(self.message)
            self.message.clear()
            self._dispatch_message(packet_type, message)

    def _dispatch_message(self, packet_type, payload):
        """Handle one complete client message; report handler failures as TDS errors."""
        try:
            if packet_type == self.PKT_PRELOGIN:
                logger.debug("Processing PreLogin packet")
                self.handle_prelogin(payload)
            elif packet_type == self.PKT_LOGIN7:
                self._handle_login()
            elif packet_type == self.PKT_SQL_BATCH:
                self._handle_sql_batch(payload)
            elif packet_type == self.PKT_ATTENTION:
                # Client cancelled a request; acknowledge the attention.
                self._send_message(self._done_token(self.DONE_ATTN))
            else:
                logger.warning(f"Unsupported TDS packet type 0x{packet_type:02X}")
                self._send_error(f"Unsupported TDS packet type 0x{packet_type:02X}")
        except OSError:
            raise  # client socket failure: let the connection loop close it
        except Exception as e:
            logger.error(f"Error processing packet: {e}")
            self._send_error(f"Error: {str(e)}")

    def _handle_login(self):
        """Accept a LOGIN7 by connecting to SQL Server; report failures to the client."""
        logger.debug("Processing Login packet")
        if self.login_complete:
            self._send_error("Login already completed on this connection")
            return
        if self.connect_to_server():
            self.login_complete = True
            self.send_login_ack()
        else:
            # Reply instead of leaving the client waiting for a LOGINACK.
            self._send_error(
                "Login failed: proxy could not connect to SQL Server",
                number=18456,
                severity=14,
            )

    def _handle_sql_batch(self, payload):
        """Decode a SQLBatch message and execute its text."""
        if not self.login_complete:
            self._send_error("SQL batch received before login")
            return
        logger.debug(f"Raw SQL Batch packet data: {payload.hex()}")
        try:
            # SQL text is UTF-16LE and follows the ALL_HEADERS prefix.
            query = self._strip_all_headers(payload).decode('utf-16-le').strip()
        except UnicodeDecodeError as e:
            logger.error(f"Failed to decode or execute query: {e}")
            self._send_error(f"Error decoding query: {str(e)}")
            return
        logger.info(f"Received SQL query from client: {query}")
        print(f"[SQLCMD Query Received]: {query}")
        self.execute_query(query)

    @staticmethod
    def _strip_all_headers(payload):
        """Return the SQL-text part of a SQLBatch payload, without ALL_HEADERS.

        TDS 7.2+ clients prefix the batch with ALL_HEADERS, whose first
        little-endian DWORD is its own total length. For plain UTF-16 text that
        DWORD is >= 0x10000 unless the second character is NUL. A payload
        without a plausible prefix is therefore returned unchanged.
        """
        if len(payload) >= 4:
            (total_length,) = struct.unpack_from('<I', payload, 0)
            if 4 <= total_length <= len(payload):
                return payload[total_length:]
        return payload

    def handle_prelogin(self, packet_data):
        """Answer the client's PRELOGIN message.

        Args:
            packet_data: the client's PRELOGIN payload. It is not inspected;
                every client gets the same reply.

        The reply advertises:
        * server version 15.0.2000;
        * ENCRYPT_NOT_SUP, because the proxy cannot speak TLS;
        * INSTOPT accepted;
        * an empty THREADID;
        * MARS off.

        Each option-table entry is a token (1 byte), an offset (2 bytes,
        big-endian, relative to the payload start) and a length (2 bytes,
        big-endian). The table is terminated by 0xFF.
        """
        options = (
            (0x00, self.SERVER_VERSION + b'\x00\x00'),   # VERSION + sub-build
            (0x01, bytes([self.ENCRYPT_NOT_SUP])),      # ENCRYPTION
            (0x02, b'\x00'),                            # INSTOPT
            (0x03, b''),                                # THREADID
            (0x04, b'\x00'),                            # MARS off
        )
        data_start = 5 * len(options) + 1  # option entries + terminator byte
        table = bytearray()
        data = bytearray()
        for token, value in options:
            table += struct.pack('>BHH', token, data_start + len(data), len(value))
            data += value
        table.append(0xFF)
        self._send_message(bytes(table + data))
        logger.debug("Sent PreLogin response")

    def send_login_ack(self):
        """Accept the login: send LOGINACK, ENVCHANGE(packet size) and DONE."""
        ack_body = (
            bytes([0x01])                      # Interface: SQL_TSQL
            + bytes([0x74, 0x00, 0x00, 0x04])  # TDS 7.4 (this field is big-endian)
            + self._b_varchar('Microsoft SQL Server')
            + self.SERVER_VERSION              # ProgVersion: major, minor, build
        )
        login_ack = struct.pack('<BH', self.TOKEN_LOGINACK, len(ack_body)) + ack_body
        size = str(self.PACKET_SIZE)
        # ENVCHANGE type 4 = packet size (new value, old value).
        env_body = bytes([0x04]) + self._b_varchar(size) + self._b_varchar(size)
        env_change = struct.pack('<BH', self.TOKEN_ENVCHANGE, len(env_body)) + env_body
        self._send_message(login_ack + env_change + self._done_token(self.DONE_FINAL))
        logger.debug("Sent Login acknowledgement")

    def execute_query(self, query):
        """Execute *query* on SQL Server and send the outcome to the client as TDS tokens.

        * A row-returning query is sent as COLMETADATA, ROW tokens and
          DONE(COUNT).
        * Any other query is sent as a DONE token, carrying the row count when
          pytds reports one.
        * Failures are sent as ERROR followed by DONE(ERROR).
        """
        try:
            logger.debug(f"Forwarding query to SQL Server: {query}")
            self.server_cursor.execute(query)
            description = self.server_cursor.description
            if description:
                rows = self.server_cursor.fetchall()
                logger.debug("Query results:\n%s", self.format_results(rows, description))
                payload = self._result_tokens(description, rows)
            else:
                rowcount = self.server_cursor.rowcount
                if isinstance(rowcount, int) and rowcount >= 0:
                    payload = self._done_token(self.DONE_COUNT, rowcount)
                else:
                    payload = self._done_token(self.DONE_FINAL)
        except Exception as e:
            error_msg = f"Error executing query: {str(e)}"
            logger.error(error_msg)
            self._send_error(error_msg)
            return
        self._send_message(payload)

    def _result_tokens(self, description, rows):
        """Encode one result set as COLMETADATA + ROW* + DONE(COUNT); all columns NVARCHAR."""
        out = bytearray(struct.pack('<BH', self.TOKEN_COLMETADATA, len(description)))
        for column in description:
            out += struct.pack('<IH', 0, 0x0001)  # UserType 0, Flags: nullable
            out += struct.pack('<BH', self.NVARCHAR_TYPE, self.NVARCHAR_MAX_BYTES)
            out += self.COLLATION
            out += self._b_varchar(str(column[0] or ''))
        for row in rows:
            out.append(self.TOKEN_ROW)
            for value in row:
                if value is None:
                    out += b'\xff\xff'  # NULL marker for USHORT-length types
                else:
                    encoded = str(value).encode('utf-16-le', 'surrogatepass')
                    encoded = encoded[:self.NVARCHAR_MAX_BYTES]
                    out += struct.pack('<H', len(encoded)) + encoded
        out += self._done_token(self.DONE_COUNT, len(rows))
        return bytes(out)

    def format_results(self, rows, description):
        """Format query results as a text table (used for debug logging)."""
        if not rows:
            return "No results returned"
        columns = [col[0] for col in description]
        col_widths = [len(str(col)) for col in columns]
        # Widen each column to fit its longest value.
        for row in rows:
            for i, val in enumerate(row):
                col_widths[i] = max(col_widths[i], len(str(val)))
        header = " | ".join(f"{col:<{width}}" for col, width in zip(columns, col_widths))
        separator = "-+-".join("-" * width for width in col_widths)
        lines = [header, separator]
        for row in rows:
            lines.append(" | ".join(f"{str(val):<{width}}" for val, width in zip(row, col_widths)))
        return "\n".join(lines) + "\n"

    def _send_error(self, message, number=50000, severity=16, state=1):
        """Send *message* to the client as an ERROR token followed by DONE(ERROR)."""
        body = (
            struct.pack('<iBB', number, state, severity)
            + self._us_varchar(message)
            + self._b_varchar(self.SERVER_NAME)
            + self._b_varchar('')        # ProcName
            + struct.pack('<i', 1)       # LineNumber
        )
        token = struct.pack('<BH', self.TOKEN_ERROR, len(body)) + body
        self._send_message(token + self._done_token(self.DONE_ERROR))

    def _send_message(self, payload):
        """Split *payload* into TABULAR_RESULT packets and send them all to the client."""
        chunk_size = self.PACKET_SIZE - self.HEADER_SIZE
        chunks = [payload[i:i + chunk_size] for i in range(0, len(payload), chunk_size)] or [b'']
        packets = []
        for packet_id, chunk in enumerate(chunks, start=1):
            status = self.STATUS_EOM if packet_id == len(chunks) else 0x00
            header = struct.pack(
                '>BBHHBB', self.PKT_TABULAR_RESULT, status,
                self.HEADER_SIZE + len(chunk), 0, packet_id % 256, 0,
            )
            packets.append(header + chunk)
        # sendall: a plain send() may write only part of the buffer.
        self.client_socket.sendall(b''.join(packets))

    @staticmethod
    def _b_varchar(text):
        """Encode *text* as B_VARCHAR: 1-byte character count + UTF-16LE (max 255 units)."""
        encoded = text.encode('utf-16-le', 'surrogatepass')[:510]
        return bytes([len(encoded) // 2]) + encoded

    @staticmethod
    def _us_varchar(text, max_chars=4000):
        """Encode *text* as US_VARCHAR: 2-byte character count + UTF-16LE."""
        encoded = text.encode('utf-16-le', 'surrogatepass')[:2 * max_chars]
        return struct.pack('<H', len(encoded) // 2) + encoded

    @classmethod
    def _done_token(cls, status, row_count=0):
        """Build a DONE token (TDS 7.2+ layout, 8-byte row count)."""
        return struct.pack('<BHHQ', cls.TOKEN_DONE, status, 0, row_count)
def handle_client(client_socket):
    """Serve one client connection until it disconnects, then clean up.

    Reads from *client_socket* in chunks of up to 4096 bytes and feeds each
    chunk to a fresh TDSPacketHandler. On exit the handler's SQL Server
    connection (if one was opened) and the client socket are always closed.
    A failure while closing the server connection is logged and does not
    prevent the client socket from being closed.
    """
    handler = TDSPacketHandler()
    try:
        # Inside the try: getpeername() raises OSError if the peer already left.
        logger.info(f"New connection from {client_socket.getpeername()}")
        while True:
            data = client_socket.recv(4096)
            if not data:
                break
            handler.process_packets(data, client_socket)
    except ConnectionResetError:
        logger.info("Client disconnected")
    except Exception as e:
        logger.exception(f"Connection error: {e}")
    finally:
        server_conn = getattr(handler, 'server_conn', None)
        if server_conn:
            try:
                server_conn.close()
            except Exception as e:
                logger.warning(f"Error closing SQL Server connection: {e}")
        client_socket.close()
        logger.info("Connection closed")
def start_proxy_server(host='0.0.0.0', port=1434, backlog=5):
    """Accept TDS clients on (*host*, *port*) and serve each on a daemon thread.

    Runs until interrupted with Ctrl+C (KeyboardInterrupt). The listening
    socket is closed on every exit path, including a failed bind().

    Args:
        host: interface to bind; the default listens on all IPv4 interfaces.
        port: TCP port clients connect to.
        backlog: listen() queue length.

    Security note: the packet handler does not check the credentials that
    clients send at login. Every accepted connection reaches SQL Server with
    the proxy's configured account, so do not expose the proxy beyond trusted
    networks.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.bind((host, port))
        server_socket.listen(backlog)
        logger.info(f"Proxy server listening on port {port}...")
        try:
            while True:
                client_socket, addr = server_socket.accept()
                logger.info(f"Accepted connection from {addr}")
                threading.Thread(
                    target=handle_client,
                    args=(client_socket,),
                    daemon=True,
                ).start()
        except KeyboardInterrupt:
            logger.info("Shutting down proxy server")
if __name__ == "__main__":
    # Script entry point: run the proxy in the foreground until Ctrl+C.
    start_proxy_server()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment