Created
January 24, 2024 14:11
-
-
Save egordm/f38b11143d5b1aa5f2680f32981ac6f4 to your computer and use it in GitHub Desktop.
Luminis Blog — Deploying SageMaker Pipelines Using CDK
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the pipeline object through the factory. Per the original note, this
# step also uploads the code and packages the pipeline requires to S3.
pipeline = pipeline_factory.create(
    role=sm_execution_role_arn,
    pipeline_name=pipeline_name,
    sm_session=sagemaker_session,
)
# Parse the generated definition and re-serialize it with sorted keys and
# two-space indentation.
parsed_definition = json.loads(pipeline.definition())
pipeline_def_json = json.dumps(parsed_definition, indent=2, sort_keys=True)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_pipeline_resource(
    self,
    pipeline_name: str,
    pipeline_factory: SagemakerPipelineFactory,
    sources_bucket_name: str,
    sm_execution_role_arn: str,
) -> Tuple[sm.CfnPipeline, str]:
    """Declare a SageMaker pipeline as a CloudFormation resource.

    Args:
        pipeline_name: Name given to the pipeline and used in the construct id.
        pipeline_factory: Factory that builds the pipeline (used in the elided part).
        sources_bucket_name: Name of the sources bucket (not used in the visible part).
        sm_execution_role_arn: ARN of the role assigned to the pipeline.

    Returns:
        The ``CfnPipeline`` construct and the pipeline ARN built with
        ``self.format_arn``.
    """
    ...
    # Define the pipeline (this step uploads the code and packages required by the pipeline to S3)
    ...
    # NOTE(review): pipeline_def_json is expected to be set by the elided step above.
    definition_body = {"PipelineDefinitionBody": pipeline_def_json}

    # CloudFormation resource for the pipeline, so it can be deployed to your account
    pipeline_cfn = sm.CfnPipeline(
        self,
        id=f"SagemakerPipeline-{pipeline_name}",
        pipeline_name=pipeline_name,
        pipeline_definition=definition_body,
        role_arn=sm_execution_role_arn,
    )

    pipeline_arn = self.format_arn(
        service='sagemaker',
        resource='pipeline',
        resource_name=pipeline_cfn.pipeline_name,
    )
    return pipeline_cfn, pipeline_arn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the infrastructure stack outputs from SSM as concrete values.
# value_from_lookup resolves them while the app is synthesized (i.e. when
# `cdk deploy` runs), so they are available to the pipeline factory below.
bucket_parameter_name = f"/{self.prefix}/SourcesBucketName"
role_parameter_name = f"/{self.prefix}/SagemakerExecutionRoleArn"
sources_bucket_name = ssm.StringParameter.value_from_lookup(self, bucket_parameter_name)
sm_execution_role_arn = ssm.StringParameter.value_from_lookup(self, role_parameter_name)

# Configure the example pipeline factory and register the pipeline resource
example_factory = ExamplePipeline(pipeline_config_parameter="Hello world!")
pipeline_resource, pipeline_arn = self.create_pipeline_resource(
    pipeline_name='example-pipeline',
    pipeline_factory=example_factory,
    sources_bucket_name=sources_bucket_name,
    sm_execution_role_arn=sm_execution_role_arn,
)
self.example_pipeline = pipeline_resource
self.example_pipeline_arn = pipeline_arn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SagemakerPipelineFactory(BaseModel):
    """Base class for all pipeline factories.

    Subclasses declare their configuration as model fields and implement
    ``create`` to build the pipeline.
    """

    @abstractmethod
    def create(
        self,
        role: str,
        pipeline_name: str,
        sm_session: sagemaker.Session,
    ) -> Pipeline:
        """Build and return the pipeline.

        Args:
            role: Role identifier passed through to the pipeline's jobs.
            pipeline_name: Name to give the pipeline.
            sm_session: SageMaker session the pipeline is built with.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ExamplePipeline(SagemakerPipelineFactory):
    """Example pipeline factory configured by a single string parameter."""

    # Value forwarded to the processing script as its --config_parameter argument
    pipeline_config_parameter: str

    def create(
        self,
        role: str,
        pipeline_name: str,
        sm_session: sagemaker.Session,
    ) -> Pipeline:
        """Build the example pipeline (body elided in this snippet)."""
        ...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
...
# Look up the AWS-provided SKLearn image for the session's region
session_region = sm_session.boto_region_name
image_uri = sagemaker.image_uris.retrieve(
    framework="sklearn",
    version="0.23-1",
    region=session_region,
)

# ScriptProcessor that runs the script with python3 on a single instance;
# the instance type comes from instance_type_var (defined in the elided part).
processor = ScriptProcessor(
    role=role,
    image_uri=image_uri,
    command=["python3"],
    instance_count=1,
    instance_type=instance_type_var,
    sagemaker_session=sm_session,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The script arguments, inputs and outputs are passed to processor.run()
# rather than to ProcessingStep: when a step is built from step_args, its job
# definition is taken from that run() call, so step-level job_arguments /
# inputs / outputs would not reach the processing job.
processing_step = ProcessingStep(
    name="processing-example",
    step_args=processor.run(
        code="pipelines/sources/example_pipeline/evaluate.py",
        arguments=["--config_parameter", self.pipeline_config_parameter],
        inputs=[],
        outputs=[],
    ),
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
return Pipeline( | |
name=pipeline_name, | |
steps=[processing_step], | |
sagemaker_session=sm_session, | |
parameters=[instance_type_var], | |
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse


def parse_args(argv=None):
    """Parse the script's command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed namespace with a ``config_parameter`` attribute
        (``None`` when the flag is not given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_parameter", type=str)
    return parser.parse_args(argv)


def build_greeting(config_parameter):
    """Return the greeting line printed for *config_parameter*."""
    return f"Hello {config_parameter}!"


def main(argv=None):
    """Parse the arguments and print the greeting."""
    args = parse_args(argv)
    print(build_greeting(args.config_parameter))


# Guard the entry point so importing this module has no side effects;
# running it as a script behaves exactly as before.
if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Deploy the data project; cdk only runs if the directory change succeeded.
cd ./data_project && cdk deploy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Deploy the infrastructure project; cdk only runs if the directory change succeeded.
cd ./infrastructure_project && cdk deploy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Execution role that the SageMaker service can assume.
# The AWS-managed policy is referenced by name instead of a hard-coded
# "arn:aws:..." string so the ARN is built for the partition being deployed to.
role = iam.Role(
    self,
    'SagemakerExecutionRole',
    assumed_by=iam.ServicePrincipal('sagemaker.amazonaws.com'),
    role_name=f"{self.prefix}-sm-execution-role",
    managed_policies=[
        iam.ManagedPolicy.from_aws_managed_policy_name("AmazonSageMakerFullAccess"),
    ],
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Bucket for the SageMaker sources: unversioned, with no lifecycle rules, and
# destroyed (objects included) together with the stack.
self.sm_sources_bucket = s3.Bucket(
    self,
    id="SourcesBucket",
    bucket_name=f"{self.prefix}-sm-sources",
    # Lifecycle
    versioned=False,
    lifecycle_rules=[],
    auto_delete_objects=True,
    removal_policy=cdk.RemovalPolicy.DESTROY,
    # Access: private ACL, all public access blocked, SSL enforced
    public_read_access=False,
    block_public_access=s3.BlockPublicAccess.BLOCK_ALL,
    access_control=s3.BucketAccessControl.PRIVATE,
    object_ownership=s3.ObjectOwnership.OBJECT_WRITER,
    enforce_ssl=True,
    # Encryption
    encryption=s3.BucketEncryption.S3_MANAGED,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
execution_role = self.sm_execution_role
# The SageMaker execution role may only read from the sources bucket ...
self.sm_sources_bucket.grant_read(execution_role)
# ... but may both read and write the data bucket.
self.sm_data_bucket.grant_read_write(execution_role)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Look up the existing VPC. Its name comes from the "vpc_name" context value,
# falling back to the prefixed default when that value is missing or empty.
vpc_name = self.node.try_get_context("vpc_name")
self.vpc = ec2.Vpc.from_lookup(
    self,
    id="ImportedVpc",
    vpc_name=vpc_name or f"{self.prefix}-vpc",
)
public_subnet_ids = [subnet.subnet_id for subnet in self.vpc.public_subnets]

# SageMaker Studio domain: IAM auth, public-internet-only network access,
# placed in the VPC's public subnets; users default to the execution role.
default_user_settings = sm.CfnDomain.UserSettingsProperty(
    execution_role=self.sm_execution_role.role_arn,
)
project_tag = cdk.CfnTag(key="project", value="example-pipelines")
self.domain = sm.CfnDomain(
    self,
    "SagemakerDomain",
    auth_mode='IAM',
    domain_name=f'{self.prefix}-SG-Project',
    default_user_settings=default_user_settings,
    app_network_access_type='PublicInternetOnly',
    vpc_id=self.vpc.vpc_id,
    subnet_ids=public_subnet_ids,
    tags=[project_tag],
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Default user profile in the Studio domain, created with empty user settings
empty_user_settings = sm.CfnUserProfile.UserSettingsProperty()
self.user = sm.CfnUserProfile(
    self,
    'SageMakerStudioUserProfile',
    user_profile_name='default-user',
    domain_id=self.domain.attr_domain_id,
    user_settings=empty_user_settings,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Subnet layout: one public, one private-with-egress and one isolated subnet
# group, each described by its own SubnetConfiguration.
public_subnets = ec2.SubnetConfiguration(
    name="Public",
    subnet_type=ec2.SubnetType.PUBLIC,
    cidr_mask=24,
)
private_subnets = ec2.SubnetConfiguration(
    name="Private",
    subnet_type=ec2.SubnetType.PRIVATE_WITH_EGRESS,
    cidr_mask=23,
)
isolated_subnets = ec2.SubnetConfiguration(
    name="Isolated",
    subnet_type=ec2.SubnetType.PRIVATE_ISOLATED,
    cidr_mask=24,
)

# VPC on 10.0.0.0/16 spanning up to three AZs with a single NAT gateway
vpc = ec2.Vpc(
    self,
    id="VpcConstruct",
    vpc_name=f"{self.prefix}-vpc",
    ip_addresses=ec2.IpAddresses.cidr("10.0.0.0/16"),
    max_azs=3,
    nat_gateways=1,
    subnet_configuration=[public_subnets, private_subnets, isolated_subnets],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment