From 11622e83377901133e836551ddbe761f59a0142d Mon Sep 17 00:00:00 2001 From: David Leifker Date: Sat, 1 Mar 2025 10:15:45 -0600 Subject: [PATCH] feat(operations): ES and Kafka Operations Endpoints --- .../elasticsearch/ElasticSearchService.java | 17 + .../elasticsearch/query/ESSearchDAO.java | 42 ++ .../metadata/trace/KafkaTraceReader.java | 229 ++++++++- .../search/ElasticSearchServiceTest.java | 204 ++++++++ .../query/ESSearchDAORawEntityTest.java | 69 +++ .../trace/BaseKafkaTraceReaderTest.java | 303 ++++++++++++ ...ller.java => ElasticsearchController.java} | 62 ++- .../operations/kafka/KafkaController.java | 437 ++++++++++++++++++ .../operations/kafka/KafkaOffsetResponse.java | 68 +++ .../metadata/search/EntitySearchService.java | 10 + .../search/EntitySearchServiceTest.java | 7 + 11 files changed, 1440 insertions(+), 8 deletions(-) create mode 100644 metadata-io/src/test/java/com/linkedin/metadata/search/query/ESSearchDAORawEntityTest.java rename metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/elastic/{OperationsController.java => ElasticsearchController.java} (90%) create mode 100644 metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/kafka/KafkaController.java create mode 100644 metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/kafka/KafkaOffsetResponse.java diff --git a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/ElasticSearchService.java b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/ElasticSearchService.java index 8ec6d4c699e371..00b13e64e14769 100644 --- a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/ElasticSearchService.java +++ b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/ElasticSearchService.java @@ -31,6 +31,8 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.Getter; @@ -401,6 +403,21 @@ public Optional raw( return esSearchDAO.raw(opContext, indexName, jsonQuery); } + @Override + @Nonnull + public Map> raw( + @Nonnull OperationContext opContext, @Nonnull Set urns) { + return esSearchDAO.rawEntity(opContext, urns).entrySet().stream() + .flatMap( + entry -> + Optional.ofNullable(entry.getValue().getHits().getHits()) + .filter(hits -> hits.length > 0) + .map(hits -> Map.entry(entry.getKey(), hits[0])) + .stream()) + .map(entry -> Map.entry(entry.getKey(), entry.getValue().getSourceAsMap())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + @Override public int maxResultSize() { return ESUtils.MAX_RESULT_SIZE; diff --git a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java index 519322c5720802..840481934a730e 100644 --- a/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java +++ b/metadata-io/src/main/java/com/linkedin/metadata/search/elasticsearch/query/ESSearchDAO.java @@ -2,6 +2,7 @@ import static com.linkedin.metadata.Constants.*; import static com.linkedin.metadata.aspect.patch.template.TemplateUtil.*; +import static com.linkedin.metadata.timeseries.elastic.indexbuilder.MappingsBuilder.URN_FIELD; import static com.linkedin.metadata.utils.SearchUtil.*; import com.datahub.util.exception.ESQueryException; @@ -9,10 +10,12 
@@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; +import com.linkedin.common.urn.Urn; import com.linkedin.data.template.LongMap; import com.linkedin.metadata.config.search.SearchConfiguration; import com.linkedin.metadata.config.search.custom.CustomSearchConfiguration; import com.linkedin.metadata.models.EntitySpec; +import com.linkedin.metadata.models.registry.EntityRegistry; import com.linkedin.metadata.query.AutoCompleteResult; import com.linkedin.metadata.query.filter.Filter; import com.linkedin.metadata.query.filter.SortCriterion; @@ -37,6 +40,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; import javax.annotation.Nonnull; @@ -58,6 +62,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; @@ -615,6 +620,43 @@ public Optional raw( }); } + public Map rawEntity(@Nonnull OperationContext opContext, Set urns) { + EntityRegistry entityRegistry = opContext.getEntityRegistry(); + Map specs = + urns.stream() + .flatMap( + urn -> + Optional.ofNullable(entityRegistry.getEntitySpec(urn.getEntityType())) + .map(spec -> Map.entry(urn, spec)) + .stream()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + return specs.entrySet().stream() + .map( + entry -> { + try { + String indexName = + opContext + .getSearchContext() + .getIndexConvention() + .getIndexName(entry.getValue()); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query( + QueryBuilders.termQuery(URN_FIELD, entry.getKey().toString())); + + SearchRequest searchRequest = new SearchRequest(indexName); + searchRequest.source(searchSourceBuilder); + + return Map.entry( + entry.getKey(), client.search(searchRequest, RequestOptions.DEFAULT)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + private boolean supportsPointInTime() { return pointInTimeCreationEnabled && ELASTICSEARCH_IMPLEMENTATION_ELASTICSEARCH.equalsIgnoreCase(elasticSearchImplementation); diff --git a/metadata-io/src/main/java/com/linkedin/metadata/trace/KafkaTraceReader.java b/metadata-io/src/main/java/com/linkedin/metadata/trace/KafkaTraceReader.java index 8b045084c2c5c2..92c5c0db10fba3 100644 --- a/metadata-io/src/main/java/com/linkedin/metadata/trace/KafkaTraceReader.java +++ b/metadata-io/src/main/java/com/linkedin/metadata/trace/KafkaTraceReader.java @@ -11,6 +11,7 @@ import io.datahubproject.openapi.v1.models.TraceStorageStatus; import io.datahubproject.openapi.v1.models.TraceWriteStatus; import java.time.Duration; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -62,7 +63,14 @@ public abstract class KafkaTraceReader { private final Cache offsetCache = Caffeine.newBuilder() .maximumSize(100) // unlikely to have more than 100 partitions - .expireAfterWrite(Duration.ofMinutes(5)) // Shorter expiry for offsets + .expireAfterWrite( + Duration.ofMinutes(5)) // Short expiry since end offsets change frequently + .build(); + private final Cache endOffsetCache 
= + Caffeine.newBuilder() + .maximumSize(100) // Match the size of offsetCache + .expireAfterWrite( + Duration.ofSeconds(5)) // Short expiry since end offsets change frequently .build(); public KafkaTraceReader( @@ -218,6 +226,225 @@ public Map> tracePendingStatuses( } } + /** + * Returns the current consumer group offsets for all partitions of the topic. + * + * @param skipCache Whether to skip the cache when fetching offsets + * @return Map of TopicPartition to OffsetAndMetadata, empty map if no offsets found or error + * occurs + */ + public Map getAllPartitionOffsets(boolean skipCache) { + final String consumerGroupId = getConsumerGroupId(); + if (consumerGroupId == null) { + log.warn("Cannot get partition offsets: consumer group ID is null"); + return Collections.emptyMap(); + } + + try { + // Get all topic partitions first + Map topicInfo = + adminClient + .describeTopics(Collections.singletonList(getTopicName())) + .all() + .get(timeoutSeconds, TimeUnit.SECONDS); + + if (topicInfo == null || !topicInfo.containsKey(getTopicName())) { + log.error("Failed to get topic information for topic: {}", getTopicName()); + return Collections.emptyMap(); + } + + // Create a list of all TopicPartitions + List allPartitions = + topicInfo.get(getTopicName()).partitions().stream() + .map(partitionInfo -> new TopicPartition(getTopicName(), partitionInfo.partition())) + .collect(Collectors.toList()); + + // For each partition that exists in the cache and wasn't requested to skip, + // pre-populate the result map + Map result = new HashMap<>(); + if (!skipCache) { + for (TopicPartition partition : allPartitions) { + OffsetAndMetadata cached = offsetCache.getIfPresent(partition); + if (cached != null) { + result.put(partition, cached); + } + } + } + + // If we have all partitions from cache and aren't skipping cache, return early + if (!skipCache && result.size() == allPartitions.size()) { + return result; + } + + // Get all offsets for the consumer group + ListConsumerGroupOffsetsResult offsetsResult = + adminClient.listConsumerGroupOffsets(consumerGroupId); + if (offsetsResult == null) { + log.error("Failed to get consumer group offsets for group: {}", consumerGroupId); + return result; + } + + Map fetchedOffsets = + offsetsResult.partitionsToOffsetAndMetadata().get(timeoutSeconds, TimeUnit.SECONDS); + + if (fetchedOffsets == null) { + log.error("Null offsets returned for consumer group: {}", consumerGroupId); + return result; + } + + // Filter to only keep offsets for our topic + Map topicOffsets = + fetchedOffsets.entrySet().stream() + .filter(entry -> entry.getKey().topic().equals(getTopicName())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + // Update the cache for each offset + for (Map.Entry entry : topicOffsets.entrySet()) { + offsetCache.put(entry.getKey(), entry.getValue()); + } + + // Return all offsets + return topicOffsets; + } catch (Exception e) { + log.error("Error fetching all partition offsets for topic {}", getTopicName(), e); + return Collections.emptyMap(); + } + } + + /** + * Returns the end offsets (latest offsets) for all partitions of the topic. 
+ * + * @param skipCache Whether to skip the cache when fetching end offsets + * @return Map of TopicPartition to end offset, empty map if no offsets found or error occurs + */ + public Map getEndOffsets(boolean skipCache) { + try { + // Get all topic partitions first (reuse the same approach as in getAllPartitionOffsets) + Map topicInfo = + adminClient + .describeTopics(Collections.singletonList(getTopicName())) + .all() + .get(timeoutSeconds, TimeUnit.SECONDS); + + if (topicInfo == null || !topicInfo.containsKey(getTopicName())) { + log.error("Failed to get topic information for topic: {}", getTopicName()); + return Collections.emptyMap(); + } + + // Create a list of all TopicPartitions + List allPartitions = + topicInfo.get(getTopicName()).partitions().stream() + .map(partitionInfo -> new TopicPartition(getTopicName(), partitionInfo.partition())) + .collect(Collectors.toList()); + + // Pre-populate result map from cache if not skipping cache + Map result = new HashMap<>(); + if (!skipCache) { + for (TopicPartition partition : allPartitions) { + Long cached = endOffsetCache.getIfPresent(partition); + if (cached != null) { + result.put(partition, cached); + } + } + + // If we have all partitions from cache and aren't skipping cache, return early + if (result.size() == allPartitions.size()) { + return result; + } + } else { + // If skipping cache, invalidate all entries for these partitions + for (TopicPartition partition : allPartitions) { + endOffsetCache.invalidate(partition); + } + } + + // Fetch missing end offsets using a consumer + try (Consumer consumer = consumerSupplier.get()) { + // Determine which partitions we need to fetch + List partitionsToFetch = + allPartitions.stream() + .filter(partition -> skipCache || !result.containsKey(partition)) + .collect(Collectors.toList()); + + if (!partitionsToFetch.isEmpty()) { + // Assign partitions to the consumer + consumer.assign(partitionsToFetch); + + // Fetch end offsets for all partitions at once + Map fetchedEndOffsets = consumer.endOffsets(partitionsToFetch); + + // Update the cache and result map + for (Map.Entry entry : fetchedEndOffsets.entrySet()) { + endOffsetCache.put(entry.getKey(), entry.getValue()); + result.put(entry.getKey(), entry.getValue()); + } + } + } + + return result; + } catch (Exception e) { + log.error("Error fetching end offsets for topic {}", getTopicName(), e); + return Collections.emptyMap(); + } + } + + /** + * Returns the end offsets for a specific set of partitions. 
+ * + * @param partitions Collection of TopicPartitions to get end offsets for + * @param skipCache Whether to skip the cache when fetching end offsets + * @return Map of TopicPartition to end offset + */ + public Map getEndOffsets( + Collection partitions, boolean skipCache) { + if (partitions == null || partitions.isEmpty()) { + return Collections.emptyMap(); + } + + Map result = new HashMap<>(); + List partitionsToFetch = new ArrayList<>(); + + // Check cache first if not skipping + if (!skipCache) { + for (TopicPartition partition : partitions) { + Long cached = endOffsetCache.getIfPresent(partition); + if (cached != null) { + result.put(partition, cached); + } else { + partitionsToFetch.add(partition); + } + } + + // If all partitions were cached, return early + if (partitionsToFetch.isEmpty()) { + return result; + } + } else { + // If skipping cache, fetch all partitions + partitionsToFetch.addAll(partitions); + // Invalidate cache entries + for (TopicPartition partition : partitions) { + endOffsetCache.invalidate(partition); + } + } + + // Fetch end offsets for partitions not in cache + try (Consumer consumer = consumerSupplier.get()) { + consumer.assign(partitionsToFetch); + Map fetchedOffsets = consumer.endOffsets(partitionsToFetch); + + // Update cache and results + for (Map.Entry entry : fetchedOffsets.entrySet()) { + endOffsetCache.put(entry.getKey(), entry.getValue()); + result.put(entry.getKey(), entry.getValue()); + } + } catch (Exception e) { + log.error("Error fetching end offsets for specific partitions", e); + } + + return result; + } + private Map tracePendingStatuses( Urn urn, Collection aspectNames, diff --git a/metadata-io/src/test/java/com/linkedin/metadata/search/ElasticSearchServiceTest.java b/metadata-io/src/test/java/com/linkedin/metadata/search/ElasticSearchServiceTest.java index 470a6ca27169d6..c1ebbe87b5f437 100644 --- a/metadata-io/src/test/java/com/linkedin/metadata/search/ElasticSearchServiceTest.java +++ b/metadata-io/src/test/java/com/linkedin/metadata/search/ElasticSearchServiceTest.java @@ -3,7 +3,10 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; import com.linkedin.common.urn.Urn; import com.linkedin.common.urn.UrnUtils; @@ -20,6 +23,7 @@ import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.Map; +import java.util.Set; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -123,4 +127,204 @@ public void testAppendRunId_NullRunId() { public void testAppendRunId_NullUrn() { testInstance.appendRunId(opContext, null, "test-run-id"); } + + @Test + public void testRaw_WithValidUrns() { + // Mock dependencies + ESSearchDAO mockEsSearchDAO = mock(ESSearchDAO.class); + EntityIndexBuilders indexBuilders = + new EntityIndexBuilders( + mock(ESIndexBuilder.class), + opContext.getEntityRegistry(), + opContext.getSearchContext().getIndexConvention(), + mock(SettingsBuilder.class)); + + // Create test instance with mocked ESSearchDAO + testInstance = + new ElasticSearchService( + indexBuilders, mockEsSearchDAO, mock(ESBrowseDAO.class), mockEsWriteDAO); + + // Create test data + Urn urn1 = UrnUtils.getUrn("urn:li:dataset:(urn:li:dataPlatform:snowflake,test_dataset1,PROD)"); + Urn urn2 = 
UrnUtils.getUrn("urn:li:dataset:(urn:li:dataPlatform:snowflake,test_dataset2,PROD)"); + Set urns = Set.of(urn1, urn2); + + // Create mock search hits for each URN + org.opensearch.search.SearchHit hit1 = mock(org.opensearch.search.SearchHit.class); + Map sourceMap1 = Map.of("field1", "value1", "field2", 123); + when(hit1.getSourceAsMap()).thenReturn(sourceMap1); + + org.opensearch.search.SearchHit hit2 = mock(org.opensearch.search.SearchHit.class); + Map sourceMap2 = Map.of("field1", "value2", "field3", true); + when(hit2.getSourceAsMap()).thenReturn(sourceMap2); + + // Create mock search results + org.opensearch.search.SearchHits searchHits1 = mock(org.opensearch.search.SearchHits.class); + when(searchHits1.getHits()).thenReturn(new org.opensearch.search.SearchHit[] {hit1}); + + org.opensearch.search.SearchHits searchHits2 = mock(org.opensearch.search.SearchHits.class); + when(searchHits2.getHits()).thenReturn(new org.opensearch.search.SearchHit[] {hit2}); + + org.opensearch.action.search.SearchResponse response1 = + mock(org.opensearch.action.search.SearchResponse.class); + when(response1.getHits()).thenReturn(searchHits1); + + org.opensearch.action.search.SearchResponse response2 = + mock(org.opensearch.action.search.SearchResponse.class); + when(response2.getHits()).thenReturn(searchHits2); + + // Mock the rawEntity response from ESSearchDAO + Map mockResponses = + Map.of( + urn1, response1, + urn2, response2); + when(mockEsSearchDAO.rawEntity(opContext, urns)).thenReturn(mockResponses); + + // Execute the method + Map> result = testInstance.raw(opContext, urns); + + // Verify the results + assertEquals(result.size(), 2); + assertEquals(result.get(urn1), sourceMap1); + assertEquals(result.get(urn2), sourceMap2); + + // Verify ESSearchDAO.rawEntity was called with the correct parameters + verify(mockEsSearchDAO).rawEntity(opContext, urns); + } + + @Test + public void testRaw_WithEmptyHits() { + // Mock dependencies + ESSearchDAO mockEsSearchDAO = mock(ESSearchDAO.class); + EntityIndexBuilders indexBuilders = + new EntityIndexBuilders( + mock(ESIndexBuilder.class), + opContext.getEntityRegistry(), + opContext.getSearchContext().getIndexConvention(), + mock(SettingsBuilder.class)); + + // Create test instance with mocked ESSearchDAO + testInstance = + new ElasticSearchService( + indexBuilders, mockEsSearchDAO, mock(ESBrowseDAO.class), mockEsWriteDAO); + + // Create test data + Urn urn1 = UrnUtils.getUrn("urn:li:dataset:(urn:li:dataPlatform:snowflake,test_dataset1,PROD)"); + Urn urn2 = UrnUtils.getUrn("urn:li:dataset:(urn:li:dataPlatform:snowflake,test_dataset2,PROD)"); + Set urns = Set.of(urn1, urn2); + + // Create search response with empty hits for the first URN + org.opensearch.search.SearchHits emptySearchHits = mock(org.opensearch.search.SearchHits.class); + when(emptySearchHits.getHits()).thenReturn(new org.opensearch.search.SearchHit[] {}); + + org.opensearch.action.search.SearchResponse emptyResponse = + mock(org.opensearch.action.search.SearchResponse.class); + when(emptyResponse.getHits()).thenReturn(emptySearchHits); + + // Create normal response for the second URN + org.opensearch.search.SearchHit hit = mock(org.opensearch.search.SearchHit.class); + Map sourceMap = Map.of("field1", "value", "field2", 456); + when(hit.getSourceAsMap()).thenReturn(sourceMap); + + org.opensearch.search.SearchHits searchHits = mock(org.opensearch.search.SearchHits.class); + when(searchHits.getHits()).thenReturn(new org.opensearch.search.SearchHit[] {hit}); + + 
org.opensearch.action.search.SearchResponse response = + mock(org.opensearch.action.search.SearchResponse.class); + when(response.getHits()).thenReturn(searchHits); + + // Mock the rawEntity response from ESSearchDAO + Map mockResponses = + Map.of( + urn1, emptyResponse, + urn2, response); + when(mockEsSearchDAO.rawEntity(opContext, urns)).thenReturn(mockResponses); + + // Execute the method + Map> result = testInstance.raw(opContext, urns); + + // Verify the results - should only have one entry for urn2 + assertEquals(result.size(), 1); + assertEquals(result.get(urn2), sourceMap); + assertFalse(result.containsKey(urn1)); + + // Verify ESSearchDAO.rawEntity was called with the correct parameters + verify(mockEsSearchDAO).rawEntity(opContext, urns); + } + + @Test + public void testRaw_WithNullHits() { + // Mock dependencies + ESSearchDAO mockEsSearchDAO = mock(ESSearchDAO.class); + EntityIndexBuilders indexBuilders = + new EntityIndexBuilders( + mock(ESIndexBuilder.class), + opContext.getEntityRegistry(), + opContext.getSearchContext().getIndexConvention(), + mock(SettingsBuilder.class)); + + // Create test instance with mocked ESSearchDAO + testInstance = + new ElasticSearchService( + indexBuilders, mockEsSearchDAO, mock(ESBrowseDAO.class), mockEsWriteDAO); + + // Create test data + Urn urn = UrnUtils.getUrn("urn:li:dataset:(urn:li:dataPlatform:snowflake,test_dataset1,PROD)"); + Set urns = Set.of(urn); + + // Create search response with null hits + org.opensearch.search.SearchHits nullSearchHits = mock(org.opensearch.search.SearchHits.class); + when(nullSearchHits.getHits()).thenReturn(null); + + org.opensearch.action.search.SearchResponse nullHitsResponse = + mock(org.opensearch.action.search.SearchResponse.class); + when(nullHitsResponse.getHits()).thenReturn(nullSearchHits); + + // Mock the rawEntity response from ESSearchDAO + Map mockResponses = + Map.of(urn, nullHitsResponse); + when(mockEsSearchDAO.rawEntity(opContext, urns)).thenReturn(mockResponses); + + // Execute the method + Map> result = testInstance.raw(opContext, urns); + + // Verify the results - should be empty since hits are null + assertTrue(result.isEmpty()); + + // Verify ESSearchDAO.rawEntity was called with the correct parameters + verify(mockEsSearchDAO).rawEntity(opContext, urns); + } + + @Test + public void testRaw_WithEmptyUrns() { + // Mock dependencies + ESSearchDAO mockEsSearchDAO = mock(ESSearchDAO.class); + EntityIndexBuilders indexBuilders = + new EntityIndexBuilders( + mock(ESIndexBuilder.class), + opContext.getEntityRegistry(), + opContext.getSearchContext().getIndexConvention(), + mock(SettingsBuilder.class)); + + // Create test instance with mocked ESSearchDAO + testInstance = + new ElasticSearchService( + indexBuilders, mockEsSearchDAO, mock(ESBrowseDAO.class), mockEsWriteDAO); + + // Create empty set of URNs + Set emptyUrns = Collections.emptySet(); + + // Mock the rawEntity response from ESSearchDAO + Map emptyResponses = Collections.emptyMap(); + when(mockEsSearchDAO.rawEntity(opContext, emptyUrns)).thenReturn(emptyResponses); + + // Execute the method + Map> result = testInstance.raw(opContext, emptyUrns); + + // Verify the results are empty + assertTrue(result.isEmpty()); + + // Verify ESSearchDAO.rawEntity was called with the correct parameters + verify(mockEsSearchDAO).rawEntity(opContext, emptyUrns); + } } diff --git a/metadata-io/src/test/java/com/linkedin/metadata/search/query/ESSearchDAORawEntityTest.java 
b/metadata-io/src/test/java/com/linkedin/metadata/search/query/ESSearchDAORawEntityTest.java new file mode 100644 index 00000000000000..ed08706bdf322e --- /dev/null +++ b/metadata-io/src/test/java/com/linkedin/metadata/search/query/ESSearchDAORawEntityTest.java @@ -0,0 +1,69 @@ +package com.linkedin.metadata.search.query; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +import com.linkedin.common.urn.Urn; +import com.linkedin.metadata.config.search.SearchConfiguration; +import com.linkedin.metadata.search.elasticsearch.query.ESSearchDAO; +import io.datahubproject.metadata.context.OperationContext; +import io.datahubproject.test.metadata.context.TestOperationContexts; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.mockito.Mockito; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.testng.annotations.Test; + +public class ESSearchDAORawEntityTest { + + @Test + public void testRawEntityWithMockedClient() throws Exception { + // Setup mocks + RestHighLevelClient mockClient = Mockito.mock(RestHighLevelClient.class); + OperationContext opContext = TestOperationContexts.systemContextNoValidate(); + + // Mock search response + SearchResponse mockResponse = Mockito.mock(SearchResponse.class); + SearchHits mockHits = Mockito.mock(SearchHits.class); + Mockito.when(mockHits.getHits()).thenReturn(new SearchHit[0]); + Mockito.when(mockResponse.getHits()).thenReturn(mockHits); + + // Setup behavior for mocks + Mockito.when(mockClient.search(Mockito.any(), Mockito.eq(RequestOptions.DEFAULT))) + .thenReturn(mockResponse); + + // Create test URN + Urn datasetUrn = + Urn.createFromString("urn:li:dataset:(urn:li:dataPlatform:test,test.table,PROD)"); + Set urns = new HashSet<>(); + urns.add(datasetUrn); + + // Create ESSearchDAO with mocked client + ESSearchDAO esSearchDAO = + new ESSearchDAO( + mockClient, + false, + "elasticsearch", + new SearchConfiguration(), + null, + com.linkedin.metadata.search.elasticsearch.query.filter.QueryFilterRewriteChain.EMPTY); + + // Execute rawEntity method + Map results = esSearchDAO.rawEntity(opContext, urns); + + // Verify results + assertNotNull(results); + assertEquals(results.size(), 1); + assertTrue(results.containsKey(datasetUrn)); + assertEquals(results.get(datasetUrn), mockResponse); + + // Verify the search was performed with correct parameters + Mockito.verify(mockClient).search(Mockito.any(), Mockito.eq(RequestOptions.DEFAULT)); + } +} diff --git a/metadata-io/src/test/java/com/linkedin/metadata/trace/BaseKafkaTraceReaderTest.java b/metadata-io/src/test/java/com/linkedin/metadata/trace/BaseKafkaTraceReaderTest.java index 06e40a4e142f9e..28a40a30718ac8 100644 --- a/metadata-io/src/test/java/com/linkedin/metadata/trace/BaseKafkaTraceReaderTest.java +++ b/metadata-io/src/test/java/com/linkedin/metadata/trace/BaseKafkaTraceReaderTest.java @@ -3,6 +3,8 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyCollection; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -19,6 +21,7 @@ import 
io.datahubproject.openapi.v1.models.TraceWriteStatus; import java.io.IOException; import java.time.Duration; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -236,4 +239,304 @@ public void testFindMessages() throws Exception { assertEquals(result.get(TEST_URN).get(ASPECT_NAME).getFirst(), mockRecord); assertEquals(result.get(TEST_URN).get(ASPECT_NAME).getSecond(), systemMetadata); } + + @Test + public void testGetAllPartitionOffsets_WithCache() { + // Arrange + Node mockNode = new Node(0, "localhost", 9092); + + // Setup multiple partitions for testing + TopicPartitionInfo partitionInfo0 = + new TopicPartitionInfo( + 0, mockNode, Collections.singletonList(mockNode), Collections.singletonList(mockNode)); + TopicPartitionInfo partitionInfo1 = + new TopicPartitionInfo( + 1, mockNode, Collections.singletonList(mockNode), Collections.singletonList(mockNode)); + + List partitionInfos = Arrays.asList(partitionInfo0, partitionInfo1); + TopicDescription topicDescription = new TopicDescription(TOPIC_NAME, false, partitionInfos); + + DescribeTopicsResult mockDescribeTopicsResult = mock(DescribeTopicsResult.class); + when(mockDescribeTopicsResult.all()) + .thenReturn( + KafkaFuture.completedFuture(Collections.singletonMap(TOPIC_NAME, topicDescription))); + when(adminClient.describeTopics(anyCollection())).thenReturn(mockDescribeTopicsResult); + + // Setup consumer group offsets for multiple partitions + TopicPartition topicPartition0 = new TopicPartition(TOPIC_NAME, 0); + TopicPartition topicPartition1 = new TopicPartition(TOPIC_NAME, 1); + + Map offsetMap = new HashMap<>(); + offsetMap.put(topicPartition0, new OffsetAndMetadata(100L)); + offsetMap.put(topicPartition1, new OffsetAndMetadata(200L)); + + ListConsumerGroupOffsetsResult mockOffsetResult = mock(ListConsumerGroupOffsetsResult.class); + when(adminClient.listConsumerGroupOffsets(CONSUMER_GROUP)).thenReturn(mockOffsetResult); + when(mockOffsetResult.partitionsToOffsetAndMetadata()) + .thenReturn(KafkaFuture.completedFuture(offsetMap)); + + // Act + Map result1 = traceReader.getAllPartitionOffsets(false); + + // Assert + assertEquals(result1.size(), 2); + assertEquals(result1.get(topicPartition0).offset(), 100L); + assertEquals(result1.get(topicPartition1).offset(), 200L); + + // Act again - this should use cache + Map result2 = traceReader.getAllPartitionOffsets(false); + + // Assert again + assertEquals(result2.size(), 2); + assertEquals(result2.get(topicPartition0).offset(), 100L); + assertEquals(result2.get(topicPartition1).offset(), 200L); + + // The implementation actually calls describeTopics for each getAllPartitionOffsets call + // This is because the topicPartitionCache in KafkaTraceReader doesn't cache the topic + // description + verify(adminClient, times(2)).describeTopics(anyCollection()); + } + + @Test + public void testGetAllPartitionOffsets_SkipCache() { + // Arrange + Node mockNode = new Node(0, "localhost", 9092); + TopicPartitionInfo partitionInfo = + new TopicPartitionInfo( + 0, mockNode, Collections.singletonList(mockNode), Collections.singletonList(mockNode)); + + TopicDescription topicDescription = + new TopicDescription(TOPIC_NAME, false, Collections.singletonList(partitionInfo)); + + DescribeTopicsResult mockDescribeTopicsResult = mock(DescribeTopicsResult.class); + when(mockDescribeTopicsResult.all()) + .thenReturn( + KafkaFuture.completedFuture(Collections.singletonMap(TOPIC_NAME, topicDescription))); + 
when(adminClient.describeTopics(anyCollection())).thenReturn(mockDescribeTopicsResult); + + TopicPartition topicPartition = new TopicPartition(TOPIC_NAME, 0); + + // First call returns offset 100 + ListConsumerGroupOffsetsResult mockOffsetResult1 = mock(ListConsumerGroupOffsetsResult.class); + when(adminClient.listConsumerGroupOffsets(CONSUMER_GROUP)) + .thenReturn(mockOffsetResult1) + .thenReturn(mockOffsetResult1); // Return same mock for second call + + Map offsetMap1 = new HashMap<>(); + offsetMap1.put(topicPartition, new OffsetAndMetadata(100L)); + + when(mockOffsetResult1.partitionsToOffsetAndMetadata()) + .thenReturn(KafkaFuture.completedFuture(offsetMap1)); + + // Act - first call should populate cache + Map result1 = traceReader.getAllPartitionOffsets(false); + + // Assert first result + assertEquals(result1.size(), 1); + assertEquals(result1.get(topicPartition).offset(), 100L); + + // Change the mock to return a different offset for the next call + ListConsumerGroupOffsetsResult mockOffsetResult2 = mock(ListConsumerGroupOffsetsResult.class); + when(adminClient.listConsumerGroupOffsets(CONSUMER_GROUP)).thenReturn(mockOffsetResult2); + + Map offsetMap2 = new HashMap<>(); + offsetMap2.put(topicPartition, new OffsetAndMetadata(200L)); + + when(mockOffsetResult2.partitionsToOffsetAndMetadata()) + .thenReturn(KafkaFuture.completedFuture(offsetMap2)); + + // Act - second call with skipCache=true should bypass cache + Map result2 = traceReader.getAllPartitionOffsets(true); + + // Assert second result + assertEquals(result2.size(), 1); + assertEquals(result2.get(topicPartition).offset(), 200L); + + // Verify that listConsumerGroupOffsets was called twice + verify(adminClient, times(2)).listConsumerGroupOffsets(CONSUMER_GROUP); + } + + @Test + public void testGetEndOffsets_WithCache() { + // Arrange + Node mockNode = new Node(0, "localhost", 9092); + TopicPartitionInfo partitionInfo = + new TopicPartitionInfo( + 0, mockNode, Collections.singletonList(mockNode), Collections.singletonList(mockNode)); + + TopicDescription topicDescription = + new TopicDescription(TOPIC_NAME, false, Collections.singletonList(partitionInfo)); + + DescribeTopicsResult mockDescribeTopicsResult = mock(DescribeTopicsResult.class); + when(mockDescribeTopicsResult.all()) + .thenReturn( + KafkaFuture.completedFuture(Collections.singletonMap(TOPIC_NAME, topicDescription))); + when(adminClient.describeTopics(anyCollection())).thenReturn(mockDescribeTopicsResult); + + // Setup consumer to return end offsets + TopicPartition topicPartition = new TopicPartition(TOPIC_NAME, 0); + Map endOffsets = Collections.singletonMap(topicPartition, 500L); + when(consumer.endOffsets(anyCollection())).thenReturn(endOffsets); + + // Act + Map result1 = traceReader.getEndOffsets(false); + + // Assert + assertEquals(result1.size(), 1); + assertEquals(result1.get(topicPartition).longValue(), 500L); + + // Act again - this should use cache + Map result2 = traceReader.getEndOffsets(false); + + // Assert again + assertEquals(result2.size(), 1); + assertEquals(result2.get(topicPartition).longValue(), 500L); + + // Verify that endOffsets was called only once + verify(consumer, times(1)).endOffsets(anyCollection()); + } + + @Test + public void testGetEndOffsets_SkipCache() { + // Arrange + Node mockNode = new Node(0, "localhost", 9092); + TopicPartitionInfo partitionInfo = + new TopicPartitionInfo( + 0, mockNode, Collections.singletonList(mockNode), Collections.singletonList(mockNode)); + + TopicDescription topicDescription = + new 
TopicDescription(TOPIC_NAME, false, Collections.singletonList(partitionInfo)); + + DescribeTopicsResult mockDescribeTopicsResult = mock(DescribeTopicsResult.class); + when(mockDescribeTopicsResult.all()) + .thenReturn( + KafkaFuture.completedFuture(Collections.singletonMap(TOPIC_NAME, topicDescription))); + when(adminClient.describeTopics(anyCollection())).thenReturn(mockDescribeTopicsResult); + + // Setup consumer to return end offsets + TopicPartition topicPartition = new TopicPartition(TOPIC_NAME, 0); + Map endOffsets1 = Collections.singletonMap(topicPartition, 500L); + Map endOffsets2 = Collections.singletonMap(topicPartition, 600L); + + when(consumer.endOffsets(anyCollection())).thenReturn(endOffsets1).thenReturn(endOffsets2); + + // Act + Map result1 = traceReader.getEndOffsets(false); + + // Assert + assertEquals(result1.size(), 1); + assertEquals(result1.get(topicPartition).longValue(), 500L); + + // Act again with skipCache=true + Map result2 = traceReader.getEndOffsets(true); + + // Assert again + assertEquals(result2.size(), 1); + assertEquals(result2.get(topicPartition).longValue(), 600L); + + // Verify that endOffsets was called twice + verify(consumer, times(2)).endOffsets(anyCollection()); + } + + @Test + public void testGetEndOffsets_SpecificPartitions() { + // Arrange + Node mockNode = new Node(0, "localhost", 9092); + TopicPartition topicPartition = new TopicPartition(TOPIC_NAME, 0); + TopicPartition topicPartition2 = new TopicPartition(TOPIC_NAME, 1); + List partitions = Arrays.asList(topicPartition, topicPartition2); + + // Setup consumer to return end offsets + Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 500L); + endOffsets.put(topicPartition2, 600L); + + when(consumer.endOffsets(anyCollection())).thenReturn(endOffsets); + + // Act + Map result = traceReader.getEndOffsets(partitions, false); + + // Assert + assertEquals(result.size(), 2); + assertEquals(result.get(topicPartition).longValue(), 500L); + assertEquals(result.get(topicPartition2).longValue(), 600L); + + // Verify that endOffsets was called once + verify(consumer, times(1)).endOffsets(anyCollection()); + + // Verify that assign was called with the correct partitions + verify(consumer, times(1)).assign(partitions); + } + + @Test + public void testGetEndOffsets_SpecificPartitions_WithCache() { + // Arrange + TopicPartition topicPartition = new TopicPartition(TOPIC_NAME, 0); + TopicPartition topicPartition2 = new TopicPartition(TOPIC_NAME, 1); + List partitions = Arrays.asList(topicPartition, topicPartition2); + + // Setup consumer to return end offsets + Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 500L); + endOffsets.put(topicPartition2, 600L); + + when(consumer.endOffsets(anyCollection())).thenReturn(endOffsets); + + // Act - first call to populate cache + Map result1 = traceReader.getEndOffsets(partitions, false); + + // Assert first result + assertEquals(result1.size(), 2); + assertEquals(result1.get(topicPartition).longValue(), 500L); + assertEquals(result1.get(topicPartition2).longValue(), 600L); + + // Act - second call should use cache + Map result2 = traceReader.getEndOffsets(partitions, false); + + // Assert second result + assertEquals(result2.size(), 2); + assertEquals(result2.get(topicPartition).longValue(), 500L); + assertEquals(result2.get(topicPartition2).longValue(), 600L); + + // Verify that endOffsets was called only once + verify(consumer, times(1)).endOffsets(anyCollection()); + } + + @Test + public void 
testGetEndOffsets_SpecificPartitions_SkipCache() { + // Arrange + TopicPartition topicPartition = new TopicPartition(TOPIC_NAME, 0); + TopicPartition topicPartition2 = new TopicPartition(TOPIC_NAME, 1); + List partitions = Arrays.asList(topicPartition, topicPartition2); + + // Setup consumer to return different end offsets on each call + Map endOffsets1 = new HashMap<>(); + endOffsets1.put(topicPartition, 500L); + endOffsets1.put(topicPartition2, 600L); + + Map endOffsets2 = new HashMap<>(); + endOffsets2.put(topicPartition, 700L); + endOffsets2.put(topicPartition2, 800L); + + when(consumer.endOffsets(anyCollection())).thenReturn(endOffsets1).thenReturn(endOffsets2); + + // Act - first call to populate cache + Map result1 = traceReader.getEndOffsets(partitions, false); + + // Assert first result + assertEquals(result1.size(), 2); + assertEquals(result1.get(topicPartition).longValue(), 500L); + assertEquals(result1.get(topicPartition2).longValue(), 600L); + + // Act - second call with skipCache=true should bypass cache + Map result2 = traceReader.getEndOffsets(partitions, true); + + // Assert second result + assertEquals(result2.size(), 2); + assertEquals(result2.get(topicPartition).longValue(), 700L); + assertEquals(result2.get(topicPartition2).longValue(), 800L); + + // Verify that endOffsets was called twice + verify(consumer, times(2)).endOffsets(anyCollection()); + } } diff --git a/metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/elastic/OperationsController.java b/metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/elastic/ElasticsearchController.java similarity index 90% rename from metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/elastic/OperationsController.java rename to metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/elastic/ElasticsearchController.java index 6b20b4e65a586b..da56036eba3421 100644 --- a/metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/elastic/OperationsController.java +++ b/metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/elastic/ElasticsearchController.java @@ -13,6 +13,7 @@ import com.deblock.jsondiff.viewer.PatchDiffViewer; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.linkedin.common.urn.Urn; import com.linkedin.common.urn.UrnUtils; import com.linkedin.metadata.authorization.PoliciesConfig; import com.linkedin.metadata.entity.EntityService; @@ -31,12 +32,16 @@ import io.datahubproject.openapi.util.ElasticsearchUtils; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.servlet.http.HttpServletRequest; import java.net.URISyntaxException; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -63,9 +68,9 @@ @RequestMapping("/openapi/operations/elasticSearch") @Slf4j @Tag( - name = "ElasticSearchOperations", + name = "ElasticSearch Operations", description = "An API for managing your elasticsearch instance") -public class OperationsController { +public class 
ElasticsearchController { private final AuthorizerChain authorizerChain; private final OperationContext systemOperationContext; private final SystemMetadataService systemMetadataService; @@ -74,7 +79,7 @@ public class OperationsController { private final EntityService entityService; private final ObjectMapper objectMapper; - public OperationsController( + public ElasticsearchController( OperationContext systemOperationContext, SystemMetadataService systemMetadataService, TimeseriesAspectService timeseriesAspectService, @@ -96,7 +101,6 @@ public void initBinder(WebDataBinder binder) { binder.registerCustomEditor(String[].class, new StringArrayPropertyEditor(null)); } - @Tag(name = "ElasticSearchOperations") @GetMapping(path = "/getTaskStatus", produces = MediaType.APPLICATION_JSON_VALUE) @Operation(summary = "Get Task Status") public ResponseEntity getTaskStatus(HttpServletRequest request, String task) { @@ -139,7 +143,6 @@ public ResponseEntity getTaskStatus(HttpServletRequest request, String t return ResponseEntity.ok(j.toString()); } - @Tag(name = "ElasticSearchOperations") @GetMapping(path = "/getIndexSizes", produces = MediaType.APPLICATION_JSON_VALUE) @Operation(summary = "Get Index Sizes") public ResponseEntity getIndexSizes(HttpServletRequest request) { @@ -177,7 +180,53 @@ public ResponseEntity getIndexSizes(HttpServletRequest request) { return ResponseEntity.ok(j.toString()); } - @Tag(name = "ElasticSearchOperations") + @PostMapping(path = "/entity/raw", produces = MediaType.APPLICATION_JSON_VALUE) + @Operation( + description = + "Retrieves raw Elasticsearch documents for the provided URNs. Requires MANAGE_SYSTEM_OPERATIONS_PRIVILEGE.", + responses = { + @ApiResponse( + responseCode = "200", + description = "Successfully retrieved raw documents", + content = @Content(mediaType = MediaType.APPLICATION_JSON_VALUE)), + @ApiResponse( + responseCode = "403", + description = "Caller not authorized to access raw documents"), + @ApiResponse(responseCode = "400", description = "Invalid URN format provided") + }) + public ResponseEntity>> getEntityRaw( + HttpServletRequest request, + @RequestBody + @Nonnull + @Schema( + description = "Set of URN strings to fetch raw documents for", + example = "[\"urn:li:dataset:(urn:li:dataPlatform:hive,SampleTable,PROD)\"]") + Set urnStrs) { + + Set urns = urnStrs.stream().map(UrnUtils::getUrn).collect(Collectors.toSet()); + + Authentication authentication = AuthenticationContext.getAuthentication(); + String actorUrnStr = authentication.getActor().toUrnStr(); + OperationContext opContext = + systemOperationContext.asSession( + RequestContext.builder() + .buildOpenapi( + actorUrnStr, + request, + "getRawEntity", + urns.stream().map(Urn::getEntityType).distinct().toList()), + authorizerChain, + authentication); + + if (!AuthUtil.isAPIOperationsAuthorized( + opContext, PoliciesConfig.MANAGE_SYSTEM_OPERATIONS_PRIVILEGE)) { + log.error("{} is not authorized to get raw ES documents", actorUrnStr); + return ResponseEntity.status(HttpStatus.FORBIDDEN).body(null); + } + + return ResponseEntity.ok(searchService.raw(opContext, urns)); + } + @GetMapping(path = "/explainSearchQuery", produces = MediaType.APPLICATION_JSON_VALUE) @Operation(summary = "Explain Search Query") public ResponseEntity explainSearchQuery( @@ -280,7 +329,6 @@ public ResponseEntity explainSearchQuery( return ResponseEntity.ok(response); } - @Tag(name = "ElasticSearchOperations") @GetMapping(path = "/explainSearchQueryDiff", produces = MediaType.TEXT_PLAIN_VALUE) @Operation(summary = "Explain 
the differences in scoring for 2 documents") public ResponseEntity explainSearchQueryDiff( diff --git a/metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/kafka/KafkaController.java b/metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/kafka/KafkaController.java new file mode 100644 index 00000000000000..f7c1ff2b207fce --- /dev/null +++ b/metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/kafka/KafkaController.java @@ -0,0 +1,437 @@ +package io.datahubproject.openapi.operations.kafka; + +import com.datahub.authentication.Authentication; +import com.datahub.authentication.AuthenticationContext; +import com.datahub.authorization.AuthUtil; +import com.datahub.authorization.AuthorizerChain; +import com.linkedin.metadata.authorization.PoliciesConfig; +import com.linkedin.metadata.trace.MCLTraceReader; +import com.linkedin.metadata.trace.MCPTraceReader; +import io.datahubproject.metadata.context.OperationContext; +import io.datahubproject.metadata.context.RequestContext; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.servlet.http.HttpServletRequest; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; + +@RestController +@RequestMapping("/openapi/operations/kafka") +@Slf4j +public class KafkaController { + + private final OperationContext systemOperationContext; + private final AuthorizerChain authorizerChain; + private final MCPTraceReader mcpTraceReader; + private final MCLTraceReader mclTraceReader; + private final MCLTraceReader mclTimeseriesTraceReader; + + public KafkaController( + @Qualifier("systemOperationContext") OperationContext systemOperationContext, + AuthorizerChain authorizerChain, + MCPTraceReader mcpTraceReader, + @Qualifier("mclVersionedTraceReader") MCLTraceReader mclTraceReader, + @Qualifier("mclTimeseriesTraceReader") MCLTraceReader mclTimeseriesTraceReader) { + this.systemOperationContext = systemOperationContext; + this.authorizerChain = authorizerChain; + this.mcpTraceReader = mcpTraceReader; + this.mclTraceReader = mclTraceReader; + this.mclTimeseriesTraceReader = mclTimeseriesTraceReader; + } + + @Tag( + name = "Kafka Offsets", + description = "APIs for retrieving Kafka consumer offset information") + @GetMapping(path = "/mcp/consumer/offsets", produces = MediaType.APPLICATION_JSON_VALUE) + @Operation( + summary = "Get MetadataChangeProposal consumer kafka offsets with lag metrics", + 
description = + "Retrieves the current offsets and lag information for all partitions of the MCP topic from the consumer group", + responses = { + @ApiResponse( + responseCode = "200", + description = "Successfully retrieved consumer offsets and lag metrics", + content = + @Content( + mediaType = MediaType.APPLICATION_JSON_VALUE, + schema = @Schema(implementation = KafkaOffsetResponse.class))), + @ApiResponse( + responseCode = "403", + description = "Caller is not authorized to access this endpoint", + content = + @Content( + mediaType = MediaType.APPLICATION_JSON_VALUE, + schema = @Schema(implementation = ErrorResponse.class))) + }) + @Parameter( + name = "skipCache", + description = "Whether to bypass the offset cache and fetch fresh values directly from Kafka", + schema = @Schema(type = "boolean", defaultValue = "false"), + required = false) + @Parameter( + name = "detailed", + description = "Whether to include per-partition offset details in the response", + schema = @Schema(type = "boolean", defaultValue = "false"), + required = false) + public ResponseEntity getMCPOffsets( + HttpServletRequest httpServletRequest, + @RequestParam(value = "skipCache", defaultValue = "false") boolean skipCache, + @RequestParam(value = "detailed", defaultValue = "false") boolean detailed) { + Authentication authentication = AuthenticationContext.getAuthentication(); + String actorUrnStr = authentication.getActor().toUrnStr(); + + OperationContext opContext = + OperationContext.asSession( + systemOperationContext, + RequestContext.builder() + .buildOpenapi(actorUrnStr, httpServletRequest, "getMCPOffsets", List.of()), + authorizerChain, + authentication, + true); + + if (!AuthUtil.isAPIAuthorized(opContext, PoliciesConfig.MANAGE_SYSTEM_OPERATIONS_PRIVILEGE)) { + return ResponseEntity.status(HttpStatus.FORBIDDEN) + .body( + ErrorResponse.builder() + .error(actorUrnStr + " is not authorized to get kafka offsets") + .build()); + } + + // Get consumer offsets + Map offsetMap = + mcpTraceReader.getAllPartitionOffsets(skipCache); + + // Get end offsets for the same partitions to calculate lag + Map endOffsets = + mcpTraceReader.getEndOffsets(offsetMap.keySet(), skipCache); + + KafkaOffsetResponse response = + convertToResponse(mcpTraceReader.getConsumerGroupId(), offsetMap, endOffsets, detailed); + + return ResponseEntity.ok(response); + } + + @Tag( + name = "Kafka Offsets", + description = "APIs for retrieving Kafka consumer offset information") + @GetMapping(path = "/mcl/consumer/offsets", produces = MediaType.APPLICATION_JSON_VALUE) + @Operation( + summary = "Get MetadataChangeLog consumer kafka offsets with lag metrics", + description = + "Retrieves the current offsets and lag information for all partitions of the MCL topic from the consumer group", + responses = { + @ApiResponse( + responseCode = "200", + description = "Successfully retrieved consumer offsets and lag metrics", + content = + @Content( + mediaType = MediaType.APPLICATION_JSON_VALUE, + schema = @Schema(implementation = KafkaOffsetResponse.class))), + @ApiResponse( + responseCode = "403", + description = "Caller is not authorized to access this endpoint", + content = + @Content( + mediaType = MediaType.APPLICATION_JSON_VALUE, + schema = @Schema(implementation = ErrorResponse.class))) + }) + @Parameter( + name = "skipCache", + description = "Whether to bypass the offset cache and fetch fresh values directly from Kafka", + schema = @Schema(type = "boolean", defaultValue = "false"), + required = false) + @Parameter( + name = "detailed", + 
description = "Whether to include per-partition offset details in the response", + schema = @Schema(type = "boolean", defaultValue = "false"), + required = false) + public ResponseEntity getMCLOffsets( + HttpServletRequest httpServletRequest, + @RequestParam(value = "skipCache", defaultValue = "false") boolean skipCache, + @RequestParam(value = "detailed", defaultValue = "false") boolean detailed) { + Authentication authentication = AuthenticationContext.getAuthentication(); + String actorUrnStr = authentication.getActor().toUrnStr(); + + OperationContext opContext = + OperationContext.asSession( + systemOperationContext, + RequestContext.builder() + .buildOpenapi(actorUrnStr, httpServletRequest, "getMCLOffsets", List.of()), + authorizerChain, + authentication, + true); + + if (!AuthUtil.isAPIAuthorized(opContext, PoliciesConfig.MANAGE_SYSTEM_OPERATIONS_PRIVILEGE)) { + return ResponseEntity.status(HttpStatus.FORBIDDEN) + .body( + ErrorResponse.builder() + .error(actorUrnStr + " is not authorized to get kafka offsets") + .build()); + } + + // Get consumer offsets + Map offsetMap = + mclTraceReader.getAllPartitionOffsets(skipCache); + + // Get end offsets for the same partitions to calculate lag + Map endOffsets = + mclTraceReader.getEndOffsets(offsetMap.keySet(), skipCache); + + KafkaOffsetResponse response = + convertToResponse(mclTraceReader.getConsumerGroupId(), offsetMap, endOffsets, detailed); + + return ResponseEntity.ok(response); + } + + @Tag( + name = "Kafka Offsets", + description = "APIs for retrieving Kafka consumer offset information") + @GetMapping( + path = "/mcl-timeseries/consumer/offsets", + produces = MediaType.APPLICATION_JSON_VALUE) + @Operation( + summary = "Get MetadataChangeLog timeseries consumer kafka offsets with lag metrics", + description = + "Retrieves the current offsets and lag information for all partitions of the MCL timeseries topic from the consumer group", + responses = { + @ApiResponse( + responseCode = "200", + description = "Successfully retrieved consumer offsets and lag metrics", + content = + @Content( + mediaType = MediaType.APPLICATION_JSON_VALUE, + schema = @Schema(implementation = KafkaOffsetResponse.class))), + @ApiResponse( + responseCode = "403", + description = "Caller is not authorized to access this endpoint", + content = + @Content( + mediaType = MediaType.APPLICATION_JSON_VALUE, + schema = @Schema(implementation = ErrorResponse.class))) + }) + @Parameter( + name = "skipCache", + description = "Whether to bypass the offset cache and fetch fresh values directly from Kafka", + schema = @Schema(type = "boolean", defaultValue = "false"), + required = false) + @Parameter( + name = "detailed", + description = "Whether to include per-partition offset details in the response", + schema = @Schema(type = "boolean", defaultValue = "false"), + required = false) + public ResponseEntity getMCLTimeseriesOffsets( + HttpServletRequest httpServletRequest, + @RequestParam(value = "skipCache", defaultValue = "false") boolean skipCache, + @RequestParam(value = "detailed", defaultValue = "false") boolean detailed) { + Authentication authentication = AuthenticationContext.getAuthentication(); + String actorUrnStr = authentication.getActor().toUrnStr(); + + OperationContext opContext = + OperationContext.asSession( + systemOperationContext, + RequestContext.builder() + .buildOpenapi(actorUrnStr, httpServletRequest, "getMCLOffsets", List.of()), + authorizerChain, + authentication, + true); + + if (!AuthUtil.isAPIAuthorized(opContext, 
PoliciesConfig.MANAGE_SYSTEM_OPERATIONS_PRIVILEGE)) { + return ResponseEntity.status(HttpStatus.FORBIDDEN) + .body( + ErrorResponse.builder() + .error(actorUrnStr + " is not authorized to get kafka offsets") + .build()); + } + + // Get consumer offsets + Map offsetMap = + mclTimeseriesTraceReader.getAllPartitionOffsets(skipCache); + + // Get end offsets for the same partitions to calculate lag + Map endOffsets = + mclTimeseriesTraceReader.getEndOffsets(offsetMap.keySet(), skipCache); + + KafkaOffsetResponse response = + convertToResponse( + mclTimeseriesTraceReader.getConsumerGroupId(), offsetMap, endOffsets, detailed); + + return ResponseEntity.ok(response); + } + + /** + * Converts the Kafka offset data into a strongly-typed response object. + * + * @param consumerGroupId The consumer group ID + * @param offsetMap Map of TopicPartition to OffsetAndMetadata + * @param endOffsets Map of TopicPartition to end offset + * @param detailed Whether to include detailed partition information + * @return A structured KafkaOffsetResponse object + */ + private KafkaOffsetResponse convertToResponse( + String consumerGroupId, + Map offsetMap, + Map endOffsets, + boolean detailed) { + + // Early return if map is empty + if (offsetMap == null || offsetMap.isEmpty()) { + return new KafkaOffsetResponse(); + } + + // Group by topic + Map> topicToPartitions = + new HashMap<>(); + Map> topicToLags = new HashMap<>(); + + // Process each entry in the offset map + for (Map.Entry entry : offsetMap.entrySet()) { + TopicPartition tp = entry.getKey(); + OffsetAndMetadata offset = entry.getValue(); + + String topic = tp.topic(); + int partition = tp.partition(); + + // Calculate lag if we have end offset information + long consumerOffset = offset.offset(); + Long endOffset = endOffsets.get(tp); + Long lag = (endOffset != null) ? Math.max(0, endOffset - consumerOffset) : null; + + // Create partition info + KafkaOffsetResponse.PartitionInfo partitionInfo = + KafkaOffsetResponse.PartitionInfo.builder().offset(consumerOffset).lag(lag).build(); + + // Add metadata if present + if (offset.metadata() != null && !offset.metadata().isEmpty()) { + partitionInfo.setMetadata(offset.metadata()); + } + + // Store partition info by topic and partition ID + topicToPartitions.computeIfAbsent(topic, k -> new HashMap<>()).put(partition, partitionInfo); + + // Store lag for aggregate calculations + if (lag != null) { + topicToLags.computeIfAbsent(topic, k -> new ArrayList<>()).add(lag); + } + } + + // Create the response structure with sorted topics and partitions + Map topicMap = new LinkedHashMap<>(); + + // Process topics in sorted order + topicToPartitions.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .forEach( + topicEntry -> { + String topic = topicEntry.getKey(); + Map partitionMap = topicEntry.getValue(); + + // Create sorted map of partitions + Map sortedPartitions = + new LinkedHashMap<>(); + partitionMap.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .forEach(e -> sortedPartitions.put(String.valueOf(e.getKey()), e.getValue())); + + // Calculate metrics if we have lag information + KafkaOffsetResponse.LagMetrics metrics = null; + List lags = topicToLags.get(topic); + if (lags != null && !lags.isEmpty()) { + metrics = calculateLagMetrics(lags); + } + + // Create topic info + KafkaOffsetResponse.TopicOffsetInfo topicInfo = + KafkaOffsetResponse.TopicOffsetInfo.builder() + .partitions(detailed ? 
+                          sortedPartitions : null)
+                      .metrics(metrics)
+                      .build();
+
+              topicMap.put(topic, topicInfo);
+            });
+
+    // Create map of consumer group ID to its topic information
+    KafkaOffsetResponse response = new KafkaOffsetResponse();
+    response.put(consumerGroupId, topicMap);
+    return response;
+  }
+
+  /**
+   * Calculates aggregate lag metrics from a list of lag values.
+   *
+   * @param lags List of lag values
+   * @return Structured lag metrics
+   */
+  private KafkaOffsetResponse.LagMetrics calculateLagMetrics(List<Long> lags) {
+    if (lags == null || lags.isEmpty()) {
+      return null;
+    }
+
+    // Sort the lags for median calculation
+    List<Long> sortedLags = new ArrayList<>(lags);
+    Collections.sort(sortedLags);
+
+    // Calculate max lag
+    long maxLag = sortedLags.get(sortedLags.size() - 1);
+
+    // Calculate median lag
+    long medianLag;
+    int middle = sortedLags.size() / 2;
+    if (sortedLags.size() % 2 == 0) {
+      // Even number of elements, average the middle two
+      medianLag = (sortedLags.get(middle - 1) + sortedLags.get(middle)) / 2;
+    } else {
+      // Odd number of elements, take the middle one
+      medianLag = sortedLags.get(middle);
+    }
+
+    // Calculate total lag
+    long totalLag = 0;
+    for (Long lag : lags) {
+      totalLag += lag;
+    }
+
+    // Calculate average lag
+    double avgLag = (double) totalLag / lags.size();
+
+    return KafkaOffsetResponse.LagMetrics.builder()
+        .maxLag(maxLag)
+        .medianLag(medianLag)
+        .totalLag(totalLag)
+        .avgLag(Math.round(avgLag))
+        .build();
+  }
+
+  /** Simple error response class for auth failures. */
+  @Data
+  @Builder
+  @NoArgsConstructor
+  @AllArgsConstructor
+  @Schema(description = "Error response")
+  public static class ErrorResponse {
+    @Schema(description = "Error message")
+    private String error;
+  }
+}
diff --git a/metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/kafka/KafkaOffsetResponse.java b/metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/kafka/KafkaOffsetResponse.java
new file mode 100644
index 00000000000000..edd8cf0241c294
--- /dev/null
+++ b/metadata-service/openapi-servlet/src/main/java/io/datahubproject/openapi/operations/kafka/KafkaOffsetResponse.java
@@ -0,0 +1,68 @@
+package io.datahubproject.openapi.operations.kafka;
+
+import io.swagger.v3.oas.annotations.media.Schema;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+/** Response model for Kafka consumer offsets API endpoint. */
+@Schema(description = "Kafka consumer group offset information with lag metrics")
+public class KafkaOffsetResponse extends LinkedHashMap<String, Map<String, KafkaOffsetResponse.TopicOffsetInfo>> {
+
+  /** Class representing information for a specific topic. */
+  @Data
+  @Builder
+  @NoArgsConstructor
+  @AllArgsConstructor
+  @Schema(description = "Information for a specific Kafka topic")
+  public static class TopicOffsetInfo {
+
+    @Schema(description = "Map of partition ID to partition offset information")
+    private Map<String, PartitionInfo> partitions;
+
+    @Schema(description = "Aggregate metrics for this topic")
+    private LagMetrics metrics;
+  }
+
+  /** Class representing information for a specific partition. */
+  @Data
+  @Builder
+  @NoArgsConstructor
+  @AllArgsConstructor
+  @Schema(description = "Information for a specific Kafka partition")
+  public static class PartitionInfo {
+
+    @Schema(description = "Current consumer offset")
+    private Long offset;
+
+    @Schema(description = "Additional metadata for this offset, if available")
+    private String metadata;
+
+    @Schema(description = "Consumer lag (difference between end offset and consumer offset)")
+    private Long lag;
+  }
+
+  /** Class representing aggregate lag metrics for a topic. */
+  @Data
+  @Builder
+  @NoArgsConstructor
+  @AllArgsConstructor
+  @Schema(description = "Aggregated lag metrics across all partitions of a topic")
+  public static class LagMetrics {
+
+    @Schema(description = "Maximum lag across all partitions")
+    private Long maxLag;
+
+    @Schema(description = "Median lag across all partitions")
+    private Long medianLag;
+
+    @Schema(description = "Total lag across all partitions")
+    private Long totalLag;
+
+    @Schema(description = "Average lag across all partitions (rounded)")
+    private Long avgLag;
+  }
+}
diff --git a/metadata-service/services/src/main/java/com/linkedin/metadata/search/EntitySearchService.java b/metadata-service/services/src/main/java/com/linkedin/metadata/search/EntitySearchService.java
index bbe96d49353514..837e3d17908afc 100644
--- a/metadata-service/services/src/main/java/com/linkedin/metadata/search/EntitySearchService.java
+++ b/metadata-service/services/src/main/java/com/linkedin/metadata/search/EntitySearchService.java
@@ -394,6 +394,16 @@ ExplainResponse explain(
       int size,
       @Nonnull List<String> facets);
 
+  /**
+   * Fetch raw entity documents
+   *
+   * @param opContext operational context
+   * @param urns the document identifiers
+   * @return map of documents by urn
+   */
+  @Nonnull
+  Map<Urn, Map<String, Object>> raw(@Nonnull OperationContext opContext, @Nonnull Set<Urn> urns);
+
   /**
    * Return index convention
    *
diff --git a/metadata-service/services/src/test/java/com/linkedin/metadata/service/search/EntitySearchServiceTest.java b/metadata-service/services/src/test/java/com/linkedin/metadata/service/search/EntitySearchServiceTest.java
index 7bd6f3abe37136..1ee9acce080756 100644
--- a/metadata-service/services/src/test/java/com/linkedin/metadata/service/search/EntitySearchServiceTest.java
+++ b/metadata-service/services/src/test/java/com/linkedin/metadata/service/search/EntitySearchServiceTest.java
@@ -29,6 +29,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 import org.apache.commons.lang3.NotImplementedException;
@@ -345,5 +346,11 @@ public ExplainResponse explain(
     public IndexConvention getIndexConvention() {
       return null;
     }
+
+    @Override
+    public @Nonnull Map<Urn, Map<String, Object>> raw(
+        @Nonnull OperationContext opContext, @Nonnull Set<Urn> urns) {
+      return Map.of();
+    }
   }
 }
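
Reviewer note (illustrative, not part of the patch): the sketch below shows the nested response shape the new offset endpoints produce — consumer group -> topic -> partitions/metrics — by building a KafkaOffsetResponse by hand and serializing it with Jackson. The example class name, consumer group, topic name, and offset values are invented for illustration; only the KafkaOffsetResponse types come from this change, and Jackson availability is assumed.

// Hypothetical example class; all values below are made up.
package io.datahubproject.openapi.operations.kafka;

import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.LinkedHashMap;
import java.util.Map;

public class KafkaOffsetResponseShapeExample {
  public static void main(String[] args) throws Exception {
    // One partition with a committed offset of 120 and a lag of 5.
    Map<String, KafkaOffsetResponse.PartitionInfo> partitions = new LinkedHashMap<>();
    partitions.put(
        "0", KafkaOffsetResponse.PartitionInfo.builder().offset(120L).lag(5L).build());

    KafkaOffsetResponse.TopicOffsetInfo topicInfo =
        KafkaOffsetResponse.TopicOffsetInfo.builder()
            .partitions(partitions) // included by the controller only when detailed=true
            .metrics(
                KafkaOffsetResponse.LagMetrics.builder()
                    .maxLag(5L)
                    .medianLag(5L)
                    .totalLag(5L)
                    .avgLag(5L)
                    .build())
            .build();

    Map<String, KafkaOffsetResponse.TopicOffsetInfo> topics = new LinkedHashMap<>();
    topics.put("ExampleTopic_v1", topicInfo);

    KafkaOffsetResponse response = new KafkaOffsetResponse();
    response.put("example-consumer-group", topics);

    // Prints the nested group -> topic -> {partitions, metrics} JSON structure.
    System.out.println(new ObjectMapper().writeValueAsString(response));
  }
}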
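
Reviewer note (illustrative, not part of the patch): calculateLagMetrics is plain order statistics over the per-partition lag values. Below is a self-contained sketch of the same max/median/total/average arithmetic, with a hypothetical class name and invented lag values.

// Hypothetical example mirroring the aggregation logic in calculateLagMetrics.
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class LagMetricsSketch {
  public static void main(String[] args) {
    // Invented per-partition lags (end offset minus committed offset).
    List<Long> lags = List.of(3L, 10L, 7L);

    List<Long> sorted = new ArrayList<>(lags);
    Collections.sort(sorted); // [3, 7, 10]

    long maxLag = sorted.get(sorted.size() - 1); // 10
    int middle = sorted.size() / 2;
    long medianLag =
        (sorted.size() % 2 == 0)
            ? (sorted.get(middle - 1) + sorted.get(middle)) / 2
            : sorted.get(middle); // odd count here, so 7
    long totalLag = lags.stream().mapToLong(Long::longValue).sum(); // 20
    long avgLag = Math.round((double) totalLag / lags.size()); // 20 / 3 rounded = 7

    System.out.printf("max=%d median=%d total=%d avg=%d%n", maxLag, medianLag, totalLag, avgLag);
  }
}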