Last active
October 25, 2015 11:43
-
-
Save GaretJax/124c523a62ba48c9eec1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from OpenSSL import SSL as ssl | |
from zope.interface import implementer | |
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol | |
from twisted.internet import defer | |
from twisted.internet.ssl import CertificateOptions | |
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator | |
from twisted.logger import Logger | |
from .utils import CachingDict | |
class DummyTransport(object):
    """
    Transport stand-in whose operations are all no-ops: written data is
    discarded and requests to drop the connection are ignored.
    """

    def write(self, bytes):
        """
        Discard the data passed in.
        """
        return None

    def loseConnection(self):
        """
        Ignore the request to close the connection.
        """
        return None
class TLSServerNameCallbackHelper(TLSMemoryBIOProtocol, object):
    """
    Fake TLSMemoryBIOProtocol to be used until the client hello is received
    and the SNI callback can be triggered by the underlying SSL implementation.

    Every chunk passed to C{dataReceived} is recorded, so that the complete
    byte stream can be handed to the SNI handler together with the outcome
    of the server name callback.
    """

    def __init__(self, sniHandler, *args, **kwargs):
        """
        @param sniHandler: object whose C{gotContext(bytes, context)} or
            C{gotError(bytes, failure)} method is called once the outcome of
            the SNI lookup is known.
        @param args: positional arguments forwarded to the parent class.
        @param kwargs: keyword arguments forwarded to the parent class.
        """
        super(TLSServerNameCallbackHelper, self).__init__(*args, **kwargs)
        self._receivedBytes = []
        self._sniHandlerCalled = False
        self._sniHandler = sniHandler
        # Replace the connection and the transport with throw-away objects.
        self._tlsConnection = self._buildDummyConnection()
        self.transport = DummyTransport()

    def _buildDummyConnection(self):
        """
        Build a server-side connection on a fresh default context, with the
        server name and info callbacks pointing back at this helper.
        """
        context = CertificateOptions().getContext()
        context.set_tlsext_servername_callback(self._executeServernameCallback)
        context.set_info_callback(self._handover)
        connection = ssl.Connection(context, None)
        connection.set_accept_state()
        return connection

    def _recordedBytes(self):
        """
        Return everything received so far as a single byte string.
        """
        # The chunks are byte strings, so they must be joined with a bytes
        # separator; a text separator raises TypeError on Python 3.
        return b''.join(self._receivedBytes)

    def _handover(self, connection, where, ret):
        """
        Info callback: when invoked with the C{SSL_CB_EXIT} flag before the
        SNI handler has been triggered, hand over without a new context.
        """
        if self._sniHandlerCalled:
            return
        if where & ssl.SSL_CB_EXIT:
            self._sniHandlerCalled = True
            self._gotContext(None, connection)

    def _gotContext(self, context, connection):
        """
        Shut the dummy connection down and pass the recorded bytes and the
        context (possibly C{None}) to the SNI handler.
        """
        connection.shutdown()
        self._sniHandler.gotContext(self._recordedBytes(), context)

    def _gotError(self, failure, connection):
        """
        Shut the dummy connection down and pass the recorded bytes and the
        failure to the SNI handler.
        """
        connection.shutdown()
        self._sniHandler.gotError(self._recordedBytes(), failure)

    def _executeServernameCallback(self, connection):
        """
        Server name callback: ask the connection creator for a context and
        route the (possibly deferred) result to the SNI handler.
        """
        assert not self._sniHandlerCalled
        self._sniHandlerCalled = True
        # maybeDeferred accepts both plain return values and deferreds, and
        # turns a synchronous exception into a failure.
        d = defer.maybeDeferred(
            self.factory._connectionCreator.serverContextForSNI, connection)
        d.addCallback(self._gotContext, connection)
        d.addErrback(self._gotError, connection)

    def dataReceived(self, bytes):
        """
        Record the chunk, then feed it to the parent implementation.
        """
        self._receivedBytes.append(bytes)
        super(TLSServerNameCallbackHelper, self).dataReceived(bytes)
class SNIEnabledTLSMemoryBIOProtocol(TLSMemoryBIOProtocol, object):
    """
    TLSMemoryBIOProtocol which first routes incoming data to a helper in
    order to trigger the SNI callback (which may return a deferred), waits
    for the result, and then replays the recorded bytes on the real
    implementation, for which a context is available by then.
    """

    log = Logger()

    def _replayHandshake(self, bytes):
        """
        Put the original dataReceived method back in place and feed it the
        bytes recorded so far.
        """
        original = self._originalDataReceived
        self.dataReceived = original
        original(bytes)

    def gotContext(self, bytes, context):
        """
        Called by the helper with the recorded bytes and the context to use;
        a falsy context leaves the current connection context untouched.
        """
        if context:
            handle = self.getHandle()
            handle.set_context(context)
        self._replayHandshake(bytes)

    def gotError(self, bytes, failure):
        """
        Called by the helper when building the context failed: log the
        failure and drop the connection.
        """
        self.log.error('failed to build context', failure=failure)
        self.loseConnection()

    def makeConnection(self, transport):
        """
        Divert dataReceived to a handshake helper before connecting.

        The diversion stays in place until the client hello is received,
        the SSL implementation parsed the SNI extension, and the deferred
        returned by the SNI callback has fired.
        """
        helper = TLSServerNameCallbackHelper(
            self, self.factory, self.wrappedProtocol, self._connectWrapped)
        self._originalDataReceived = self.dataReceived
        self.dataReceived = helper.dataReceived
        super(SNIEnabledTLSMemoryBIOProtocol, self).makeConnection(transport)
