Created
September 6, 2020 13:36
-
-
Save tuliocasagrande/55611d1deb278cfbf09354b1d85d411c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import os

import boto3

# SageMaker client, created once when the module is imported and shared by
# every function below.
CLIENT = boto3.client('sagemaker')

# Role ARN sent as 'RoleArn' when a processing job is created. Read eagerly,
# so a missing environment variable raises KeyError at import time.
SAGEMAKER_ROLE_ARN = os.environ['SAGEMAKER_ROLE_ARN']
class ResourcePending(Exception):
    """Signal that the processing job has not reached a final state yet."""
class ResourceFailed(Exception):
    """Signal that the processing job ended as failed or stopped."""
def _check_job_status(response):
    """Raise unless the described processing job has completed.

    Args:
        response: Mapping with a 'ProcessingJobStatus' entry and, optionally,
            a 'FailureReason' entry (the caller passes the result of
            ``describe_processing_job``).

    Raises:
        ResourcePending: The status is 'InProgress' or 'Stopping'.
        ResourceFailed: The status is 'Failed' or 'Stopped'. The message is
            the 'FailureReason' value, or an empty string when absent.
    """
    # Valid Values: InProgress | Completed | Failed | Stopping | Stopped
    status = response['ProcessingJobStatus']
    if status in {'InProgress', 'Stopping'}:
        raise ResourcePending
    if status in {'Failed', 'Stopped'}:
        raise ResourceFailed(response.get('FailureReason', ''))
    # Any other status (e.g. 'Completed') falls through and returns None.
def lambda_handler(event, context):
    """Create the processing job named in *event* if missing, then report it.

    The job is looked up by name. When the lookup reports that no such job
    exists, the job is created from *event* and described again. The
    description is then checked with ``_check_job_status``, so the function
    only returns normally once the job is neither pending nor failed.

    Args:
        event: Mapping with at least 'ProcessingJobName', plus the keys
            ``_create_processing_job`` needs when the job has to be created.
        context: Unused.

    Returns:
        The job description serialized as a JSON string.

    Raises:
        ResourcePending: The job is still in progress or stopping.
        ResourceFailed: The job failed or was stopped.
        ClientError: The lookup failed for a reason other than "not found",
            or creating the job failed.
    """
    print('New event:', event)
    job_name = event['ProcessingJobName']

    # Error codes that mean "no job with this name". Anything else
    # (throttling, access denied, ...) is a real failure and must not
    # trigger an attempt to create the job.
    not_found_codes = {'ResourceNotFound', 'ValidationException'}

    try:
        response = CLIENT.describe_processing_job(ProcessingJobName=job_name)
    except CLIENT.exceptions.ClientError as err:
        error_code = err.response.get('Error', {}).get('Code')
        if error_code not in not_found_codes:
            raise
        print('Creating new processing job:', job_name)
        _create_processing_job(event)
        response = CLIENT.describe_processing_job(ProcessingJobName=job_name)

    _check_job_status(response)
    # default=str stringifies values json cannot encode natively.
    return json.dumps(response, default=str)
def _create_processing_job(event):
    """Create a processing job from the settings carried in *event*.

    Args:
        event: Mapping with the required keys 'ProcessingJobName',
            'ImageUri', 'Entrypoint', 'InputsConfig' and 'OutputsConfig',
            and the optional keys 'Arguments', 'InstanceType'
            (default 'ml.m5.2xlarge'), 'InstanceCount' (default 1) and
            'VolumeSizeInGB' (default 30).

    Raises:
        KeyError: A required key is missing from *event*.
    """
    CLIENT.create_processing_job(**_build_processing_request(event))


def _build_processing_request(event):
    """Translate *event* into keyword arguments for ``create_processing_job``."""
    job_name = event['ProcessingJobName']
    image_uri = event['ImageUri']
    entrypoint = event['Entrypoint']
    inputs_config = event['InputsConfig']
    outputs_config = event['OutputsConfig']
    arguments = event.get('Arguments')

    # Only the values of the two config mappings are used; their keys are
    # ignored.
    inputs = [_make_input(input_) for input_ in inputs_config.values()]
    outputs = [_make_output(output) for output in outputs_config.values()]

    app_specification = {
        'ImageUri': image_uri,
        'ContainerEntrypoint': entrypoint,
    }
    # Container arguments are attached only when a non-empty value is given.
    if arguments:
        app_specification['ContainerArguments'] = arguments

    return {
        'ProcessingJobName': job_name,
        'ProcessingResources': {
            'ClusterConfig': {
                'InstanceType': event.get('InstanceType', 'ml.m5.2xlarge'),
                'InstanceCount': event.get('InstanceCount', 1),
                'VolumeSizeInGB': event.get('VolumeSizeInGB', 30),
            },
        },
        'AppSpecification': app_specification,
        'RoleArn': SAGEMAKER_ROLE_ARN,
        'ProcessingInputs': inputs,
        'ProcessingOutputConfig': {
            'Outputs': outputs,
        },
    }
def _make_input(input_config): | |
return { | |
'InputName': input_config['InputName'], | |
'S3Input': { | |
'S3Uri': input_config['S3Uri'], | |
'LocalPath': '/opt/ml/processing/' + input_config['InputName'], | |
'S3DataType': 'S3Prefix', | |
'S3InputMode': 'File' | |
} | |
} | |
def _make_output(output_config): | |
return { | |
'OutputName': output_config['OutputName'], | |
'S3Output': { | |
'S3Uri': output_config['S3Uri'], | |
'LocalPath': '/opt/ml/processing/' + output_config['OutputName'], | |
'S3UploadMode': 'EndOfJob' | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment