Skip to content

Instantly share code, notes, and snippets.

@wrouesnel
Created March 21, 2025 15:19
Show Gist options
  • Save wrouesnel/5589ccec92b035b0f1287d6ba289a7ef to your computer and use it in GitHub Desktop.
Wrapper script for launching an "on-demand" ephemeral EC2 instance
#!/usr/bin/env python3
# Wrapper script to launch an ephemeral EC2 instance
import os
import time
import sys
import argparse
import boto3
import logging
import socket
import getpass
import random
import subprocess
import signal
import queue
import threading
import base64
import json
import botocore
import signal
from typing import Mapping

# Cryptographically-secure RNG (reads the OS entropy source); used for subnet
# selection and unique-identifier generation.  Fix: the typing import above
# previously sat below this statement, violating the imports-first convention.
rand = random.SystemRandom()
# Route all log output to stderr with a detailed, location-bearing format so
# progress messages do not pollute stdout.
logHandler = logging.StreamHandler(sys.stderr)
formatter = logging.Formatter(
    "%(asctime)s:%(name)s:%(funcName)s:%(pathname)s:%(levelname)s:%(lineno)d: %(message)s"
)
logHandler.setFormatter(formatter)
logging.basicConfig(handlers=[logHandler], level=logging.INFO)

# Module-level logger for this script.
logger = logging.getLogger(__name__)
# Queue used to hand caught signals from the handler over to the main loop.
termination_queue = queue.Queue()


# Setup a graceful signal handler now.
def signal_handler(signum, frm):
    """Record a received signal on the termination queue for graceful shutdown."""
    # Bug fix: the original logged the undefined name ``sig`` (NameError when
    # a signal actually arrived).
    logging.info(f"Signal caught: {signal.strsignal(signum)}")
    termination_queue.put(signum)


# Bug fix: the original loop body was ``termination_queue.put(signum)`` with
# ``signum`` undefined — a NameError at import time — and it never installed
# the handler at all.  Register the handler for each shutdown signal instead.
for sig in (signal.Signals.SIGINT, signal.Signals.SIGTERM, signal.Signals.SIGQUIT, signal.Signals.SIGHUP):
    signal.signal(sig, signal_handler)
class CaughtSignal(Exception):
    """Exception raised when a signal is caught and application should shutdown."""

    def __init__(self, signalnum):
        # Bug fix: the original signature omitted ``self``, so constructing
        # this exception raised TypeError instead of carrying its message.
        super().__init__(f"Shutdown due to signal: {signal.strsignal(signalnum)}")


