
Commit

Update gcs_source.py

acrylJonny authored Mar 3, 2025
1 parent a76d8eb commit a2a08b2
Showing 1 changed file with 16 additions and 51 deletions.
metadata-ingestion/src/datahub/ingestion/source/gcs/gcs_source.py
@@ -1,8 +1,7 @@
 import logging
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional
 from urllib.parse import unquote
 
-from pandas import DataFrame
 from pydantic import Field, SecretStr, validator
 
 from datahub.configuration.common import ConfigModel
@@ -21,12 +20,8 @@
 from datahub.ingestion.source.data_lake_common.config import PathSpecsConfigMixin
 from datahub.ingestion.source.data_lake_common.data_lake_utils import PLATFORM_GCS
 from datahub.ingestion.source.data_lake_common.path_spec import PathSpec, is_gcs_uri
-from datahub.ingestion.source.gcs.gcs_utils import (
-    get_gcs_bucket_name,
-    get_gcs_bucket_relative_path,
-)
+from datahub.ingestion.source.gcs.gcs_utils import strip_gcs_prefix
 from datahub.ingestion.source.s3.config import DataLakeSourceConfig
-from datahub.ingestion.source.s3.datalake_profiler_config import DataLakeProfilerConfig
 from datahub.ingestion.source.s3.report import DataLakeSourceReport
 from datahub.ingestion.source.s3.source import S3Source, TableData
 from datahub.ingestion.source.state.stale_entity_removal_handler import (
@@ -63,19 +58,6 @@ class GCSSourceConfig(
         description="Number of files to list to sample for schema inference. This will be ignored if sample_files is set to False in the pathspec.",
     )
 
-    profiling: Optional[DataLakeProfilerConfig] = Field(
-        default=DataLakeProfilerConfig(), description="Data profiling configuration"
-    )
-
-    spark_driver_memory: str = Field(
-        default="4g", description="Max amount of memory to grant Spark."
-    )
-
-    spark_config: Dict[str, Any] = Field(
-        description="Spark configuration properties",
-        default={},
-    )
-
     stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None
 
     @validator("path_specs", always=True)
@@ -91,9 +73,6 @@ def check_path_specs_and_infer_platform(
 
         return path_specs
 
-    def is_profiling_enabled(self) -> bool:
-        return self.profiling is not None and self.profiling.enabled
-
 
 class GCSSourceReport(DataLakeSourceReport):
     pass
@@ -104,7 +83,7 @@ class GCSSourceReport(DataLakeSourceReport):
 @support_status(SupportStatus.INCUBATING)
 @capability(SourceCapability.CONTAINERS, "Enabled by default")
 @capability(SourceCapability.SCHEMA_METADATA, "Enabled by default")
-@capability(SourceCapability.DATA_PROFILING, "Enabled via configuration")
+@capability(SourceCapability.DATA_PROFILING, "Not supported", supported=False)
 class GCSSource(StatefulIngestionSourceBase):
     def __init__(self, config: GCSSourceConfig, ctx: PipelineContext):
         super().__init__(config, ctx)
@@ -132,11 +111,6 @@ def create_equivalent_s3_config(self):
             env=self.config.env,
             max_rows=self.config.max_rows,
             number_of_files_to_sample=self.config.number_of_files_to_sample,
-            profiling=self.config.profiling,
-            spark_driver_memory=self.config.spark_driver_memory,
-            spark_config=self.config.spark_config,
-            use_s3_bucket_tags=False,
-            use_s3_object_tags=False,
         )
         return s3_config
 
@@ -172,30 +146,21 @@ def s3_source_overrides(self, source: S3Source) -> S3Source:
         source.create_s3_path = lambda bucket_name, key: unquote(  # type: ignore
             f"s3://{bucket_name}/{key}"
         )
-
-        if self.config.is_profiling_enabled():
-            original_read_file_spark = source.read_file_spark
-
-            from types import MethodType
-
-            def read_file_spark_with_gcs(
-                self_source: S3Source, file: str, ext: str
-            ) -> Optional[DataFrame]:
-                # Convert s3:// path back to gs:// for Spark
-                if file.startswith("s3://"):
-                    file = f"gs://{file[5:]}"
-                return original_read_file_spark(file, ext)
-
-            source.read_file_spark = MethodType(read_file_spark_with_gcs, source)  # type: ignore
-
-        def get_external_url_override(table_data: TableData) -> Optional[str]:
-            bucket_name = get_gcs_bucket_name(table_data.table_path)
-            key_prefix = get_gcs_bucket_relative_path(table_data.table_path)
-            return f"https://console.cloud.google.com/storage/browser/{bucket_name}/{key_prefix}"
-
-        source.get_external_url = get_external_url_override  # type: ignore
+        source.get_external_url = self.get_external_url_override.__get__(source)  # type: ignore
         return source
+
+    def get_external_url_override(self, table_data: TableData) -> Optional[str]:
+        """
+        Convert S3 URIs back to GCS URIs for external URLs.
+        This method gets bound to the S3Source instance.
+        """
+        if not table_data.table_path.startswith("s3://"):
+            return None
+
+        # Replace the s3:// with gs:// to create the GCS URI
+        gcs_uri = table_data.table_path.replace("s3://", "gs://")
+        return f"https://console.cloud.google.com/storage/browser/{strip_gcs_prefix(gcs_uri)}"
 
     def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
         return [
             *super().get_workunit_processors(),
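
A note on the new override wiring in `s3_source_overrides`: instead of defining a nested function and assigning it, the commit attaches `self.get_external_url_override.__get__(source)` to the `S3Source` instance. The sketch below uses hypothetical `Wrapper`/`Target` classes (not DataHub code) to show how that pattern behaves under those assumptions: the callable lands in the target's instance `__dict__` and shadows the class attribute, and because calling `__get__` on an already-bound method returns that same bound method, `self` inside the override still refers to the wrapper (here, the `GCSSource`).

```python
# Minimal sketch of the instance-attribute override pattern used in s3_source_overrides.
# Wrapper and Target are illustrative stand-ins, not DataHub classes.


class Target:
    def external_url(self) -> str:
        return "default"


class Wrapper:
    def external_url_override(self, path: str) -> str:
        # `self` here stays the Wrapper instance: __get__ on an
        # already-bound method returns that same bound method unchanged.
        return f"override for {path}"


wrapper = Wrapper()
target = Target()

# Shadow Target.external_url with an instance attribute on this one object.
target.external_url = wrapper.external_url_override.__get__(target)  # type: ignore

print(target.external_url("gs://bucket/key"))   # override for gs://bucket/key
print(target.external_url.__self__ is wrapper)  # True
```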
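The new `get_external_url_override` maps the internal `s3://` table path back to a `gs://` URI and then builds a Cloud Console browser link from it. Below is a rough standalone sketch of that translation, assuming `strip_gcs_prefix` simply removes the `gs://` scheme; `_strip_prefix` and the example bucket/key are illustrative stand-ins, not DataHub code.

```python
from typing import Optional


def _strip_prefix(gcs_uri: str) -> str:
    # Assumed stand-in for datahub's gcs_utils.strip_gcs_prefix.
    prefix = "gs://"
    if not gcs_uri.startswith(prefix):
        raise ValueError(f"Not a GCS URI: {gcs_uri}")
    return gcs_uri[len(prefix):]


def external_url_for(table_path: str) -> Optional[str]:
    # The GCS source reuses the S3 source internally, so table paths arrive as s3://.
    if not table_path.startswith("s3://"):
        return None
    gcs_uri = table_path.replace("s3://", "gs://")
    return f"https://console.cloud.google.com/storage/browser/{_strip_prefix(gcs_uri)}"


print(external_url_for("s3://example-bucket/events/2025"))
# https://console.cloud.google.com/storage/browser/example-bucket/events/2025
```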
