Skip to content

Instantly share code, notes, and snippets.

@DomWeldon
Last active September 21, 2024 10:27
Show Gist options
  • Save DomWeldon/ce7e070283d97368cd9abc5be71b247d to your computer and use it in GitHub Desktop.
Save DomWeldon/ce7e070283d97368cd9abc5be71b247d to your computer and use it in GitHub Desktop.
SSM + pydantic: ARNs in environment variables are queried at load time

SSM + Pydantic

Query values from SSM when deployed, by placing an SSM ARN as the environment variable

Background

I wanted to query secrets from SSM at runtime, to load them into a pydantic.BaseSettings settings object, but still be able to pass standard values during development (and I guess, if I want, in prod).

I've done a couple of similar implementations before, but they have always felt clunky and involved altering the object after instantiation, or hard coding which values to take out of SSM.

Implementation

I read the master branch of pydantic, saw the code and was able to implement it using the new customise_sources() classmethod on the Config object. I look forward to when this feature is released into pydantic.

See issue 2107 on the pydantic repo.

"""Configuration for the API."""
# Standard Library
import secrets
import typing
# Third Party Libraries
import pydantic
from . import ssm_config
class Settings(pydantic.BaseSettings):
    """Application settings read from init kwargs, the environment and secret files.

    The environment source is swapped for
    ``ssm_config.SSMAwareEnvSettingsSource`` in ``Config.customise_sources``.
    """

    # secrets
    SECRET_KEY: str

    # for serving locally
    SERVER_NAME: str = "0.0.0.0"
    SERVER_HOST: pydantic.AnyHttpUrl = "http://0.0.0.0"

    # API
    API_V1_STR: str = "/api/v1"
    ENVIRONMENT_NAME: typing.Optional[str] = "dev"

    # BACKEND_CORS_ORIGINS is a JSON-formatted list of origins
    BACKEND_CORS_ORIGINS: typing.List[pydantic.AnyHttpUrl] = []

    # for swagger
    OPENAPI_URL: str = "/openapi.json"
    SWAGGER_UI_OAUTH2_REDIRECT_URL: str = "/docs/oauth2-redirect"
    SWAGGER_UI_INIT_OAUTH: typing.Optional[typing.Dict[str, typing.Any]] = None
    DOCS_URL: str = "/docs"

    @pydantic.validator("BACKEND_CORS_ORIGINS", pre=True, allow_reuse=True)
    def assemble_cors_origins(
        cls, v: typing.Union[str, typing.List[str]]
    ) -> typing.Union[typing.List[str], str]:
        """Normalise the raw CORS origins value before field validation.

        A string that does not start with ``[`` is split on commas into a
        list of stripped items; a list, or a string starting with ``[``, is
        passed through untouched. Any other type raises ``ValueError``.
        """
        if not isinstance(v, (list, str)):
            raise ValueError(v)
        if isinstance(v, list) or v.startswith("["):
            return v
        return [origin.strip() for origin in v.split(",")]

    PROJECT_TITLE: str = "Template API Lambda"
    SENTRY_DSN: typing.Optional[pydantic.HttpUrl] = None

    class Config:
        case_sensitive = True

        @classmethod
        def customise_sources(
            cls,
            init_settings: ssm_config.SettingsSourceCallable,
            env_settings: ssm_config.SettingsSourceCallable,
            file_secret_settings: ssm_config.SettingsSourceCallable,
        ) -> typing.Tuple[ssm_config.SettingsSourceCallable, ...]:
            """Return the settings sources in priority order.

            The default environment source is replaced by an SSM-aware one
            built from the same env file and encoding.
            """
            ssm_env_settings = ssm_config.SSMAwareEnvSettingsSource(
                env_settings.env_file,
                env_settings.env_file_encoding,
            )
            return (init_settings, ssm_env_settings, file_secret_settings)