# Name of the CloudFormation stack whose outputs describe the shared VPC and
# its subnets (consumed by main() when choosing a launch subnet).
NETWORK_STACK = "ocd-vpc-shared-master-cf"
class KeyValueJsonAction(argparse.Action):
    """
    Argparse action that collects ``KEY=VALUE`` arguments into a dict.

    Each value is decoded as JSON when possible (numbers, lists, objects,
    booleans); values that are not valid JSON are stored as raw strings.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        collected = {}
        for item in values:
            # Split only on the first '=' so values may themselves contain '='.
            key, raw = item.split("=", 1)
            try:
                collected[key] = json.loads(raw)
            except json.JSONDecodeError:
                # Not JSON — keep the literal string.
                collected[key] = raw
        setattr(namespace, self.dest, collected)
def boto_results(key, fn, *args, **kwargs):
    """Call a paginated boto3 API and accumulate every page's *key* field.

    Follows ``NextToken`` pagination until the response no longer carries a
    token.  String/bytes and mapping payloads are appended whole; any other
    payload is assumed iterable and its elements are extended into the result.
    """
    collected = []
    while True:
        response = fn(*args, **kwargs)
        payload = response[key]
        if isinstance(payload, (str, bytes, Mapping)):
            # Scalar-ish payloads are kept as single entries.
            collected.append(payload)
        else:
            collected.extend(payload)
        token = response.get("NextToken")
        if token is None:
            return collected
        # Feed the continuation token into the next call.
        kwargs["NextToken"] = token
def unique_id() -> str:
    """Generate a short, lower-case unique identifier.

    Draws 16 random bytes (128 bits), base32-encodes them, and strips the
    padding, yielding a 26-character string over ``a-z`` and ``2-7``.
    """
    random_bytes = bytes(random.getrandbits(8) for _ in range(16))
    encoded = base64.b32encode(random_bytes).decode("utf8")
    return encoded.rstrip("=").lower()
# Cloud-init user-data template, rendered via ``str.format`` with {username}
# and {id_rsa} (literal shell braces are escaped as ``{{ }}``).  The generated
# script redirects all output into /root/userdata.log, then creates a login
# user carrying the caller's SSH public key and sudo access.
# NOTE(review): ``sudo -i`` on the second line spawns a shell rather than
# re-execing the remainder of the script; user-data already runs as root, so
# this looks like a no-op left over from interactive use — confirm.
# NOTE(review): the ``cd "$HOME"`` failure message contains nested double
# quotes ("Could not change to the "$HOME" directory…"), which splits the
# quoted string in the generated shell — confirm intended.
USER_DATA = """#!/bin/bash -x
sudo -i
_build_start=$(date +%s)
export HOME="/root"
export LOG_FILE="$HOME/userdata.log"
mkdir -p $(dirname "$LOG_FILE")
touch "$LOG_FILE"
chmod 600 "$LOG_FILE"
# Close standard output file descriptor
exec 1<&-
# Close standard error file descriptor
exec 2<&-
# Open standard output as $LOG_FILE file for read and write.
exec 1<>$LOG_FILE
# Redirect standard error to standard output
exec 2>&1
export PATH="/opt/aws/bin/:$PATH"
export AWS_REGION="$(cloud-init query region)"
export AWS_DEFAULT_REGION="$(cloud-init query region)"
function fatal() {{
# Fail the startup.
echo "$*" 1>&2
# Force the instance to die quickly.
shutdown -h now
exit 1
}}
export AWS_INSTANCE_ID="$(cloud-init query instance_id)"
if [ $? != 0 ]; then
fatal "Failed to get instance-id"
fi
cd "$HOME" || fatal "Could not change to the "$HOME" directory in userinit"
if ! adduser "{username}"; then
fatal "Could not create user"
fi
mkdir -p "/home/{username}/.ssh"
chmod 700 "/home/{username}/.ssh"
echo "{id_rsa}" >> "/home/{username}/.ssh/authorized_keys"
chmod 600 "/home/{username}/.ssh/authorized_keys"
chown -R {username}:{username} /home/{username}/.ssh
echo "{username} ALL=(ALL) ALL" >> /etc/sudoers
echo "Build Finished: $(( $(date +%s) - _build_start )) seconds"
"""
def cmd_wrap(cmd):
    """Run *cmd* attached to this process's terminal and return its exit code.

    Every catchable signal received while the child runs is forwarded to it,
    so e.g. Ctrl-C reaches the ``aws ssm start-session`` child rather than
    killing this wrapper.

    Fixes over the original: the unused ``as m`` exception binding is dropped,
    and the previous signal dispositions are restored after the child exits —
    the original left handlers pointing at a dead process installed forever.
    """
    env = os.environ.copy()
    p = subprocess.Popen(
        cmd, env=env, cwd=os.path.realpath(os.curdir), stdin=sys.stdin, stderr=sys.stderr, stdout=sys.stdout
    )

    def sighandler(signum, stack):
        p.send_signal(signum)

    # Install the forwarder for every SIG* attribute; uncatchable or invalid
    # entries (SIGKILL, SIGSTOP, SIG_DFL, ...) raise and are skipped, exactly
    # as before.  Remember what we replaced so it can be restored.
    saved_handlers = {}
    for name in [x for x in dir(signal) if x.startswith("SIG")]:
        try:
            signum = getattr(signal, name)
            saved_handlers[signum] = signal.signal(signum, sighandler)
        except (OSError, RuntimeError, ValueError):
            pass
    try:
        p.wait()
    finally:
        # Put the previous dispositions back even if wait() is interrupted.
        for signum, handler in saved_handlers.items():
            try:
                signal.signal(signum, handler)
            except (OSError, RuntimeError, ValueError):
                pass
    return p.returncode
def main(argv):
    """Launch an ephemeral EC2 instance and open an interactive SSM session.

    Creates a throwaway IAM role + instance profile, launches an Amazon Linux 2
    instance whose user-data provisions the invoking user's SSH public key,
    waits for the instance to reach ``running``, then retries
    ``aws ssm start-session`` until it connects or ``--ssm-timeout`` elapses.
    Every created resource is registered for teardown and removed (newest
    first) in the ``finally`` block.
    """
    parser = argparse.ArgumentParser(
        description="Launch an EC2 instance for updating the S3 instance",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--instance-size", default="t2.small", help="AWS instance size to launch instance as")
    # NOTE(review): no ``type=int`` here — a value supplied on the command line
    # stays a string, and the later ``... > args.ssm_timeout`` comparison would
    # raise TypeError.  Only the integer default works; confirm and add type=int.
    parser.add_argument("--ssm-timeout", default=180, help="Timeout for attempts to make an SSM connection")
    parser.add_argument("--wait-before-teardown", action="store_true", help="Pause before tearing down resources")
    parser.add_argument("--subnet-group", default="Public", help="Subnet group to select from")
    parser.add_argument(
        "--iam-permissions",
        action=KeyValueJsonAction,
        nargs="*",
        help="Specify IAM permissions as <resource>=<permissions>. Allow is implied, wildcards allowed, multiple permissions may be specified in JSON form.",
    )
    parser.add_argument(
        "--s3-bucket-allow", type=str, nargs="*", help="List of S3 buckets to grant full permissions too"
    )
    parser.add_argument("--disk-size", type=int, default=100, help="Disk size in gigabytes of the root disk")
    args = parser.parse_args(argv)

    # AWS API clients; region/credentials come from the ambient environment.
    cf = boto3.client("cloudformation")
    ec2 = boto3.client("ec2")
    ssm = boto3.client("ssm")
    iam = boto3.client("iam")

    # The invoking user's SSH public key gets baked into the instance user-data.
    id_rsa = open(os.path.join(os.environ["HOME"], ".ssh", "id_rsa.pub"), "rt").read().strip()

    logging.info("Getting latest Amazon Image ID")
    # Resolve the latest Amazon Linux 2 AMI via the public SSM parameter.
    ami = ssm.get_parameter(Name="/aws/service/ami-amazon-linux-latest/amzn2-ami-hvm-x86_64-gp2")["Parameter"]["Value"]
    logging.info("Launching using AMI: %s", ami)

    # Pick a launch subnet at random from the shared network stack's outputs
    # whose key matches the requested subnet group (e.g. "oPublicSubnet...").
    network_stack = boto_results("Stacks", cf.describe_stacks, StackName=NETWORK_STACK)[0]
    network_outputs = {e["OutputKey"]: e["OutputValue"] for e in network_stack["Outputs"]}
    # TODO: scan the subnets in the target account, and find one which matches our SSM requirements
    subnet_source = rand.choice([k for k in network_outputs if k.startswith(f"o{args.subnet_group}Subnet")])
    username = getpass.getuser()

    # Teardown functions are registered as resources are created and executed
    # in reverse order during do_teardown().
    teardown_functions = []

    def teardown(fn, *args, **kwargs):
        # Register a cleanup call (function + its arguments) for later.
        teardown_functions.append((fn, args, kwargs))

    def do_teardown():
        # Run registered cleanups newest-first so dependents go before dependencies.
        for fn, args, kwargs in reversed(teardown_functions):
            logging.info(f"Teardown function: {fn.__name__}, {args}, {kwargs}")
            fn(*args, **kwargs)

    # Unique, human-traceable names for all ephemeral resources.
    identifier = unique_id()
    instance_name = f"e-{socket.gethostname()}-{username}-{identifier}"
    instance_profile_name = f"e-{identifier}-InstanceProfile"
    instance_role_name = f"e-{identifier}-InstanceRole"
    instance_inline_policy_name = f"e-{identifier}-InlinePolicy"

    # Trust policy letting EC2 assume the ephemeral role.
    instance_role_assume_role_policy = {
        "Version": "2012-10-17",
        "Statement": {
            "Effect": "Allow",
            "Principal": {"Service": ["ec2.amazonaws.com"]},
            "Action": "sts:AssumeRole",
        },
    }
    # AWS-managed policies attached to the role (CloudWatch agent, SSM core,
    # Image Builder profile permissions).
    managed_role_arns = [
        "arn:aws:iam::aws:policy/CloudWatchAgentServerPolicy",
        "arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore",
        # "arn:aws:iam::aws:policy/service-role/AmazonSSMAutomationRole",
        "arn:aws:iam::aws:policy/EC2InstanceProfileForImageBuilder",
    ]
    # Base inline policy: CloudWatch Logs writing plus SSM association calls.
    instance_inline_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": [
                    "logs:CreateLogGroup",
                    "logs:CreateLogStream",
                    "logs:PutLogEvents",
                    "logs:DescribeLogStreams",
                ],
                "Resource": "arn:aws:logs:*:*:*",
            },
            {
                "Effect": "Allow",
                "Action": [
                    "ssm:StartAssociationsOnce",
                    "ssm:CreateAssociation",
                    "ssm:CreateAssociationBatch",
                    "ssm:UpdateAssociation",
                ],
                "Resource": "*",
            },
        ],
    }
    # Parse the command line policy requests
    if args.iam_permissions is not None:
        for resource, iam_action in args.iam_permissions.items():
            statement = {
                "Effect": "Allow",
                # NOTE(review): both ternary branches are identical; presumably
                # the non-str branch was meant to pass JSON lists through
                # unchanged — which it already does.  Confirm intent.
                "Action": iam_action if isinstance(iam_action, str) else iam_action,
                "Resource": resource,
            }
            instance_inline_policy["Statement"].append(statement)
    # Grant full access to any requested S3 buckets — both the bucket ARN and
    # its object ARNs are needed for full read/write.
    if (args.s3_bucket_allow is not None) and (len(args.s3_bucket_allow) > 0):
        statement = {
            "Effect": "Allow",
            "Action": "s3:*",
            "Resource": [f"arn:aws:s3:::{bucket}" for bucket in args.s3_bucket_allow],
        }
        instance_inline_policy["Statement"].append(statement)
        statement = {
            "Effect": "Allow",
            "Action": "s3:*",
            "Resource": [f"arn:aws:s3:::{bucket}/*" for bucket in args.s3_bucket_allow],
        }
        instance_inline_policy["Statement"].append(statement)
    try:
        logging.info("Creating role: %s", instance_role_name)
        instance_role_arn = iam.create_role(
            Path="/service/",
            RoleName=instance_role_name,
            Description="Role for temporary instance",
            Tags=[{"Key": "instance-name", "Value": instance_name}],
            AssumeRolePolicyDocument=json.dumps(instance_role_assume_role_policy),
        )["Role"]["Arn"]
        teardown(iam.delete_role, RoleName=instance_role_name)
        # Attach managed roles
        for managed_role_arn in managed_role_arns:
            logging.info("Attaching managed roles: %s", managed_role_arn)
            iam.attach_role_policy(RoleName=instance_role_name, PolicyArn=managed_role_arn)
            teardown(iam.detach_role_policy, RoleName=instance_role_name, PolicyArn=managed_role_arn)
        logging.info("Putting inline policy: %s", instance_inline_policy_name)
        iam.put_role_policy(
            RoleName=instance_role_name,
            PolicyName=instance_inline_policy_name,
            PolicyDocument=json.dumps(instance_inline_policy),
        )
        teardown(iam.delete_role_policy, PolicyName=instance_inline_policy_name, RoleName=instance_role_name)
        logging.info("Setting up an instance profile for the ephemeral instance: %s", instance_profile_name)
        iam.create_instance_profile(InstanceProfileName=instance_profile_name, Path="/service/")
        teardown(iam.delete_instance_profile, InstanceProfileName=instance_profile_name)
        logging.info("Add role to instance profile: %s <- %s", instance_profile_name, instance_role_name)
        iam.add_role_to_instance_profile(InstanceProfileName=instance_profile_name, RoleName=instance_role_name)
        teardown(
            iam.remove_role_from_instance_profile,
            InstanceProfileName=instance_profile_name,
            RoleName=instance_role_name,
        )
        # NOTE(review): despite the message this does not wait — it fetches the
        # profile ARN once; the run_instances retry loop below absorbs the IAM
        # propagation delay instead.
        logging.info("Waiting for instance profile is ready: %s", instance_profile_name)
        instance_profile_arn = iam.get_instance_profile(InstanceProfileName=instance_profile_name)["InstanceProfile"][
            "Arn"
        ]
        logging.info("Launching instance: %s", instance_name)
        # New instance profiles are eventually consistent: EC2 can reject the
        # ARN with InvalidParameterValue until IAM propagates, so retry.
        # NOTE(review): this loop has no sleep and busy-spins while waiting;
        # consider a short delay between attempts.
        while True:
            try:
                r = ec2.run_instances(
                    ImageId=ami,
                    MinCount=1,
                    MaxCount=1,
                    InstanceType=args.instance_size,
                    InstanceInitiatedShutdownBehavior="terminate",
                    UserData=USER_DATA.format(username=username, id_rsa=id_rsa),
                    SubnetId=network_outputs[subnet_source],
                    TagSpecifications=[
                        {
                            "ResourceType": "instance",
                            "Tags": [{"Key": "Name", "Value": instance_name}],
                        }
                    ],
                    IamInstanceProfile={
                        "Arn": instance_profile_arn,
                    },
                    BlockDeviceMappings=[
                        {
                            "DeviceName": "/dev/xvda",
                            "Ebs": {
                                "VolumeSize": args.disk_size,
                                "VolumeType": "standard",
                            },
                        }
                    ],
                )
                break
            except botocore.exceptions.ClientError as e:
                # NOTE(review): ``e.response`` is the public accessor for this
                # data — reaching into __dict__ works but is fragile.
                if e.__dict__["response"]["Error"]["Code"] == "InvalidParameterValue":
                    continue
                raise
        instance = r["Instances"][0]
        instance_id = instance["InstanceId"]
        teardown(ec2.terminate_instances, InstanceIds=[instance_id], DryRun=False)
        logging.info("Launching instance waiting for readiness: %s", instance_id)
        # Poll until the instance reports "running", aborting promptly if a
        # shutdown signal was queued by the module-level signal handler.
        while True:
            try:
                signalnum = termination_queue.get_nowait()
                raise CaughtSignal(signalnum)
            except queue.Empty:
                pass
            instances = {}
            for reservation in ec2.describe_instances(InstanceIds=[instance_id]).get("Reservations", []):
                instances.update({e["InstanceId"]: e for e in reservation.get("Instances", [])})
            if instance_id not in instances:
                time.sleep(1)
                continue
            if instances[instance_id]["State"]["Name"] != "running":
                time.sleep(1)
                continue
            break
        logging.info("Instance Online at: %s", instance_id)
        logging.info("Waiting for instance to become available")
        # The SSM agent can take a while to register after boot: keep retrying
        # the interactive session until it connects or the timeout elapses.
        retcode = -1
        start_time = time.time()
        while retcode != 0:
            retcode = cmd_wrap(["aws", "ssm", "start-session", "--target", instance_id])
            if (time.time() - start_time) > args.ssm_timeout:
                break
    finally:
        # Always attempt cleanup, optionally pausing first so the operator can
        # inspect the live resources.  (input() actually requires Enter, not
        # "any key", despite the prompt text.)
        if args.wait_before_teardown:
            input("Press any key to proceed with teardown operations...")
        do_teardown()


if __name__ == "__main__":
    main(sys.argv[1:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment