diff --git a/tests/test_aws_keyspaces_client.py b/tests/test_aws_keyspaces_client.py
index 1a62dd5..e2a85b8 100644
--- a/tests/test_aws_keyspaces_client.py
+++ b/tests/test_aws_keyspaces_client.py
@@ -1,6 +1,7 @@
 from datetime import datetime
 from uptime_service_validation.coordinator.aws_keyspaces_client import (
     AWSKeyspacesClient,
+    ShardCalculator,
 )
 
 
@@ -19,3 +20,50 @@ def test_get_submitted_at_date_list():
     end = datetime(2023, 11, 8, 0, 0, 0)
     result = AWSKeyspacesClient.get_submitted_at_date_list(start, end)
     assert result == ["2023-11-06", "2023-11-07", "2023-11-08"]
+
+
+def test_calculate_shard():
+    matrix = [
+        [0, 0, 0, 0],
+        [0, 0, 1, 0],
+        [0, 2, 59, 1],
+        [12, 1, 3, 300],
+        [15, 10, 0, 379],
+        [23, 22, 23, 584],
+        [23, 59, 59, 599],
+    ]
+    for i in matrix:
+        hour = i[0]
+        minute = i[1]
+        second = i[2]
+        expected = i[3]
+        assert ShardCalculator.calculate_shard(hour, minute, second) == expected
+
+
+def test_calculate_shards_in_range():
+    # 1 between days
+    start_time = datetime(2024, 2, 29, 23, 58, 29)
+    end_time = datetime(2024, 3, 1, 0, 3, 29)
+    expected_cql_statement = "shard in (0,1,599)"
+    result_cql_statement = ShardCalculator.calculate_shards_in_range(
+        start_time, end_time
+    )
+    assert result_cql_statement == expected_cql_statement
+
+    # 2 within the same day
+    start_time = datetime(2024, 2, 29, 12, 58, 1)
+    end_time = datetime(2024, 2, 29, 12, 59, 59)
+    expected_cql_statement = "shard in (324)"
+    result_cql_statement = ShardCalculator.calculate_shards_in_range(
+        start_time, end_time
+    )
+    assert result_cql_statement == expected_cql_statement
+
+    # 3 shard boundary
+    start_time = datetime(2024, 2, 29, 0, 0, 0)
+    end_time = datetime(2024, 2, 29, 0, 2, 24)
+    expected_cql_statement = "shard in (0,1)"
+    result_cql_statement = ShardCalculator.calculate_shards_in_range(
+        start_time, end_time
+    )
+    assert result_cql_statement == expected_cql_statement
diff --git a/uptime_service_validation/coordinator/aws_keyspaces_client.py b/uptime_service_validation/coordinator/aws_keyspaces_client.py
index f651576..4db7674 100644
--- a/uptime_service_validation/coordinator/aws_keyspaces_client.py
+++ b/uptime_service_validation/coordinator/aws_keyspaces_client.py
@@ -1,13 +1,15 @@
 import os
 import boto3
+import time
+import random
 from cassandra import ProtocolVersion
 from cassandra.auth import PlainTextAuthProvider
 from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
 from cassandra_sigv4.auth import SigV4AuthProvider
-from cassandra.policies import DCAwareRoundRobinPolicy
+from cassandra.policies import DCAwareRoundRobinPolicy, RetryPolicy
 from ssl import SSLContext, CERT_REQUIRED, PROTOCOL_TLS_CLIENT
 from dataclasses import dataclass
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import Optional, ByteString, List
 import pandas as pd
 
@@ -58,20 +60,41 @@ def __init__(self):
         self.aws_ssl_certificate_path = os.environ.get("SSL_CERTFILE")
         self.aws_region = self.cassandra_host.split(".")[1]
         self.ssl_context = self._create_ssl_context()
+        self.request_timeout = 20.0
 
         if self.cassandra_user and self.cassandra_pass:
             self.auth_provider = PlainTextAuthProvider(
                 username=self.cassandra_user, password=self.cassandra_pass
             )
+            profile = ExecutionProfile(
+                # assuming this is for hosted Cassandra, load balancing policy to be determined
+                # load_balancing_policy=DCAwareRoundRobinPolicy(local_dc=self.aws_region),
+                retry_policy=ExponentialBackOffRetryPolicy(),
+                request_timeout=self.request_timeout,
+            )
             self.cluster = Cluster(
                [self.cassandra_host],
                ssl_context=self.ssl_context,
                auth_provider=self.auth_provider,
                port=int(self.cassandra_port),
+                execution_profiles={EXEC_PROFILE_DEFAULT: profile},
+                protocol_version=ProtocolVersion.V4,
             )
         else:
             self.auth_provider = self._create_sigv4auth_provider()
-            self.cluster = self._create_cluster()
+            profile = ExecutionProfile(
+                load_balancing_policy=DCAwareRoundRobinPolicy(local_dc=self.aws_region),
+                retry_policy=ExponentialBackOffRetryPolicy(),
+                request_timeout=self.request_timeout,
+            )
+            self.cluster = Cluster(
+                [self.cassandra_host],
+                ssl_context=self.ssl_context,
+                auth_provider=self.auth_provider,
+                port=int(self.cassandra_port),
+                execution_profiles={EXEC_PROFILE_DEFAULT: profile},
+                protocol_version=ProtocolVersion.V4,
+            )
 
     def _create_ssl_context(self):
         ssl_context = SSLContext(PROTOCOL_TLS_CLIENT)
@@ -118,19 +141,6 @@ def _create_sigv4auth_provider(self):
         )
         return SigV4AuthProvider(boto_session)
 
-    def _create_cluster(self):
-        profile = ExecutionProfile(
-            load_balancing_policy=DCAwareRoundRobinPolicy(local_dc=self.aws_region)
-        )
-        return Cluster(
-            [self.cassandra_host],
-            ssl_context=self.ssl_context,
-            auth_provider=self.auth_provider,
-            port=int(self.cassandra_port),
-            execution_profiles={EXEC_PROFILE_DEFAULT: profile},
-            protocol_version=ProtocolVersion.V4,
-        )
-
     def connect(self):
         self.session = self.cluster.connect()
 
@@ -208,6 +218,10 @@ def get_submissions(
             submitted_at_start, submitted_at_end
         )
 
+        shard_condition = ShardCalculator.calculate_shards_in_range(
+            submitted_at_start, submitted_at_end
+        )
+
        if len(submitted_at_date_list) == 1:
            submitted_at_date = submitted_at_date_list[0]
        else:
@@ -224,6 +238,10 @@ def get_submissions(
             parameters.append(submitted_at_date)
         elif submitted_at_dates:
             conditions.append(f"submitted_at_date IN ({submitted_at_dates})")
+
+        # Add the shard condition here, since we have a submitted_at_date or submitted_at_dates
+        conditions.append(shard_condition)
+
         if submitted_at_start:
             start_operator = ">=" if start_inclusive else ">"
             conditions.append(f"submitted_at {start_operator} %s")
@@ -273,6 +291,92 @@ def close(self):
         self.cluster.shutdown()
 
 
+class ExponentialBackOffRetryPolicy(RetryPolicy):
+    def __init__(self, base_delay=0.1, max_delay=10, max_retries=10):
+        self.base_delay = base_delay  # seconds
+        self.max_delay = max_delay  # seconds
+        self.max_retries = max_retries
+
+    def get_backoff_time(self, retry_num):
+        # Calculate exponential backoff time
+        delay = min(self.max_delay, self.base_delay * (2**retry_num))
+        # Add some randomness to avoid the thundering herd problem
+        jitter = random.uniform(0, 0.1) * delay
+        return delay + jitter
+
+    def on_read_timeout(
+        self,
+        query,
+        consistency,
+        required_responses,
+        received_responses,
+        data_retrieved,
+        retry_num,
+    ):
+        if retry_num >= self.max_retries:
+            return (self.RETHROW, None)
+        time.sleep(self.get_backoff_time(retry_num))
+        return (self.RETRY, consistency)
+
+    def on_write_timeout(
+        self,
+        query,
+        consistency,
+        write_type,
+        required_responses,
+        received_responses,
+        retry_num,
+    ):
+        if retry_num >= self.max_retries:
+            return (self.RETHROW, None)
+        time.sleep(self.get_backoff_time(retry_num))
+        return (self.RETRY, consistency)
+
+    def on_unavailable(
+        self, query, consistency, required_replica, alive_replica, retry_num
+    ):
+        if retry_num >= self.max_retries:
+            return (self.RETHROW, None)
+        time.sleep(self.get_backoff_time(retry_num))
+        return (self.RETRY_NEXT_HOST, None)
+
+
+class ShardCalculator:
+    @classmethod
+    def calculate_shard(cls, hour, minute, second):
+        return (3600 * hour + 60 * minute + second) // 144
+
+    @classmethod
+    def calculate_shards_in_range(cls, start_time, end_time):
+        shards = set()
+        current_time = start_time
+
+        while current_time < end_time:
+            shard = cls.calculate_shard(
+                current_time.hour, current_time.minute, current_time.second
+            )
+            shards.add(shard)
+            # Move to the next second
+            current_time += timedelta(seconds=1)
+
+        # Check if end_time falls exactly on a new shard boundary and add it if necessary
+        end_shard = cls.calculate_shard(end_time.hour, end_time.minute, end_time.second)
+        if end_shard not in shards:
+            # Check if end_time is exactly on the boundary of a new shard
+            total_seconds_end = (
+                (end_time.hour * 3600) + (end_time.minute * 60) + end_time.second
+            )
+            if total_seconds_end % 144 == 0:
+                shards.add(end_shard)
+
+        # Convert the set of unique shards into a sorted list for readability
+        shards_list = sorted(list(shards))
+        # Format the shards into a CQL statement string
+        shards_str = ",".join(map(str, shards_list))
+        cql_statement = f"shard in ({shards_str})"
+        return cql_statement
+
+
 # Usage Example
 if __name__ == "__main__":
     client = AWSKeyspacesClient()
diff --git a/uptime_service_validation/coordinator/coordinator.py b/uptime_service_validation/coordinator/coordinator.py
index 603bbfd..b13ff04 100644
--- a/uptime_service_validation/coordinator/coordinator.py
+++ b/uptime_service_validation/coordinator/coordinator.py
@@ -2,6 +2,7 @@
 the validator processes, distribute work them and, when they're done, collect
 their results, compute scores for the delegation program and put the results
 in the database."""
+
 from dataclasses import asdict
 from datetime import datetime, timedelta, timezone
 import logging
@@ -21,7 +22,7 @@
     create_graph,
     apply_weights,
     bfs,
-    send_slack_message
+    send_slack_message,
 )
 from uptime_service_validation.coordinator.server import (
     bool_env_var_set,
@@ -33,8 +34,7 @@
 )
 
 # Add project root to python path
-project_root = os.path.abspath(os.path.join(
-    os.path.dirname(__file__), "..", ".."))
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
 sys.path.insert(0, project_root)
 
 
@@ -56,8 +56,7 @@ def wait_until_batch_ends(self):
         "If the time window if the current batch is not yet over, sleep until it is."
         if self.batch.end_time > self.current_timestamp:
             delta = timedelta(minutes=2)
-            sleep_interval = (self.batch.end_time -
-                              self.current_timestamp) + delta
+            sleep_interval = (self.batch.end_time - self.current_timestamp) + delta
             time_until = self.current_timestamp + sleep_interval
             logging.info(
                 "All submissions are processed till date. "
" @@ -105,45 +104,49 @@ def __warn_if_work_took_longer_then_expected(self): self.current_timestamp, ) -def load_submissions(batch): + +def load_submissions(time_intervals): """Load submissions from Cassandra and return them as a DataFrame.""" submissions = [] submissions_verified = [] cassandra = AWSKeyspacesClient() + try: cassandra.connect() - submissions = cassandra.get_submissions( - submitted_at_start=batch.start_time, - submitted_at_end=batch.end_time, - start_inclusive=True, - end_inclusive=False, - ) - # for further processing - # we use only submissions verified = True and validation_error = None - for submission in submissions: - if submission.verified and submission.validation_error is None: - submissions_verified.append(submission) + for time_interval in time_intervals: + submissions.extend( + cassandra.get_submissions( + submitted_at_start=time_interval[0], + submitted_at_end=time_interval[1], + start_inclusive=True, + end_inclusive=False, + ) + ) except Exception as e: logging.error("Error in loading submissions: %s", e) return pd.DataFrame([]) finally: cassandra.close() + # for further processing + # we use only submissions verified = True and validation_error = None or "" + for submission in submissions: + if submission.verified and ( + submission.validation_error is None or submission.validation_error == "" + ): + submissions_verified.append(submission) + all_submissions_count = len(submissions) submissions_to_process_count = len(submissions_verified) logging.info("number of all submissions: %s", all_submissions_count) - logging.info( - "number of submissions to process: %s", - submissions_to_process_count - ) + logging.info("number of submissions to process: %s", submissions_to_process_count) if submissions_to_process_count < all_submissions_count: logging.warning( "some submissions were not processed, because they were not \ verified or had validation errors" ) - return pd.DataFrame( - [asdict(submission) for submission in submissions_verified] - ) + return pd.DataFrame([asdict(submission) for submission in submissions_verified]) + def process_statehash_df(db, batch, state_hash_df, verification_time): """Process the state hash dataframe and return the master dataframe.""" @@ -178,8 +181,7 @@ def process_statehash_df(db, batch, state_hash_df, verification_time): nodes_in_cur_batch = pd.DataFrame( master_df["submitter"].unique(), columns=["block_producer_key"] ) - node_to_insert = find_new_values_to_insert( - existing_nodes, nodes_in_cur_batch) + node_to_insert = find_new_values_to_insert(existing_nodes, nodes_in_cur_batch) if not node_to_insert.empty: node_to_insert["updated_at"] = datetime.now(timezone.utc) @@ -190,15 +192,13 @@ def process_statehash_df(db, batch, state_hash_df, verification_time): columns={ "file_updated": "file_timestamps", "submitter": "block_producer_key", - } + }, ) - relation_df, p_selected_node_df = db.get_previous_statehash( - batch.bot_log_id) + relation_df, p_selected_node_df = db.get_previous_statehash(batch.bot_log_id) p_map = list(get_relations(relation_df)) c_selected_node = filter_state_hash_percentage(master_df) - batch_graph = create_graph( - master_df, p_selected_node_df, c_selected_node, p_map) + batch_graph = create_graph(master_df, p_selected_node_df, c_selected_node, p_map) weighted_graph = apply_weights( batch_graph=batch_graph, c_selected_node=c_selected_node, @@ -217,8 +217,7 @@ def process_statehash_df(db, batch, state_hash_df, verification_time): # but it's not used anywhere inside the function) ) point_record_df = master_df[ - 
master_df["state_hash"].isin( - shortlisted_state_hash_df["state_hash"].values) + master_df["state_hash"].isin(shortlisted_state_hash_df["state_hash"].values) ] for index, row in shortlisted_state_hash_df.iterrows(): @@ -227,15 +226,13 @@ def process_statehash_df(db, batch, state_hash_df, verification_time): p_selected_node_df = shortlisted_state_hash_df.copy() parent_hash = [] for s in shortlisted_state_hash_df["state_hash"].values: - p_hash = master_df[master_df["state_hash"] == s][ - "parent_state_hash" - ].values[0] + p_hash = master_df[master_df["state_hash"] == s]["parent_state_hash"].values[0] parent_hash.append(p_hash) shortlisted_state_hash_df["parent_state_hash"] = parent_hash - p_map = list(get_relations( - shortlisted_state_hash_df[["parent_state_hash", "state_hash"]] - )) + p_map = list( + get_relations(shortlisted_state_hash_df[["parent_state_hash", "state_hash"]]) + ) if not point_record_df.empty: file_timestamp = master_df.iloc[-1]["file_timestamps"] else: @@ -251,7 +248,7 @@ def process_statehash_df(db, batch, state_hash_df, verification_time): file_timestamp, batch.start_time.timestamp(), batch.end_time.timestamp(), - verification_time.total_seconds() + verification_time.total_seconds(), ) bot_log_id = db.create_bot_log(values) @@ -260,8 +257,7 @@ def process_statehash_df(db, batch, state_hash_df, verification_time): if not point_record_df.empty: point_record_df.loc[:, "amount"] = 1 - point_record_df.loc[:, "created_at"] = datetime.now( - timezone.utc) + point_record_df.loc[:, "created_at"] = datetime.now(timezone.utc) point_record_df.loc[:, "bot_log_id"] = bot_log_id point_record_df = point_record_df[ [ @@ -287,7 +283,7 @@ def process(db, state): logging.info( "iteration start at: %s, cur_timestamp: %s", state.batch.start_time, - state.current_timestamp + state.current_timestamp, ) logging.info( "running for batch: %s - %s.", state.batch.start_time, state.batch.end_time @@ -295,8 +291,7 @@ def process(db, state): # sleep until batch ends, update the state accordingly, then continue. 
     state.wait_until_batch_ends()
-    time_intervals = list(state.batch.split(
-        int(os.environ["MINI_BATCH_NUMBER"])))
+    time_intervals = list(state.batch.split(int(os.environ["MINI_BATCH_NUMBER"])))
 
     worker_image = os.environ["WORKER_IMAGE"]
     worker_tag = os.environ["WORKER_TAG"]
@@ -312,7 +307,7 @@ def process(db, state):
     logging.info(
         "reading ZKValidator results from a db between the time range: %s - %s",
         state.batch.start_time,
-        state.batch.end_time
+        state.batch.end_time,
     )
 
     logging.info("ZKValidator results read from a db in %s.", timer.duration)
@@ -331,7 +326,7 @@ def process(db, state):
             logging,
         )
 
-    state_hash_df = load_submissions(state.batch)
+    state_hash_df = load_submissions(time_intervals)
     if not state_hash_df.empty:
         try:
             bot_log_id = process_statehash_df(
diff --git a/uptime_service_validation/coordinator/server.py b/uptime_service_validation/coordinator/server.py
index 72fdd8c..197be64 100644
--- a/uptime_service_validation/coordinator/server.py
+++ b/uptime_service_validation/coordinator/server.py
@@ -127,11 +127,7 @@ def setUpValidatorPods(time_intervals, logging, worker_image, worker_tag):
             client.V1EnvVar(name="CASSANDRA_USE_SSL", value="1"),
             client.V1EnvVar(
                 name="SSL_CERTFILE",
-                value=os.environ.get("SSL_CERTFILE"),
-            ),
-            client.V1EnvVar(
-                name="CQLSH",
-                value=os.environ.get("CQLSH"),
+                value="/root/.cassandra/sf-class2-root.crt",
             ),
             client.V1EnvVar(
                 name="AUTH_VOLUME_MOUNT_PATH",
@@ -153,6 +149,14 @@ def setUpValidatorPods(time_intervals, logging, worker_image, worker_tag):
                 name="NO_CHECKS",
                 value=os.environ.get("NO_CHECKS"),
             ),
+            client.V1EnvVar(
+                name="AWS_ACCESS_KEY_ID",
+                value=os.environ.get("AWS_ACCESS_KEY_ID"),
+            ),
+            client.V1EnvVar(
+                name="AWS_SECRET_ACCESS_KEY",
+                value=os.environ.get("AWS_SECRET_ACCESS_KEY"),
+            ),
         ]
 
         # Entrypoint configmap name
@@ -171,24 +175,12 @@ def setUpValidatorPods(time_intervals, logging, worker_image, worker_tag):
             ),  # 0777 permission in octal as int
         )
 
-        cassandra_ssl_volume = client.V1Volume(
-            name="cassandra-crt",
-            secret=client.V1SecretVolumeSource(
-                secret_name="uptime-service-cassandra-crt"
-            ),
-        )
-
         # Define the volumeMounts
         auth_volume_mount = client.V1VolumeMount(
             name="auth-volume",
             mount_path=os.environ.get("AUTH_VOLUME_MOUNT_PATH"),
         )
 
-        cassandra_ssl_volume_mount = client.V1VolumeMount(
-            name="cassandra-crt",
-            mount_path="/certs",
-        )
-
         entrypoint_volume_mount = client.V1VolumeMount(
             name="entrypoint-volume",
             mount_path="/bin/entrypoint",
@@ -210,14 +202,18 @@ def setUpValidatorPods(time_intervals, logging, worker_image, worker_tag):
             resources=resource_requirements_container,
             env=env_vars,
             image_pull_policy=os.environ.get("IMAGE_PULL_POLICY", "IfNotPresent"),
-            volume_mounts=[auth_volume_mount, entrypoint_volume_mount, cassandra_ssl_volume_mount],
+            volume_mounts=[
+                auth_volume_mount,
+                entrypoint_volume_mount,
+            ],
         )
 
         # Define the init container
         init_container = client.V1Container(
             name="delegation-verify-init",
             image=f"{worker_image}:{worker_tag}",
-            command=["/bin/authenticate.sh"],
+            # command=["/bin/authenticate.sh"],
+            command=["ls"],
             env=env_vars,
             image_pull_policy=os.environ.get("IMAGE_PULL_POLICY", "IfNotPresent"),
             volume_mounts=[auth_volume_mount],
@@ -236,7 +232,7 @@ def setUpValidatorPods(time_intervals, logging, worker_image, worker_tag):
                     containers=[container],
                     restart_policy="Never",
                     service_account_name=service_account_name,
-                    volumes=[auth_volume, entrypoint_volume, cassandra_ssl_volume],
+                    volumes=[auth_volume, entrypoint_volume],
                 )
             ),
         ),
@@ -245,7 +241,9 @@ def setUpValidatorPods(time_intervals, logging, worker_image, worker_tag):
     # Create the job and configmap in Kubernetes
     try:
         api_batch.create_namespaced_job(namespace, job)
-        logging.info(f"Job {job_name} created in namespace {namespace}")
+        logging.info(
+            f"Job {job_name} created in namespace {namespace}; start: {datetime_formatter(mini_batch[0])}, end: {datetime_formatter(mini_batch[1])}."
+        )
         jobs.append(job_name)
     except Exception as e:
         logging.error(f"Error creating job {job_name}: {e}")
@@ -297,6 +295,8 @@ def setUpValidatorProcesses(time_intervals, logging, worker_image, worker_tag):
             "-e",
             "CASSANDRA_PASSWORD",
             "-e",
+            "AWS_KEYSPACE",
+            "-e",
             "AWS_ACCESS_KEY_ID",
             "-e",
             "AWS_SECRET_ACCESS_KEY",
@@ -310,12 +310,10 @@ def setUpValidatorProcesses(time_intervals, logging, worker_image, worker_tag):
             "SSL_CERTFILE=/var/ssl/ssl-cert.crt",
             "-e",
             "CASSANDRA_USE_SSL=1",
-            "-e",
-            "CQLSH=/bin/cqlsh-expansion",
             image,
-            "cassandra",
-            "--keyspace",
-            os.environ.get("AWS_KEYSPACE"),
+            # "cassandra",
+            # "--keyspace",
+            # os.environ.get("AWS_KEYSPACE"),
             f"{datetime_formatter(mini_batch[0])}",
             f"{datetime_formatter(mini_batch[1])}",
         ]
diff --git a/uptime_service_validation/database/createDB.sql b/uptime_service_validation/database/createDB.sql
index a28ca08..599a63b 100644
--- a/uptime_service_validation/database/createDB.sql
+++ b/uptime_service_validation/database/createDB.sql
@@ -40,7 +40,7 @@ CREATE TABLE nodes (
     block_producer_key TEXT,
     updated_at TIMESTAMPTZ(6),
     score INT,
-    score_percent NUMERIC(6,2),
+    score_percent NUMERIC(10,2),
     discord_id TEXT,
     email_id TEXT,
     application_status BOOLEAN
@@ -75,7 +75,7 @@ CREATE TABLE score_history (
     node_id INT,
     score_at TIMESTAMP(6),
     score INT,
-    score_percent NUMERIC(6,2),
+    score_percent NUMERIC(10,2),
     CONSTRAINT fk_nodes FOREIGN KEY(node_id) REFERENCES nodes(id)
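
Reviewer note, not part of the patch: a shard is a 144-second bucket (86400 s/day spread over shards 0..599), so calculate_shards_in_range does not have to step second by second. Below is a minimal sketch of an equivalent boundary-to-boundary walk, assuming the same half-open [start_time, end_time) semantics and the same end-boundary special case as the code above; the helper names are illustrative, not from the codebase:

    from datetime import datetime, timedelta

    SHARD_SECONDS = 144  # 600 shards per day: 86400 / 600

    def seconds_of_day(t: datetime) -> int:
        # Seconds elapsed since midnight, ignoring sub-second parts.
        return t.hour * 3600 + t.minute * 60 + t.second

    def shards_in_range(start_time: datetime, end_time: datetime) -> str:
        shards = set()
        current = start_time.replace(microsecond=0)
        while current < end_time:
            shards.add(seconds_of_day(current) // SHARD_SECONDS)
            # Jump straight to the next shard boundary instead of the next
            # second; crossing midnight is handled by timedelta arithmetic.
            current += timedelta(
                seconds=SHARD_SECONDS - seconds_of_day(current) % SHARD_SECONDS
            )
        # Mirror the special case above: include end_time's shard when it
        # sits exactly on a shard boundary.
        if seconds_of_day(end_time) % SHARD_SECONDS == 0:
            shards.add(seconds_of_day(end_time) // SHARD_SECONDS)
        return f"shard in ({','.join(map(str, sorted(shards)))})"

This should reproduce all three cases in test_calculate_shards_in_range, including the midnight rollover, while doing at most one iteration per shard rather than one per second.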
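A second note, on ExponentialBackOffRetryPolicy: with the defaults above (base_delay=0.1, max_delay=10, max_retries=10) the nominal delays are 0.1, 0.2, 0.4, 0.8, 1.6, 3.2 and 6.4 s, then capped at 10 s for the last three retries, roughly 42.7 s cumulative before jitter. Because the on_* callbacks call time.sleep, that wait happens on the driver thread handling the response and can hold up other in-flight requests, which is worth keeping in mind when tuning max_retries. A quick sketch to inspect the schedule, using the class as defined in the patch:

    policy = ExponentialBackOffRetryPolicy()
    for retry_num in range(policy.max_retries):
        # Deterministic part of get_backoff_time(), i.e. the delay before jitter.
        print(retry_num, min(policy.max_delay, policy.base_delay * (2**retry_num)))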