# Module-level settings instance, built once each time this module is executed.
settings = Settings()
# ``settings`` is the only name exported by a star-import of this module.
__all__ = ["settings"]
# Standard Library
import itertools
import typing
# Third Party Libraries
import pytest
@pytest.fixture(scope="function")
def mock_ssm_values(moto_account_id: int) -> typing.Dict[str, typing.Dict[str, str]]:
    """Fixed SSM test parameters, grouped by parameter type and keyed by ARN."""
    arn_prefix = f"arn:aws:ssm:eu-west-2:{moto_account_id}:parameter/test/"
    plain_params = {
        arn_prefix + "woof": "grr",
        arn_prefix + "meow": "purr",
    }
    secure_params = {
        arn_prefix + "bleat": "baaa",
        arn_prefix + "moo": "mooo",
    }
    return {"String": plain_params, "SecureString": secure_params}
@pytest.fixture(scope="function")
def mock_ssm_values_env(
    mock_ssm_values: typing.Dict[str, typing.Dict[str, str]]
) -> typing.Dict:
    """Derive environment variables and expected config from the SSM mocks.

    Returns a dict holding the flattened ``(arn, value, type)`` triples, the
    plain and ARN-valued environment variables, their union, and the config
    expected once every ARN has been replaced by its SSM value.
    """
    # flatten to (arn, value, type) triples, carrying the parent key along
    flat_ssm_vals = [
        (arn, value, param_type)
        for param_type, params in mock_ssm_values.items()
        for arn, value in params.items()
    ]

    def env_name_for(arn: str) -> str:
        # last path segment of the ARN, upper-cased, e.g. ANIMAL_WOOF
        return "ANIMAL_" + arn.split("/")[-1].upper()

    plain_env_vars = {"SOUND_BIRD": "TWITWOO"}
    ssm_arn_env_vars = {env_name_for(arn): arn for arn, _, _ in flat_ssm_vals}
    resolved_env_vars = {
        env_name_for(arn): value for arn, value, _ in flat_ssm_vals
    }

    # mock up our environment vars
    all_env_vars = dict(plain_env_vars)
    all_env_vars.update(ssm_arn_env_vars)

    expected_config = dict(plain_env_vars)
    expected_config.update(resolved_env_vars)

    return {
        "flat_ssm_vals": flat_ssm_vals,
        "expected_config": expected_config,
        "all_env_vars": all_env_vars,
        "ssm_arn_env_vars": ssm_arn_env_vars,
        "plain_env_vars": plain_env_vars,
    }
"""Classes to get secrets out of SSM."""
# Standard Library
import os
import pathlib
import typing
# Third Party Libraries
import boto3
import pydantic
from mypy_boto3_ssm.client import SSMClient
from pydantic.env_settings import SettingsError, read_env_file
# Type of a settings source: a callable that takes the settings object and
# returns a dict of values keyed by string.
SettingsSourceCallable = typing.Callable[['BaseSettings'], typing.Dict[str, typing.Any]]
def _get_ssm_client() -> SSMClient:
    """Build a boto3 client for the ``ssm`` service with default configuration."""
    ssm_client = boto3.client("ssm")
    return ssm_client
def _is_ssm_arn(val: str) -> bool:
"""Is the value an SSM ARN?"""
if not hasattr(val, "startswith"):
return False
arn_parts = val.split(":")
if len(arn_parts) < 6:
return False
return val.startswith("arn:aws:ssm") and arn_parts[5].startswith(
"parameter"
)
class SSMAwareEnvSettingsSource:
    """Environment, with SSM ARNs replaced with their plaintext values.

    Based on pydantic's EnvSettingsSource, this class reads the environment
    from ``os.environ`` (plus an optional env file), but replaces any value
    that is an SSM parameter ARN with the decrypted value from SSM.

    It assumes it is operating inside a Lambda function or similar and so
    does not attempt to configure credentials for AWS in any way.
    """

    __slots__ = ("env_file", "env_file_encoding")

    # GetParameters accepts at most this many names in a single request.
    _MAX_NAMES_PER_REQUEST = 10

    def __init__(
        self,
        env_file: typing.Union[pathlib.Path, str, None],
        env_file_encoding: typing.Optional[str],
    ):
        """Store the optional env file path and the encoding used to read it."""
        self.env_file: typing.Union[pathlib.Path, str, None] = env_file
        self.env_file_encoding: typing.Optional[str] = env_file_encoding

    def __call__(
        self, settings: pydantic.BaseSettings
    ) -> typing.Dict[str, typing.Any]:
        """Build environment variables suitable for passing to the Model.

        Args:
            settings: the settings object whose fields are being populated.

        Returns:
            A dict keyed by field alias holding the value found for each
            field, with SSM ARNs already replaced by their SSM values.

        Raises:
            SettingsError: if a complex field's value is not valid JSON.
        """
        d: typing.Dict[str, typing.Any] = {}
        case_sensitive = settings.__config__.case_sensitive

        if case_sensitive:
            env_vars: typing.Mapping[str, typing.Optional[str]] = os.environ
        else:
            env_vars = {k.lower(): v for k, v in os.environ.items()}

        if self.env_file is not None:
            env_path = pathlib.Path(self.env_file).expanduser()
            if env_path.is_file():
                # real environment variables take precedence over the file
                env_vars = {
                    **read_env_file(
                        env_path,
                        encoding=self.env_file_encoding,
                        case_sensitive=case_sensitive,
                    ),
                    **env_vars,
                }

        # SSM aware section. Only variables that some field actually reads
        # are resolved, so unrelated ARN-valued variables in the process
        # environment never trigger (or break) an SSM request.
        wanted_names = {
            env_name
            for field in settings.__fields__.values()
            for env_name in field.field_info.extra["env_names"]
        }
        env_vars = self._replace_arns_with_ssm_vals(
            {k: v for k, v in env_vars.items() if k in wanted_names}
        )

        for field in settings.__fields__.values():
            env_val: typing.Optional[str] = None
            # first env name with a value wins
            for env_name in field.field_info.extra["env_names"]:
                env_val = env_vars.get(env_name)
                if env_val is not None:
                    break

            if env_val is None:
                continue

            if field.is_complex():
                try:
                    cfg = settings.__config__
                    env_val = cfg.json_loads(env_val)  # type: ignore
                except ValueError as e:
                    raise SettingsError(
                        f'error parsing JSON for "{env_name}"'
                    ) from e
            d[field.alias] = env_val

        return d

    def _replace_arns_with_ssm_vals(
        self, env_vars: typing.Mapping[str, typing.Any]
    ) -> typing.Dict[str, typing.Any]:
        """Replace values with SSM value if the original is an SSM ARN.

        Parameters are requested in batches because a single request may
        name at most ``_MAX_NAMES_PER_REQUEST`` parameters. A value whose
        ARN is not among the ARNs returned by SSM is left unchanged.

        Args:
            env_vars: mapping of variable name to raw value.

        Returns:
            A new dict with the same keys, ARN values swapped for the
            decrypted parameter values.
        """
        # get the ARNs we'll need
        ssm_arns = {v for v in env_vars.values() if _is_ssm_arn(v)}

        # if no SSM vars we can skip looking them up
        if not ssm_arns:
            return dict(env_vars)

        # parameter name is the final ARN segment minus the "parameter"
        # prefix; de-duplicated and sorted so batches are deterministic
        names = sorted(
            {arn.split(":")[-1][len("parameter") :] for arn in ssm_arns}
        )

        # lookup the ARNs as a dict, one batch at a time
        client = _get_ssm_client()
        plaintexts: typing.Dict[str, str] = {}
        step = self._MAX_NAMES_PER_REQUEST
        for start in range(0, len(names), step):
            resp = client.get_parameters(
                Names=names[start : start + step],
                WithDecryption=True,
            )
            plaintexts.update(
                (param["ARN"], param["Value"]) for param in resp["Parameters"]
            )

        return {k: plaintexts.get(v, v) for k, v in env_vars.items()}

    def __repr__(self) -> str:
        """Return a constructor-style representation of this source."""
        return (
            f"SSMAwareEnvSettingsSource(env_file={self.env_file!r}, "
            f"env_file_encoding={self.env_file_encoding!r})"
        )
