Skip to content

Commit

Permalink
Merge pull request #38 from MinaFoundation/sharding+cassandra-creds-c…
Browse files Browse the repository at this point in the history
…oordinator

PM-1284 Sharding with 600 per 24h + ExponentialBackOffRetryPolicy
  • Loading branch information
piotr-iohk authored Mar 14, 2024
2 parents f5b853a + 77b8b36 commit de8c9b9
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 91 deletions.
48 changes: 48 additions & 0 deletions tests/test_aws_keyspaces_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime
from uptime_service_validation.coordinator.aws_keyspaces_client import (
AWSKeyspacesClient,
ShardCalculator,
)


Expand All @@ -19,3 +20,50 @@ def test_get_submitted_at_date_list():
end = datetime(2023, 11, 8, 0, 0, 0)
result = AWSKeyspacesClient.get_submitted_at_date_list(start, end)
assert result == ["2023-11-06", "2023-11-07", "2023-11-08"]


def test_calculate_shard():
matrix = [
[0, 0, 0, 0],
[0, 0, 1, 0],
[0, 2, 59, 1],
[12, 1, 3, 300],
[15, 10, 0, 379],
[23, 22, 23, 584],
[23, 59, 59, 599],
]
for i in matrix:
hour = i[0]
minute = i[1]
second = i[2]
expected = i[3]
assert ShardCalculator.calculate_shard(hour, minute, second) == expected


def test_calculate_shards_in_range():
# 1 between days
start_time = datetime(2024, 2, 29, 23, 58, 29)
end_time = datetime(2024, 3, 1, 0, 3, 29)
expected_cql_statement = "shard in (0,1,599)"
result_cql_statement = ShardCalculator.calculate_shards_in_range(
start_time, end_time
)
assert result_cql_statement == expected_cql_statement

# 2 within the same day
start_time = datetime(2024, 2, 29, 12, 58, 1)
end_time = datetime(2024, 2, 29, 12, 59, 59)
expected_cql_statement = "shard in (324)"
result_cql_statement = ShardCalculator.calculate_shards_in_range(
start_time, end_time
)
assert result_cql_statement == expected_cql_statement

# 2 shard boundary
start_time = datetime(2024, 2, 29, 0, 0, 0)
end_time = datetime(2024, 2, 29, 0, 2, 24)
expected_cql_statement = "shard in (0,1)"
result_cql_statement = ShardCalculator.calculate_shards_in_range(
start_time, end_time
)
assert result_cql_statement == expected_cql_statement
137 changes: 121 additions & 16 deletions uptime_service_validation/coordinator/aws_keyspaces_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
import boto3
import time
import random
from cassandra import ProtocolVersion
from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
from cassandra_sigv4.auth import SigV4AuthProvider
from cassandra.policies import DCAwareRoundRobinPolicy
from cassandra.policies import DCAwareRoundRobinPolicy, RetryPolicy
from ssl import SSLContext, CERT_REQUIRED, PROTOCOL_TLS_CLIENT
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timedelta
from typing import Optional, ByteString, List

import pandas as pd
Expand Down Expand Up @@ -58,20 +60,41 @@ def __init__(self):
self.aws_ssl_certificate_path = os.environ.get("SSL_CERTFILE")
self.aws_region = self.cassandra_host.split(".")[1]
self.ssl_context = self._create_ssl_context()
self.request_timeout = 20.0

if self.cassandra_user and self.cassandra_pass:
self.auth_provider = PlainTextAuthProvider(
username=self.cassandra_user, password=self.cassandra_pass
)
profile = ExecutionProfile(
# assuming this is for hosted Cassandra, load balancing policy to be determined
# load_balancing_policy=DCAwareRoundRobinPolicy(local_dc=self.aws_region),
retry_policy=ExponentialBackOffRetryPolicy(),
request_timeout=self.request_timeout,
)
self.cluster = Cluster(
[self.cassandra_host],
ssl_context=self.ssl_context,
auth_provider=self.auth_provider,
port=int(self.cassandra_port),
execution_profiles={EXEC_PROFILE_DEFAULT: profile},
protocol_version=ProtocolVersion.V4,
)
else:
self.auth_provider = self._create_sigv4auth_provider()
self.cluster = self._create_cluster()
profile = ExecutionProfile(
load_balancing_policy=DCAwareRoundRobinPolicy(local_dc=self.aws_region),
retry_policy=ExponentialBackOffRetryPolicy(),
request_timeout=self.request_timeout,
)
self.cluster = Cluster(
[self.cassandra_host],
ssl_context=self.ssl_context,
auth_provider=self.auth_provider,
port=int(self.cassandra_port),
execution_profiles={EXEC_PROFILE_DEFAULT: profile},
protocol_version=ProtocolVersion.V4,
)

def _create_ssl_context(self):
ssl_context = SSLContext(PROTOCOL_TLS_CLIENT)
Expand Down Expand Up @@ -118,19 +141,6 @@ def _create_sigv4auth_provider(self):
)
return SigV4AuthProvider(boto_session)

def _create_cluster(self):
profile = ExecutionProfile(
load_balancing_policy=DCAwareRoundRobinPolicy(local_dc=self.aws_region)
)
return Cluster(
[self.cassandra_host],
ssl_context=self.ssl_context,
auth_provider=self.auth_provider,
port=int(self.cassandra_port),
execution_profiles={EXEC_PROFILE_DEFAULT: profile},
protocol_version=ProtocolVersion.V4,
)

def connect(self):
self.session = self.cluster.connect()

Expand Down Expand Up @@ -208,6 +218,10 @@ def get_submissions(
submitted_at_start, submitted_at_end
)

shard_condition = ShardCalculator.calculate_shards_in_range(
submitted_at_start, submitted_at_end
)

if len(submitted_at_date_list) == 1:
submitted_at_date = submitted_at_date_list[0]
else:
Expand All @@ -224,6 +238,10 @@ def get_submissions(
parameters.append(submitted_at_date)
elif submitted_at_dates:
conditions.append(f"submitted_at_date IN ({submitted_at_dates})")

# Add shard condition here since we have a submitted_at_date or submitted_at_dates
conditions.append(shard_condition)

if submitted_at_start:
start_operator = ">=" if start_inclusive else ">"
conditions.append(f"submitted_at {start_operator} %s")
Expand Down Expand Up @@ -273,6 +291,93 @@ def close(self):
self.cluster.shutdown()


class ExponentialBackOffRetryPolicy(RetryPolicy):
def __init__(self, base_delay=0.1, max_delay=10, max_retries=10):
self.base_delay = base_delay # seconds
self.max_delay = max_delay # seconds
self.max_retries = max_retries

def get_backoff_time(self, retry_num):
# Calculate exponential backoff time
delay = min(self.max_delay, self.base_delay * (2**retry_num))
# Add some randomness to avoid thundering herd problem
jitter = random.uniform(0, 0.1) * delay
return delay + jitter

def on_read_timeout(
self,
query,
consistency,
required_responses,
received_responses,
data_retrieved,
retry_num,
):
if retry_num >= self.max_retries:
return (self.RETHROW, None)
time.sleep(self.get_backoff_time(retry_num))
return (self.RETRY, consistency)

def on_write_timeout(
self,
query,
consistency,
write_type,
required_responses,
received_responses,
retry_num,
):
if retry_num >= self.max_retries:
return (self.RETHROW, None)
time.sleep(self.get_backoff_time(retry_num))
return (self.RETRY, consistency)

def on_unavailable(
self, query, consistency, required_replica, alive_replica, retry_num
):
if retry_num >= self.max_retries:
return (self.RETHROW, None)
time.sleep(self.get_backoff_time(retry_num))
return (self.RETRY_NEXT_HOST, None)


class ShardCalculator:
@classmethod
def calculate_shard(cls, hour, minute, second):
return (3600 * hour + 60 * minute + second) // 144

@classmethod
def calculate_shards_in_range(cls, start_time, end_time):
shards = set()
current_time = start_time

while current_time < end_time:
shard = cls.calculate_shard(
current_time.hour, current_time.minute, current_time.second
)
shards.add(shard)
# Move to the next second
current_time += timedelta(seconds=1)

# Check if endTime falls exactly on a new shard boundary and add it if necessary
end_shard = cls.calculate_shard(end_time.hour, end_time.minute, end_time.second)
if end_shard not in shards:
# Check if end_time is exactly on the boundary of a new shard
total_seconds_end = (
(end_time.hour * 3600) + (end_time.minute * 60) + end_time.second
)
if total_seconds_end % 144 == 0:
shards.add(end_shard)

# Convert the set of unique shards into a sorted list for readability
shards_list = sorted(list(shards))
# Format the shards into a CQL statement string
shards_list = sorted(list(shards)) # Sort the shards for readability
shards_str = ",".join(map(str, shards_list))
cql_statement = f"shard in ({shards_str})"
return cql_statement


# Usage Example
if __name__ == "__main__":
client = AWSKeyspacesClient()
Expand Down
Loading

0 comments on commit de8c9b9

Please sign in to comment.