diff --git a/src/dht_extension.cpp b/src/dht_extension.cpp index aad32b2..cee3e45 100644 --- a/src/dht_extension.cpp +++ b/src/dht_extension.cpp @@ -19,16 +19,39 @@ namespace duckdb { +// Global socket path variable +static std::string g_dhtd_socket; + +// Socket path getter with environment fallback static const char* GetDhtdSocketPath() { static const char* env_socket = std::getenv("DUCKDB_DHTD_SOCKET"); if (env_socket != nullptr && strlen(env_socket) > 0) { - return env_socket; + g_dhtd_socket = env_socket; + } else { + g_dhtd_socket = "/tmp/dhtd.sock"; } - return "/tmp/dhtd.sock"; + return g_dhtd_socket.c_str(); } static const char* DEFAULT_DHTD_SOCKET = GetDhtdSocketPath(); +// Setting function +static void DHTSetSocketFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &path_vector = args.data[0]; + + UnaryExecutor::Execute( + path_vector, result, args.size(), + [&](string_t path) { + std::string new_path = path.GetString(); + if (new_path.empty()) { + throw std::runtime_error("Socket path cannot be empty"); + } + g_dhtd_socket = new_path; + DEFAULT_DHTD_SOCKET = g_dhtd_socket.c_str(); + return StringVector::AddString(result, "DHT socket path set to: " + new_path); + }); +} + // Function to compute SHA-256 hash of input using modern OpenSSL API static std::string ComputeSHA256(const std::string& input) { EVP_MD_CTX* ctx = EVP_MD_CTX_new(); @@ -55,7 +78,7 @@ static std::string ComputeSHA256(const std::string& input) { } EVP_MD_CTX_free(ctx); - + std::stringstream ss; for(unsigned int i = 0; i < lengthOfHash; i++) { ss << std::hex << std::setw(2) << std::setfill('0') << (int)hash[i]; @@ -79,7 +102,7 @@ static std::string ValidateOrHashInput(const std::string& input) { return input; } } - + // Not a valid hash, compute hash from input return ComputeSHA256(input); } @@ -122,7 +145,7 @@ static string_t CommunicateWithDhtd(Vector& result, const std::string& command) // DHT Search function static void DHTSearchFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &id_vector = args.data[0]; - + UnaryExecutor::Execute( id_vector, result, args.size(), [&](string_t id) { @@ -135,7 +158,7 @@ static void DHTSearchFunction(DataChunk &args, ExpressionState &state, Vector &r // DHT Query function (search + results) static void DHTQueryFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &id_vector = args.data[0]; - + UnaryExecutor::Execute( id_vector, result, args.size(), [&](string_t id) { @@ -149,7 +172,7 @@ static void DHTQueryFunction(DataChunk &args, ExpressionState &state, Vector &re static void DHTAnnounceFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &id_vector = args.data[0]; auto &port_vector = args.data[1]; - + UnaryExecutor::Execute( id_vector, result, args.size(), [&](string_t id) { @@ -163,7 +186,7 @@ static void DHTAnnounceFunction(DataChunk &args, ExpressionState &state, Vector // DHT Stop Announce function static void DHTStopAnnounceFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &id_vector = args.data[0]; - + UnaryExecutor::Execute( id_vector, result, args.size(), [&](string_t id) { @@ -176,7 +199,7 @@ static void DHTStopAnnounceFunction(DataChunk &args, ExpressionState &state, Vec // DHT Status function static void DHTStatusFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &dummy_vector = args.data[0]; - + UnaryExecutor::Execute( dummy_vector, result, args.size(), [&](string_t dummy) { @@ -187,7 +210,7 @@ static void DHTStatusFunction(DataChunk &args, ExpressionState &state, Vector &r // DHT Add Peer function static void DHTPeerFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &address_vector = args.data[0]; - + UnaryExecutor::Execute( address_vector, result, args.size(), [&](string_t address) { @@ -199,7 +222,7 @@ static void DHTPeerFunction(DataChunk &args, ExpressionState &state, Vector &res // Compute hash function (exposed for testing/verification) static void DHTHashFunction(DataChunk &args, ExpressionState &state, Vector &result) { auto &input_vector = args.data[0]; - + UnaryExecutor::Execute( input_vector, result, args.size(), [&](string_t input) { @@ -215,7 +238,7 @@ struct DhtResultsBindData : public TableFunctionData { struct DhtResultsGlobalState : public GlobalTableFunctionState { DhtResultsGlobalState() : position(0) {} - + std::vector> results; idx_t position; @@ -227,17 +250,14 @@ struct DhtResultsGlobalState : public GlobalTableFunctionState { static unique_ptr DhtResultsBind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { auto result = make_uniq(); - - // Validate we have exactly one parameter + if (input.inputs.size() != 1 || input.inputs[0].IsNull()) { throw std::runtime_error("DHT results table function requires one non-null parameter"); } - // Get the hash parameter result->hash = input.inputs[0].ToString(); result->hash = ValidateOrHashInput(result->hash); - // Define the table structure return_types = {LogicalType::VARCHAR, LogicalType::INTEGER}; names = {"address", "port"}; @@ -247,7 +267,7 @@ static unique_ptr DhtResultsBind(ClientContext &context, TableFunc static void DhtResultsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { auto &bind_data = data_p.bind_data->Cast(); auto &state = data_p.global_state->Cast(); - + // If this is the first call, fetch results if (state.position == 0) { // Get results from dhtd @@ -291,8 +311,7 @@ static void DhtResultsFunction(ClientContext &context, TableFunctionInput &data_ while (!line.empty() && (line.back() == '\r' || line.back() == ' ' || line.back() == '\t')) { line.pop_back(); } - - // Skip empty lines + if (line.empty()) { continue; } @@ -327,7 +346,7 @@ static void DhtResultsFunction(ClientContext &context, TableFunctionInput &data_ while (state.position < state.results.size() && count < STANDARD_VECTOR_SIZE) { const auto& result = state.results[state.position]; - if (!result.first.empty()) { // Extra check to ensure no empty addresses + if (!result.first.empty()) { addr_data[count] = StringVector::AddString(output.data[0], result.first); port_data[count] = result.second; count++; @@ -338,38 +357,181 @@ static void DhtResultsFunction(ClientContext &context, TableFunctionInput &data_ output.SetCardinality(count); } +// Status Function +// Add these structures in dht_extension.cpp: +struct DhtTableStatusBindData : public TableFunctionData { +}; + +struct DhtStatusGlobalState : public GlobalTableFunctionState { + DhtStatusGlobalState() : position(0) {} + DhtStatusInfo status; + idx_t position; + + static unique_ptr Init(ClientContext &context, TableFunctionInitInput &input) { + return make_uniq(); + } +}; + +// Add the table function +static void DhtTableFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &state = data_p.global_state->Cast(); + + // Only fetch status once + if (state.position == 0) { + // Get status from dhtd via socket + int sock = socket(AF_UNIX, SOCK_STREAM, 0); + if (sock < 0) { + throw std::runtime_error("Failed to create socket: " + std::string(strerror(errno))); + } + + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, GetDhtdSocketPath(), sizeof(addr.sun_path) - 1); + + if (connect(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0) { + close(sock); + throw std::runtime_error("Failed to connect to dhtd: " + std::string(strerror(errno))); + } + + // Send status command + std::string cmd = "status\n"; + if (write(sock, cmd.c_str(), cmd.length()) < 0) { + close(sock); + throw std::runtime_error("Failed to send command: " + std::string(strerror(errno))); + } + + // Read response + std::string response; + char buffer[1024]; + while (true) { + ssize_t bytes = read(sock, buffer, sizeof(buffer)); + if (bytes <= 0) break; + response.append(buffer, bytes); + } + close(sock); + + // Parse status using our new parser + ParseDhtStatus(response, state.status); + } + + if (state.position == 0) { // Only output one row + output.SetCardinality(1); + + int col = 0; + FlatVector::GetData(output.data[col++])[0] = StringVector::AddString(output.data[0], state.status.version); + FlatVector::GetData(output.data[col++])[0] = StringVector::AddString(output.data[1], state.status.node_id); + FlatVector::GetData(output.data[col++])[0] = StringVector::AddString(output.data[2], state.status.uptime); + FlatVector::GetData(output.data[col++])[0] = StringVector::AddString(output.data[3], state.status.listen_info); + FlatVector::GetData(output.data[col++])[0] = state.status.port; + FlatVector::GetData(output.data[col++])[0] = state.status.ipv4_nodes; + FlatVector::GetData(output.data[col++])[0] = state.status.ipv4_good_nodes; + FlatVector::GetData(output.data[col++])[0] = state.status.ipv6_nodes; + FlatVector::GetData(output.data[col++])[0] = state.status.ipv6_good_nodes; + FlatVector::GetData(output.data[col++])[0] = state.status.storage_entries; + FlatVector::GetData(output.data[col++])[0] = state.status.storage_addresses; + FlatVector::GetData(output.data[col++])[0] = state.status.ipv4_searches; + FlatVector::GetData(output.data[col++])[0] = state.status.ipv4_searches_done; + FlatVector::GetData(output.data[col++])[0] = state.status.ipv6_searches; + FlatVector::GetData(output.data[col++])[0] = state.status.ipv6_searches_done; + FlatVector::GetData(output.data[col++])[0] = state.status.announcements; + FlatVector::GetData(output.data[col++])[0] = state.status.blocklist; + FlatVector::GetData(output.data[col++])[0] = state.status.traffic_in; + FlatVector::GetData(output.data[col++])[0] = StringVector::AddString(output.data[18], state.status.traffic_in_rate); + FlatVector::GetData(output.data[col++])[0] = state.status.traffic_out; + FlatVector::GetData(output.data[col++])[0] = StringVector::AddString(output.data[20], state.status.traffic_out_rate); + + state.position++; + } else { + output.SetCardinality(0); + } +} + +static unique_ptr DhtTableStatusBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + // Define the table structure + return_types = { + LogicalType::VARCHAR, // version + LogicalType::VARCHAR, // node_id + LogicalType::VARCHAR, // uptime + LogicalType::VARCHAR, // listen_info + LogicalType::INTEGER, // port + LogicalType::INTEGER, // ipv4_nodes + LogicalType::INTEGER, // ipv4_good_nodes + LogicalType::INTEGER, // ipv6_nodes + LogicalType::INTEGER, // ipv6_good_nodes + LogicalType::INTEGER, // storage_entries + LogicalType::INTEGER, // storage_addresses + LogicalType::INTEGER, // ipv4_searches + LogicalType::INTEGER, // ipv4_searches_done + LogicalType::INTEGER, // ipv6_searches + LogicalType::INTEGER, // ipv6_searches_done + LogicalType::INTEGER, // announcements + LogicalType::INTEGER, // blocklist + LogicalType::DOUBLE, // traffic_in + LogicalType::VARCHAR, // traffic_in_rate + LogicalType::DOUBLE, // traffic_out + LogicalType::VARCHAR // traffic_out_rate + }; + + names = { + "version", "node_id", "uptime", "listen_info", "port", + "ipv4_nodes", "ipv4_good_nodes", "ipv6_nodes", "ipv6_good_nodes", + "storage_entries", "storage_addresses", + "ipv4_searches", "ipv4_searches_done", "ipv6_searches", "ipv6_searches_done", + "announcements", "blocklist", + "traffic_in", "traffic_in_rate", "traffic_out", "traffic_out_rate" + }; + + return make_uniq(); +} + + + + static void LoadInternal(DatabaseInstance &instance) { // Register scalar functions - ExtensionUtil::RegisterFunction(instance, + ExtensionUtil::RegisterFunction(instance, ScalarFunction("dht_search", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DHTSearchFunction)); - + ExtensionUtil::RegisterFunction(instance, ScalarFunction("dht_query", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DHTQueryFunction)); - + ExtensionUtil::RegisterFunction(instance, - ScalarFunction("dht_announce", {LogicalType::VARCHAR, LogicalType::INTEGER}, + ScalarFunction("dht_announce", {LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::VARCHAR, DHTAnnounceFunction)); - + ExtensionUtil::RegisterFunction(instance, - ScalarFunction("dht_stop_announce", {LogicalType::VARCHAR}, + ScalarFunction("dht_stop_announce", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DHTStopAnnounceFunction)); - + ExtensionUtil::RegisterFunction(instance, - ScalarFunction("dht_status", {LogicalType::VARCHAR}, + ScalarFunction("dht_status", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DHTStatusFunction)); - + ExtensionUtil::RegisterFunction(instance, - ScalarFunction("dht_peer", {LogicalType::VARCHAR}, + ScalarFunction("dht_peer", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DHTPeerFunction)); - + ExtensionUtil::RegisterFunction(instance, - ScalarFunction("dht_hash", {LogicalType::VARCHAR}, + ScalarFunction("dht_hash", {LogicalType::VARCHAR}, LogicalType::VARCHAR, DHTHashFunction)); // Register table function - TableFunction dht_results("dht_results", {LogicalType::VARCHAR}, DhtResultsFunction, DhtResultsBind, + TableFunction dht_results("dht_results", {LogicalType::VARCHAR}, DhtResultsFunction, DhtResultsBind, DhtResultsGlobalState::Init); ExtensionUtil::RegisterFunction(instance, dht_results); + + // Status Table function + TableFunction dht_status("dht_status", {}, DhtTableFunction, DhtTableStatusBind, + DhtStatusGlobalState::Init); + ExtensionUtil::RegisterFunction(instance, dht_status); + + // Register socket setting function + ExtensionUtil::RegisterFunction(instance, + ScalarFunction("dht_set_socket", {LogicalType::VARCHAR}, + LogicalType::VARCHAR, DHTSetSocketFunction)); + } void DhtExtension::Load(DuckDB &db) {