import asyncio
from asyncio.futures import Future
from collections import defaultdict
from typing import Dict, Any, List, Optional, Set, Tuple

import ray
import ray.cloudpickle as pickle
from ray.actor import ActorHandle
from ray.serve.async_goal_manager import AsyncGoalManager
from ray.serve.backend_worker import create_backend_replica
from ray.serve.common import (
    BackendInfo,
    BackendTag,
    Duration,
    GoalId,
    ReplicaTag,
)
from ray.serve.config import BackendConfig, ReplicaConfig
from ray.serve.constants import LongPollKey
from ray.serve.exceptions import RayServeException
from ray.serve.kv_store import RayInternalKVStore
from ray.serve.long_poll import LongPollHost
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
                             try_schedule_resources_on_nodes)

CHECKPOINT_KEY = "serve-backend-state-checkpoint"

# Feature flag for controller resource checking. If true, controller will
# error if the desired replicas exceed current resource availability.
_RESOURCE_CHECK_ENABLED = True
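
# The sketches below refer to a ReplicaState enum that is not defined in this
# file. A minimal sketch of what it could look like (the exact members are an
# assumption, inferred from the states used by BackendState and BackendReplica
# below):
from enum import Enum


class ReplicaState(Enum):
    NEW = 1
    SHOULD_START = 2
    PENDING_START = 3
    RUNNING = 4
    SHOULD_STOP = 5
    PENDING_STOP = 6
    STOPPED = 7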


class BackendState:
    """Manages all state for backends in the system.

    This class is *not* thread safe, so any state-modifying methods should be
    called with a lock held.
    """

    def __init__(self, controller_name: str, detached: bool,
                 kv_store: RayInternalKVStore, long_poll_host: LongPollHost,
                 goal_manager: AsyncGoalManager):
        self._controller_name = controller_name
        self._detached = detached
        self._kv_store = kv_store
        self._long_poll_host = long_poll_host
        self._goal_manager = goal_manager
        # Replica lists keyed by backend, then by replica state.
        self._replicas: Dict[BackendTag, Dict[
            ReplicaState, List[BackendReplica]]] = defaultdict(
                lambda: defaultdict(list))
        self._backend_metadata: Dict[BackendTag, BackendInfo] = dict()
        self._target_replicas: Dict[BackendTag, int] = dict()

    def create_backend(self, backend: BackendTag,
                       backend_info: BackendInfo):
        self._backend_metadata[backend] = backend_info
        # Can be changed later by autoscaling.
        self._target_replicas[backend] = backend_info.num_replicas

    def update(self):
        for backend, target in self._target_replicas.items():
            # Replicas that are running or on their way up count towards
            # the target.
            curr_num_replicas = (count(SHOULD_START) + count(PENDING_START) +
                                 count(RUNNING))
            if curr_num_replicas < target:
                # Mark (target - curr_num_replicas) new replicas SHOULD_START.
                pass
            elif curr_num_replicas > target:
                # Mark (curr_num_replicas - target) replicas SHOULD_STOP.
                pass

        checkpoint_if_needed()

        # The loops below only handle state transitions.
        for replica in SHOULD_START:
            # Start them (SHOULD_START -> PENDING_START).
            pass
        for replica in SHOULD_STOP:
            # Stop them (SHOULD_STOP -> PENDING_STOP).
            pass
        for replica in PENDING_START | PENDING_STOP:
            # Check pending starts (PENDING_START -> RUNNING).
            # Check pending stops (PENDING_STOP -> STOPPED).
            pass


class BackendReplica:
    def __init__(self, actor_name):
        self._actor_name = actor_name
        self._state = ReplicaState.NEW

    def mark_to_start(self):
        self._state = ReplicaState.SHOULD_START

    def start(self):
        # Do all the stuff to start the actor.
        self._state = ReplicaState.PENDING_START

    def mark_to_stop(self):
        self._state = ReplicaState.SHOULD_STOP

    def stop(self):
        # Actually kill the actor.
        self._state = ReplicaState.STOPPED
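
# Replica lifecycle implied by the sketch above (an inference for clarity;
# the sketch itself collapses some transitions, e.g. stop() moves straight to
# STOPPED):
#
#   NEW -> SHOULD_START -> PENDING_START -> RUNNING
#       -> SHOULD_STOP -> PENDING_STOP -> STOPPED
#
# Everything from here down appears to be the existing BackendState
# implementation that the sketch above is meant to replace; the attribute
# initialization immediately below continues BackendState.__init__.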

        # Non-checkpointed state.
        #
        # Replicas in the "pending" state: actors that are starting up or
        # shutting down, keyed by the future that completes when the
        # operation finishes.
        #
        self.currently_starting_replicas: Dict[asyncio.Future, Tuple[
            BackendTag, ReplicaTag, ActorHandle]] = dict()
        self.currently_stopping_replicas: Dict[asyncio.Future, Tuple[
            BackendTag, ReplicaTag]] = dict()

        # Checkpointed state.
        #
        # Metadata for all registered backends.
        #
        self.backends: Dict[BackendTag, BackendInfo] = dict()
        #
        # Replicas in the "running" state.
        #
        self.backend_replicas: Dict[BackendTag, Dict[
            ReplicaTag, ActorHandle]] = defaultdict(dict)
        #
        # Quirk of checkpointing -- "intent" to start a replica.
        #
        self.backend_replicas_to_start: Dict[BackendTag, List[
            ReplicaTag]] = defaultdict(list)
        #
        # Quirk of checkpointing -- "intent" to stop a replica.
        #
        self.backend_replicas_to_stop: Dict[BackendTag, List[Tuple[
            ReplicaTag, Duration]]] = defaultdict(list)
        self.backends_to_remove: List[BackendTag] = list()
        self.backend_goals: Dict[BackendTag, GoalId] = dict()

        checkpoint = self._kv_store.get(CHECKPOINT_KEY)
        if checkpoint is not None:
            (self.backends, self.backend_replicas, self.backend_goals,
             self.backend_replicas_to_start, self.backend_replicas_to_stop,
             self.backends_to_remove,
             pending_goal_ids) = pickle.loads(checkpoint)

            for goal_id in pending_goal_ids:
                self._goal_manager.create_goal(goal_id)

            # Fetch actor handles for all backend replicas in the system.
            # All of these backend_replicas are guaranteed to already exist
            # because they would not be written to a checkpoint in
            # self.backend_replicas until they were created.
            for backend_tag, replica_dict in self.backend_replicas.items():
                for replica_tag in replica_dict.keys():
                    replica_name = format_actor_name(replica_tag,
                                                     self._controller_name)
                    self.backend_replicas[backend_tag][
                        replica_tag] = ray.get_actor(replica_name)

            self._notify_backend_configs_changed()
            self._notify_replica_handles_changed()

    def _checkpoint(self) -> None:
        self._kv_store.put(
            CHECKPOINT_KEY,
            pickle.dumps(
                (self.backends, self.backend_replicas, self.backend_goals,
                 self.backend_replicas_to_start,
                 self.backend_replicas_to_stop, self.backends_to_remove,
                 self._goal_manager.get_pending_goal_ids())))
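
    # The pattern used throughout this class is: record the intent (e.g.
    # append to backend_replicas_to_start), call _checkpoint() to persist it,
    # and only then act on it. A worked crash scenario (illustrative, not
    # from the original comments): if the controller dies after the
    # checkpoint but before the replica actor is started, the recovered
    # controller replays backend_replicas_to_start from the checkpoint and
    # starts the replica; if it dies before the checkpoint, the intent is
    # simply lost and no orphaned replica is created.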

    def _notify_backend_configs_changed(self) -> None:
        self._long_poll_host.notify_changed(LongPollKey.BACKEND_CONFIGS,
                                            self.get_backend_configs())

    def _notify_replica_handles_changed(self) -> None:
        self._long_poll_host.notify_changed(
            LongPollKey.REPLICA_HANDLES, {
                backend_tag: list(replica_dict.values())
                for backend_tag, replica_dict in self.backend_replicas.items()
            })

    def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
        return {
            tag: info.backend_config
            for tag, info in self.backends.items()
        }

    def get_replica_handles(
            self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]:
        return self.backend_replicas

    def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
        return self.backends.get(backend_tag)

    def _set_backend_goal(self, backend_tag: BackendTag,
                          backend_info: Optional[BackendInfo]
                          ) -> Tuple[GoalId, Optional[GoalId]]:
        existing_goal_id = self.backend_goals.get(backend_tag)
        new_goal_id = self._goal_manager.create_goal()

        if backend_info is not None:
            self.backends[backend_tag] = backend_info

        self.backend_goals[backend_tag] = new_goal_id

        return new_goal_id, existing_goal_id

    def create_backend(self, backend_tag: BackendTag,
                       backend_config: BackendConfig,
                       replica_config: ReplicaConfig) -> Optional[GoalId]:
        # Ensures this method is idempotent.
        backend_info = self.backends.get(backend_tag)
        if backend_info is not None:
            if (backend_info.backend_config == backend_config
                    and backend_info.replica_config == replica_config):
                return None

        backend_replica = create_backend_replica(replica_config.func_or_class)

        # Save the creator that starts replicas, the arguments to be passed
        # in, and the configuration for the backends.
        backend_info = BackendInfo(
            worker_class=backend_replica,
            backend_config=backend_config,
            replica_config=replica_config)

        new_goal_id, existing_goal_id = self._set_backend_goal(
            backend_tag, backend_info)

        try:
            self.scale_backend_replicas(backend_tag,
                                        backend_config.num_replicas)
        except RayServeException as e:
            del self.backends[backend_tag]
            raise e

        # NOTE(edoakes): we must write a checkpoint before starting new
        # replicas or pushing the updated config to avoid inconsistent state
        # if we crash while making the change.
        self._checkpoint()
        self._notify_backend_configs_changed()

        if existing_goal_id is not None:
            self._goal_manager.complete_goal(existing_goal_id)
        return new_goal_id

    def delete_backend(self, backend_tag: BackendTag,
                       force_kill: bool = False) -> Optional[GoalId]:
        # This method must be idempotent. We should validate that the
        # specified backend exists on the client.
        if backend_tag not in self.backends:
            return None

        # Scale its replicas down to 0.
        self.scale_backend_replicas(backend_tag, 0, force_kill)

        # Remove the backend's metadata.
        del self.backends[backend_tag]

        # Add the intention to remove the backend from the routers.
        self.backends_to_remove.append(backend_tag)

        new_goal_id, existing_goal_id = self._set_backend_goal(
            backend_tag, None)
        self._checkpoint()
        if existing_goal_id is not None:
            self._goal_manager.complete_goal(existing_goal_id)
        return new_goal_id

    def update_backend_config(self, backend_tag: BackendTag,
                              config_options: BackendConfig):
        if backend_tag not in self.backends:
            raise ValueError(f"Backend {backend_tag} is not registered")

        stored_backend_config = self.backends[backend_tag].backend_config
        updated_config = stored_backend_config.copy(
            update=config_options.dict(exclude_unset=True))
        updated_config._validate_complete()
        self.backends[backend_tag].backend_config = updated_config

        new_goal_id, existing_goal_id = self._set_backend_goal(
            backend_tag, self.backends[backend_tag])

        # Scale the replicas with the new configuration.
        self.scale_backend_replicas(backend_tag, updated_config.num_replicas)

        # NOTE(edoakes): we must write a checkpoint before pushing the
        # update to avoid inconsistent state if we crash after pushing the
        # update.
        self._checkpoint()
        if existing_goal_id is not None:
            self._goal_manager.complete_goal(existing_goal_id)

        # Inform the routers and backend replicas about config changes.
        # TODO(edoakes): this should only happen if we change something other
        # than num_replicas.
        self._notify_backend_configs_changed()

        return new_goal_id

    def _start_backend_replica(self, backend_tag: BackendTag,
                               replica_tag: ReplicaTag) -> ActorHandle:
        """Start a replica and return its actor handle.

        Checks if the named actor already exists before starting a new one.
        Assumes that the backend configuration is already in the goal state.
        """
        # NOTE(edoakes): the replicas may already be created if we
        # failed after creating them but before writing a
        # checkpoint.
        replica_name = format_actor_name(replica_tag, self._controller_name)
        try:
            replica_handle = ray.get_actor(replica_name)
        except ValueError:
            logger.debug("Starting replica '{}' for backend '{}'.".format(
                replica_tag, backend_tag))
            backend_info = self.get_backend(backend_tag)

            replica_handle = ray.remote(backend_info.worker_class).options(
                name=replica_name,
                lifetime="detached" if self._detached else None,
                max_restarts=-1,
                max_task_retries=-1,
                **backend_info.replica_config.ray_actor_options).remote(
                    backend_tag, replica_tag,
                    backend_info.replica_config.actor_init_args,
                    backend_info.backend_config, self._controller_name)

        return replica_handle
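
    # Illustration of the idempotency above (hypothetical scenario, not from
    # the original comments): if the controller crashed after creating the
    # actor but before writing a checkpoint, the retry after recovery finds
    # the named actor via ray.get_actor(replica_name) and reuses it instead
    # of creating a duplicate replica.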

    def scale_backend_replicas(
            self,
            backend_tag: BackendTag,
            num_replicas: int,
            force_kill: bool = False,
    ) -> None:
        """Scale the given backend to the number of replicas.

        NOTE: this does not actually start or stop the replicas, but instead
        adds the intention to start/stop them to
        self.backend_replicas_to_start and self.backend_replicas_to_stop.
        The caller is responsible for then first writing a checkpoint and
        then actually starting/stopping the intended replicas. This avoids
        inconsistencies with starting/stopping a replica and then crashing
        before writing a checkpoint.
        """
        logger.debug("Scaling backend '{}' to {} replicas".format(
            backend_tag, num_replicas))
        assert (backend_tag in self.backends
                ), "Backend {} is not registered.".format(backend_tag)
        assert num_replicas >= 0, ("Number of replicas must be"
                                   " greater than or equal to 0.")

        current_num_replicas = len(self.backend_replicas[backend_tag])
        delta_num_replicas = num_replicas - current_num_replicas

        backend_info: BackendInfo = self.backends[backend_tag]
        if delta_num_replicas > 0:
            can_schedule = try_schedule_resources_on_nodes(requirements=[
                backend_info.replica_config.resource_dict
                for _ in range(delta_num_replicas)
            ])

            if _RESOURCE_CHECK_ENABLED and not all(can_schedule):
                num_possible = sum(can_schedule)
                raise RayServeException(
                    "Cannot scale backend {} to {} replicas. Ray Serve tried "
                    "to add {} replicas but the available resources only "
                    "allow {} to be added. To fix this, consider scaling to "
                    "{} replicas or adding more resources to the cluster. "
                    "You can check available resources with "
                    "ray.nodes().".format(
                        backend_tag, num_replicas, delta_num_replicas,
                        num_possible, current_num_replicas + num_possible))

            logger.debug("Adding {} replicas to backend {}".format(
                delta_num_replicas, backend_tag))
            for _ in range(delta_num_replicas):
                replica_tag = "{}#{}".format(backend_tag,
                                             get_random_letters())
                self.backend_replicas_to_start[backend_tag].append(
                    replica_tag)

        elif delta_num_replicas < 0:
            logger.debug("Removing {} replicas from backend '{}'".format(
                -delta_num_replicas, backend_tag))
            assert len(
                self.backend_replicas[backend_tag]) >= -delta_num_replicas

            replicas_copy = self.backend_replicas.copy()
            for _ in range(-delta_num_replicas):
                replica_tag, _ = replicas_copy[backend_tag].popitem()

                graceful_timeout_s = (
                    backend_info.backend_config.
                    experimental_graceful_shutdown_timeout_s)
                if force_kill:
                    graceful_timeout_s = 0
                self.backend_replicas_to_stop[backend_tag].append((
                    replica_tag,
                    graceful_timeout_s,
                ))
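
    # Worked example for the scaling logic above (hypothetical numbers): if
    # backend "A" has 2 entries in self.backend_replicas and num_replicas=5
    # is requested, delta_num_replicas is +3, so three fresh replica tags
    # such as "A#xqzflt" are appended to backend_replicas_to_start["A"]. With
    # num_replicas=1 the delta is -1, so one (replica_tag, graceful_timeout_s)
    # pair is appended to backend_replicas_to_stop["A"].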

    def _start_pending_replicas(self):
        for backend_tag, replicas_to_create in (
                self.backend_replicas_to_start.items()):
            for replica_tag in replicas_to_create:
                replica_handle = self._start_backend_replica(
                    backend_tag, replica_tag)
                ready_future = replica_handle.ready.remote().as_future()
                self.currently_starting_replicas[ready_future] = (
                    backend_tag, replica_tag, replica_handle)

    def _stop_pending_replicas(self):
        for backend_tag, replicas_to_stop in (
                self.backend_replicas_to_stop.items()):
            for replica_tag, shutdown_timeout in replicas_to_stop:
                replica_name = format_actor_name(replica_tag,
                                                 self._controller_name)

                async def kill_actor(replica_name_to_use):
                    # NOTE: the replicas may already be stopped if we failed
                    # after stopping them but before writing a checkpoint.
                    try:
                        replica = ray.get_actor(replica_name_to_use)
                    except ValueError:
                        return

                    try:
                        await asyncio.wait_for(
                            replica.drain_pending_queries.remote(),
                            timeout=shutdown_timeout)
                    except asyncio.TimeoutError:
                        # Graceful period passed, kill it forcefully.
                        logger.debug(
                            f"{replica_name_to_use} did not shutdown after "
                            f"{shutdown_timeout}s, killing.")
                    finally:
                        ray.kill(replica, no_restart=True)

                self.currently_stopping_replicas[asyncio.ensure_future(
                    kill_actor(replica_name))] = (backend_tag, replica_tag)

    async def _check_currently_starting_replicas(self) -> int:
        """Returns the number of pending replicas waiting to start."""
        in_flight: Set[Future[Any]] = set()

        if self.currently_starting_replicas:
            done, in_flight = await asyncio.wait(
                list(self.currently_starting_replicas.keys()), timeout=0)
            for fut in done:
                (backend_tag, replica_tag,
                 replica_handle) = self.currently_starting_replicas.pop(fut)
                self.backend_replicas[backend_tag][
                    replica_tag] = replica_handle

                backend = self.backend_replicas_to_start.get(backend_tag)
                if backend:
                    try:
                        backend.remove(replica_tag)
                    except ValueError:
                        pass
                    if len(backend) == 0:
                        del self.backend_replicas_to_start[backend_tag]

        return len(in_flight)

    async def _check_currently_stopping_replicas(self) -> int:
        """Returns the number of replicas waiting to stop."""
        in_flight: Set[Future[Any]] = set()

        if self.currently_stopping_replicas:
            done_stopping, in_flight = await asyncio.wait(
                list(self.currently_stopping_replicas.keys()), timeout=0)
            for fut in done_stopping:
                (backend_tag,
                 replica_tag) = self.currently_stopping_replicas.pop(fut)

                backend_to_stop = self.backend_replicas_to_stop.get(
                    backend_tag)
                if backend_to_stop:
                    try:
                        backend_to_stop.remove(replica_tag)
                    except ValueError:
                        pass
                    if len(backend_to_stop) == 0:
                        del self.backend_replicas_to_stop[backend_tag]

                backend = self.backend_replicas.get(backend_tag)
                if backend:
                    try:
                        del backend[replica_tag]
                    except KeyError:
                        pass
                    if len(self.backend_replicas[backend_tag]) == 0:
                        del self.backend_replicas[backend_tag]

        return len(in_flight)

    def _completed_goals(self) -> List[GoalId]:
        completed_goals = []
        all_tags = set(self.backend_replicas.keys()).union(
            set(self.backends.keys()))

        for backend_tag in all_tags:
            desired_info = self.backends.get(backend_tag)
            existing_info = self.backend_replicas.get(backend_tag)

            # Check for deletion.
            if ((not desired_info
                 or desired_info.backend_config.num_replicas == 0)
                    and (not existing_info or len(existing_info) == 0)):
                completed_goals.append(self.backend_goals[backend_tag])

            # Check for a non-zero number of replicas.
            if (desired_info and existing_info
                    and desired_info.backend_config.num_replicas == len(
                        existing_info)):
                completed_goals.append(self.backend_goals[backend_tag])

        return completed_goals

    async def update(self) -> None:
        for goal_id in self._completed_goals():
            self._goal_manager.complete_goal(goal_id)

        self._start_pending_replicas()
        self._stop_pending_replicas()

        num_starting = len(self.currently_starting_replicas)
        num_stopping = len(self.currently_stopping_replicas)

        await self._check_currently_starting_replicas()
        await self._check_currently_stopping_replicas()

        if (len(self.currently_starting_replicas) != num_starting) or \
                (len(self.currently_stopping_replicas) != num_stopping):
            self._checkpoint()
            self._notify_replica_handles_changed()
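

# A minimal sketch (an assumption for illustration; the real controller wiring
# lives outside this file) of how an event loop might drive BackendState:
async def _example_control_loop(backend_state: BackendState,
                                interval_s: float = 0.1) -> None:
    # Repeatedly reconcile the actual replica state with the declared goals.
    while True:
        await backend_state.update()
        await asyncio.sleep(interval_s)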