Created
August 1, 2025 22:01
-
-
Save jnm2/31707910819560e848b3545ea5fb45d3 to your computer and use it in GitHub Desktop.
SqlServerDiscoverer (IObservable)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Buffers; | |
using System.Buffers.Binary; | |
using System.Collections.Immutable; | |
using System.Net; | |
using System.Net.NetworkInformation; | |
using System.Net.Sockets; | |
using System.Text; | |
/// <summary>
/// Discovers SQL Server instances on the local network asynchronously using the SQL Server Resolution Protocol (<see
/// href="https://learn.microsoft.com/en-us/openspecs/windows_protocols/mc-sqlr">[MC-SQLR]</see>).
/// </summary>
/// <remarks>
/// Each subscription broadcasts a single <c>CLNT_BCAST_EX</c> request from one IPv4 address per eligible network
/// interface (or from <see cref="IPAddress.Any"/> if none qualify), then keeps listening for responses until the
/// subscription is disposed or an error occurs. <see cref="IObserver{T}.OnCompleted"/> is never called. Calls to the
/// observer are serialized, and <see cref="IObserver{T}.OnError"/> is called at most once, after which discovery stops
/// and no further notifications are delivered.
/// </remarks>
public sealed class SqlServerDiscoverer : IObservable<SqlServerInstance>
{
    /// <summary>
    /// Starts discovery and pushes each discovered instance to <paramref name="observer"/>.
    /// </summary>
    /// <param name="observer">Receives discovered instances and, at most once, a discovery failure.</param>
    /// <returns>A disposable that stops discovery and suppresses any further notifications.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="observer"/> is <see langword="null"/>.</exception>
    public IDisposable Subscribe(IObserver<SqlServerInstance> observer)
    {
        ArgumentNullException.ThrowIfNull(observer);

        var subscriptionEnded = new CancellationTokenSource();

        // One receive loop runs per interface, concurrently. The observer contract requires serialized calls and no
        // notifications after OnError, so every observer call goes through this gate.
        var observerGate = new object();
        var isTerminated = false;

        void DeliverInstances(ImmutableArray<SqlServerInstance> instances)
        {
            lock (observerGate)
            {
                if (isTerminated || subscriptionEnded.IsCancellationRequested) return;
                foreach (var instance in instances)
                    observer.OnNext(instance);
            }
        }

        void DeliverError(Exception exception)
        {
            lock (observerGate)
            {
                if (isTerminated || subscriptionEnded.IsCancellationRequested) return;
                isTerminated = true;
                observer.OnError(exception);
            }

            // The sequence has terminated, so stop the receive loops on the other interfaces too.
            subscriptionEnded.Cancel();
        }

        try
        {
            foreach (var interfaceAddress in GetBroadcastSourceAddresses())
            {
                var client = new SsrpClient(new IPEndPoint(interfaceAddress, 0));
                client.InstancesDiscovered += (_, instances) => DeliverInstances(instances);
                client.ResolveAsync(subscriptionEnded.Token).ContinueWith(
                    task =>
                    {
                        client.Dispose();
                        if (task.Exception is { } aggregate)
                            DeliverError(aggregate.InnerExceptions is [var single] ? single : aggregate);
                    },
                    CancellationToken.None,
                    TaskContinuationOptions.None,
                    // Explicit scheduler so the continuation never lands on a caller's UI/custom scheduler.
                    TaskScheduler.Default);
            }
        }
        catch
        {
            // Subscribe is failing, so the caller never gets a handle to stop the clients already started; stop them.
            subscriptionEnded.Cancel();
            throw;
        }

        return Disposable.Create(subscriptionEnded.Cancel);
    }

    /// <summary>
    /// Returns one local IPv4 address for each interface that is up, can send, and supports multicast, or just
    /// <see cref="IPAddress.Any"/> when there is no such interface.
    /// </summary>
    private static IEnumerable<IPAddress> GetBroadcastSourceAddresses()
    {
        // If more than one interface can broadcast, you must bind to a particular interface before broadcasting or
        // Windows will discard the packet.
        return NetworkInterface.GetAllNetworkInterfaces()
            .Where(i => i.OperationalStatus == OperationalStatus.Up && !i.IsReceiveOnly && i.SupportsMulticast)
            // FirstOrDefault, not SingleOrDefault: an interface may carry several IPv4 addresses, and binding to any
            // one of them is enough to pick the interface.
            .Select(i => i.GetIPProperties().UnicastAddresses
                .Select(a => a.Address)
                .FirstOrDefault(a => a.AddressFamily == AddressFamily.InterNetwork))
            .OfType<IPAddress>() // drops interfaces without an IPv4 address
            .DefaultIfEmpty(IPAddress.Any);
    }

    /// <summary>
    /// Sends one SSRP broadcast request from a single local endpoint and raises <see cref="InstancesDiscovered"/> for
    /// each <c>SVR_RESP</c> datagram subsequently received on that socket.
    /// </summary>
    private sealed class SsrpClient : IDisposable
    {
        /// <summary>[MC-SQLR] CLNT_BCAST_EX message type: broadcast request for all instances.</summary>
        private const byte CLNT_BCAST_EX = 0x02;

        /// <summary>[MC-SQLR] SVR_RESP message type: server response.</summary>
        private const byte SVR_RESP = 0x05;

        /// <summary>UDP port the broadcast request is sent to.</summary>
        private const int SsrpPort = 1434;

        /// <summary>
        /// Large enough for any IPv4 UDP datagram (max payload 65,507 bytes). A smaller buffer can truncate a large
        /// response (or fail the receive on Windows), which would end discovery on this interface.
        /// </summary>
        private const int ReceiveBufferSize = 64 * 1024;

        private readonly Socket socket;

        // 0 = live, 1 = disposed; accessed with Interlocked/Volatile because Dispose runs on another thread.
        private int isDisposed;

        /// <summary>Raised once per valid SVR_RESP datagram with the instances it lists.</summary>
        public event EventHandler<ImmutableArray<SqlServerInstance>>? InstancesDiscovered;

        /// <summary>Creates a broadcast-enabled UDP socket bound to <paramref name="localEndPoint"/>.</summary>
        public SsrpClient(IPEndPoint localEndPoint)
        {
            socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp) { EnableBroadcast = true };
            try
            {
                socket.Bind(localEndPoint);
            }
            catch
            {
                // The constructor is failing, so nobody else can dispose the socket.
                socket.Dispose();
                throw;
            }
        }

        /// <summary>Closes the socket. Safe to call more than once.</summary>
        public void Dispose()
        {
            if (Interlocked.Exchange(ref isDisposed, 1) != 0) return;
            socket.Dispose();
        }

        /// <summary>
        /// Broadcasts a CLNT_BCAST_EX request, then processes responses until <paramref name="cancellationToken"/> is
        /// canceled (which completes the task normally) or an error occurs (which faults it).
        /// </summary>
        /// <exception cref="InvalidDataException">A response could not be parsed.</exception>
        /// <exception cref="SocketException">Sending or receiving failed.</exception>
        public async Task ResolveAsync(CancellationToken cancellationToken)
        {
            var buffer = ArrayPool<byte>.Shared.Rent(ReceiveBufferSize);
            try
            {
                // The request is the single message-type byte.
                buffer[0] = CLNT_BCAST_EX;
                await socket.SendToAsync(buffer.AsMemory(0, 1), new IPEndPoint(IPAddress.Broadcast, SsrpPort), cancellationToken);

                while (Volatile.Read(ref isDisposed) == 0)
                {
                    var bytesReceived = await socket.ReceiveAsync(buffer, cancellationToken);
                    HandleResponse(buffer.AsSpan(0, bytesReceived));
                }
            }
            catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
            {
                // Cancellation is how discovery normally ends.
            }
            finally
            {
                // Also return the buffer when a socket or parse error faults the task.
                ArrayPool<byte>.Shared.Return(buffer);
            }
        }

        /// <summary>
        /// Parses one received datagram and raises <see cref="InstancesDiscovered"/>. Datagrams that are not
        /// SVR_RESP (type byte followed by a two-byte little-endian RESP_SIZE) are ignored.
        /// </summary>
        /// <exception cref="InvalidDataException">RESP_SIZE exceeds the received data, or RESP_DATA is malformed.</exception>
        private void HandleResponse(ReadOnlySpan<byte> data)
        {
            var handler = InstancesDiscovered;
            if (handler == null) return;

            if (data is not [SVR_RESP, _, _, ..]) return;

            var responseLength = BinaryPrimitives.ReadUInt16LittleEndian(data[1..]);
            if (responseLength > data.Length - 3)
                throw new InvalidDataException("Response did not fit in buffer.");

            handler.Invoke(this, ParseRespData(data.Slice(3, responseLength)));
        }

        /// <summary>
        /// Parses RESP_DATA: a sequence of instance records, each beginning with
        /// <c>ServerName;…;InstanceName;…;IsClustered;…;Version;…;</c>, optionally followed by further name/value
        /// pairs, and terminated by an empty field (<c>;;</c>). The instance name <c>MSSQLSERVER</c> is reported as
        /// <see langword="null"/>.
        /// </summary>
        private static ImmutableArray<SqlServerInstance> ParseRespData(ReadOnlySpan<byte> respData)
        {
            var instances = ImmutableArray.CreateBuilder<SqlServerInstance>();
            var reader = new RespDataReader(respData);

            while (!reader.IsEndOfData)
            {
                var serverName = Encoding.UTF8.GetString(reader.ReadField("ServerName"u8));
                var instanceName = Encoding.UTF8.GetString(reader.ReadField("InstanceName"u8));
                _ = reader.ReadField("IsClustered"u8);
                var version = Version.Parse(Encoding.UTF8.GetString(reader.ReadField("Version"u8)));

                instances.Add(new SqlServerInstance
                {
                    ServerName = serverName,
                    InstanceName = instanceName.Equals("MSSQLSERVER", StringComparison.OrdinalIgnoreCase) ? null : instanceName,
                    Version = version,
                });

                // Skip any remaining name/value pairs of this instance (e.g. protocol entries).
                reader.ReadToEndOfInstance();
            }

            return instances.DrainToImmutable();
        }

        /// <summary>Forward-only reader over semicolon-delimited <c>name;value;</c> pairs in RESP_DATA.</summary>
        private ref struct RespDataReader(ReadOnlySpan<byte> respData)
        {
            private ReadOnlySpan<byte> remaining = respData;

            public bool IsEndOfData => remaining.IsEmpty;

            /// <summary>
            /// Consumes <c><paramref name="expectedName"/>;value;</c> and returns <c>value</c>.
            /// </summary>
            /// <exception cref="InvalidDataException">The next field is not named <paramref name="expectedName"/>, or its value is not terminated by a semicolon.</exception>
            public ReadOnlySpan<byte> ReadField(ReadOnlySpan<byte> expectedName)
            {
                if (!remaining.StartsWith(expectedName) || !remaining[expectedName.Length..].StartsWith((byte)';'))
                    throw new InvalidDataException($"Expected \"{Encoding.UTF8.GetString(expectedName)};\" but found: {Encoding.UTF8.GetString(remaining)}");

                remaining = remaining[(expectedName.Length + 1)..];

                var fieldEndIndex = remaining.IndexOf((byte)';');
                if (fieldEndIndex == -1)
                    throw new InvalidDataException($"Expected semicolon following \"{Encoding.UTF8.GetString(expectedName)};\" but found: {Encoding.UTF8.GetString(remaining)}");

                var fieldValue = remaining[..fieldEndIndex];
                remaining = remaining[(fieldEndIndex + 1)..];
                return fieldValue;
            }

            /// <summary>
            /// Skips <c>name;value;</c> pairs until the empty field that ends the current instance, and consumes it.
            /// </summary>
            /// <exception cref="InvalidDataException">The data ends before the terminating empty field.</exception>
            public void ReadToEndOfInstance()
            {
                while (true)
                {
                    var nextSemicolonIndex = remaining.IndexOf((byte)';');
                    if (nextSemicolonIndex == -1)
                        throw new InvalidDataException($"Expected semicolon at end of instance, but found: {Encoding.UTF8.GetString(remaining)}");

                    remaining = remaining[(nextSemicolonIndex + 1)..];

                    if (nextSemicolonIndex == 0)
                    {
                        // The last thing that was read would have been a semicolon, and here we are again at a semicolon
                        // with no field name. This must be the end.
                        break;
                    }

                    // Otherwise, we read past a field name and its semicolon.
                    nextSemicolonIndex = remaining.IndexOf((byte)';');
                    if (nextSemicolonIndex == -1)
                        throw new InvalidDataException($"Expected field value and semicolon following field name and semicolon, but found: {Encoding.UTF8.GetString(remaining)}");

                    remaining = remaining[(nextSemicolonIndex + 1)..];
                }
            }
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment