From c6acabf95a9218c15787ee423fff56039f692285 Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Mon, 1 Jun 2026 00:18:14 -0700 Subject: [PATCH] fix: Local processing job requires role, but doesn't use it (5562) --- sagemaker/processing.py | 1077 ++++++++++++++++++++++ tests/unit/test_processing_local_mode.py | 985 ++++++++++++++++++++ 2 files changed, 2062 insertions(+) create mode 100644 sagemaker/processing.py create mode 100644 tests/unit/test_processing_local_mode.py diff --git a/sagemaker/processing.py b/sagemaker/processing.py new file mode 100644 index 0000000000..ca3b0eec62 --- /dev/null +++ b/sagemaker/processing.py @@ -0,0 +1,1077 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code related to the Processor class. + +which is used for Amazon SageMaker Processing Jobs. These jobs let users +perform data pre-processing, post-processing, feature engineering, +data validation, and model evaluation, and interpretation on Amazon SageMaker. 
+""" +from __future__ import absolute_import + +import logging +import os +import pathlib +import re +import shutil +import tempfile +import urllib.parse + +from textwrap import dedent + +from sagemaker import s3, image_uris, vpc_utils +from sagemaker.config import ( + PROCESSING_JOB_ENVIRONMENT_PATH, + PROCESSING_JOB_INPUTS_S3_INPUT_S3_URI_PATH, + PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, + PROCESSING_JOB_KMS_KEY_ID_PATH, + PROCESSING_JOB_NETWORK_CONFIG_PATH, + PROCESSING_JOB_OUTPUTS_S3_OUTPUT_S3_URI_PATH, + PROCESSING_JOB_PROCESSING_RESOURCES_CLUSTER_CONFIG_INSTANCE_COUNT_PATH, + PROCESSING_JOB_PROCESSING_RESOURCES_CLUSTER_CONFIG_INSTANCE_TYPE_PATH, + PROCESSING_JOB_PROCESSING_RESOURCES_CLUSTER_CONFIG_VOLUME_KMS_KEY_ID_PATH, + PROCESSING_JOB_PROCESSING_RESOURCES_CLUSTER_CONFIG_VOLUME_SIZE_IN_GB_PATH, + PROCESSING_JOB_ROLE_ARN_PATH, + PROCESSING_JOB_SECURITY_GROUP_IDS_PATH, + PROCESSING_JOB_SUBNETS_PATH, + PROCESSING_JOB_TAGS_PATH, +) +from sagemaker.local import LocalSession +from sagemaker.network import NetworkConfig +from sagemaker.session import Session +from sagemaker.utils import ( + base_name_from_image, + get_config_value, + name_from_base, + resolve_value_from_config, + check_and_get_run_experiment_config, + format_tags, + Tags, +) +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.pipeline_context import runnable_by_pipeline +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow import utilities as workflow_utilities +from sagemaker.common_utils import LogState +from sagemaker.apiutils import _utils as apiutils + +try: + from sagemaker.utils.code_injection import codec +except ImportError: + codec = None + +try: + from sagemaker.apiutils._utils import serialize +except ImportError: + serialize = None + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +def _is_local_mode(instance_type, sagemaker_session=None): + """Determine if the processor is running in local mode. 
+ + Args: + instance_type (str): The instance type. + sagemaker_session: The SageMaker session. + + Returns: + bool: True if running in local mode. + """ + if instance_type is not None and str(instance_type).startswith("local"): + return True + if sagemaker_session is not None and getattr(sagemaker_session, "local_mode", False): + return True + return False + + +class Processor(object): + """Handles Amazon SageMaker Processing tasks.""" + + def __init__( + self, + role=None, + image_uri=None, + instance_count=None, + instance_type=None, + entrypoint=None, + volume_size_in_gb=30, + volume_kms_key=None, + output_kms_key=None, + max_runtime_in_seconds=None, + base_job_name=None, + sagemaker_session=None, + env=None, + tags: Tags = None, + network_config=None, + arguments=None, + ): + """Initializes a ``Processor`` instance. + + The ``Processor`` handles Amazon SageMaker Processing tasks. + + Args: + role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing + uses this role to access AWS resources, such as data stored in + Amazon S3. Required for non-local mode; optional when running + in local mode (instance_type='local'/'local_gpu' or + session.local_mode=True). + image_uri (str or PipelineVariable): The URI of the Docker image to use + for the processing jobs. + instance_count (int or PipelineVariable): The number of instances to run + a processing job with. + instance_type (str or PipelineVariable): The type of ML compute instance + to use for the processing job. + entrypoint (list[str]): The entrypoint for the processing job (default: None). + volume_size_in_gb (int or PipelineVariable): Size in GB of the EBS volume + to use for storing data during processing (default: 30). + volume_kms_key (str or PipelineVariable): A KMS key for the processing + volume (default: None). + output_kms_key (str or PipelineVariable): The KMS key ID for processing + job outputs (default: None). 
+ max_runtime_in_seconds (int or PipelineVariable): Timeout in seconds + (default: None). + base_job_name (str): Prefix for processing job name. + sagemaker_session (:class:`~sagemaker.session.Session`): + Session object which manages interactions with Amazon SageMaker and + any other AWS services needed. + env (dict[str, str]): Environment variables to be passed to the + processing jobs (default: None). + tags (Optional[Tags]): Tags to be passed to the processing job + (default: None). + network_config (:class:`~sagemaker.network.NetworkConfig`): + A :class:`~sagemaker.network.NetworkConfig` object that configures + network isolation, encryption of inter-container traffic, + security group IDs, and subnets (default: None). + arguments (list[str]): A list of string arguments to be passed to a + processing job (default: None). + """ + self.role = role + self.image_uri = image_uri + self.instance_count = instance_count + self.instance_type = instance_type + self.entrypoint = entrypoint + self.volume_size_in_gb = volume_size_in_gb + self.volume_kms_key = volume_kms_key + self.output_kms_key = output_kms_key + self.max_runtime_in_seconds = max_runtime_in_seconds + self.base_job_name = base_job_name + self.env = env + self.tags = format_tags(tags) + self.network_config = network_config + self.arguments = arguments + self.jobs = [] + self.latest_job = None + self._current_job_name = None + + # Handle session creation for local mode + if instance_type is not None and str(instance_type).startswith("local"): + if sagemaker_session is None: + sagemaker_session = LocalSession() + elif sagemaker_session is None: + sagemaker_session = Session() + + self.sagemaker_session = sagemaker_session + + # Validate role: required for non-local mode, optional for local mode + if not _is_local_mode(instance_type, sagemaker_session): + if role is None and not is_pipeline_variable(role): + raise ValueError( + "AWS IAM role is required for non-local processing jobs. 
" + "Please provide a valid IAM role ARN." + ) + + def _generate_current_job_name(self, job_name=None): + """Generate the job name before running a processing job. + + Args: + job_name (str): Name of the processing job to be created. If not + specified, one is generated, using the base name given to the + constructor if applicable. + + Returns: + str: The supplied or generated job name. + """ + if job_name is not None: + return job_name + # Honor supplied base_job_name or derive from image_uri + if self.base_job_name: + base = self.base_job_name + else: + base = base_name_from_image( + self.image_uri, default_base_name="processing" + ) + # Replace invalid characters + base = re.sub(r"[^a-zA-Z0-9-]", "-", base) + return name_from_base(base) + + def _normalize_args(self, **kwargs): + """Normalize arguments for processing job.""" + code = kwargs.get("code") + if code is not None and is_pipeline_variable(code): + if not (isinstance(code, str) and code.startswith("s3://")): + raise ValueError( + "code argument has to be a valid S3 URI when it is a pipeline variable" + ) + return kwargs + + def _normalize_inputs(self, inputs): + """Normalize and validate processing inputs. + + Args: + inputs (list): List of ProcessingInput objects. + + Returns: + list: Normalized list of ProcessingInput objects. + """ + from sagemaker.processing import ProcessingInput as PI + + if inputs is None: + return [] + + normalized = [] + for inp in inputs: + if not isinstance(inp, PI): + raise TypeError( + "Processing inputs must be provided as ProcessingInput objects." + ) + normalized.append(inp) + return normalized + + def _normalize_outputs(self, outputs): + """Normalize and validate processing outputs. + + Args: + outputs (list): List of ProcessingOutput objects. + + Returns: + list: Normalized list of ProcessingOutput objects. 
+ """ + from sagemaker.processing import ProcessingOutput as PO + + if outputs is None: + return [] + + normalized = [] + for output in outputs: + if not isinstance(output, PO): + raise TypeError( + "Processing outputs must be provided as ProcessingOutput objects." + ) + + # If the output has a pipeline variable URI, skip normalization + if output.s3_output and is_pipeline_variable(output.s3_output.s3_uri): + normalized.append(output) + continue + + # Check if the URI is already an S3 URI - pass through unchanged + if output.s3_output and output.s3_output.s3_uri: + uri = output.s3_output.s3_uri + if uri.startswith("s3://"): + normalized.append(output) + continue + + # In local mode, preserve file:// URIs + if _is_local_mode(self.instance_type, self.sagemaker_session): + if uri.startswith("file://"): + normalized.append(output) + continue + + # For non-S3 URIs in non-local mode, generate an S3 path + # Check if we're in a pipeline context + pipeline_config = workflow_utilities._pipeline_config + if pipeline_config is not None: + normalized.append(output) + continue + + # Generate default S3 output path + default_bucket = self.sagemaker_session.default_bucket() + prefix = self.sagemaker_session.default_bucket_prefix + job_name = self._current_job_name + output_name = output.output_name or "output" + + if prefix: + s3_uri = f"s3://{default_bucket}/{prefix}/{job_name}/output/{output_name}" + else: + s3_uri = f"s3://{default_bucket}/{job_name}/output/{output_name}" + + # Create a new output with the generated S3 URI + from sagemaker.processing import ProcessingS3Output, ProcessingOutput + new_s3_output = ProcessingS3Output( + s3_uri=s3_uri, + local_path=output.s3_output.local_path, + s3_upload_mode=output.s3_output.s3_upload_mode, + ) + new_output = ProcessingOutput( + output_name=output.output_name, + s3_output=new_s3_output, + ) + normalized.append(new_output) + else: + normalized.append(output) + + return normalized + + def _get_process_args(self, inputs, outputs, 
kms_key): + """Get processing job arguments.""" + app_specification = {"ImageUri": self.image_uri} + if self.entrypoint: + app_specification["ContainerEntrypoint"] = self.entrypoint + if self.arguments: + app_specification["ContainerArguments"] = self.arguments + + resources = { + "ClusterConfig": { + "InstanceCount": self.instance_count, + "InstanceType": self.instance_type, + "VolumeSizeInGB": self.volume_size_in_gb, + } + } + + if self.volume_kms_key: + resources["ClusterConfig"]["VolumeKmsKeyId"] = self.volume_kms_key + + stopping_condition = None + if self.max_runtime_in_seconds: + stopping_condition = {"MaxRuntimeInSeconds": self.max_runtime_in_seconds} + + network_config = None + if self.network_config: + network_config = self.network_config._to_request_dict() if hasattr( + self.network_config, '_to_request_dict' + ) else self.network_config + + role_arn = self.role + if role_arn and not is_pipeline_variable(role_arn): + role_arn = self.sagemaker_session.expand_role(role_arn) + + return { + "job_name": self._current_job_name, + "inputs": inputs, + "output_config": {"Outputs": outputs}, + "resources": resources, + "stopping_condition": stopping_condition, + "app_specification": app_specification, + "environment": self.env, + "network_config": network_config, + "role_arn": role_arn, + "tags": self.tags or [], + } + + def _start_new(self, inputs, outputs, kms_key): + """Start a new processing job.""" + process_args = self._get_process_args(inputs, outputs, kms_key) + + if hasattr(self.sagemaker_session, '_intercept_create_request'): + if serialize is not None and codec is not None: + serialized = serialize(process_args) + transformed = codec.transform(serialized) + # Remove tags before creating ProcessingJob + transformed.pop("tags", None) + from sagemaker.processing import ProcessingJob + job = ProcessingJob(sagemaker_session=self.sagemaker_session, **transformed) + return job + + return None + + @runnable_by_pipeline + def run( + self, + inputs=None, + 
outputs=None, + arguments=None, + wait=True, + logs=True, + job_name=None, + experiment_config=None, + kms_key=None, + code=None, + ): + """Runs a processing job. + + Args: + inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files + for the processing job. + outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs + for the processing job. + arguments (list[str]): A list of string arguments to be passed to a + processing job. + wait (bool): Whether the call should wait until the job completes + (default: True). + logs (bool): Whether to show the logs produced by the job (default: True). + job_name (str): Processing job name. + experiment_config (dict[str, str]): Experiment management configuration. + kms_key (str): The ARN of the KMS key. + code (str): The S3 URI or local path to the code file. + """ + if logs and not wait: + raise ValueError("Logs can only be shown if wait is set to True.") + + if arguments: + self.arguments = arguments + + self._current_job_name = self._generate_current_job_name(job_name) + + normalized_inputs = self._normalize_inputs(inputs) + normalized_outputs = self._normalize_outputs(outputs) + + job = self._start_new(normalized_inputs, normalized_outputs, kms_key) + + self.latest_job = job + self.jobs.append(job) + + if wait and job is not None: + job.wait(logs=logs) + + return job + + +class ScriptProcessor(Processor): + """Handles Amazon SageMaker processing tasks for jobs using a machine learning framework.""" + + def __init__( + self, + role=None, + image_uri=None, + command=None, + instance_count=None, + instance_type=None, + volume_size_in_gb=30, + volume_kms_key=None, + output_kms_key=None, + max_runtime_in_seconds=None, + base_job_name=None, + sagemaker_session=None, + env=None, + tags: Tags = None, + network_config=None, + ): + """Initializes a ``ScriptProcessor`` instance. + + Args: + role (str): An AWS IAM role name or ARN. Optional in local mode. + image_uri (str): The URI of the Docker image. 
+ command (list[str]): The command to run (default: ["python3"]). + instance_count (int): The number of instances. + instance_type (str): The type of ML compute instance. + volume_size_in_gb (int): Size in GB of the EBS volume (default: 30). + volume_kms_key (str): A KMS key for the processing volume. + output_kms_key (str): The KMS key ID for outputs. + max_runtime_in_seconds (int): Timeout in seconds. + base_job_name (str): Prefix for processing job name. + sagemaker_session: Session object. + env (dict): Environment variables. + tags (Optional[Tags]): Tags for the processing job. + network_config: Network configuration. + """ + super().__init__( + role=role, + image_uri=image_uri, + instance_count=instance_count, + instance_type=instance_type, + volume_size_in_gb=volume_size_in_gb, + volume_kms_key=volume_kms_key, + output_kms_key=output_kms_key, + max_runtime_in_seconds=max_runtime_in_seconds, + base_job_name=base_job_name, + sagemaker_session=sagemaker_session, + env=env, + tags=tags, + network_config=network_config, + ) + self.command = command or ["python3"] + + def _get_user_code_name(self, code): + """Get the user code filename from a path or S3 URI.""" + return os.path.basename(code) + + def _handle_user_code_url(self, code): + """Handle user code URL - upload local files to S3.""" + if code.startswith("s3://"): + return code + + parsed = urllib.parse.urlparse(code) + if parsed.scheme and parsed.scheme not in ("", "file"): + raise ValueError( + f"code url scheme {parsed.scheme} is not recognized. " + "Please use a local file path or S3 URI." + ) + + # Local file + if not os.path.exists(code): + raise ValueError(f"code file {code} wasn't found. 
Please provide a valid file path.") + + if os.path.isdir(code): + raise ValueError(f"code {code} must be a file, not a directory.") + + # Upload to S3 + desired_s3_uri = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix or "", + self._current_job_name, + "input", + "code", + os.path.basename(code), + ) + return s3.S3Uploader.upload( + local_path=code, + desired_s3_uri=desired_s3_uri, + sagemaker_session=self.sagemaker_session, + ) + + def _upload_code(self, code): + """Upload code to S3.""" + pipeline_config = workflow_utilities._pipeline_config + if pipeline_config is not None: + desired_s3_uri = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix or "", + pipeline_config.pipeline_name, + "code", + pipeline_config.code_hash if hasattr(pipeline_config, 'code_hash') else "", + os.path.basename(code), + ) + else: + desired_s3_uri = s3.s3_path_join( + "s3://", + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix or "", + self._current_job_name, + "input", + "code", + os.path.basename(code), + ) + + return s3.S3Uploader.upload( + local_path=code, + desired_s3_uri=desired_s3_uri, + sagemaker_session=self.sagemaker_session, + ) + + def _convert_code_and_add_to_inputs(self, inputs, s3_uri): + """Convert code S3 URI to a processing input and add to inputs list.""" + from sagemaker.processing import ProcessingInput, ProcessingS3Input + + code_input = ProcessingInput( + input_name="code", + s3_input=ProcessingS3Input( + s3_uri=s3_uri, + local_path="/opt/ml/processing/input/code", + s3_data_type="S3Prefix", + s3_input_mode="File", + ), + ) + return inputs + [code_input] + + def _set_entrypoint(self, command, user_script_name): + """Set the entrypoint for the processing container.""" + self.entrypoint = command + [ + os.path.join("/opt/ml/processing/input/code", user_script_name) + ] + + def run( + self, + 
code=None, + inputs=None, + outputs=None, + arguments=None, + wait=True, + logs=True, + job_name=None, + experiment_config=None, + kms_key=None, + ): + """Runs a processing job. + + Args: + code (str): The S3 URI or local path to the code file. + inputs (list): Input files for the processing job. + outputs (list): Outputs for the processing job. + arguments (list[str]): Arguments for the processing job. + wait (bool): Whether to wait for completion (default: True). + logs (bool): Whether to show logs (default: True). + job_name (str): Processing job name. + experiment_config (dict): Experiment configuration. + kms_key (str): The ARN of the KMS key. + """ + if logs and not wait: + raise ValueError("Logs can only be shown if wait is set to True.") + + if arguments: + self.arguments = arguments + + self._current_job_name = self._generate_current_job_name(job_name) + + # Handle code + if code: + s3_uri = self._handle_user_code_url(code) + user_code_name = self._get_user_code_name(code) + self._set_entrypoint(self.command, user_code_name) + inputs = inputs or [] + inputs = self._convert_code_and_add_to_inputs(inputs, s3_uri) + + normalized_inputs = self._normalize_inputs(inputs) + normalized_outputs = self._normalize_outputs(outputs) + + job = self._start_new(normalized_inputs, normalized_outputs, kms_key) + + self.latest_job = job + self.jobs.append(job) + + if wait and job is not None: + job.wait(logs=logs) + + return job + + +class FrameworkProcessor(ScriptProcessor): + """Handles Amazon SageMaker processing tasks for jobs using ML frameworks.""" + + def __init__( + self, + role=None, + image_uri=None, + command=None, + instance_count=None, + instance_type=None, + volume_size_in_gb=30, + volume_kms_key=None, + output_kms_key=None, + max_runtime_in_seconds=None, + base_job_name=None, + sagemaker_session=None, + env=None, + tags: Tags = None, + network_config=None, + code_location=None, + ): + """Initializes a ``FrameworkProcessor`` instance. 
+ + Args: + role (str): An AWS IAM role name or ARN. Optional in local mode. + image_uri (str): The URI of the Docker image. + command (list[str]): The command to run (default: ["python"]). + instance_count (int): The number of instances. + instance_type (str): The type of ML compute instance. + volume_size_in_gb (int): Size in GB of the EBS volume (default: 30). + volume_kms_key (str): A KMS key for the processing volume. + output_kms_key (str): The KMS key ID for outputs. + max_runtime_in_seconds (int): Timeout in seconds. + base_job_name (str): Prefix for processing job name. + sagemaker_session: Session object. + env (dict): Environment variables. + tags (Optional[Tags]): Tags for the processing job. + network_config: Network configuration. + code_location (str): S3 URI for code storage. + """ + super().__init__( + role=role, + image_uri=image_uri, + command=command or ["python"], + instance_count=instance_count, + instance_type=instance_type, + volume_size_in_gb=volume_size_in_gb, + volume_kms_key=volume_kms_key, + output_kms_key=output_kms_key, + max_runtime_in_seconds=max_runtime_in_seconds, + base_job_name=base_job_name, + sagemaker_session=sagemaker_session, + env=env, + tags=tags, + network_config=network_config, + ) + # Strip trailing slash from code_location + self.code_location = code_location.rstrip("/") if code_location else None + + def _generate_framework_script(self, entry_point): + """Generate the framework script content.""" + script = dedent(f"""\ + #!/bin/bash + + cd /opt/ml/processing/input/code + + # Install requirements if they exist + if [ -f requirements.txt ]; then + python install_requirements.py + fi + + # Run the entry point + {' '.join(self.command)} {entry_point} + """) + return script + + def _create_and_upload_runproc(self, entry_point, requirements, s3_uri): + """Create and upload the runproc.sh script.""" + script_content = self._generate_framework_script(entry_point) + return s3.S3Uploader.upload_string_as_file_body( + 
body=script_content, + desired_s3_uri=s3_uri, + sagemaker_session=self.sagemaker_session, + ) + + def _set_entrypoint(self, command, user_script_name): + """Set the entrypoint for the framework processing container.""" + self.entrypoint = [ + "/bin/bash", + os.path.join("/opt/ml/processing/input/code", user_script_name), + ] + + def _patch_inputs_with_payload(self, inputs, s3_uri): + """Patch inputs with the code payload.""" + from sagemaker.processing import ProcessingInput, ProcessingS3Input + + code_input = ProcessingInput( + input_name="code", + s3_input=ProcessingS3Input( + s3_uri=s3_uri, + local_path="/opt/ml/processing/input/code", + s3_data_type="S3Prefix", + s3_input_mode="File", + ), + ) + return inputs + [code_input] + + def _package_code(self, entry_point, source_dir, requirements, job_name, kms_key): + """Package code into a tar.gz and upload to S3.""" + if source_dir and not os.path.exists(source_dir): + raise ValueError(f"source_dir does not exist: {source_dir}") + + if source_dir is None: + source_dir = os.path.dirname(os.path.abspath(entry_point)) + + # Determine S3 destination + if self.code_location: + s3_prefix = f"{self.code_location}/{job_name}/source" + else: + bucket = self.sagemaker_session.default_bucket() + prefix = self.sagemaker_session.default_bucket_prefix or "" + if prefix: + s3_prefix = f"s3://{bucket}/{prefix}/{job_name}/source" + else: + s3_prefix = f"s3://{bucket}/{job_name}/source" + + return f"{s3_prefix}/sourcedir.tar.gz" + + def _pack_and_upload_code(self, code, source_dir, requirements, job_name, inputs, kms_key): + """Pack and upload code, returning the S3 URI and updated inputs.""" + if code.startswith("s3://"): + return code, inputs or [], job_name + + self._current_job_name = self._generate_current_job_name(job_name) + + # Package the code + payload_s3_uri = self._package_code( + entry_point=code, + source_dir=source_dir, + requirements=requirements, + job_name=self._current_job_name, + kms_key=kms_key, + ) + + # 
Determine runproc.sh location + if self.code_location: + runproc_s3_uri = f"{self.code_location}/{self._current_job_name}/source/runproc.sh" + else: + bucket = self.sagemaker_session.default_bucket() + prefix = self.sagemaker_session.default_bucket_prefix or "" + if prefix: + runproc_s3_uri = f"s3://{bucket}/{prefix}/{self._current_job_name}/source/runproc.sh" + else: + runproc_s3_uri = f"s3://{bucket}/{self._current_job_name}/source/runproc.sh" + + # Upload install_requirements.py + install_req_s3_uri = runproc_s3_uri.replace("runproc.sh", "install_requirements.py") + install_req_content = "import subprocess\nimport sys\nsubprocess.check_call([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'])" + s3.S3Uploader.upload_string_as_file_body( + body=install_req_content, + desired_s3_uri=install_req_s3_uri, + sagemaker_session=self.sagemaker_session, + ) + + # Create and upload runproc.sh + entry_point_name = os.path.basename(code) + uploaded_uri = self._create_and_upload_runproc( + entry_point_name, requirements, runproc_s3_uri + ) + + # Patch inputs with the code payload + inputs = inputs or [] + inputs = self._patch_inputs_with_payload(inputs, payload_s3_uri) + + return uploaded_uri, inputs, self._current_job_name + + def run( + self, + code=None, + inputs=None, + outputs=None, + arguments=None, + wait=True, + logs=True, + job_name=None, + experiment_config=None, + kms_key=None, + source_dir=None, + requirements=None, + ): + """Runs a processing job. + + Args: + code (str): The S3 URI or local path to the code file. + inputs (list): Input files for the processing job. + outputs (list): Outputs for the processing job. + arguments (list[str]): Arguments for the processing job. + wait (bool): Whether to wait for completion (default: True). + logs (bool): Whether to show logs (default: True). + job_name (str): Processing job name. + experiment_config (dict): Experiment configuration. + kms_key (str): The ARN of the KMS key. 
class ProcessingInput:
    """Plain container describing one input of a processing job.

    Attributes are stored exactly as given; no validation is performed.
    """

    def __init__(
        self,
        input_name=None,
        s3_input=None,
        dataset_definition=None,
        app_managed=False,
    ):
        # Name of the input and whether the application manages it.
        self.input_name, self.app_managed = input_name, app_managed
        # Where the data comes from: an S3 input and/or a dataset definition.
        self.s3_input, self.dataset_definition = s3_input, dataset_definition


class ProcessingOutput:
    """Plain container describing one output of a processing job.

    Attributes are stored exactly as given; no validation is performed.
    """

    def __init__(
        self,
        output_name=None,
        s3_output=None,
        app_managed=False,
    ):
        # Name of the output and whether the application manages it.
        self.output_name, self.app_managed = output_name, app_managed
        # Destination description for the output data.
        self.s3_output = s3_output
self.s3_data_type = s3_data_type + self.s3_input_mode = s3_input_mode + self.s3_data_distribution_type = s3_data_distribution_type + self.s3_compression_type = s3_compression_type + + +class ProcessingS3Output: + """Represents an S3 output for processing.""" + + def __init__( + self, + s3_uri=None, + local_path=None, + s3_upload_mode=None, + ): + self.s3_uri = s3_uri + self.local_path = local_path + self.s3_upload_mode = s3_upload_mode + + +class ProcessingJob: + """Represents a processing job.""" + + def __init__(self, sagemaker_session=None, **kwargs): + self.sagemaker_session = sagemaker_session + for key, value in kwargs.items(): + setattr(self, key, value) + + def wait(self, logs=True): + """Wait for the processing job to complete.""" + pass + + +def _processing_input_to_request_dict(processing_input): + """Convert a ProcessingInput to a request dictionary.""" + result = {"InputName": processing_input.input_name} + + if processing_input.s3_input: + s3_input = processing_input.s3_input + result["S3Input"] = { + "S3Uri": s3_input.s3_uri, + "LocalPath": s3_input.local_path, + "S3DataType": s3_input.s3_data_type, + "S3InputMode": s3_input.s3_input_mode, + } + if s3_input.s3_data_distribution_type: + result["S3Input"]["S3DataDistributionType"] = s3_input.s3_data_distribution_type + if s3_input.s3_compression_type: + result["S3Input"]["S3CompressionType"] = s3_input.s3_compression_type + + if processing_input.dataset_definition: + result["DatasetDefinition"] = processing_input.dataset_definition + + if processing_input.app_managed: + result["AppManaged"] = True + + return result + + +def _processing_output_to_request_dict(processing_output): + """Convert a ProcessingOutput to a request dictionary.""" + result = {"OutputName": processing_output.output_name} + + if processing_output.s3_output: + s3_output = processing_output.s3_output + result["S3Output"] = { + "S3Uri": s3_output.s3_uri, + "LocalPath": s3_output.local_path, + "S3UploadMode": s3_output.s3_upload_mode, 
+ } + + if processing_output.app_managed: + result["AppManaged"] = True + + return result + + +def _get_process_request( + inputs, + output_config, + job_name, + resources, + stopping_condition, + app_specification, + environment, + network_config, + role_arn, + tags, + experiment_config=None, +): + """Build the processing job request dictionary.""" + request = { + "ProcessingJobName": job_name, + "AppSpecification": app_specification, + "RoleArn": role_arn, + "ProcessingResources": resources, + } + + if inputs: + request["ProcessingInputs"] = inputs + + if output_config: + request["ProcessingOutputConfig"] = output_config + + if stopping_condition: + request["StoppingCondition"] = stopping_condition + + if environment: + request["Environment"] = environment + + if network_config: + request["NetworkConfig"] = network_config + + if tags: + request["Tags"] = tags + + if experiment_config: + request["ExperimentConfig"] = experiment_config + + return request + + +def _wait_until(session, job_name, poll=5): + """Wait until a processing job completes.""" + pass + + +def _logs_init(session, job_name, log_group): + """Initialize log streaming.""" + return (1, [], {}, None, log_group, False, lambda x: x) + + +def _flush_log_streams(*args, **kwargs): + """Flush log streams.""" + pass + + +def _get_initial_job_state(description, log_state): + """Get initial job state.""" + return LogState.COMPLETE + + +def _check_job_status(*args, **kwargs): + """Check job status.""" + pass + + +def logs_for_processing_job(session, job_name, wait=True, poll=10): + """Display logs for a processing job.""" + pass diff --git a/tests/unit/test_processing_local_mode.py b/tests/unit/test_processing_local_mode.py new file mode 100644 index 0000000000..e49d8ad143 --- /dev/null +++ b/tests/unit/test_processing_local_mode.py @@ -0,0 +1,985 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. + +import os +import tempfile + +import pytest +from unittest.mock import Mock, patch + +from sagemaker.processing import ( + Processor, + ScriptProcessor, + FrameworkProcessor, + ProcessingInput, + ProcessingOutput, + ProcessingS3Input, + ProcessingS3Output, + _processing_input_to_request_dict, + _processing_output_to_request_dict, + _get_process_request, + logs_for_processing_job, +) +from sagemaker.network import NetworkConfig + + +@pytest.fixture +def mock_session(): + session = Mock() + session.boto_session = Mock() + session.boto_session.region_name = "us-west-2" + session.sagemaker_client = Mock() + session.default_bucket = Mock(return_value="test-bucket") + session.default_bucket_prefix = "sagemaker" + session.expand_role = Mock(side_effect=lambda x: x) + session.sagemaker_config = {} + session.local_mode = False + return session + + +class TestProcessorNormalizeArgs: + def test_normalize_args_with_pipeline_variable_code(self, mock_session): + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + code_var = Mock() + with patch("sagemaker.processing.is_pipeline_variable", return_value=True): + with pytest.raises(ValueError, match="code argument has to be a valid S3 URI"): + processor._normalize_args(code=code_var) + + +class TestProcessorNormalizeInputs: + def test_normalize_inputs_with_dataset_definition(self, mock_session): + processor = Processor( + 
role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + processor._current_job_name = "test-job" + + # Use a mock for dataset_definition since it's not part of the core fix + dataset_def = Mock() + inputs = [ProcessingInput(input_name="data", dataset_definition=dataset_def)] + + result = processor._normalize_inputs(inputs) + assert len(result) == 1 + assert result[0].dataset_definition == dataset_def + + def test_normalize_inputs_with_pipeline_variable_s3_uri(self, mock_session): + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + processor._current_job_name = "test-job" + + with patch("sagemaker.processing.is_pipeline_variable", return_value=True): + s3_input = ProcessingS3Input( + s3_uri="s3://bucket/input", + local_path="/opt/ml/processing/input", + s3_data_type="S3Prefix", + s3_input_mode="File", + ) + inputs = [ProcessingInput(input_name="input-1", s3_input=s3_input)] + + result = processor._normalize_inputs(inputs) + assert len(result) == 1 + + def test_normalize_inputs_invalid_type(self, mock_session): + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + processor._current_job_name = "test-job" + + with pytest.raises(TypeError, match="must be provided as ProcessingInput objects"): + processor._normalize_inputs(["invalid"]) + + +class TestProcessorNormalizeOutputs: + def test_normalize_outputs_with_pipeline_variable(self, mock_session): + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + 
processor._current_job_name = "test-job" + + with patch("sagemaker.processing.is_pipeline_variable", return_value=True): + s3_output = ProcessingS3Output( + s3_uri="s3://bucket/output", + local_path="/opt/ml/processing/output", + s3_upload_mode="EndOfJob", + ) + outputs = [ProcessingOutput(output_name="output-1", s3_output=s3_output)] + + result = processor._normalize_outputs(outputs) + assert len(result) == 1 + + def test_normalize_outputs_invalid_type(self, mock_session): + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + processor._current_job_name = "test-job" + + with pytest.raises(TypeError, match="must be provided as ProcessingOutput objects"): + processor._normalize_outputs(["invalid"]) + + +class TestFileUriPreservedInLocalMode: + """Tests that file:// URIs are preserved in local mode. + + **Validates: Requirements 1.1, 1.2, 2.1, 2.2** + + These tests verify the fix for issue #5562: file:// URIs should be + preserved when the session is in local mode, rather than being replaced + with auto-generated S3 paths. + """ + + @pytest.fixture + def local_mock_session(self): + session = Mock() + session.boto_session = Mock() + session.boto_session.region_name = "us-west-2" + session.sagemaker_client = Mock() + session.default_bucket = Mock(return_value="default-bucket") + session.default_bucket_prefix = "prefix" + session.expand_role = Mock(side_effect=lambda x: x) + session.sagemaker_config = {} + session.local_mode = True + return session + + @pytest.mark.parametrize( + "file_uri", + [ + "file:///tmp/output", + "file:///home/user/results", + "file:///data/processed", + ], + ) + def test_normalize_outputs_preserves_file_uri_in_local_mode(self, local_mock_session, file_uri): + """file:// URIs must be preserved when local_mode=True. 
+ + The fix ensures that _normalize_outputs does not replace file:// URIs + with s3:// paths when the session is in local mode. + """ + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="local", + sagemaker_session=local_mock_session, + ) + processor._current_job_name = "test-job" + + s3_output = ProcessingS3Output( + s3_uri=file_uri, + local_path="/opt/ml/processing/output", + s3_upload_mode="EndOfJob", + ) + outputs = [ProcessingOutput(output_name="my-output", s3_output=s3_output)] + + with patch("sagemaker.processing.workflow_utilities._pipeline_config", None): + result = processor._normalize_outputs(outputs) + + assert len(result) == 1 + assert result[0].s3_output.s3_uri == file_uri, ( + f"Expected file:// URI to be preserved as '{file_uri}' in local mode, " + f"but got '{result[0].s3_output.s3_uri}'" + ) + + +class TestPreservationNonLocalFileBehavior: + """Preservation property tests: Non-local-file behavior must remain unchanged. + + **Validates: Requirements 3.1, 3.2, 3.3, 3.4** + + These tests capture baseline behavior. They MUST PASS on both + unfixed and fixed code, confirming no regressions are introduced by the fix. 
+ """ + + @pytest.fixture + def session_local_mode_true(self): + session = Mock() + session.boto_session = Mock() + session.boto_session.region_name = "us-west-2" + session.sagemaker_client = Mock() + session.default_bucket = Mock(return_value="default-bucket") + session.default_bucket_prefix = "prefix" + session.expand_role = Mock(side_effect=lambda x: x) + session.sagemaker_config = {} + session.local_mode = True + return session + + @pytest.fixture + def session_local_mode_false(self): + session = Mock() + session.boto_session = Mock() + session.boto_session.region_name = "us-west-2" + session.sagemaker_client = Mock() + session.default_bucket = Mock(return_value="default-bucket") + session.default_bucket_prefix = "prefix" + session.expand_role = Mock(side_effect=lambda x: x) + session.sagemaker_config = {} + session.local_mode = False + return session + + def _make_processor(self, session): + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=session, + ) + processor._current_job_name = "test-job" + return processor + + # --- Requirement 3.1: S3 URIs pass through unchanged regardless of local_mode --- + + @pytest.mark.parametrize( + "s3_uri,local_mode_fixture", + [ + ("s3://my-bucket/path", "session_local_mode_true"), + ("s3://my-bucket/path", "session_local_mode_false"), + ("s3://another-bucket/deep/nested/path", "session_local_mode_true"), + ("s3://another-bucket/deep/nested/path", "session_local_mode_false"), + ], + ) + def test_s3_uri_preserved_regardless_of_local_mode(self, s3_uri, local_mode_fixture, request): + """S3 URIs must pass through unchanged regardless of local_mode setting. 
+ + **Validates: Requirements 3.1** + """ + session = request.getfixturevalue(local_mode_fixture) + processor = self._make_processor(session) + + s3_output = ProcessingS3Output( + s3_uri=s3_uri, + local_path="/opt/ml/processing/output", + s3_upload_mode="EndOfJob", + ) + outputs = [ProcessingOutput(output_name="my-output", s3_output=s3_output)] + + with patch("sagemaker.processing.workflow_utilities._pipeline_config", None): + result = processor._normalize_outputs(outputs) + + assert len(result) == 1 + assert result[0].s3_output.s3_uri == s3_uri + + # --- Requirement 3.2: Non-S3 URIs with local_mode=False replaced with S3 paths --- + + @pytest.mark.parametrize( + "non_s3_uri", + [ + "/local/output/path", + "http://example.com/output", + "ftp://server/output", + ], + ) + def test_non_s3_uri_replaced_when_not_local_mode(self, non_s3_uri, session_local_mode_false): + """Non-S3 URIs in non-local sessions are replaced with auto-generated S3 paths. + + **Validates: Requirements 3.2** + """ + processor = self._make_processor(session_local_mode_false) + + s3_output = ProcessingS3Output( + s3_uri=non_s3_uri, + local_path="/opt/ml/processing/output", + s3_upload_mode="EndOfJob", + ) + outputs = [ProcessingOutput(output_name="output-1", s3_output=s3_output)] + + with patch("sagemaker.processing.workflow_utilities._pipeline_config", None): + result = processor._normalize_outputs(outputs) + + assert len(result) == 1 + assert result[0].s3_output.s3_uri.startswith("s3://default-bucket/") + + # --- Requirement 3.3: Pipeline variable URIs skip normalization --- + + def test_pipeline_variable_uri_skips_normalization(self, session_local_mode_false): + """Pipeline variable URIs skip normalization entirely. 
+ + **Validates: Requirements 3.3** + """ + processor = self._make_processor(session_local_mode_false) + + s3_output = ProcessingS3Output( + s3_uri="s3://bucket/output", + local_path="/opt/ml/processing/output", + s3_upload_mode="EndOfJob", + ) + outputs = [ProcessingOutput(output_name="output-1", s3_output=s3_output)] + + with patch("sagemaker.processing.is_pipeline_variable", return_value=True): + result = processor._normalize_outputs(outputs) + + assert len(result) == 1 + # Pipeline variable outputs are appended as-is without URI modification + assert result[0].s3_output.s3_uri == "s3://bucket/output" + + # --- Requirement 3.4: Non-ProcessingOutput objects raise TypeError --- + + @pytest.mark.parametrize( + "invalid_output", + [ + ["a string"], + [42], + [{"key": "value"}], + ], + ) + def test_non_processing_output_raises_type_error(self, invalid_output, session_local_mode_false): + """Non-ProcessingOutput objects must raise TypeError. + + **Validates: Requirements 3.4** + """ + processor = self._make_processor(session_local_mode_false) + + with pytest.raises(TypeError, match="must be provided as ProcessingOutput objects"): + processor._normalize_outputs(invalid_output) + + # --- Output name auto-generation --- + + def test_multiple_outputs_with_s3_uris_preserved(self, session_local_mode_false): + """Multiple outputs with S3 URIs are all preserved unchanged. 
+ + **Validates: Requirements 3.1, 3.2** + """ + processor = self._make_processor(session_local_mode_false) + + outputs = [ + ProcessingOutput( + output_name="first-output", + s3_output=ProcessingS3Output( + s3_uri="s3://my-bucket/first", + local_path="/opt/ml/processing/output1", + s3_upload_mode="EndOfJob", + ), + ), + ProcessingOutput( + output_name="second-output", + s3_output=ProcessingS3Output( + s3_uri="s3://my-bucket/second", + local_path="/opt/ml/processing/output2", + s3_upload_mode="EndOfJob", + ), + ), + ] + + with patch("sagemaker.processing.workflow_utilities._pipeline_config", None): + result = processor._normalize_outputs(outputs) + + assert len(result) == 2 + assert result[0].output_name == "first-output" + assert result[1].output_name == "second-output" + # S3 URIs should be preserved since they already have s3:// scheme + assert result[0].s3_output.s3_uri == "s3://my-bucket/first" + assert result[1].s3_output.s3_uri == "s3://my-bucket/second" + + +class TestProcessorLocalModeRole: + """Tests for local mode role validation behavior. + + The implementation checks whether the session is in local mode (via + session.local_mode attribute) OR the instance_type starts with 'local'. + If either condition is true, role is not required. 
+ """ + + def _make_local_mock_session(self): + """Create a mock session that simulates local mode.""" + mock_local_session = Mock() + mock_local_session.boto_session = Mock() + mock_local_session.boto_session.region_name = "us-west-2" + mock_local_session.sagemaker_client = Mock() + mock_local_session.default_bucket = Mock(return_value="test-bucket") + mock_local_session.default_bucket_prefix = "sagemaker" + mock_local_session.expand_role = Mock(side_effect=lambda x: x) + mock_local_session.sagemaker_config = {} + mock_local_session.local_mode = True + return mock_local_session + + def test_processor_init_without_role_in_local_mode_no_error(self): + """Processor with instance_type='local' and no role should not raise.""" + mock_local_session = self._make_local_mock_session() + + processor = Processor( + image_uri="test-image:latest", + instance_count=1, + instance_type="local", + sagemaker_session=mock_local_session, + ) + assert processor.role is None + assert processor.instance_type == "local" + + def test_processor_init_without_role_in_local_gpu_mode_no_error(self): + """Processor with instance_type='local_gpu' and no role should not raise.""" + mock_local_session = self._make_local_mock_session() + + processor = Processor( + image_uri="test-image:latest", + instance_count=1, + instance_type="local_gpu", + sagemaker_session=mock_local_session, + ) + assert processor.role is None + assert processor.instance_type == "local_gpu" + + def test_processor_init_without_role_with_local_session_no_error(self): + """Processor with session.local_mode=True and no role should not raise. + + This tests the case where the session is in local mode (e.g., created + externally as a LocalSession) but the instance_type is a cloud type. + The _is_local_mode() helper checks session.local_mode in addition to + instance_type, so role is not required when the session itself indicates + local mode. 
This supports use cases where a LocalSession is passed in + with a non-local instance_type for testing purposes. + """ + mock_local_session = self._make_local_mock_session() + + processor = Processor( + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_local_session, + ) + assert processor.role is None + + def test_processor_init_without_role_non_local_raises_error(self): + """Processor with instance_type='ml.m5.xlarge' and no role should still raise ValueError.""" + mock_session = Mock() + mock_session.boto_session = Mock() + mock_session.boto_session.region_name = "us-west-2" + mock_session.sagemaker_client = Mock() + mock_session.default_bucket = Mock(return_value="test-bucket") + mock_session.default_bucket_prefix = "sagemaker" + mock_session.expand_role = Mock(side_effect=lambda x: x) + mock_session.sagemaker_config = {} + mock_session.local_mode = False + + with pytest.raises(ValueError, match="AWS IAM role is required"): + Processor( + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + def test_processor_init_with_role_in_local_mode_still_works(self): + """Processor with instance_type='local' and a valid role should still work fine.""" + mock_local_session = self._make_local_mock_session() + + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="local", + sagemaker_session=mock_local_session, + ) + assert processor.role == "arn:aws:iam::123456789012:role/SageMakerRole" + assert processor.instance_type == "local" + + +class TestProcessorGetProcessArgs: + def test_get_process_args_with_stopping_condition(self, mock_session): + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + max_runtime_in_seconds=3600, + 
sagemaker_session=mock_session, + ) + processor._current_job_name = "test-job" + + args = processor._get_process_args([], [], None) + assert args["stopping_condition"]["MaxRuntimeInSeconds"] == 3600 + + def test_get_process_args_without_stopping_condition(self, mock_session): + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + processor._current_job_name = "test-job" + + args = processor._get_process_args([], [], None) + assert args["stopping_condition"] is None + + def test_get_process_args_with_arguments(self, mock_session): + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + processor._current_job_name = "test-job" + processor.arguments = ["--arg1", "value1"] + + args = processor._get_process_args([], [], None) + assert args["app_specification"]["ContainerArguments"] == ["--arg1", "value1"] + + def test_get_process_args_with_entrypoint(self, mock_session): + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + entrypoint=["python", "script.py"], + sagemaker_session=mock_session, + ) + processor._current_job_name = "test-job" + + args = processor._get_process_args([], [], None) + assert args["app_specification"]["ContainerEntrypoint"] == ["python", "script.py"] + + def test_get_process_args_with_network_config(self, mock_session): + network_config = NetworkConfig(enable_network_isolation=True) + + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + network_config=network_config, + sagemaker_session=mock_session, + ) + 
processor._current_job_name = "test-job" + + args = processor._get_process_args([], [], None) + assert args["network_config"] is not None + + +class TestScriptProcessor: + def test_init_with_sklearn_image(self, mock_session): + processor = ScriptProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="sklearn:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + assert processor.command == ["python3"] + + def test_get_user_code_name(self, mock_session): + processor = ScriptProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + result = processor._get_user_code_name("s3://bucket/path/script.py") + assert result == "script.py" + + def test_handle_user_code_url_s3(self, mock_session): + processor = ScriptProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + result = processor._handle_user_code_url("s3://bucket/script.py") + assert result == "s3://bucket/script.py" + + def test_handle_user_code_url_local_file(self, mock_session): + processor = ScriptProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + processor._current_job_name = "test-job" + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".py") as f: + f.write("print('test')") + temp_file = f.name + + try: + with patch("sagemaker.s3.S3Uploader.upload", return_value="s3://bucket/script.py"): + result = processor._handle_user_code_url(temp_file) + assert result == "s3://bucket/script.py" + finally: + os.unlink(temp_file) + + def 
test_handle_user_code_url_file_not_found(self, mock_session): + processor = ScriptProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + with pytest.raises(ValueError, match="wasn't found"): + processor._handle_user_code_url("/nonexistent/file.py") + + def test_handle_user_code_url_directory(self, mock_session): + processor = ScriptProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(ValueError, match="must be a file"): + processor._handle_user_code_url(tmpdir) + + def test_handle_user_code_url_invalid_scheme(self, mock_session): + processor = ScriptProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + with pytest.raises(ValueError, match="url scheme .* is not recognized"): + processor._handle_user_code_url("http://example.com/script.py") + + def test_convert_code_and_add_to_inputs(self, mock_session): + processor = ScriptProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + inputs = [] + result = processor._convert_code_and_add_to_inputs(inputs, "s3://bucket/code.py") + + assert len(result) == 1 + assert result[0].input_name == "code" + + def test_set_entrypoint(self, mock_session): + processor = ScriptProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + 
instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + processor._set_entrypoint(["python3"], "script.py") + assert processor.entrypoint[-1].endswith("script.py") + + +class TestFrameworkProcessor: + def test_init_default_command(self, mock_session): + processor = FrameworkProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + assert processor.command == ["python"] + + def test_init_with_code_location(self, mock_session): + processor = FrameworkProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + code_location="s3://bucket/code/", + sagemaker_session=mock_session, + ) + assert processor.code_location == "s3://bucket/code" + + def test_patch_inputs_with_payload(self, mock_session): + processor = FrameworkProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + inputs = [] + result = processor._patch_inputs_with_payload(inputs, "s3://bucket/code/sourcedir.tar.gz") + + assert len(result) == 1 + assert result[0].input_name == "code" + + def test_set_entrypoint_framework(self, mock_session): + processor = FrameworkProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + processor._set_entrypoint(["python"], "runproc.sh") + assert processor.entrypoint[0] == "/bin/bash" + + def test_generate_framework_script(self, mock_session): + processor = FrameworkProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + command=["python3"], + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + 
) + + script = processor._generate_framework_script("train.py") + assert "#!/bin/bash" in script + assert "train.py" in script + assert "python3" in script + assert "install_requirements.py" in script + + def test_create_and_upload_runproc(self, mock_session): + processor = FrameworkProcessor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + with patch( + "sagemaker.s3.S3Uploader.upload_string_as_file_body", + return_value="s3://bucket/runproc.sh", + ): + result = processor._create_and_upload_runproc( + "train.py", None, "s3://bucket/runproc.sh" + ) + assert result == "s3://bucket/runproc.sh" + + +class TestHelperFunctions: + def test_processing_input_to_request_dict(self): + s3_input = ProcessingS3Input( + s3_uri="s3://bucket/input", + local_path="/opt/ml/processing/input", + s3_data_type="S3Prefix", + s3_input_mode="File", + ) + processing_input = ProcessingInput(input_name="data", s3_input=s3_input) + + result = _processing_input_to_request_dict(processing_input) + + assert result["InputName"] == "data" + assert result["S3Input"]["S3Uri"] == "s3://bucket/input" + + def test_processing_output_to_request_dict(self): + s3_output = ProcessingS3Output( + s3_uri="s3://bucket/output", + local_path="/opt/ml/processing/output", + s3_upload_mode="EndOfJob", + ) + processing_output = ProcessingOutput(output_name="results", s3_output=s3_output) + + result = _processing_output_to_request_dict(processing_output) + + assert result["OutputName"] == "results" + assert result["S3Output"]["S3Uri"] == "s3://bucket/output" + + def test_get_process_request_minimal(self): + result = _get_process_request( + inputs=[], + output_config={"Outputs": []}, + job_name="test-job", + resources={"ClusterConfig": {}}, + stopping_condition=None, + app_specification={"ImageUri": "test-image"}, + environment=None, + network_config=None, + 
role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + tags=None, + ) + + assert result["ProcessingJobName"] == "test-job" + assert result["RoleArn"] == "arn:aws:iam::123456789012:role/SageMakerRole" + + def test_get_process_request_with_all_params(self): + result = _get_process_request( + inputs=[{"InputName": "data"}], + output_config={"Outputs": [{"OutputName": "results"}]}, + job_name="test-job", + resources={"ClusterConfig": {}}, + stopping_condition={"MaxRuntimeInSeconds": 3600}, + app_specification={"ImageUri": "test-image"}, + environment={"KEY": "VALUE"}, + network_config={"EnableNetworkIsolation": True}, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + tags=[{"Key": "Project", "Value": "ML"}], + experiment_config={"ExperimentName": "test-exp"}, + ) + + assert result["ProcessingInputs"] == [{"InputName": "data"}] + assert result["Environment"] == {"KEY": "VALUE"} + assert result["ExperimentConfig"] == {"ExperimentName": "test-exp"} + + +class TestProcessingInputOutputHelpers: + def test_processing_input_with_app_managed(self): + s3_input = ProcessingS3Input( + s3_uri="s3://bucket/input", + local_path="/opt/ml/processing/input", + s3_data_type="S3Prefix", + s3_input_mode="File", + ) + processing_input = ProcessingInput(input_name="data", s3_input=s3_input, app_managed=True) + + result = _processing_input_to_request_dict(processing_input) + + assert result["AppManaged"] is True + + def test_processing_output_with_app_managed(self): + s3_output = ProcessingS3Output( + s3_uri="s3://bucket/output", + local_path="/opt/ml/processing/output", + s3_upload_mode="EndOfJob", + ) + processing_output = ProcessingOutput( + output_name="results", s3_output=s3_output, app_managed=True + ) + + result = _processing_output_to_request_dict(processing_output) + + assert result["AppManaged"] is True + + +class TestProcessorBasics: + """Test cases for basic Processor functionality""" + + def test_init_with_minimal_params(self, mock_session): + """Test 
initialization with minimal parameters""" + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + assert processor.role == "arn:aws:iam::123456789012:role/SageMakerRole" + assert processor.image_uri == "test-image:latest" + assert processor.instance_count == 1 + assert processor.instance_type == "ml.m5.xlarge" + assert processor.volume_size_in_gb == 30 + + def test_init_with_all_params(self, mock_session): + """Test initialization with all parameters""" + network_config = NetworkConfig( + enable_network_isolation=True, security_group_ids=["sg-123"], subnets=["subnet-123"] + ) + + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=2, + instance_type="ml.m5.2xlarge", + entrypoint=["python", "script.py"], + volume_size_in_gb=50, + volume_kms_key="kms-key-123", + output_kms_key="output-kms-key", + max_runtime_in_seconds=7200, + base_job_name="test-processor", + sagemaker_session=mock_session, + env={"KEY": "VALUE"}, + tags=[("Project", "ML")], + network_config=network_config, + ) + + assert processor.instance_count == 2 + assert processor.volume_size_in_gb == 50 + assert processor.entrypoint == ["python", "script.py"] + assert processor.env == {"KEY": "VALUE"} + assert processor.network_config == network_config + + def test_init_without_role_raises_error(self, mock_session): + """Test initialization without role raises ValueError in non-local mode""" + with pytest.raises(ValueError, match="AWS IAM role is required"): + Processor( + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + def test_init_with_local_instance_type(self): + """Test initialization with local instance type creates a LocalSession.""" + with patch("sagemaker.processing.LocalSession") as 
mock_local_session_cls: + mock_local_session = Mock() + mock_local_session.local_mode = True + mock_local_session.boto_session = Mock() + mock_local_session.boto_session.region_name = "us-west-2" + mock_local_session.sagemaker_client = Mock() + mock_local_session.default_bucket = Mock(return_value="test-bucket") + mock_local_session.default_bucket_prefix = "sagemaker" + mock_local_session.expand_role = Mock(side_effect=lambda x: x) + mock_local_session.sagemaker_config = {} + mock_local_session_cls.return_value = mock_local_session + + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="local", + ) + + mock_local_session_cls.assert_called_once() + assert processor.sagemaker_session == mock_local_session + + def test_run_with_logs_but_no_wait_raises_error(self, mock_session): + """Test run with logs=True but wait=False raises ValueError""" + processor = Processor( + role="arn:aws:iam::123456789012:role/SageMakerRole", + image_uri="test-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + with pytest.raises(ValueError, match="Logs can only be shown if wait is set to True"): + processor.run(wait=False, logs=True)