Skip to content

Commit

Permalink
* sse: send ErrorResponse to client via "event: error" on exception
Browse files Browse the repository at this point in the history
Signed-off-by: neo <1100909+neowu@users.noreply.github.com>
  • Loading branch information
neowu committed Feb 26, 2025
1 parent 50eb785 commit bc1d147
Show file tree
Hide file tree
Showing 18 changed files with 97 additions and 40 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
## Change log

### 9.1.7 (2/26/2025 - )

* sse: send ErrorResponse to client via "event: error" on exception

### 9.1.6 (2/10/2025 - 2/25/2025)

* http_client: tweak sse checking
Expand Down
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ apply(plugin = "project")

subprojects {
group = "core.framework"
version = "9.1.6"
version = "9.1.7-b0"
}

val elasticVersion = "8.15.0"
Expand Down
14 changes: 10 additions & 4 deletions core-ng/src/main/java/core/framework/http/EventSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ public final class EventSource implements AutoCloseable, Iterable<EventSource.Ev
private int responseBodyLength;
private long elapsed;

private String lastId;
private String lastType; // for "event" field
private String lastId; // for "id" field
private Event nextEvent;

public EventSource(int statusCode, Map<String, String> headers, ResponseBody body, int requestBodyLength, long elapsed) {
Expand Down Expand Up @@ -69,11 +70,16 @@ private Event parseResponse(BufferedSource source) {
case "id":
lastId = line.substring(index + 2);
break;
case "event":
lastType = line.substring(index + 2);
break;
case "data":
String id = lastId;
lastId = null;
return new Event(id, line.substring(index + 2));
default: // ignore "event", "retry" and other fields
String type = lastType;
lastType = null;
return new Event(id, type, line.substring(index + 2));
default: // ignore "retry" and other fields
}
}
} catch (IOException e) {
Expand All @@ -83,7 +89,7 @@ private Event parseResponse(BufferedSource source) {
}
}

