diff --git a/python/base_connector/auth.py b/python/base_connector/auth.py
new file mode 100644
index 0000000000..3ba7681812
--- /dev/null
+++ b/python/base_connector/auth.py
@@ -0,0 +1,48 @@
+from requests import post
+import base64
+
+from airbyte_cdk.sources.streams.http.requests_native_auth import Oauth2Authenticator as BaseOauth2Authenticator
+
+
+class Oauth2Authenticator(BaseOauth2Authenticator):
+    grant_type = "refresh_token"
+
+    def __init__(self, use_base64, content_type, use_body, refresh_token_fields, **kwargs):
+        super().__init__(**kwargs)
+        self.token_refresh_endpoint = kwargs["token_refresh_endpoint"]
+        self.grant_type = kwargs["grant_type"]
+        self.client_id = kwargs["client_id"]
+        self.client_secret = kwargs["client_secret"]
+        self.refresh_token = kwargs["refresh_token"]
+
+        self.use_base64 = use_base64
+        self.content_type = content_type
+        self.use_body = use_body
+        self.refresh_token_fields = refresh_token_fields
+
+    def refresh_access_token(self):
+        auth_request = {
+            "headers": {},
+            "data": {}
+        }
+
+        # Optionally send the configured credential fields in the request body.
+        if self.use_body is True:
+            for token_field in self.refresh_token_fields:
+                auth_request["data"][token_field] = self.__dict__[token_field]
+
+        if self.content_type is not None:
+            auth_request["headers"]["Content-Type"] = self.content_type
+
+        # Some providers expect the client id and secret as an HTTP Basic auth header.
+        # use_base64 is always a bool (see BaseSource), so test its value rather than "is not None".
+        if self.use_base64 is True:
+            auth_hash = base64.b64encode(f"{self.client_id}:{self.client_secret}".encode('ascii')).decode('ascii')
+            auth_request["headers"]["Authorization"] = f"Basic {auth_hash}"
+
+        response = post(self.token_refresh_endpoint, **auth_request)
+        response_body = response.json()
+
+        if response.status_code >= 400:
+            raise Exception(response_body)
+
+        return response_body["access_token"], response_body["expires_in"]
\ No newline at end of file
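Reviewer note: a minimal sketch of how this authenticator is meant to be wired up. Every value below is an invented placeholder, and it assumes the `python/` directory is on the import path:

```python
# Illustrative only: placeholder endpoint and credentials, not real configuration.
from base_connector.auth import Oauth2Authenticator

authenticator = Oauth2Authenticator(
    token_refresh_endpoint="https://example.com/oauth/token",
    grant_type="refresh_token",
    client_id="my-client-id",
    client_secret="my-client-secret",
    refresh_token="my-refresh-token",
    use_base64=True,   # send client id/secret as an HTTP Basic auth header
    content_type="application/x-www-form-urlencoded",
    use_body=True,     # also send the refresh_token_fields in the request body
    refresh_token_fields=["grant_type", "client_id", "client_secret", "refresh_token"],
)

access_token, expires_in = authenticator.refresh_access_token()
```

In practice `BaseSource._get_authenticator` (below) builds this from the connector config and the manifest.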
manifest["definitions"]["refresh_token_fields"] + + def get_args(self, config: Mapping[str, Any]): + args = { + "connector": self.connector, + "authenticator": self._get_authenticator(config), + "paginator": self.paginator, + "url_base": self.base_endpoint + } + + return args + + def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> Tuple[bool, Any]: + try: + check_stream = BaseStream( + config_dict = self.streams_dict[self.check_stream], + **self.get_args(config) + ) + + next(check_stream.read_records(sync_mode = SyncMode.full_refresh)) + return True, None + except Exception as e: + return False, e + + def _get_authenticator(self, config: dict) -> Union[TokenAuthenticator, Oauth2Authenticator]: + if "access_token" in config: + return TokenAuthenticator(token = config["access_token"]) + + creds = config.get("credentials") + + if "personal_access_token" in creds: + return TokenAuthenticator(token = creds["personal_access_token"]) + else: + return Oauth2Authenticator( + token_refresh_endpoint = self.token_refresh_endpoint, + grant_type = self.grant_type, + client_id = creds["client_id"], + client_secret = creds["client_secret"], + refresh_token = creds["refresh_token"], + use_base64 = self.use_base64, + content_type = self.content_type, + use_body = self.use_body, + refresh_token_fields = self.refresh_token_fields, + ) + + def generate_stream(self, config_dict, args): + stream = next((x for x in self.schema_streams if x.name == config_dict["name"]), None) + if stream is not None: + return stream + + stream = BaseStream(config_dict = config_dict, **args) + if "schema_name" in config_dict: + self.schema_streams.append(stream) + + return stream + + def streams(self, config: Mapping[str, Any]) -> List[Stream]: + args = self.get_args(config) + + for _, config_dict in self.streams_dict.items(): + if "parent_streams" in config_dict: + for index, parent_config_dict in enumerate(config_dict["parent_streams"]): + parent_stream = self.generate_stream(self.streams_dict[parent_config_dict["name"]], args) + config_dict["parent_streams"][index]["instance"] = parent_stream + + self.generate_stream(config_dict, args) + + return self.schema_streams \ No newline at end of file diff --git a/python/base_connector/stream.py b/python/base_connector/stream.py new file mode 100644 index 0000000000..14c7d64476 --- /dev/null +++ b/python/base_connector/stream.py @@ -0,0 +1,186 @@ +from typing import Any, Iterable, Mapping, MutableMapping, Optional, Type +from requests import Response +import json + +from airbyte_cdk.models import SyncMode +from airbyte_cdk.sources.streams.http import HttpStream + +class BaseStream(HttpStream): + url_base = "" + primary_key = None + use_cache = False + name = "base_stream" + raise_on_http_errors = True + parent_streams_configs = None + + @property + def StreamType(self) -> Type: + return self.__class__ + + def __init__(self, connector: str, config_dict: dict, paginator: dict, url_base: str, **kwargs): + super().__init__(**kwargs) + + self.connector = connector + self.page_size = paginator["page_size"] + self.paginator_opt_field_name = paginator["opt_field_name"] + self.paginator_request_field_name = paginator["request_field_name"] + self.has_pagination = config_dict["has_pagination"] if "has_pagination" in config_dict else False + + self.url_path = config_dict["path"] + self.url_base = url_base + self.primary_key = config_dict["primary_key"] if "primary_key" in config_dict else None + self.use_cache = "use_cache" in config_dict or "schema_name" not in 
diff --git a/python/base_connector/stream.py b/python/base_connector/stream.py
new file mode 100644
index 0000000000..14c7d64476
--- /dev/null
+++ b/python/base_connector/stream.py
@@ -0,0 +1,190 @@
+from typing import Any, Iterable, Mapping, MutableMapping, Optional, Type
+from requests import Response
+import json
+
+from airbyte_cdk.models import SyncMode
+from airbyte_cdk.sources.streams.http import HttpStream
+
+class BaseStream(HttpStream):
+    url_base = ""
+    primary_key = None
+    use_cache = False
+    name = "base_stream"
+    raise_on_http_errors = True
+    parent_streams_configs = None
+
+    @property
+    def StreamType(self) -> Type:
+        return self.__class__
+
+    def __init__(self, connector: str, config_dict: dict, paginator: dict, url_base: str, **kwargs):
+        super().__init__(**kwargs)
+
+        self.connector = connector
+        self.page_size = paginator["page_size"]
+        self.paginator_opt_field_name = paginator["opt_field_name"]
+        self.paginator_request_field_name = paginator["request_field_name"]
+        self.has_pagination = config_dict["has_pagination"] if "has_pagination" in config_dict else False
+
+        self.url_path = config_dict["path"]
+        self.url_base = url_base
+        self.primary_key = config_dict["primary_key"] if "primary_key" in config_dict else None
+        self.use_cache = "use_cache" in config_dict or "schema_name" not in config_dict
+        self.name = config_dict["name"]
+        self.schema_name = config_dict["schema_name"] if "schema_name" in config_dict else None
+        self.record_extractor = config_dict["record_extractor"] if "record_extractor" in config_dict else None
+        self.extra_opt = config_dict["extra_opt"] if "extra_opt" in config_dict else None
+
+        self.raise_on_http_errors = "ignore_error" not in config_dict
+
+        if "parent_streams" in config_dict:
+            self.parent_streams_configs = config_dict["parent_streams"]
+
+        if "use_sync_token" in config_dict:
+            self.sync_token = None
+
+    def get_json_schema(self) -> Mapping[str, Any]:
+        with open(f"{self.connector}/{self.connector.replace('-', '_')}/schemas/{self.schema_name}", "r") as file:
+            try:
+                json_schema = json.load(file)
+            except json.JSONDecodeError as error:
+                raise ValueError(f"Could not read json spec file: {error}. Please ensure that it is a valid JSON.")
+
+        return json_schema
+
+    def path(self, stream_slice = None, **kwargs: Mapping[str, Any]) -> str:
+        if self.parent_streams_configs is not None:
+            path = self.url_path
+
+            for parent_stream_configs in self.parent_streams_configs:
+                partition_field = parent_stream_configs["partition_field"]
+                if partition_field in stream_slice:
+                    path = path.format(**{ partition_field: stream_slice[partition_field] })
+
+            return path
+
+        return self.url_path
+
+    def next_page_token(self, response: Response) -> Optional[Mapping[str, Any]]:
+        if not self.has_pagination:
+            return None
+
+        decoded_response = response.json()
+
+        if self.check_use_sync():
+            last_sync = decoded_response.get("sync")
+
+            if last_sync:
+                return { "sync": last_sync }
+
+        next_page = decoded_response.get(self.paginator_request_field_name)
+
+        if next_page:
+            return { "offset": next_page["offset"] }
+
+    def stream_slices(self, iteration = 0, **kwargs) -> Iterable[Optional[Mapping[str, Any]]]:
+        if self.parent_streams_configs is not None:
+            yield from self.read_slices_from_records(self.parent_streams_configs[iteration])
+            # recurse so slices from every parent stream are yielded, not just the first
+            if iteration + 1 < len(self.parent_streams_configs):
+                yield from self.stream_slices(iteration = iteration + 1, **kwargs)
+        else:
+            yield None
+
+    def parse_response(self, response: Response, **kwargs: Mapping[str, Any]) -> Iterable[Mapping]:
+        response_json = response.json()
+
+        # A 412 means the previous sync token is no longer valid; capture the fresh one from the error body.
+        if self.check_use_sync() and response.status_code == 412:
+            if "sync" in response_json:
+                self.sync_token = response_json["sync"]
+            else:
+                self.sync_token = None
+        else:
+            if "code" in response_json:
+                return
+
+            if "sync" in response_json:
+                self.sync_token = response_json["sync"]
+
+        if self.record_extractor is not None:
+            section_data = response_json.get(self.record_extractor, [])
+        else:
+            section_data = [response_json]
+
+        if isinstance(section_data, dict):
+            yield section_data
+        elif isinstance(section_data, list):
+            yield from section_data
+
+    def request_params(self, next_page_token: Mapping[str, Any] = None, stream_slice: Mapping[str, Any] = None, **kwargs: Mapping[str, Any]) -> MutableMapping[str, Any]:
+        params = { self.paginator_opt_field_name: self.page_size }
+        params.update(self.get_opt_fields())
+
+        if self.extra_opt is not None:
+            for key, value in self.extra_opt.items():
+                if stream_slice and value in stream_slice:
+                    params.update({ key: stream_slice[value] })
+                else:
+                    params.update({ key: value })
+
+        if next_page_token:
+            params.update(next_page_token)
+
+        return params
+
+    def _handle_object_type(self, prop: str, value: MutableMapping[str, Any]) -> str:
+        return f"{prop}.id"
+
+    def _handle_array_type(self, prop: str, value: MutableMapping[str, Any]) -> str:
+        if "type" in value and "object" in value["type"]:
+            return self._handle_object_type(prop, value)
+
+        return prop
+
+    def get_opt_fields(self) -> MutableMapping[str, str]:
+        if self.schema_name is None:
+            return { "opt_fields": "" }
+
+        opt_fields = list()
+        schema = self.get_json_schema()
+
+        # Flatten schema properties into request field names; objects and arrays of objects are requested as "<prop>.id".
+        for prop, value in schema["properties"].items():
+            if "object" in value["type"]:
+                opt_fields.append(self._handle_object_type(prop, value))
+            elif "array" in value["type"]:
+                opt_fields.append(self._handle_array_type(prop, value.get("items", {})))
+            else:
+                opt_fields.append(prop)
+
+        return { "opt_fields": ",".join(opt_fields) } if opt_fields else dict()
+
+    def read_slices_from_records(self, stream) -> Iterable[Optional[Mapping[str, Any]]]:
+        stream_instance = stream["instance"]
+        stream_slices = stream_instance.stream_slices(sync_mode = SyncMode.full_refresh)
+
+        for stream_slice in stream_slices:
+            for record in stream_instance.read_records(sync_mode = SyncMode.full_refresh, stream_slice = stream_slice):
+                yield { stream["partition_field"]: record[stream["parent_key"]] }
+
+    def read_records(self, *args, **kwargs):
+        # resume from the sync token captured during the previous read, if any
+        if self.check_use_sync() and self.sync_token is not None:
+            kwargs["next_page_token"] = { "sync": self.sync_token }
+
+        yield from super().read_records(*args, **kwargs)
+
+        if self.check_use_sync():
+            self.sync_token = self.get_latest_sync_token()
+
+    def get_latest_sync_token(self) -> Optional[str]:
+        latest_sync_token = self.state.get("last_sync_token")
+
+        if latest_sync_token is None:
+            return None
+
+        return latest_sync_token["sync"]
+
+    def check_use_sync(self):
+        return hasattr(self, "sync_token")
\ No newline at end of file