Skip to content

Commit

Permalink
Fix JsonGenerationException error in Local Sample Calculator and Anom…
Browse files Browse the repository at this point in the history
…aly Localization Execution Response (opensearch-project#3434)

* Add root object wrapper in LocalSampleCalculatorOutput and AnomalyLocalizationOutput.

Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>

* Fix format violations.

Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>

* Modify import * in RestMLExecuteActionTests, remove root object wrapper in AnomalyLocalizationOutputTests and LocalSampleCalculatorOutputTest.

Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>

* Removed unused variable in testAnomalyLocalizationExecutionResponse function.

Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>

---------

Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
  • Loading branch information
nathaliellenaa committed Jan 28, 2025
1 parent 17251cd commit 78eb74b
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
@Override
@SneakyThrows
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) {
builder.startObject();
builder.startArray(FIELD_RESULTS);
for (Map.Entry<String, Result> entry : this.results.entrySet()) {
builder.startObject();
Expand All @@ -196,6 +197,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
builder.endObject();
}
builder.endArray();
builder.endObject();
return builder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ public void writeTo(StreamOutput out) throws IOException {

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (result != null) {
builder.field("result", result);
}
builder.endObject();
return builder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ public void testWriteable() throws Exception {
@Test
public void testXContent() throws Exception {
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder = output.toXContent(builder, null);
builder.endObject();
String json = builder.toString();
XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, null, json);
AnomalyLocalizationOutput newOutput = AnomalyLocalizationOutput.parse(parser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,17 @@ public void setUp() {

@Test
public void toXContent() throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
output.toXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject();
String jsonStr = builder.toString();
assertEquals("{\"result\":1.0}", jsonStr);
}

@Test
public void toXContent_EmptyOutput() throws IOException {
LocalSampleCalculatorOutput output = LocalSampleCalculatorOutput.builder().build();
XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
output.toXContent(builder, ToXContent.EMPTY_PARAMS);
builder.endObject();
String jsonStr = builder.toString();
assertEquals("{}", jsonStr);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.utils.TestHelper.getAnomalyLocalizationRestRequest;
import static org.opensearch.ml.utils.TestHelper.getExecuteAgentRestRequest;
import static org.opensearch.ml.utils.TestHelper.getLocalSampleCalculatorRestRequest;
import static org.opensearch.ml.utils.TestHelper.getMetricsCorrelationRestRequest;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand All @@ -32,8 +36,13 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.output.execute.anomalylocalization.AnomalyLocalizationOutput;
import org.opensearch.ml.common.output.execute.anomalylocalization.AnomalyLocalizationOutput.Bucket;
import org.opensearch.ml.common.output.execute.anomalylocalization.AnomalyLocalizationOutput.Result;
import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
Expand Down Expand Up @@ -337,4 +346,79 @@ public void testAgentExecutionResponsePlainText() throws Exception {
"{\"error\":{\"reason\":\"Invalid Request\",\"details\":\"Illegal Argument Exception\",\"type\":\"IllegalArgumentException\"},\"status\":400}";
assertEquals(expectedError, response.content().utf8ToString());
}

public void testLocalSampleCalculatorExecutionResponse() throws Exception {
RestRequest request = getLocalSampleCalculatorRestRequest();
XContentBuilder builder = XContentFactory.jsonBuilder();
when(channel.newBuilder()).thenReturn(builder);
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
LocalSampleCalculatorOutput output = LocalSampleCalculatorOutput.builder().totalSum(3.0).build();
MLExecuteTaskResponse response = MLExecuteTaskResponse
.builder()
.output(output)
.functionName(FunctionName.LOCAL_SAMPLE_CALCULATOR)
.build();
actionListener.onResponse(response);
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
doNothing().when(channel).sendResponse(any());
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<RestResponse> responseCaptor = ArgumentCaptor.forClass(RestResponse.class);
verify(channel).sendResponse(responseCaptor.capture());
BytesRestResponse response = (BytesRestResponse) responseCaptor.getValue();
assertEquals(RestStatus.OK, response.status());
assertEquals("{\"result\":3.0}", response.content().utf8ToString());
}

public void testAnomalyLocalizationExecutionResponse() throws Exception {
RestRequest request = getAnomalyLocalizationRestRequest();
XContentBuilder builder = XContentFactory.jsonBuilder();
when(channel.newBuilder()).thenReturn(builder);
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);

Bucket bucket1 = new Bucket();
bucket1.setStartTime(1620630000000L);
bucket1.setEndTime(1620716400000L);
bucket1.setOverallAggValue(65.0);

Result result = new Result();
result.setBuckets(Arrays.asList(bucket1));

AnomalyLocalizationOutput output = new AnomalyLocalizationOutput();
Map<String, Result> results = new HashMap<>();
results.put("sum", result);
output.setResults(results);

MLExecuteTaskResponse response = MLExecuteTaskResponse
.builder()
.output(output)
.functionName(FunctionName.ANOMALY_LOCALIZATION)
.build();
actionListener.onResponse(response);
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
doNothing().when(channel).sendResponse(any());
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<RestResponse> responseCaptor = ArgumentCaptor.forClass(RestResponse.class);
verify(channel).sendResponse(responseCaptor.capture());
BytesRestResponse response = (BytesRestResponse) responseCaptor.getValue();
assertEquals(RestStatus.OK, response.status());
String expectedJson = "{\"results\":[{"
+ "\"name\":\"sum\","
+ "\"result\":{"
+ "\"buckets\":["
+ "{"
+ "\"start_time\":1620630000000,"
+ "\"end_time\":1620716400000,"
+ "\"overall_aggregate_value\":65.0"
+ "}"
+ "]"
+ "}"
+ "}]}";
assertEquals(expectedJson, response.content().utf8ToString());
}
}
23 changes: 23 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.Constants;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.execute.anomalylocalization.AnomalyLocalizationInput;
import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput;
import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput;
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
Expand Down Expand Up @@ -360,6 +361,27 @@ public static RestRequest getMetricsCorrelationRestRequest() {
.build();
}

public static RestRequest getAnomalyLocalizationRestRequest() {
Map<String, String> params = new HashMap<>();
params.put(PARAMETER_ALGORITHM, FunctionName.ANOMALY_LOCALIZATION.name());
final String requestContent = "{"
+ "\"input_data\": {"
+ "\"index_name\": \"test-index\","
+ "\"attribute_field_names\": [\"attribute\"],"
+ "\"time_field_name\": \"timestamp\","
+ "\"start_time\": 1620630000000,"
+ "\"end_time\": 1621234800000,"
+ "\"min_time_interval\": 86400000,"
+ "\"num_outputs\": 1"
+ "}"
+ "}";
RestRequest request = new FakeRestRequest.Builder(getXContentRegistry())
.withParams(params)
.withContent(new BytesArray(requestContent), XContentType.JSON)
.build();
return request;
}

public static RestRequest getExecuteAgentRestRequest() {
Map<String, String> params = new HashMap<>();
params.put(PARAMETER_AGENT_ID, "test_agent_id");
Expand Down Expand Up @@ -407,6 +429,7 @@ private static NamedXContentRegistry getXContentRegistry() {
entries.add(KMeansParams.XCONTENT_REGISTRY);
entries.add(LocalSampleCalculatorInput.XCONTENT_REGISTRY);
entries.add(MetricsCorrelationInput.XCONTENT_REGISTRY);
entries.add(AnomalyLocalizationInput.XCONTENT_REGISTRY_ENTRY);
return new NamedXContentRegistry(entries);
}

Expand Down

0 comments on commit 78eb74b

Please sign in to comment.