public record Event(String id, String data) {
public record Event(String id, String type, String data) {
}

private final class EventIterator implements Iterator<Event> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,7 @@ Object errorResponse(Throwable e, String userAgent, String actionId) {
}
return response;
} else {
var response = new ErrorResponse();
response.id = actionId;
response.message = e.getMessage();
if (e instanceof ErrorCode errorCode) {
response.errorCode = errorCode.errorCode();
} else {
response.errorCode = "INTERNAL_ERROR";
}
return response;
return ErrorResponse.errorResponse(e, actionId);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import core.framework.internal.log.ActionLog;
import core.framework.internal.log.LogManager;
import core.framework.internal.log.Trace;
import core.framework.internal.web.bean.ResponseBeanWriter;
import core.framework.internal.web.controller.ControllerHolder;
import core.framework.internal.web.controller.InvocationImpl;
import core.framework.internal.web.controller.WebContextImpl;
Expand Down Expand Up @@ -42,8 +41,6 @@ public class HTTPHandler implements HttpHandler {
public final WebContextImpl webContext = new WebContextImpl();
public final HTTPErrorHandler errorHandler;

public final ResponseBeanWriter responseBeanWriter = new ResponseBeanWriter();

private final Logger logger = LoggerFactory.getLogger(HTTPHandler.class);
private final LogManager logManager;
private final SessionManager sessionManager;
Expand All @@ -58,7 +55,7 @@ public class HTTPHandler implements HttpHandler {
this.logManager = logManager;
this.sessionManager = sessionManager;
this.handlerContext = handlerContext;
responseHandler = new ResponseHandler(responseBeanWriter, templateManager, sessionManager);
responseHandler = new ResponseHandler(handlerContext.responseBeanWriter, templateManager, sessionManager);
errorHandler = new HTTPErrorHandler(responseHandler);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core.framework.internal.web;

import core.framework.internal.web.bean.RequestBeanReader;
import core.framework.internal.web.bean.ResponseBeanWriter;
import core.framework.internal.web.http.IPv4AccessControl;
import core.framework.internal.web.http.RateControl;
import core.framework.internal.web.request.RequestParser;
Expand All @@ -10,6 +11,7 @@
public class HTTPHandlerContext {
public final RequestParser requestParser = new RequestParser();
public final RequestBeanReader requestBeanReader = new RequestBeanReader();
public final ResponseBeanWriter responseBeanWriter = new ResponseBeanWriter();
public final RateControl rateControl = new RateControl();
@Nullable
public IPv4AccessControl accessControl;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
/**
* @author neo
*/
final class ResponseHandlerContext {
public final class ResponseHandlerContext {
final ResponseBeanWriter writer;
final TemplateManager templateManager;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
package core.framework.internal.web.service;

import core.framework.api.json.Property;
import core.framework.log.ErrorCode;

/**
* @author neo
*/
public class ErrorResponse {
public final class ErrorResponse {
public static ErrorResponse errorResponse(Throwable e, String actionId) {
var response = new ErrorResponse();
response.id = actionId;
response.message = e.getMessage();
if (e instanceof ErrorCode errorCode) {
response.errorCode = errorCode.errorCode();
} else {
response.errorCode = "INTERNAL_ERROR";
}
return response;
}

@Property(name = "id")
public String id;

@Property(name = "errorCode")
public String errorCode;

@Property(name = "message")
public String message;
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package core.framework.internal.web.sse;

import core.framework.internal.log.filter.BytesLogParam;
import core.framework.log.ActionLogContext;
import core.framework.util.Sets;
import core.framework.util.StopWatch;
Expand Down Expand Up @@ -53,22 +54,22 @@ class ChannelImpl<T> implements java.nio.channels.Channel, Channel<T> {
@Override
public boolean send(String id, T event) {
String data = builder.build(id, event);
return send(data);
return sendBytes(Strings.bytes(data));
}

boolean send(String data) {
boolean sendBytes(byte[] data) {
if (closed) return false;

var watch = new StopWatch();
try {
queue.add(Strings.bytes(data));
queue.add(data);
lastSentTime = System.nanoTime();
sink.getIoThread().execute(() -> writeListener.handleEvent(sink));
return true;
} finally {
long elapsed = watch.elapsed();
ActionLogContext.track("sse", elapsed, 0, data.length());
LOGGER.debug("send sse data, channel={}, data={}, elapsed={}", id, data, elapsed); // message is not in json format, not masked, assume sse won't send any sensitive data
ActionLogContext.track("sse", elapsed, 0, data.length);
LOGGER.debug("send sse data, channel={}, data={}, elapsed={}", id, new BytesLogParam(data), elapsed); // message is not in json format, not masked, assume sse won't send any sensitive data
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package core.framework.internal.web.sse;

import core.framework.util.Strings;
import core.framework.web.sse.Channel;
import core.framework.web.sse.ServerSentEventContext;
import org.slf4j.Logger;
Expand Down Expand Up @@ -67,7 +68,7 @@ public void keepAlive() {
for (Channel<T> channel : channels.values()) {
ChannelImpl<?> impl = (ChannelImpl<?>) channel;
if (now - impl.lastSentTime >= 15_000_000_000L) {
impl.send(":\n");
impl.sendBytes(Strings.bytes(":\n"));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import core.framework.internal.log.LogManager;
import core.framework.internal.web.HTTPHandlerContext;
import core.framework.internal.web.request.RequestImpl;
import core.framework.internal.web.service.ErrorResponse;
import core.framework.internal.web.session.ReadOnlySession;
import core.framework.internal.web.session.SessionManager;
import core.framework.module.ServerSentEventConfig;
Expand All @@ -23,6 +24,7 @@
import org.xnio.channels.StreamSinkChannel;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -78,6 +80,7 @@ void handle(HttpServerExchange exchange, StreamSinkChannel sink) {
long httpDelay = System.nanoTime() - exchange.getRequestStartTime();
ActionLog actionLog = logManager.begin("=== sse connect begin ===", null);
var request = new RequestImpl(exchange, handlerContext.requestBeanReader);
ChannelImpl<Object> channel = null;
try {
logger.debug("httpDelay={}", httpDelay);
actionLog.stats.put("http_delay", (double) httpDelay);
Expand All @@ -92,13 +95,13 @@ void handle(HttpServerExchange exchange, StreamSinkChannel sink) {
actionLog.action("sse:" + path + ":connect");
handlerContext.rateControl.validateRate(ServerSentEventConfig.SSE_CONNECT_GROUP, request.clientIP());

var channel = new ChannelImpl<>(exchange, sink, support.context, support.builder, actionLog.id);
channel = new ChannelImpl<>(exchange, sink, support.context, support.builder, actionLog.id);
actionLog.context("channel", channel.id);
sink.getWriteSetter().set(channel.writeListener);
support.context.add(channel);
exchange.addExchangeCompleteListener(new ServerSentEventCloseHandler<>(logManager, channel, support.context));

channel.send("retry: 5000\n\n"); // set browser retry to 5s
channel.sendBytes(Strings.bytes("retry: 5000\n\n")); // set browser retry to 5s

request.session = ReadOnlySession.of(sessionManager.load(request, actionLog));
String lastEventId = exchange.getRequestHeaders().getLast(LAST_EVENT_ID);
Expand All @@ -107,13 +110,26 @@ void handle(HttpServerExchange exchange, StreamSinkChannel sink) {
if (!channel.groups.isEmpty()) actionLog.context("group", channel.groups.toArray()); // may join group onConnect
} catch (Throwable e) {
logManager.logError(e);
exchange.endExchange();

if (channel != null) {
byte[] error = errorResponse(handlerContext.responseBeanWriter.toJSON(ErrorResponse.errorResponse(e, actionLog.id)));
channel.sendBytes(error);
channel.close(); // gracefully shutdown connection to make sure retry/error can be sent
}
} finally {
logManager.end("=== sse connect end ===");
VirtualThread.COUNT.decrease();
}
}

byte[] errorResponse(byte[] errorResponse) {
ByteBuffer buffer = ByteBuffer.wrap(new byte[errorResponse.length + 38]);
buffer.put(Strings.bytes("retry: 86400000\n\nevent: error\ndata: ")); // tell browser retry in 24 hours
buffer.put(errorResponse);
buffer.put(Strings.bytes("\n\n"));
return buffer.array();
}

public <T> void add(HTTPMethod method, String path, Class<T> eventClass, ChannelListener<T> listener, ServerSentEventContextImpl<T> context) {
var previous = supports.put(key(method.name(), path), new ChannelSupport<>(listener, eventClass, context));
if (previous != null) throw new Error(Strings.format("found duplicate sse listener, method={}, path={}", method, path));
Expand Down
2 changes: 1 addition & 1 deletion core-ng/src/main/java/core/framework/module/APIConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public <T> void service(Class<T> serviceInterface, T service) {
logger.info("create web service, interface={}", serviceInterface.getCanonicalName());
var validator = new WebServiceInterfaceValidator(serviceInterface, context.beanClassValidator);
validator.requestBeanReader = context.httpServer.handlerContext.requestBeanReader;
validator.responseBeanWriter = context.httpServer.handler.responseBeanWriter;
validator.responseBeanWriter = context.httpServer.handlerContext.responseBeanWriter;
validator.validate();
new WebServiceImplValidator<>(serviceInterface, service).validate();
new InjectValidator(service).validate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ private void registerBean(Class<?> beanClass) {
}
reader.registerQueryParam(beanClass, context.beanClassValidator.beanClassNameValidator);
} else {
ResponseBeanWriter writer = context.httpServer.handler.responseBeanWriter;
ResponseBeanWriter writer = context.httpServer.handlerContext.responseBeanWriter;
if (reader.containsBean(beanClass) || writer.contains(beanClass)) {
throw new Error("bean class is already registered or referred by service interface, class=" + beanClass.getCanonicalName());
}
Expand Down
4 changes: 2 additions & 2 deletions core-ng/src/main/java/core/framework/web/sse/Channel.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ public interface Channel<T> {
// return true if event is queued, return false if channel is closed
boolean send(String id, T event);

default void send(T event) {
send(null, event);
default boolean send(T event) {
return send(null, event);
}

// gracefully close, queue "end exchange" into io thread
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class HTTPHandlerTest {

@BeforeEach
void createHTTPServerHandler() {
handler = new HTTPHandler(null, null, null, null);
handler = new HTTPHandler(null, null, null, new HTTPHandlerContext());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package core.framework.internal.web.sse;

import core.framework.util.Strings;
import core.framework.web.sse.Channel;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -71,9 +72,9 @@ void keepAlive() {
context.keepAlive();

channel.lastSentTime = 0;
doReturn(Boolean.TRUE).when(channel).send(":\n");
doReturn(Boolean.TRUE).when(channel).sendBytes(Strings.bytes(":\n"));
context.keepAlive();
verify(channel, Mockito.times(1)).send(":\n");
verify(channel, Mockito.times(1)).sendBytes(Strings.bytes(":\n"));
}

private ChannelImpl<TestEvent> channel() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package core.framework.internal.web.sse;

import core.framework.util.Strings;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.assertj.core.api.Assertions.assertThat;

class ServerSentEventHandlerTest {
private ServerSentEventHandler handler;

@BeforeEach
void createServerSentEventHandler() {
handler = new ServerSentEventHandler(null, null, null);
}

@Test
void errorResponse() {
byte[] error = handler.errorResponse(Strings.bytes("{\"error_code\": \"NOT_FOUND\"}"));
assertThat(error).asString().isEqualTo("""
retry: 86400000
event: error
data: {"error_code": "NOT_FOUND"}
""");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,4 @@ public class ActionLogEntry {
public Map<String, Double> stats;
@Property(name = "perf_stats")
public Map<String, PerformanceStatMessage> performanceStats;
@Property(name = "trace_log_path")
public String traceLogPath;
}

0 comments on commit bc1d147

Please sign in to comment.