import importlib
import secrets
import boto3
import moto
@moto.mock_ssm
def test_settings_work_with_ssm(mock_ssm_values_env, monkeypatch, moto_account_id):
    """The settings object exposes the SSM value behind an ARN-valued env var."""
    # arrange: store a random secret in (mocked) SSM ...
    expected_secret = secrets.token_urlsafe(64)
    ssm = boto3.client("ssm")
    ssm.put_parameter(
        Name="/test/secret_key",
        Value=expected_secret,
        Type="SecureString",
    )
    # ... and point SECRET_KEY at that parameter's ARN
    secret_arn = (
        f"arn:aws:ssm:eu-west-2:{moto_account_id}:parameter/test/secret_key"
    )
    monkeypatch.setenv("SECRET_KEY", secret_arn)

    # act: import, then reload so the settings are rebuilt from this env
    from app.core import config

    reloaded_config = importlib.reload(config)

    # assert
    assert reloaded_config.settings.SECRET_KEY == expected_secret
# Standard Library
import typing
# Third Party Libraries
import boto3
import moto
def test__is_ssm_arn(moto_account_id: int):
    """SSM parameter ARNs are recognised and everything else is rejected."""
    # arrange
    # App and Model Imports
    from app.core import ssm_config

    should_match = (
        "arn:aws:ssm:us-west-2:aws-account-ID:parameter/dev/doc1",
        f"arn:aws:ssm:eu-west-2:{moto_account_id}:parameter/dev/doc2",
    )
    should_not_match = (
        "",
        "some text",
        "arn:aws:ssm:kinda",
        f"arn:aws:ssm:eu-west-2:{moto_account_id}:notaparameter/dev/doc2",
        "arn:partition:service:region:account-id:resource-type:resource-id",
        None,
        1,
    )

    # act
    match_results = set(map(ssm_config._is_ssm_arn, should_match))
    non_match_results = set(map(ssm_config._is_ssm_arn, should_not_match))

    # assert: each group collapses to a single outcome
    assert match_results == {True}
    assert non_match_results == {False}
@moto.mock_ssm
def test__replace_arns_with_ssm_vals(mock_ssm_values_env: typing.Dict):
    """ARN-valued variables are swapped for the values stored in SSM."""
    # arrange
    # App and Model Imports
    from app.core import ssm_config

    env_vars = mock_ssm_values_env["all_env_vars"]
    prefix_length = len("parameter")

    # create in moto
    ssm = boto3.client("ssm")
    for arn, value, param_type in mock_ssm_values_env["flat_ssm_vals"]:
        resource = arn.split(":")[-1]
        ssm.put_parameter(
            Name=resource[prefix_length:],
            Value=value,
            Type=param_type,
        )

    # act
    source = ssm_config.SSMAwareEnvSettingsSource(None, None)
    replaced_vals = source._replace_arns_with_ssm_vals(env_vars)

    # assert
    # all SSM ARNs were detected
    detected = [
        ssm_config._is_ssm_arn(env_vars[name])
        for name in mock_ssm_values_env["ssm_arn_env_vars"]
    ]
    assert all(detected)
    # and replaced with their actual values
    assert replaced_vals == mock_ssm_values_env["expected_config"]
@wrgeorge1983
Copy link

Thanks for this! turns out I'm actually using 'secretsmanager' instead of 'ssm' but close enough to the same thing to make it fit with some tweaks! Esp. the moto testing. I had looked at moto in the past and it never really clicked. Super helpful!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment