Add EMR on EKS support (#38)
* Require secure transport on buckets
* EMR on EKS support
dacort authored Jan 12, 2024
1 parent 4fae1d2 commit 1dbf3aa
Showing 8 changed files with 230 additions and 67 deletions.
4 changes: 3 additions & 1 deletion src/emr_cli/deployments/__init__.py
@@ -7,18 +7,20 @@ class SparkParams:
    Spark `--conf` parameters based on the environment being deployed to.
    """

    SUPPORTED_ENVIRONMENTS = ["emr_serverless", "emr_ec2"]
    SUPPORTED_ENVIRONMENTS = ["emr_serverless", "emr_ec2", "emr_eks"]

    def __init__(
        self,
        common_params: Optional[Dict[str, str]] = None,
        emr_serverless_params: Optional[Dict[str, str]] = None,
        emr_ec2_params: Optional[Dict[str, str]] = None,
        emr_eks_params: Optional[Dict[str, str]] = None,
    ) -> None:
        self._common = common_params or {}
        self._environment_params = {
            "emr_serverless": emr_serverless_params or {},
            "emr_ec2": emr_ec2_params or {},
            "emr_eks": emr_eks_params or {},
        }

    def params_for(self, deployment_type: str) -> str:
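For context, a minimal sketch of how the new `emr_eks` slot might be used. The keys and values are placeholders, and the exact `--conf` rendering lives in the elided `params_for` body, so the output shown in the comment is an assumption:

from emr_cli.deployments import SparkParams

# Hypothetical parameters; the keys and values below are placeholders.
params = SparkParams(
    common_params={"spark.pyspark.python": "python3"},
    emr_eks_params={"spark.kubernetes.executor.limit.cores": "2"},
)

# params_for merges the common and per-environment settings for the target,
# presumably rendering something like:
#   --conf spark.pyspark.python=python3 --conf spark.kubernetes.executor.limit.cores=2
print(params.params_for("emr_eks"))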
20 changes: 19 additions & 1 deletion src/emr_cli/deployments/emr_ec2.py
@@ -91,6 +91,24 @@ def _create_s3_buckets(self):
                CreateBucketConfiguration={"LocationConstraint": self.s3_client.meta.region_name},
            )
            console_log(f"Created S3 bucket: s3://{bucket_name}")
            self.s3_client.put_bucket_policy(Bucket=bucket_name, Policy=self._default_s3_bucket_policy(bucket_name))

    def _default_s3_bucket_policy(self, bucket_name) -> str:
        bucket_policy = {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Sid": "RequireSecureTransport",
                    "Effect": "Deny",
                    "Principal": "*",
                    "Action": "s3:*",
                    "Resource": [f"arn:aws:s3:::{bucket_name}/*", f"arn:aws:s3:::{bucket_name}"],
                    "Condition": {"Bool": {"aws:SecureTransport": "false"}},
                }
            ],
        }
        return json.dumps(bucket_policy)

    def _create_service_role(self):
        """
@@ -329,7 +347,7 @@ def run_job(

        if spark_submit_opts:
            spark_submit_params = f"{spark_submit_params} {spark_submit_opts}".strip()

        # Escape job args if they're provided
        if job_args:
            job_args = [shlex.quote(arg) for arg in job_args]
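As a quick sanity check of the new secure-transport requirement added above, a sketch that reads the applied policy back with boto3 and confirms the deny statement is in place. The bucket name is a placeholder:

import json

import boto3

s3 = boto3.client("s3")
bucket_name = "my-example-bucket"  # placeholder

# get_bucket_policy returns the policy document as a JSON string.
policy = json.loads(s3.get_bucket_policy(Bucket=bucket_name)["Policy"])
statement = policy["Statement"][0]
assert statement["Sid"] == "RequireSecureTransport"
assert statement["Condition"]["Bool"]["aws:SecureTransport"] == "false"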
122 changes: 122 additions & 0 deletions src/emr_cli/deployments/emr_eks.py
@@ -0,0 +1,122 @@
import re
import sys
from os.path import join
from time import sleep
from typing import List, Optional

import boto3
from emr_cli.deployments.emr_serverless import DeploymentPackage
from emr_cli.utils import console_log, print_s3_gz


class EMREKS:
    def __init__(
        self, virtual_cluster_id: str, job_role: str, deployment_package: DeploymentPackage, region: str = ""
    ) -> None:
        self.virtual_cluster_id = virtual_cluster_id
        self.job_role = job_role
        self.dp = deployment_package
        self.s3_client = boto3.client("s3")
        if region:
            self.client = boto3.client("emr-containers", region_name=region)
        else:
            # Note that boto3 uses AWS_DEFAULT_REGION, not AWS_REGION
            # We may want to add an extra check here for the latter.
            self.client = boto3.client("emr-containers")

    def run_job(
        self,
        job_name: str,
        job_args: Optional[List[str]] = None,
        spark_submit_opts: Optional[str] = None,
        wait: bool = True,
        show_logs: bool = False,
        s3_logs_uri: Optional[str] = None,
    ):
        if show_logs and not s3_logs_uri:
            raise RuntimeError("--show-stdout requires --s3-logs-uri to be set.")

        # If job_name is the default, just replace the space.
        # Otherwise throw an error
        if job_name == "emr-cli job":
            job_name = "emr-cli_job"
        elif not re.fullmatch(r"[\.\-_/#A-Za-z0-9]+", job_name):
            console_log(f"Invalid characters in job name {job_name} - EMR on EKS must match [\\.\\-_/#A-Za-z0-9]+")
            sys.exit(1)

        jobDriver = {
            "sparkSubmitJobDriver": {
                "entryPoint": self.dp.entrypoint_uri(),
            }
        }
        spark_submit_parameters = self.dp.spark_submit_parameters().params_for("emr_eks")

        if spark_submit_opts:
            spark_submit_parameters = f"{spark_submit_parameters} {spark_submit_opts}".strip()

        if spark_submit_parameters:
            jobDriver["sparkSubmitJobDriver"]["sparkSubmitParameters"] = spark_submit_parameters

        if job_args:
            jobDriver["sparkSubmitJobDriver"]["entryPointArguments"] = job_args  # type: ignore

        config_overrides = {}
        if s3_logs_uri:
            config_overrides = {"monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": s3_logs_uri}}}

        response = self.client.start_job_run(
            virtualClusterId=self.virtual_cluster_id,
            executionRoleArn=self.job_role,
            name=job_name,
            jobDriver=jobDriver,
            configurationOverrides=config_overrides,
            releaseLabel="emr-6.15.0-latest",
        )
        job_run_id = response.get("id")

        console_log(f"Job submitted to EMR Virtual Cluster (Job Run ID: {job_run_id})")
        if not wait and not show_logs:
            return job_run_id

        console_log("Waiting for job to complete...")
        job_done = False
        job_state = "SUBMITTED"
        jr_response = {}
        while not job_done:
            jr_response = self.get_job_run(job_run_id)
            new_state = jr_response.get("state")
            if new_state != job_state:
                console_log(f"Job state is now: {new_state}")
                job_state = new_state
            job_done = new_state in [
                "COMPLETED",
                "FAILED",
                "CANCEL_PENDING",
                "CANCELLED",
            ]
            sleep(2)

        if show_logs:
            console_log(f"stdout for {job_run_id}\n{'-'*38}")
            log_location = join(
                f"{s3_logs_uri}",
                self.virtual_cluster_id,
                "jobs",
                job_run_id,
                "containers",
                f"spark-{job_run_id}",
                f"spark-{job_run_id}-driver",
                "stdout.gz",
            )
            print_s3_gz(self.s3_client, log_location)

        if jr_response.get("state") != "COMPLETED":
            console_log(f"EMR Containers job failed: {jr_response.get('stateDetails')}")
            sys.exit(1)
        console_log("Job completed successfully!")

        return job_run_id

    def get_job_run(self, job_run_id: str) -> dict:
        response = self.client.describe_job_run(virtualClusterId=self.virtual_cluster_id, id=job_run_id)
        return response.get("jobRun")
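For illustration, a minimal sketch of driving the new class directly. The virtual cluster ID, role ARN, and S3 URIs are placeholders, and it assumes a concrete DeploymentPackage (or subclass) whose entrypoint has already been built and uploaded to S3:

from emr_cli.deployments.emr_eks import EMREKS
from emr_cli.deployments.emr_serverless import DeploymentPackage

# Placeholder deployment package pointing at an already-uploaded entrypoint.
dp = DeploymentPackage(entry_point_path="entrypoint.py", s3_target_uri="s3://my-code-bucket/app/")

emr_eks = EMREKS(
    virtual_cluster_id="abc123virtualcluster",  # placeholder
    job_role="arn:aws:iam::111122223333:role/emr-eks-job-role",  # placeholder
    deployment_package=dp,
    region="us-east-1",
)

# Submit and wait; pass show_logs=True together with s3_logs_uri to print the driver stdout.
job_run_id = emr_eks.run_job(
    job_name="emr-cli_job",
    s3_logs_uri="s3://my-log-bucket/logs/",  # placeholder
    wait=True,
)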
64 changes: 33 additions & 31 deletions src/emr_cli/deployments/emr_serverless.py
@@ -8,15 +8,12 @@
from typing import List, Optional

import boto3

from emr_cli.deployments import SparkParams
from emr_cli.utils import console_log, find_files, mkdir, print_s3_gz


class DeploymentPackage(metaclass=abc.ABCMeta):
    def __init__(
        self, entry_point_path: str = "entrypoint.py", s3_target_uri: str = ""
    ) -> None:
    def __init__(self, entry_point_path: str = "entrypoint.py", s3_target_uri: str = "") -> None:
        self.entry_point_path = entry_point_path
        self.dist_dir = "dist"

@@ -92,11 +89,32 @@ def _create_s3_buckets(self):
        Creates both the source and log buckets if they don't already exist.
        """
        for bucket_name in set([self.code_bucket, self.log_bucket]):
            self.s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={
                'LocationConstraint': self.s3_client.meta.region_name  # type: ignore
            })

            self.s3_client.create_bucket(
                Bucket=bucket_name,
                CreateBucketConfiguration={
                    "LocationConstraint": self.s3_client.meta.region_name  # type: ignore
                },
            )

            console_log(f"Created S3 bucket: s3://{bucket_name}")
            self.s3_client.put_bucket_policy(Bucket=bucket_name, Policy=self._default_s3_bucket_policy(bucket_name))

    def _default_s3_bucket_policy(self, bucket_name) -> str:
        bucket_policy = {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Sid": "RequireSecureTransport",
                    "Effect": "Deny",
                    "Principal": "*",
                    "Action": "s3:*",
                    "Resource": [f"arn:aws:s3:::{bucket_name}/*", f"arn:aws:s3:::{bucket_name}"],
                    "Condition": {"Bool": {"aws:SecureTransport": "false"}},
                }
            ],
        }
        return json.dumps(bucket_policy)

    def _create_job_role(self):
        # First create a role that can be assumed by EMR Serverless jobs
@@ -118,19 +136,13 @@ def _create_job_role(self):
        role_arn = response.get("Role").get("Arn")
        console_log(f"Created IAM Role: {role_arn}")

        self.iam_client.attach_role_policy(
            RoleName=self.job_role_name, PolicyArn=self._create_s3_policy()
        )
        self.iam_client.attach_role_policy(
            RoleName=self.job_role_name, PolicyArn=self._create_glue_policy()
        )
        self.iam_client.attach_role_policy(RoleName=self.job_role_name, PolicyArn=self._create_s3_policy())
        self.iam_client.attach_role_policy(RoleName=self.job_role_name, PolicyArn=self._create_glue_policy())

        return role_arn

    def _create_s3_policy(self):
        bucket_arns = [
            f"arn:aws:s3:::{name}" for name in [self.code_bucket, self.log_bucket]
        ]
        bucket_arns = [f"arn:aws:s3:::{name}" for name in [self.code_bucket, self.log_bucket]]
        policy_doc = {
            "Version": "2012-10-17",
            "Statement": [
@@ -239,14 +251,10 @@ def run_job(
                "entryPoint": self.dp.entrypoint_uri(),
            }
        }
        spark_submit_parameters = self.dp.spark_submit_parameters().params_for(
            "emr_serverless"
        )
        spark_submit_parameters = self.dp.spark_submit_parameters().params_for("emr_serverless")

        if spark_submit_opts:
            spark_submit_parameters = (
                f"{spark_submit_parameters} {spark_submit_opts}".strip()
            )
            spark_submit_parameters = f"{spark_submit_parameters} {spark_submit_opts}".strip()

        if spark_submit_parameters:
            jobDriver["sparkSubmit"]["sparkSubmitParameters"] = spark_submit_parameters
@@ -256,11 +264,7 @@

        config_overrides = {}
        if s3_logs_uri:
            config_overrides = {
                "monitoringConfiguration": {
                    "s3MonitoringConfiguration": {"logUri": s3_logs_uri}
                }
            }
            config_overrides = {"monitoringConfiguration": {"s3MonitoringConfiguration": {"logUri": s3_logs_uri}}}

        response = self.client.start_job_run(
            applicationId=self.application_id,
@@ -314,7 +318,5 @@ def run_job(
        return job_run_id

    def get_job_run(self, job_run_id: str) -> dict:
        response = self.client.get_job_run(
            applicationId=self.application_id, jobRunId=job_run_id
        )
        response = self.client.get_job_run(applicationId=self.application_id, jobRunId=job_run_id)
        return response.get("jobRun")