class SNIEnabledTLSMemoryBIOFactory(TLSMemoryBIOFactory):
    """
    TLSMemoryBIOFactory whose protocol class is
    SNIEnabledTLSMemoryBIOProtocol.
    """

    protocol = SNIEnabledTLSMemoryBIOProtocol
class ISNIEnabledConnectionCreator(IOpenSSLServerConnectionCreator):
    """
    Server connection creator which can additionally select the context of
    a connection based on the server name indication, possibly
    asynchronously.
    """

    # Interface methods are declared without ``self``: the declared
    # signature is the one seen by callers of a provider.
    def serverContextForSNI(connection):
        """
        Called when the server name indication is received by the server
        (tlsext_servername_callback of pyOpenSSL).

        This method can return `None` to not alter the connection context,
        a new context instance to be used for the connection, or a deferred
        with any of the previous two return values.

        The returned context will be set as the context of the connection.
        Any context set directly on the `connection` argument (i.e. by using
        `Connection.set_context`) will be lost.
        """
class SNIEnabledTLSEndpoint(object):
    """
    Endpoint wrapper adding TLS with support for server name indication
    callbacks that return deferreds.
    """

    def __init__(self, endpoint, contextFactory):
        """
        @param endpoint: the endpoint to wrap; its C{listen} method is used.
        @param contextFactory: a provider of ISNIEnabledConnectionCreator.
        """
        assert ISNIEnabledConnectionCreator.providedBy(contextFactory)
        self.endpoint, self.contextFactory = endpoint, contextFactory

    def listen(self, factory):
        """
        Wrap C{factory} in an SNI enabled TLS factory and listen with it on
        the wrapped endpoint, returning whatever that endpoint returns.
        """
        # The second argument is False, as in the original call.
        tlsFactory = SNIEnabledTLSMemoryBIOFactory(
            self.contextFactory, False, factory)
        return self.endpoint.listen(tlsFactory)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example usage
from zope.interface import implementer | |
from twisted.internet import reactor, endpoints | |
from twisted.web import static, server | |
@implementer(ISNIEnabledConnectionCreator)
class SNICallbackSSLFactory(object):
    """
    Example connection creator which asynchronously builds a context from
    the key and certificate files found under ``certs/<hostname>/``.
    """

    def __init__(self, certificate_options):
        """
        @param certificate_options: object used to build contexts through
            its C{getContext} method.
        """
        self.certificate_options = certificate_options

    def _makeContext(self):
        """
        Return a context obtained from the certificate options after
        clearing their C{_context} attribute.
        """
        # NOTE/TODO: Somehow the connections are picky about sharing contexts
        # between them. This might not be an issue when different connection
        # instances are created for the same session, but it is here because
        # we reuse the same context connections initialized with exactly the
        # same client hello.
        self.certificate_options._context = None
        return self.certificate_options.getContext()

    def serverContextForSNI(self, connection):
        """
        Return a deferred firing with a context loaded with the key and the
        certificate of the requested server name.

        Returns C{None} (context left unaltered) when the client sent no
        server name. Raises C{ValueError} for a server name which cannot be
        used safely as a directory name. Errors raised while building the
        context are delivered through the deferred's errback.
        """
        hostname = connection.get_servername()
        if hostname is None:
            return None
        if isinstance(hostname, bytes):
            hostname = hostname.decode('ascii')
        # The server name is chosen by the client: refuse anything that
        # could make the certificate paths point outside of ``certs/``.
        if (not hostname or hostname.startswith('.')
                or '/' in hostname or '\\' in hostname):
            raise ValueError('invalid server name: {!r}'.format(hostname))

        def build(d):
            try:
                context = self._makeContext()
                context.use_privatekey_file(
                    'certs/{}/key.pem'.format(hostname))
                context.use_certificate_file(
                    'certs/{}/cert.pem'.format(hostname))
            except Exception:
                # Without this the deferred would never fire and the
                # connection would be left hanging; errback() without
                # arguments wraps the exception being handled.
                d.errback()
            else:
                d.callback(context)

        d = defer.Deferred()
        # The one second delay only simulates an asynchronous lookup.
        reactor.callLater(1, build, d)
        return d

    def serverConnectionForTLS(self, tlsProtocol):
        """
        Return a new connection built on a freshly made context.
        """
        return ssl.Connection(self._makeContext(), None)
# Serve a static plain text page over TLS on port 443, selecting the
# certificate through the SNI callback factory defined above.
# The response body is a byte string: twisted.web resources must produce
# bytes, and on Python 2 the literal is identical to the text one.
server_factory = server.Site(static.Data(b'Hello world!', 'text/plain'))
ssl_context_factory = SNICallbackSSLFactory(CertificateOptions())
tcp_endpoint = endpoints.TCP4ServerEndpoint(reactor, 443)
tls_endpoint = SNIEnabledTLSEndpoint(tcp_endpoint, ssl_context_factory)
# NOTE(review): the reactor is not started here; presumably it is run
# elsewhere (e.g. by twistd or an explicit reactor.run()) -- confirm.
tls_endpoint.listen(server_factory)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment