Created March 21, 2025 15:19
Wrapper script for launching an "on-demand" ephemeral EC2 instance
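A minimal invocation sketch, assuming the script is saved locally (the filename below is a placeholder), AWS credentials for the target account are available to boto3 and the AWS CLI, ~/.ssh/id_rsa.pub exists, and the shared VPC CloudFormation stack named by NETWORK_STACK is deployed. The flags match the argparse definitions in the script; the bucket name and IAM resource are illustrative only:

# Hypothetical filename; launches a t3.medium with a 50 GB root disk,
# full access to an example scratch bucket, and extra SSM read permissions.
python3 ./launch-ephemeral-ec2.py \
    --instance-size t3.medium \
    --disk-size 50 \
    --s3-bucket-allow example-scratch-bucket \
    --iam-permissions 'arn:aws:ssm:*:*:parameter/example/*=["ssm:GetParameter","ssm:GetParametersByPath"]' \
    --wait-before-teardown

The script drops you into an "aws ssm start-session" shell on the instance and tears down the instance, role, and instance profile when the session ends.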
#!/usr/bin/env python3
# Wrapper script to launch an ephemeral EC2 instance
import os
import time
import sys
import argparse
import boto3
import logging
import socket
import getpass
import random
import subprocess
import signal
import queue
import threading
import base64
import json
import botocore.exceptions

from typing import Mapping

rand = random.SystemRandom()

logHandler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter("%(asctime)s:%(name)s:%(funcName)s:%(pathname)s:%(levelname)s:%(lineno)d: %(message)s")
logHandler.setFormatter(formatter)
logging.basicConfig(
    handlers=[logHandler],
    level=logging.INFO,
)

logger = logging.getLogger(__name__)

termination_queue = queue.Queue()

# Setup a graceful signal handler now.
def signal_handler(signum, frm):
    logging.info(f"Signal caught: {signal.strsignal(signum)}")
    termination_queue.put(signum)

for sig in (signal.Signals.SIGINT, signal.Signals.SIGTERM, signal.Signals.SIGQUIT, signal.Signals.SIGHUP):
    signal.signal(sig, signal_handler)

class CaughtSignal(Exception):
    """Exception raised when a signal is caught and the application should shut down"""

    def __init__(self, signalnum):
        super().__init__(f"Shutdown due to signal: {signal.strsignal(signalnum)}")
NETWORK_STACK = "ocd-vpc-shared-master-cf"

class KeyValueJsonAction(argparse.Action):
    """
    Argparse action for generating dictionaries from key-value args with JSON support
    """

    def __call__(self, parser, namespace, values, option_string=None):
        result = {}
        for entry in values:
            key, value = entry.split("=", 1)
            # Try to load the value as a JSON object. If that fails, treat it as a string.
            try:
                parsed_value = json.loads(value)
            except json.JSONDecodeError:
                parsed_value = value
            result[key] = parsed_value
        setattr(namespace, self.dest, result)

def boto_results(key, fn, *args, **kwargs):
    """Collect paginated results of a boto3 call into a single list."""
    result = []
    next_token = None
    while True:
        r = fn(*args, **kwargs)
        next_token = r.get("NextToken")
        value = r[key]
        if isinstance(value, (str, bytes)):
            result.append(value)
        elif isinstance(value, Mapping):
            result.append(value)
        else:
            result.extend(value)
        if next_token is None:
            break
        kwargs.update({"NextToken": next_token})
    return result

def unique_id() -> str:
    """Generate a lower case, short unique identifier"""
    # Get 128 random bits
    value = bytearray(random.getrandbits(8) for _ in range(16))
    result = base64.b32encode(value).decode("utf8").rstrip("=")
    return result.lower()
USER_DATA = """#!/bin/bash -x | |
sudo -i | |
_build_start=$(date +%s) | |
export HOME="/root" | |
export LOG_FILE="$HOME/userdata.log" | |
mkdir -p $(dirname "$LOG_FILE") | |
touch "$LOG_FILE" | |
chmod 600 "$LOG_FILE" | |
# Close standard output file descriptor | |
exec 1<&- | |
# Close standard error file descriptor | |
exec 2<&- | |
# Open standard output as $LOG_FILE file for read and write. | |
exec 1<>$LOG_FILE | |
# Redirect standard error to standard output | |
exec 2>&1 | |
export PATH="/opt/aws/bin/:$PATH" | |
export AWS_REGION="$(cloud-init query region)" | |
export AWS_DEFAULT_REGION="$(cloud-init query region)" | |
function fatal() {{ | |
# Fail the startup. | |
echo "$*" 1>&2 | |
# Force the instance to die quickly. | |
shutdown -h now | |
exit 1 | |
}} | |
export AWS_INSTANCE_ID="$(cloud-init query instance_id)" | |
if [ $? != 0 ]; then | |
fatal "Failed to get instance-id" | |
fi | |
cd "$HOME" || fatal "Could not change to the "$HOME" directory in userinit" | |
if ! adduser "{username}"; then | |
fatal "Could not create user" | |
fi | |
mkdir -p "/home/{username}/.ssh" | |
chmod 700 "/home/{username}/.ssh" | |
echo "{id_rsa}" >> "/home/{username}/.ssh/authorized_keys" | |
chmod 600 "/home/{username}/.ssh/authorized_keys" | |
chown -R {username}:{username} /home/{username}/.ssh | |
echo "{username} ALL=(ALL) ALL" >> /etc/sudoers | |
echo "Build Finished: $(( $(date +%s) - _build_start )) seconds" | |
""" | |
def cmd_wrap(cmd):
    """Run a subprocess attached to the current terminal, forwarding signals to it."""
    env = os.environ.copy()
    p = subprocess.Popen(
        cmd, env=env, cwd=os.path.realpath(os.curdir), stdin=sys.stdin, stderr=sys.stderr, stdout=sys.stdout
    )

    # Attach all signals and forward them to the subprocess
    def sighandler(signum, stack):
        p.send_signal(signum)

    for i in [x for x in dir(signal) if x.startswith("SIG")]:
        try:
            signum = getattr(signal, i)
            signal.signal(signum, sighandler)
        except (OSError, RuntimeError, ValueError):
            # Some signals (e.g. SIGKILL) cannot be caught; skip them.
            pass

    p.wait()
    return p.returncode
def main(argv):
    parser = argparse.ArgumentParser(
        description="Launch an ephemeral on-demand EC2 instance",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--instance-size", default="t2.small", help="AWS instance type to launch the instance as")
    parser.add_argument("--ssm-timeout", type=int, default=180, help="Timeout in seconds for attempts to make an SSM connection")
    parser.add_argument("--wait-before-teardown", action="store_true", help="Pause before tearing down resources")
    parser.add_argument("--subnet-group", default="Public", help="Subnet group to select from")
    parser.add_argument(
        "--iam-permissions",
        action=KeyValueJsonAction,
        nargs="*",
        help="Specify IAM permissions as <resource>=<permissions>. Allow is implied, wildcards are allowed, and multiple permissions may be specified in JSON form.",
    )
    parser.add_argument(
        "--s3-bucket-allow", type=str, nargs="*", help="List of S3 buckets to grant full permissions to"
    )
    parser.add_argument("--disk-size", type=int, default=100, help="Disk size in gigabytes of the root disk")

    args = parser.parse_args(argv)
    cf = boto3.client("cloudformation")
    ec2 = boto3.client("ec2")
    ssm = boto3.client("ssm")
    iam = boto3.client("iam")

    # Public SSH key to install on the instance for the launching user.
    with open(os.path.join(os.environ["HOME"], ".ssh", "id_rsa.pub"), "rt") as f:
        id_rsa = f.read().strip()

    logging.info("Getting latest Amazon Linux 2 AMI ID")
    ami = ssm.get_parameter(Name="/aws/service/ami-amazon-linux-latest/amzn2-ami-hvm-x86_64-gp2")["Parameter"]["Value"]
    logging.info("Launching using AMI: %s", ami)

    network_stack = boto_results("Stacks", cf.describe_stacks, StackName=NETWORK_STACK)[0]
    network_outputs = {e["OutputKey"]: e["OutputValue"] for e in network_stack["Outputs"]}

    # TODO: scan the subnets in the target account, and find one which matches our SSM requirements
    subnet_source = rand.choice([k for k in network_outputs if k.startswith(f"o{args.subnet_group}Subnet")])

    username = getpass.getuser()

    # Teardown functions run in reverse order of registration.
    teardown_functions = []

    def teardown(fn, *args, **kwargs):
        teardown_functions.append((fn, args, kwargs))

    def do_teardown():
        for fn, args, kwargs in reversed(teardown_functions):
            logging.info(f"Teardown function: {fn.__name__}, {args}, {kwargs}")
            try:
                fn(*args, **kwargs)
            except Exception:
                # Keep tearing down the remaining resources even if one step fails.
                logging.exception("Teardown step failed: %s", fn.__name__)

    identifier = unique_id()
    instance_name = f"e-{socket.gethostname()}-{username}-{identifier}"
    instance_profile_name = f"e-{identifier}-InstanceProfile"
    instance_role_name = f"e-{identifier}-InstanceRole"
    instance_inline_policy_name = f"e-{identifier}-InlinePolicy"

    instance_role_assume_role_policy = {
        "Version": "2012-10-17",
        "Statement": {
            "Effect": "Allow",
            "Principal": {"Service": ["ec2.amazonaws.com"]},
            "Action": "sts:AssumeRole",
        },
    }

    managed_role_arns = [
        "arn:aws:iam::aws:policy/CloudWatchAgentServerPolicy",
        "arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore",
        # "arn:aws:iam::aws:policy/service-role/AmazonSSMAutomationRole",
        "arn:aws:iam::aws:policy/EC2InstanceProfileForImageBuilder",
    ]

    instance_inline_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": [
                    "logs:CreateLogGroup",
                    "logs:CreateLogStream",
                    "logs:PutLogEvents",
                    "logs:DescribeLogStreams",
                ],
                "Resource": "arn:aws:logs:*:*:*",
            },
            {
                "Effect": "Allow",
                "Action": [
                    "ssm:StartAssociationsOnce",
                    "ssm:CreateAssociation",
                    "ssm:CreateAssociationBatch",
                    "ssm:UpdateAssociation",
                ],
                "Resource": "*",
            },
        ],
    }
    # Parse the command line policy requests
    if args.iam_permissions is not None:
        for resource, iam_action in args.iam_permissions.items():
            statement = {
                "Effect": "Allow",
                # IAM accepts either a single action string or a JSON list of actions.
                "Action": iam_action,
                "Resource": resource,
            }
            instance_inline_policy["Statement"].append(statement)

    if (args.s3_bucket_allow is not None) and (len(args.s3_bucket_allow) > 0):
        statement = {
            "Effect": "Allow",
            "Action": "s3:*",
            "Resource": [f"arn:aws:s3:::{bucket}" for bucket in args.s3_bucket_allow],
        }
        instance_inline_policy["Statement"].append(statement)

        statement = {
            "Effect": "Allow",
            "Action": "s3:*",
            "Resource": [f"arn:aws:s3:::{bucket}/*" for bucket in args.s3_bucket_allow],
        }
        instance_inline_policy["Statement"].append(statement)
    try:
        logging.info("Creating role: %s", instance_role_name)
        instance_role_arn = iam.create_role(
            Path="/service/",
            RoleName=instance_role_name,
            Description="Role for temporary instance",
            Tags=[{"Key": "instance-name", "Value": instance_name}],
            AssumeRolePolicyDocument=json.dumps(instance_role_assume_role_policy),
        )["Role"]["Arn"]
        teardown(iam.delete_role, RoleName=instance_role_name)

        # Attach managed policies
        for managed_role_arn in managed_role_arns:
            logging.info("Attaching managed policy: %s", managed_role_arn)
            iam.attach_role_policy(RoleName=instance_role_name, PolicyArn=managed_role_arn)
            teardown(iam.detach_role_policy, RoleName=instance_role_name, PolicyArn=managed_role_arn)

        logging.info("Putting inline policy: %s", instance_inline_policy_name)
        iam.put_role_policy(
            RoleName=instance_role_name,
            PolicyName=instance_inline_policy_name,
            PolicyDocument=json.dumps(instance_inline_policy),
        )
        teardown(iam.delete_role_policy, PolicyName=instance_inline_policy_name, RoleName=instance_role_name)

        logging.info("Setting up an instance profile for the ephemeral instance: %s", instance_profile_name)
        iam.create_instance_profile(InstanceProfileName=instance_profile_name, Path="/service/")
        teardown(iam.delete_instance_profile, InstanceProfileName=instance_profile_name)

        logging.info("Adding role to instance profile: %s <- %s", instance_profile_name, instance_role_name)
        iam.add_role_to_instance_profile(InstanceProfileName=instance_profile_name, RoleName=instance_role_name)
        teardown(
            iam.remove_role_from_instance_profile,
            InstanceProfileName=instance_profile_name,
            RoleName=instance_role_name,
        )

        logging.info("Waiting for instance profile to be ready: %s", instance_profile_name)
        instance_profile_arn = iam.get_instance_profile(InstanceProfileName=instance_profile_name)["InstanceProfile"][
            "Arn"
        ]
logging.info("Launching instance: %s", instance_name) | |
while True: | |
try: | |
r = ec2.run_instances( | |
ImageId=ami, | |
MinCount=1, | |
MaxCount=1, | |
InstanceType=args.instance_size, | |
InstanceInitiatedShutdownBehavior="terminate", | |
UserData=USER_DATA.format(username=username, id_rsa=id_rsa), | |
SubnetId=network_outputs[subnet_source], | |
TagSpecifications=[ | |
{ | |
"ResourceType": "instance", | |
"Tags": [{"Key": "Name", "Value": instance_name}], | |
} | |
], | |
IamInstanceProfile={ | |
"Arn": instance_profile_arn, | |
}, | |
BlockDeviceMappings=[ | |
{ | |
"DeviceName": "/dev/xvda", | |
"Ebs": { | |
"VolumeSize": args.disk_size, | |
"VolumeType": "standard", | |
}, | |
} | |
], | |
) | |
break | |
except botocore.exceptions.ClientError as e: | |
if e.__dict__["response"]["Error"]["Code"] == "InvalidParameterValue": | |
continue | |
raise | |
instance = r["Instances"][0] | |
instance_id = instance["InstanceId"] | |
teardown(ec2.terminate_instances, InstanceIds=[instance_id], DryRun=False) | |
logging.info("Launching instance waiting for readiness: %s", instance_id) | |
while True: | |
try: | |
signalnum = termination_queue.get_nowait() | |
raise CaughtSignal(signalnum) | |
except queue.Empty: | |
pass | |
instances = {} | |
for reservation in ec2.describe_instances(InstanceIds=[instance_id]).get("Reservations", []): | |
instances.update({e["InstanceId"]: e for e in reservation.get("Instances", [])}) | |
if instance_id not in instances: | |
time.sleep(1) | |
continue | |
if instances[instance_id]["State"]["Name"] != "running": | |
time.sleep(1) | |
continue | |
break | |
logging.info("Instance Online at: %s", instance_id) | |
logging.info("Waiting for instance to become available") | |
retcode = -1 | |
start_time = time.time() | |
while retcode != 0: | |
retcode = cmd_wrap(["aws", "ssm", "start-session", "--target", instance_id]) | |
if (time.time() - start_time) > args.ssm_timeout: | |
break | |
finally: | |
if args.wait_before_teardown: | |
input("Press any key to proceed with teardown operations...") | |
do_teardown() | |
if __name__ == "__main__": | |
main(sys.argv[1:]) |