diff --git a/tap_athena/client.py b/tap_athena/client.py
index c5293c1..3c6ccbb 100644
--- a/tap_athena/client.py
+++ b/tap_athena/client.py
@@ -58,3 +58,74 @@ class AthenaStream(SQLStream):
     """The Stream class for Athena."""
 
     connector_class = AthenaConnector
+
+    def get_records(self, context: dict | None) -> t.Iterable[dict[str, t.Any]]:
+        """Return a generator of record-type dictionary objects.
+
+        If the stream has a replication_key value defined, records will be sorted
+        by the incremental key. If the stream also has an available starting
+        bookmark, the records will be filtered for values greater than or equal
+        to the bookmark value.
+
+        Args:
+            context: If partition context is provided, will read specifically
+                from this data slice.
+
+        Yields:
+            One dict per record.
+
+        Raises:
+            NotImplementedError: If partition is passed in context and the
+                stream does not support partitioning.
+        """
+        if context:
+            msg = f"Stream '{self.name}' does not support partitioning."
+            raise NotImplementedError(msg)
+
+        selected_column_names = self.get_selected_schema()["properties"].keys()
+        table = self.connector.get_table(
+            full_table_name=self.fully_qualified_name,
+            column_names=selected_column_names,
+        )
+        query = table.select()
+
+        if self.config["paginate"] or self.replication_key:
+            if self.config["paginate"] and not self.replication_key:
+                msg = "A replication key is required when 'paginate' is enabled."
+                raise ValueError(msg)
+            replication_key_col = table.columns[self.replication_key]
+            query = query.order_by(replication_key_col)
+
+            start_val = self.get_starting_replication_key_value(context)
+            if start_val:
+                # Compare the replication key column directly; binding the
+                # column as a text parameter would not reference it in SQL.
+                query = query.where(replication_key_col >= start_val)
+
+        if self.ABORT_AT_RECORD_COUNT is not None:
+            # Limit record count to one greater than the abort threshold. This
+            # ensures `MaxRecordsLimitException` is properly raised by caller
+            # `Stream._sync_records()` if more records are available than can
+            # be processed.
+            query = query.limit(self.ABORT_AT_RECORD_COUNT + 1)
+
+        if self.config["paginate"]:
+            batch_size = self.config["paginate_batch_size"]
+            batch_start = 0
+            with self.connector._connect() as conn:
+                while True:
+                    # Fetch one page of `batch_size` rows; `limit()` takes a
+                    # row count, not an end offset.
+                    page_query = query.limit(batch_size).offset(batch_start)
+                    record_count = 0
+                    for record in conn.execute(page_query):
+                        yield dict(record._mapping)
+                        record_count += 1
+                    if record_count < batch_size:
+                        # A short (or empty) page means the table is exhausted.
+                        break
+                    batch_start += batch_size
+        else:
+            with self.connector._connect() as conn:
+                for record in conn.execute(query):
+                    yield dict(record._mapping)
diff --git a/tap_athena/tap.py b/tap_athena/tap.py
index 3e8d0ab..f8a4c7d 100644
--- a/tap_athena/tap.py
+++ b/tap_athena/tap.py
@@ -45,6 +45,20 @@ class TapAthena(SQLTap):
             required=True,
             description="Athena schema name",
         ),
+        th.Property(
+            "paginate",
+            th.BooleanType,
+            required=False,
+            description="Whether to use LIMIT/OFFSET pagination when querying Athena. Useful for large tables where a single full-table query would run for a long time.",
+            default=False,
+        ),
+        th.Property(
+            "paginate_batch_size",
+            th.IntegerType,
+            required=False,
+            description="The number of records to fetch per page when pagination is enabled. Larger pages mean fewer, but longer-running, Athena queries.",
+            default=10000,
+        ),
     ).to_dict()
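
Usage note (illustrative, not part of the patch): a minimal sketch of exercising the two new settings from a local Python harness, assuming the elided keys are filled in with the tap's existing required Athena connection settings.

    from tap_athena.tap import TapAthena

    config = {
        # ... the tap's existing required Athena connection settings ...
        "paginate": True,             # opt in to LIMIT/OFFSET pagination
        "paginate_batch_size": 5000,  # rows fetched per Athena query
    }

    # With `paginate` enabled, every synced stream must define a replication
    # key: pages are ordered by that key, and a page shorter than
    # `paginate_batch_size` rows signals the final page.
    TapAthena(config=config).sync_all()

The replication-key requirement exists because LIMIT/OFFSET paging is only deterministic over an ordered result set; without an ORDER BY, Athena may return rows in a different order on each query, so pages could overlap or skip records.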