Skip to content

Commit

Permalink
Add root object wrapper in LocalSampleCalculatorOutput and AnomalyLoc…
Browse files Browse the repository at this point in the history
…alizationOutput.

Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
  • Loading branch information
nathaliellenaa committed Jan 27, 2025
1 parent 570edaf commit 8bedc83
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
@Override
@SneakyThrows
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) {
builder.startObject();
builder.startArray(FIELD_RESULTS);
for (Map.Entry<String, Result> entry : this.results.entrySet()) {
builder.startObject();
Expand All @@ -196,6 +197,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
builder.endObject();
}
builder.endArray();
builder.endObject();
return builder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ public void writeTo(StreamOutput out) throws IOException {

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (result != null) {
builder.field("result", result);
}
builder.endObject();
return builder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import static org.opensearch.ml.utils.TestHelper.getExecuteAgentRestRequest;
import static org.opensearch.ml.utils.TestHelper.getLocalSampleCalculatorRestRequest;
import static org.opensearch.ml.utils.TestHelper.getMetricsCorrelationRestRequest;
import static org.opensearch.ml.utils.TestHelper.getAnomalyLocalizationRestRequest;

import java.io.IOException;
import java.util.List;
import java.util.*;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand All @@ -29,6 +30,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
Expand All @@ -47,6 +49,11 @@
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.RemoteTransportException;
import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput;
import org.opensearch.ml.common.output.execute.anomalylocalization.AnomalyLocalizationOutput;
import org.opensearch.ml.common.output.execute.anomalylocalization.AnomalyLocalizationOutput.Entity;
import org.opensearch.ml.common.output.execute.anomalylocalization.AnomalyLocalizationOutput.Bucket;
import org.opensearch.ml.common.output.execute.anomalylocalization.AnomalyLocalizationOutput.Result;

public class RestMLExecuteActionTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -337,4 +344,86 @@ public void testAgentExecutionResponsePlainText() throws Exception {
"{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}";
assertEquals(expectedError, response.content().utf8ToString());
}
}

public void testLocalSampleCalculatorExecutionResponse() throws Exception {
RestRequest request = getLocalSampleCalculatorRestRequest();
XContentBuilder builder = XContentFactory.jsonBuilder();
when(channel.newBuilder()).thenReturn(builder);
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
LocalSampleCalculatorOutput output = LocalSampleCalculatorOutput.builder()
.totalSum(3.0)
.build();
MLExecuteTaskResponse response = MLExecuteTaskResponse.builder()
.output(output)
.functionName(FunctionName.LOCAL_SAMPLE_CALCULATOR)
.build();
actionListener.onResponse(response);
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
doNothing().when(channel).sendResponse(any());
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<RestResponse> responseCaptor = ArgumentCaptor.forClass(RestResponse.class);
verify(channel).sendResponse(responseCaptor.capture());
BytesRestResponse response = (BytesRestResponse) responseCaptor.getValue();
assertEquals(RestStatus.OK, response.status());
assertEquals("{\"result\":3.0}", response.content().utf8ToString());
}

public void testAnomalyLocalizationExecutionResponse() throws Exception {
RestRequest request = getAnomalyLocalizationRestRequest();
XContentBuilder builder = XContentFactory.jsonBuilder();
when(channel.newBuilder()).thenReturn(builder);
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);

Entity entity1 = new Entity();
entity1.setKey(Collections.singletonList("attr0"));
entity1.setContributionValue(1.0);
entity1.setBaseValue(2.0);
entity1.setNewValue(3.0);

Bucket bucket1 = new Bucket();
bucket1.setStartTime(1620630000000L);
bucket1.setEndTime(1620716400000L);
bucket1.setOverallAggValue(65.0);

Result result = new Result();
result.setBuckets(Arrays.asList(bucket1));

AnomalyLocalizationOutput output = new AnomalyLocalizationOutput();
Map<String, Result> results = new HashMap<>();
results.put("sum", result);
output.setResults(results);

MLExecuteTaskResponse response = MLExecuteTaskResponse.builder()
.output(output)
.functionName(FunctionName.ANOMALY_LOCALIZATION)
.build();
actionListener.onResponse(response);
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
doNothing().when(channel).sendResponse(any());
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<RestResponse> responseCaptor = ArgumentCaptor.forClass(RestResponse.class);
verify(channel).sendResponse(responseCaptor.capture());
BytesRestResponse response = (BytesRestResponse) responseCaptor.getValue();
assertEquals(RestStatus.OK, response.status());
String expectedJson =
"{\"results\":[{" +
"\"name\":\"sum\"," +
"\"result\":{" +
"\"buckets\":[" +
"{" +
"\"start_time\":1620630000000," +
"\"end_time\":1620716400000," +
"\"overall_aggregate_value\":65.0" +
"}" +
"]" +
"}" +
"}]}";
assertEquals(expectedJson, response.content().utf8ToString());
}
}
24 changes: 24 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.Constants;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.execute.anomalylocalization.AnomalyLocalizationInput;
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput;
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
Expand Down Expand Up @@ -358,6 +359,28 @@ public static RestRequest getMetricsCorrelationRestRequest() {
.build();
}

public static RestRequest getAnomalyLocalizationRestRequest() {
Map<String, String> params = new HashMap<>();
params.put(PARAMETER_ALGORITHM, FunctionName.ANOMALY_LOCALIZATION.name());
final String requestContent =
"{" +
"\"input_data\": {" +
"\"index_name\": \"test-index\"," +
"\"attribute_field_names\": [\"attribute\"]," +
"\"time_field_name\": \"timestamp\"," +
"\"start_time\": 1620630000000," +
"\"end_time\": 1621234800000," +
"\"min_time_interval\": 86400000," +
"\"num_outputs\": 1" +
"}" +
"}";
RestRequest request = new FakeRestRequest.Builder(getXContentRegistry())
.withParams(params)
.withContent(new BytesArray(requestContent), XContentType.JSON)
.build();
return request;
}

public static RestRequest getExecuteAgentRestRequest() {
Map<String, String> params = new HashMap<>();
params.put(PARAMETER_AGENT_ID, "test_agent_id");
Expand Down Expand Up @@ -405,6 +428,7 @@ private static NamedXContentRegistry getXContentRegistry() {
entries.add(KMeansParams.XCONTENT_REGISTRY);
entries.add(LocalSampleCalculatorInput.XCONTENT_REGISTRY);
entries.add(MetricsCorrelationInput.XCONTENT_REGISTRY);
entries.add(AnomalyLocalizationInput.XCONTENT_REGISTRY_ENTRY);
return new NamedXContentRegistry(entries);
}

Expand Down

0 comments on commit 8bedc83

Please sign in to comment.