Skip to content

Commit

Permalink
dht_status() as table function
Browse files Browse the repository at this point in the history
  • Loading branch information
lmangani authored Nov 13, 2024
1 parent 3807330 commit c98d267
Showing 1 changed file with 195 additions and 33 deletions.
228 changes: 195 additions & 33 deletions src/dht_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,39 @@

namespace duckdb {

// Active dhtd control-socket path. Empty until the first lookup resolves it or
// until dht_set_socket() stores an explicit path.
static std::string g_dhtd_socket;

// Returns the dhtd socket path currently in effect.
//
// A path already stored in g_dhtd_socket (for example by dht_set_socket()) is
// returned untouched, so later lookups never discard a user-configured path.
// Only when nothing is stored yet is the path resolved from the
// DUCKDB_DHTD_SOCKET environment variable (when set and non-empty), falling
// back to "/tmp/dhtd.sock"; the resolved value is then cached in g_dhtd_socket.
//
// The returned pointer aliases g_dhtd_socket's buffer and is therefore only
// valid until g_dhtd_socket is next modified.
static const char* GetDhtdSocketPath() {
    if (g_dhtd_socket.empty()) {
        const char* env_socket = std::getenv("DUCKDB_DHTD_SOCKET");
        if (env_socket != nullptr && env_socket[0] != '\0') {
            g_dhtd_socket = env_socket;
        } else {
            g_dhtd_socket = "/tmp/dhtd.sock";
        }
    }
    return g_dhtd_socket.c_str();
}

// Socket path captured when the extension's statics are initialized;
// dht_set_socket() re-points it whenever the path is changed.
static const char* DEFAULT_DHTD_SOCKET = GetDhtdSocketPath();

// dht_set_socket(path): scalar function that switches the dhtd control socket.
//
// For every input row the given path is stored in g_dhtd_socket and
// DEFAULT_DHTD_SOCKET is re-pointed at the stored string; the row's result is a
// confirmation message. An empty path raises std::runtime_error.
static void DHTSetSocketFunction(DataChunk &args, ExpressionState &state, Vector &result) {
    const auto apply_path = [&](string_t raw_path) {
        const std::string requested = raw_path.GetString();
        if (requested.empty()) {
            throw std::runtime_error("Socket path cannot be empty");
        }

        // Keep the raw-pointer alias in sync with the string that owns the bytes.
        g_dhtd_socket = requested;
        DEFAULT_DHTD_SOCKET = g_dhtd_socket.c_str();

        return StringVector::AddString(result, "DHT socket path set to: " + requested);
    };

    UnaryExecutor::Execute<string_t, string_t>(args.data[0], result, args.size(), apply_path);
}

// Function to compute SHA-256 hash of input using modern OpenSSL API
static std::string ComputeSHA256(const std::string& input) {
EVP_MD_CTX* ctx = EVP_MD_CTX_new();
Expand All @@ -55,7 +78,7 @@ static std::string ComputeSHA256(const std::string& input) {
}

EVP_MD_CTX_free(ctx);

std::stringstream ss;
for(unsigned int i = 0; i < lengthOfHash; i++) {
ss << std::hex << std::setw(2) << std::setfill('0') << (int)hash[i];
Expand All @@ -79,7 +102,7 @@ static std::string ValidateOrHashInput(const std::string& input) {
return input;
}
}

// Not a valid hash, compute hash from input
return ComputeSHA256(input);
}
Expand Down Expand Up @@ -122,7 +145,7 @@ static string_t CommunicateWithDhtd(Vector& result, const std::string& command)
// DHT Search function
static void DHTSearchFunction(DataChunk &args, ExpressionState &state, Vector &result) {
auto &id_vector = args.data[0];

UnaryExecutor::Execute<string_t, string_t>(
id_vector, result, args.size(),
[&](string_t id) {
Expand All @@ -135,7 +158,7 @@ static void DHTSearchFunction(DataChunk &args, ExpressionState &state, Vector &r
// DHT Query function (search + results)
static void DHTQueryFunction(DataChunk &args, ExpressionState &state, Vector &result) {
auto &id_vector = args.data[0];

UnaryExecutor::Execute<string_t, string_t>(
id_vector, result, args.size(),
[&](string_t id) {
Expand All @@ -149,7 +172,7 @@ static void DHTQueryFunction(DataChunk &args, ExpressionState &state, Vector &re
static void DHTAnnounceFunction(DataChunk &args, ExpressionState &state, Vector &result) {
auto &id_vector = args.data[0];
auto &port_vector = args.data[1];

UnaryExecutor::Execute<string_t, string_t>(
id_vector, result, args.size(),
[&](string_t id) {
Expand All @@ -163,7 +186,7 @@ static void DHTAnnounceFunction(DataChunk &args, ExpressionState &state, Vector
// DHT Stop Announce function
static void DHTStopAnnounceFunction(DataChunk &args, ExpressionState &state, Vector &result) {
auto &id_vector = args.data[0];

UnaryExecutor::Execute<string_t, string_t>(
id_vector, result, args.size(),
[&](string_t id) {
Expand All @@ -176,7 +199,7 @@ static void DHTStopAnnounceFunction(DataChunk &args, ExpressionState &state, Vec
// DHT Status function
static void DHTStatusFunction(DataChunk &args, ExpressionState &state, Vector &result) {
auto &dummy_vector = args.data[0];

UnaryExecutor::Execute<string_t, string_t>(
dummy_vector, result, args.size(),
[&](string_t dummy) {
Expand All @@ -187,7 +210,7 @@ static void DHTStatusFunction(DataChunk &args, ExpressionState &state, Vector &r
// DHT Add Peer function
static void DHTPeerFunction(DataChunk &args, ExpressionState &state, Vector &result) {
auto &address_vector = args.data[0];

UnaryExecutor::Execute<string_t, string_t>(
address_vector, result, args.size(),
[&](string_t address) {
Expand All @@ -199,7 +222,7 @@ static void DHTPeerFunction(DataChunk &args, ExpressionState &state, Vector &res
// Compute hash function (exposed for testing/verification)
static void DHTHashFunction(DataChunk &args, ExpressionState &state, Vector &result) {
auto &input_vector = args.data[0];

UnaryExecutor::Execute<string_t, string_t>(
input_vector, result, args.size(),
[&](string_t input) {
Expand All @@ -215,7 +238,7 @@ struct DhtResultsBindData : public TableFunctionData {

struct DhtResultsGlobalState : public GlobalTableFunctionState {
DhtResultsGlobalState() : position(0) {}

std::vector<std::pair<std::string, uint16_t>> results;
idx_t position;

Expand All @@ -227,17 +250,14 @@ struct DhtResultsGlobalState : public GlobalTableFunctionState {
static unique_ptr<FunctionData> DhtResultsBind(ClientContext &context, TableFunctionBindInput &input,
vector<LogicalType> &return_types, vector<string> &names) {
auto result = make_uniq<DhtResultsBindData>();

// Validate we have exactly one parameter

if (input.inputs.size() != 1 || input.inputs[0].IsNull()) {
throw std::runtime_error("DHT results table function requires one non-null parameter");
}

// Get the hash parameter
result->hash = input.inputs[0].ToString();
result->hash = ValidateOrHashInput(result->hash);

// Define the table structure
return_types = {LogicalType::VARCHAR, LogicalType::INTEGER};
names = {"address", "port"};

Expand All @@ -247,7 +267,7 @@ static unique_ptr<FunctionData> DhtResultsBind(ClientContext &context, TableFunc
static void DhtResultsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
auto &bind_data = data_p.bind_data->Cast<DhtResultsBindData>();
auto &state = data_p.global_state->Cast<DhtResultsGlobalState>();

// If this is the first call, fetch results
if (state.position == 0) {
// Get results from dhtd
Expand Down Expand Up @@ -291,8 +311,7 @@ static void DhtResultsFunction(ClientContext &context, TableFunctionInput &data_
while (!line.empty() && (line.back() == '\r' || line.back() == ' ' || line.back() == '\t')) {
line.pop_back();
}

// Skip empty lines

if (line.empty()) {
continue;
}
Expand Down Expand Up @@ -327,7 +346,7 @@ static void DhtResultsFunction(ClientContext &context, TableFunctionInput &data_

while (state.position < state.results.size() && count < STANDARD_VECTOR_SIZE) {
const auto& result = state.results[state.position];
if (!result.first.empty()) { // Extra check to ensure no empty addresses
if (!result.first.empty()) {
addr_data[count] = StringVector::AddString(output.data[0], result.first);
port_data[count] = result.second;
count++;
Expand All @@ -338,38 +357,181 @@ static void DhtResultsFunction(ClientContext &context, TableFunctionInput &data_
output.SetCardinality(count);
}

// Bind data for the dht_status() table function. The function is registered
// without arguments, so no values are carried from bind time to execution.
struct DhtTableStatusBindData : public TableFunctionData {};

struct DhtStatusGlobalState : public GlobalTableFunctionState {
DhtStatusGlobalState() : position(0) {}
DhtStatusInfo status;
idx_t position;

static unique_ptr<GlobalTableFunctionState> Init(ClientContext &context, TableFunctionInitInput &input) {
return make_uniq<DhtStatusGlobalState>();
}
};

// Add the table function
static void DhtTableFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
auto &state = data_p.global_state->Cast<DhtStatusGlobalState>();

// Only fetch status once
if (state.position == 0) {
// Get status from dhtd via socket
int sock = socket(AF_UNIX, SOCK_STREAM, 0);
if (sock < 0) {
throw std::runtime_error("Failed to create socket: " + std::string(strerror(errno)));
}

struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
strncpy(addr.sun_path, GetDhtdSocketPath(), sizeof(addr.sun_path) - 1);

if (connect(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
close(sock);
throw std::runtime_error("Failed to connect to dhtd: " + std::string(strerror(errno)));
}

// Send status command
std::string cmd = "status\n";
if (write(sock, cmd.c_str(), cmd.length()) < 0) {
close(sock);
throw std::runtime_error("Failed to send command: " + std::string(strerror(errno)));
}

// Read response
std::string response;
char buffer[1024];
while (true) {
ssize_t bytes = read(sock, buffer, sizeof(buffer));
if (bytes <= 0) break;
response.append(buffer, bytes);
}
close(sock);

// Parse status using our new parser
ParseDhtStatus(response, state.status);
}

if (state.position == 0) { // Only output one row
output.SetCardinality(1);

int col = 0;
FlatVector::GetData<string_t>(output.data[col++])[0] = StringVector::AddString(output.data[0], state.status.version);
FlatVector::GetData<string_t>(output.data[col++])[0] = StringVector::AddString(output.data[1], state.status.node_id);
FlatVector::GetData<string_t>(output.data[col++])[0] = StringVector::AddString(output.data[2], state.status.uptime);
FlatVector::GetData<string_t>(output.data[col++])[0] = StringVector::AddString(output.data[3], state.status.listen_info);
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.port;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.ipv4_nodes;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.ipv4_good_nodes;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.ipv6_nodes;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.ipv6_good_nodes;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.storage_entries;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.storage_addresses;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.ipv4_searches;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.ipv4_searches_done;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.ipv6_searches;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.ipv6_searches_done;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.announcements;
FlatVector::GetData<int32_t>(output.data[col++])[0] = state.status.blocklist;
FlatVector::GetData<double>(output.data[col++])[0] = state.status.traffic_in;
FlatVector::GetData<string_t>(output.data[col++])[0] = StringVector::AddString(output.data[18], state.status.traffic_in_rate);
FlatVector::GetData<double>(output.data[col++])[0] = state.status.traffic_out;
FlatVector::GetData<string_t>(output.data[col++])[0] = StringVector::AddString(output.data[20], state.status.traffic_out_rate);

state.position++;
} else {
output.SetCardinality(0);
}
}

// Bind callback for the dht_status() table function.
//
// Replaces the contents of return_types and names with the 21-column schema of
// the status row (name and type listed side by side below) and returns an
// empty DhtTableStatusBindData.
static unique_ptr<FunctionData> DhtTableStatusBind(ClientContext &context, TableFunctionBindInput &input,
                                                   vector<LogicalType> &return_types, vector<string> &names) {
    return_types.clear();
    names.clear();

    // Appends one column; name and type stay next to each other so the two
    // output lists cannot drift apart.
    const auto add_column = [&](const char *column_name, const LogicalType &column_type) {
        names.emplace_back(column_name);
        return_types.push_back(column_type);
    };

    add_column("version", LogicalType::VARCHAR);
    add_column("node_id", LogicalType::VARCHAR);
    add_column("uptime", LogicalType::VARCHAR);
    add_column("listen_info", LogicalType::VARCHAR);
    add_column("port", LogicalType::INTEGER);
    add_column("ipv4_nodes", LogicalType::INTEGER);
    add_column("ipv4_good_nodes", LogicalType::INTEGER);
    add_column("ipv6_nodes", LogicalType::INTEGER);
    add_column("ipv6_good_nodes", LogicalType::INTEGER);
    add_column("storage_entries", LogicalType::INTEGER);
    add_column("storage_addresses", LogicalType::INTEGER);
    add_column("ipv4_searches", LogicalType::INTEGER);
    add_column("ipv4_searches_done", LogicalType::INTEGER);
    add_column("ipv6_searches", LogicalType::INTEGER);
    add_column("ipv6_searches_done", LogicalType::INTEGER);
    add_column("announcements", LogicalType::INTEGER);
    add_column("blocklist", LogicalType::INTEGER);
    add_column("traffic_in", LogicalType::DOUBLE);
    add_column("traffic_in_rate", LogicalType::VARCHAR);
    add_column("traffic_out", LogicalType::DOUBLE);
    add_column("traffic_out_rate", LogicalType::VARCHAR);

    return make_uniq<DhtTableStatusBindData>();
}




// Registers every function this extension exposes on the given database
// instance: the scalar DHT commands, the dht_results(hash) and dht_status()
// table functions, and the dht_set_socket(path) setter.
static void LoadInternal(DatabaseInstance &instance) {
    // Register scalar functions
    ExtensionUtil::RegisterFunction(instance,
        ScalarFunction("dht_search", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DHTSearchFunction));

    ExtensionUtil::RegisterFunction(instance,
        ScalarFunction("dht_query", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DHTQueryFunction));

    ExtensionUtil::RegisterFunction(instance,
        ScalarFunction("dht_announce", {LogicalType::VARCHAR, LogicalType::INTEGER},
                       LogicalType::VARCHAR, DHTAnnounceFunction));

    ExtensionUtil::RegisterFunction(instance,
        ScalarFunction("dht_stop_announce", {LogicalType::VARCHAR},
                       LogicalType::VARCHAR, DHTStopAnnounceFunction));

    // Scalar form of dht_status; the table function of the same name is
    // registered further down.
    ExtensionUtil::RegisterFunction(instance,
        ScalarFunction("dht_status", {LogicalType::VARCHAR},
                       LogicalType::VARCHAR, DHTStatusFunction));

    ExtensionUtil::RegisterFunction(instance,
        ScalarFunction("dht_peer", {LogicalType::VARCHAR},
                       LogicalType::VARCHAR, DHTPeerFunction));

    ExtensionUtil::RegisterFunction(instance,
        ScalarFunction("dht_hash", {LogicalType::VARCHAR},
                       LogicalType::VARCHAR, DHTHashFunction));

    // Register table function
    TableFunction dht_results("dht_results", {LogicalType::VARCHAR}, DhtResultsFunction, DhtResultsBind,
                              DhtResultsGlobalState::Init);
    ExtensionUtil::RegisterFunction(instance, dht_results);

    // Status Table function
    TableFunction dht_status("dht_status", {}, DhtTableFunction, DhtTableStatusBind,
                             DhtStatusGlobalState::Init);
    ExtensionUtil::RegisterFunction(instance, dht_status);

    // Register socket setting function
    ExtensionUtil::RegisterFunction(instance,
        ScalarFunction("dht_set_socket", {LogicalType::VARCHAR},
                       LogicalType::VARCHAR, DHTSetSocketFunction));
}

void DhtExtension::Load(DuckDB &db) {
Expand Down

0 comments on commit c98d267

Please sign in to comment.