|
import asyncio |
|
from nats.aio.client import Client as NATS, DEFAULT_BUFFER_SIZE |
|
import time |
|
import ssl |
|
from nats.aio.errors import ErrNoServers |
|
|
|
|
|
class TLSOnlyClient(NATS):
    """A NATS client that performs the TLS handshake as soon as it connects.

    The stock client opens a plain TCP connection first and only afterwards
    upgrades it to TLS, based on the server's configuration. This subclass
    instead passes an SSL context to ``asyncio.open_connection()``, so the
    TLS handshake happens while the socket is being opened.

    @see https://github.com/nats-io/nats-server/issues/291
    """

    def _ssl_context_for_server(self, server):
        """Return the value to pass as ``ssl=`` when dialing *server*.

        A context supplied by the caller under ``options["tls"]`` always wins,
        whatever the URI scheme. Otherwise a ``tls://`` URI gets a default
        client context from ``ssl.create_default_context()``. Any other scheme
        gets ``None``, meaning plain TCP.
        """
        user_context = self.options.get("tls")
        if user_context is not None:
            return user_context
        # Checking for a present-but-None "tls" option explicitly means a
        # tls:// server is never silently dialed in plaintext.
        if server.uri.scheme == "tls":
            return ssl.create_default_context()
        return None

    # NOTE: the generator-based coroutine and the ``loop=`` arguments match the
    # nats client version this class overrides. Both were removed from asyncio
    # in Python 3.10/3.11.
    @asyncio.coroutine
    def _select_next_server(self):
        """Pick a server from the pool and open a connection to it.

        This mirrors the parent implementation. The only difference is that
        the ``ssl`` argument of ``asyncio.open_connection()`` is filled in by
        ``_ssl_context_for_server()``.

        Servers are taken from the front of ``_server_pool``:

        * When ``max_reconnect_attempts`` is positive, a server whose
          ``reconnects`` counter exceeds it is dropped from the pool.
        * Every other server is moved to the back of the pool.
        * If a server was last attempted less than ``reconnect_time_wait`` ago,
          the method sleeps for ``reconnect_time_wait`` before dialing it.

        On success, the reader/writer pair is stored on the client and the
        server becomes ``_current_server``. On failure, the error is recorded
        in ``_err``, passed to ``_error_cb`` if one is set, and the next server
        is tried.

        Raises:
            ErrNoServers: the pool is empty. ``_current_server`` is reset to
                ``None`` first.
            asyncio.CancelledError: propagated unchanged so that cancelling
                the calling task actually stops the retry loop.
        """
        while True:
            if not self._server_pool:
                self._current_server = None
                raise ErrNoServers

            now = time.monotonic()
            s = self._server_pool.pop(0)

            max_attempts = self.options["max_reconnect_attempts"]
            if max_attempts > 0 and s.reconnects > max_attempts:
                # Discard server since already tried to reconnect too many times.
                continue

            # Not yet exceeded max_reconnect_attempts so can still use
            # this server in the future.
            self._server_pool.append(s)

            wait = self.options["reconnect_time_wait"]
            if s.last_attempt is not None and now < s.last_attempt + wait:
                # Backoff connecting to server if we attempted recently.
                yield from asyncio.sleep(wait, loop=self._loop)

            try:
                ssl_context = self._ssl_context_for_server(s)
                s.last_attempt = time.monotonic()
                r, w = yield from asyncio.open_connection(
                    s.uri.hostname,
                    s.uri.port,
                    loop=self._loop,
                    ssl=ssl_context,
                    limit=DEFAULT_BUFFER_SIZE,
                )
            except asyncio.CancelledError:
                # On Python < 3.8 CancelledError subclasses Exception. Without
                # this clause the handler below would swallow the cancellation
                # and keep retrying.
                raise
            except Exception as e:
                s.last_attempt = time.monotonic()
                s.reconnects += 1

                self._err = e
                if self._error_cb is not None:
                    yield from self._error_cb(e)
                continue

            self._current_server = s
            self._bare_io_reader = self._io_reader = r
            self._bare_io_writer = self._io_writer = w
            return