Skip to content

Commit

Permalink
Merge pull request #373 from taosdata/enh/xftan/TD-33441-3.0
Browse files Browse the repository at this point in the history
enh: handle error when the polling result is null
  • Loading branch information
zitsen authored Jan 16, 2025
2 parents e4709e8 + a6c44ee commit 169e50e
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 22 deletions.
17 changes: 11 additions & 6 deletions controller/ws/tmq/tmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -733,14 +733,15 @@ func (t *TMQ) poll(ctx context.Context, session *melody.Session, req *TMQPollReq
}
t.nextTime = now.Add(t.autocommitInterval)
}
message, closed := t.wrapperPoll(logger, isDebug, req.BlockingTime)
pollResult, closed := t.wrapperPoll(logger, isDebug, req.BlockingTime)
if closed {
logger.Trace("server closed")
}
resp := &TMQPollResp{
Action: action,
ReqID: req.ReqID,
}
message := pollResult.Res
if message != nil {
messageType := wrapper.TMQGetResType(message)
if messageTypeIsValid(messageType) {
Expand Down Expand Up @@ -781,7 +782,11 @@ func (t *TMQ) poll(ctx context.Context, session *melody.Session, req *TMQPollReq
return
}
}

if pollResult.Code != 0 {
logger.Errorf("tmq poll error, code:%d, msg:%s", pollResult.Code, pollResult.ErrStr)
wsTMQErrorMsg(ctx, session, logger, int(pollResult.Code), pollResult.ErrStr, action, req.ReqID, nil)
return
}
resp.Timing = wstool.GetDuration(ctx)
wstool.WSWriteJson(session, logger, resp)
}
Expand Down Expand Up @@ -1641,7 +1646,7 @@ func (t *TMQ) wrapperCommit(logger *logrus.Entry, isDebug bool) (int32, bool) {
return errCode, false
}

func (t *TMQ) wrapperPoll(logger *logrus.Entry, isDebug bool, blockingTime int64) (unsafe.Pointer, bool) {
func (t *TMQ) wrapperPoll(logger *logrus.Entry, isDebug bool, blockingTime int64) (*tmqhandle.PollResult, bool) {
logger.Tracef("call tmq_poll, consumer:%p", t.consumer)
s := log.GetLogNow(isDebug)
t.asyncLocker.Lock()
Expand All @@ -1652,10 +1657,10 @@ func (t *TMQ) wrapperPoll(logger *logrus.Entry, isDebug bool, blockingTime int64
logger.Debugf("tmq_poll get async lock cost:%s", log.GetLogDuration(isDebug, s))
s = log.GetLogNow(isDebug)
asynctmq.TaosaTMQPollA(t.thread, t.consumer, blockingTime, t.handler.Handler)
message := <-t.handler.Caller.PollResult
result := <-t.handler.Caller.PollResult
t.asyncLocker.Unlock()
logger.Debugf("tmq_poll finish, res:%p, cost:%s", message, log.GetLogDuration(isDebug, s))
return message, false
logger.Debugf("tmq_poll finish, res:%p, code:%d, errmsg:%s, cost:%s", result.Res, result.Code, result.ErrStr, log.GetLogDuration(isDebug, s))
return result, false
}

func (t *TMQ) wrapperFreeResult(logger *logrus.Entry, isDebug bool) bool {
Expand Down
77 changes: 77 additions & 0 deletions controller/ws/tmq/tmq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3328,3 +3328,80 @@ func TestWrongPass(t *testing.T) {
assert.NoError(t, err)
assert.NotEqual(t, 0, subscribeResp.Code, subscribeResp.Message)
}

func TestPollError(t *testing.T) {
dbName := "test_ws_tmq_poll_error"
topic := "test_ws_tmq_poll_error_topic"

before(t, dbName, topic)

s := httptest.NewServer(router)
defer s.Close()
ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http")+"/rest/tmq", nil)
if err != nil {
t.Error(err)
return
}
defer func() {
err = ws.Close()
assert.NoError(t, err)
}()

defer func() {
err = after(ws, dbName, topic)
assert.NoError(t, err)
}()

// subscribe
b, _ := json.Marshal(TMQSubscribeReq{
User: "root",
Password: "taosdata",
DB: dbName,
GroupID: "test",
Topics: []string{topic},
AutoCommit: "false",
OffsetReset: "earliest",
SessionTimeoutMS: "10000",
MaxPollIntervalMS: "1000",
})
msg, err := doWebSocket(ws, TMQSubscribe, b)
assert.NoError(t, err)
var subscribeResp TMQSubscribeResp
err = json.Unmarshal(msg, &subscribeResp)
assert.NoError(t, err)
assert.Equal(t, 0, subscribeResp.Code, subscribeResp.Message)

// poll
b, _ = json.Marshal(TMQPollReq{ReqID: 100, BlockingTime: 500})
msg, err = doWebSocket(ws, TMQPoll, b)
assert.NoError(t, err)
var pollResp TMQPollResp
err = json.Unmarshal(msg, &pollResp)
assert.NoError(t, err)
assert.Equal(t, 0, pollResp.Code, string(msg))
for {
// poll until no message and no error
b, _ = json.Marshal(TMQPollReq{ReqID: 101, BlockingTime: 500})
msg, err = doWebSocket(ws, TMQPoll, b)
assert.NoError(t, err)
err = json.Unmarshal(msg, &pollResp)
assert.NoError(t, err)
if pollResp.Code != 0 {
t.Errorf("poll error: %s", pollResp.Message)
return
}
if !pollResp.HaveMessage {
break
}
}
t.Log("sleep 5s to wait for timeout")
// sleep
time.Sleep(time.Second * 5)
// poll
b, _ = json.Marshal(TMQPollReq{ReqID: 102, BlockingTime: 500})
msg, err = doWebSocket(ws, TMQPoll, b)
assert.NoError(t, err)
err = json.Unmarshal(msg, &pollResp)
assert.NoError(t, err)
assert.NotEqual(t, 0, pollResp.Code, string(msg))
}
14 changes: 11 additions & 3 deletions db/asynctmq/tmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ typedef struct tmq_thread {
pthread_t thread;
} tmq_thread;
typedef void (*adapter_tmq_poll_a_cb)(uintptr_t param, void *res);
typedef void (*adapter_tmq_poll_a_cb)(uintptr_t param, void *res, int32_t code, const char* errstr);
typedef void(*adapter_tmq_free_result_cb)(uintptr_t param);
Expand Down Expand Up @@ -201,7 +201,15 @@ void *worker(void *arg) {
case TAOSA_TMQ_POLL: {
poll_task *task = (poll_task *) thread_info->task;
void *res = tmq_consumer_poll(task->tmq, task->timeout);
task->cb(task->param, res);
if (res == NULL) {
int32_t code = taos_errno(NULL);
if (code != 0) {
const char *errstr = taos_errstr(NULL);
task->cb(task->param, res, code, errstr);
break;
}
}
task->cb(task->param, res, 0, NULL);
break;
}
case TAOSA_TMQ_FREE: {
Expand Down Expand Up @@ -545,7 +553,7 @@ void adapter_tmq_position(tmq_thread *thread_info, tmq_t *tmq, char *topic_name,
pthread_mutex_unlock(&(thread_info->lock));
}
extern void AdapterTMQPollCallback(uintptr_t param, void *res);
extern void AdapterTMQPollCallback(uintptr_t param, void *res, int32_t code, const char* errstr);
extern void AdapterTMQFreeResultCallback(uintptr_t param);
Expand Down
14 changes: 11 additions & 3 deletions db/asynctmq/tmq_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ typedef struct tmq_thread {
HANDLE thread;
} tmq_thread;
typedef void (*adapter_tmq_poll_a_cb)(uintptr_t param, void *res);
typedef void (*adapter_tmq_poll_a_cb)(uintptr_t param, void *res, int32_t code, const char* errstr);
typedef void(*adapter_tmq_free_result_cb)(uintptr_t param);
Expand Down Expand Up @@ -207,7 +207,15 @@ unsigned int worker(void *arg) {
case TAOSA_TMQ_POLL: {
poll_task *task = (poll_task *) thread_info->task;
void *res = tmq_consumer_poll(task->tmq, task->timeout);
task->cb(task->param, res);
if (res == NULL) {
int32_t code = taos_errno(NULL);
if (code != 0) {
const char *errstr = taos_errstr(NULL);
task->cb(task->param, res, code, errstr);
break;
}
}
task->cb(task->param, res, 0, NULL);
break;
}
case TAOSA_TMQ_FREE: {
Expand Down Expand Up @@ -558,7 +566,7 @@ void adapter_tmq_position(tmq_thread *thread_info, tmq_t *tmq, char *topic_name,
LeaveCriticalSection(&(thread_info->lock));
}
extern void AdapterTMQPollCallback(uintptr_t param, void *res);
extern void AdapterTMQPollCallback(uintptr_t param, void *res, int32_t code, const char* errstr);
extern void AdapterTMQFreeResultCallback(uintptr_t param);
Expand Down
4 changes: 2 additions & 2 deletions db/asynctmq/tmqcb.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ import (
)

//export AdapterTMQPollCallback
func AdapterTMQPollCallback(handle C.uintptr_t, res *C.TAOS_RES) {
func AdapterTMQPollCallback(handle C.uintptr_t, res *C.TAOS_RES, code int, errstr *C.char) {
h := cgo.Handle(handle)
caller := h.Value().(*tmqhandle.TMQCaller)
caller.PollCall(unsafe.Pointer(res))
caller.PollCall(unsafe.Pointer(res), int32(code), C.GoString(errstr))
}

//export AdapterTMQFreeResultCallback
Expand Down
18 changes: 14 additions & 4 deletions db/asynctmq/tmqhandle/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@ type ListTopicsResult struct {
Topics []string
}

type PollResult struct {
Code int32
Res unsafe.Pointer
ErrStr string
}

type TMQCaller struct {
PollResult chan unsafe.Pointer
PollResult chan *PollResult
FreeResult chan struct{}
CommitResult chan int32
SubscribeResult chan int32
Expand All @@ -50,7 +56,7 @@ type TMQCaller struct {

func NewTMQCaller() *TMQCaller {
return &TMQCaller{
PollResult: make(chan unsafe.Pointer, 1),
PollResult: make(chan *PollResult, 1),
FreeResult: make(chan struct{}, 1),
CommitResult: make(chan int32, 1),
SubscribeResult: make(chan int32, 1),
Expand All @@ -68,8 +74,12 @@ func NewTMQCaller() *TMQCaller {
}
}

func (c *TMQCaller) PollCall(res unsafe.Pointer) {
c.PollResult <- res
func (c *TMQCaller) PollCall(res unsafe.Pointer, code int32, errStr string) {
c.PollResult <- &PollResult{
Code: code,
Res: res,
ErrStr: errStr,
}
}

func (c *TMQCaller) FreeCall() {
Expand Down
10 changes: 6 additions & 4 deletions db/asynctmq/tmqhandle/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ func TestTMQHandlerPool(t *testing.T) {

caller := handler.Caller
pollRes := unsafe.Pointer(&struct{}{})
caller.PollCall(pollRes)
if <-caller.PollResult != pollRes {
caller.PollCall(pollRes, 0, "")
res := <-caller.PollResult
if res.Res != pollRes {
t.Errorf("PollCall failed")
}

Expand Down Expand Up @@ -72,8 +73,9 @@ func TestTMQCaller(t *testing.T) {
// Test PollCall
a := 1
res := unsafe.Pointer(&a)
caller.PollCall(res)
if <-caller.PollResult != res {
caller.PollCall(res, 0, "")
pollRes := <-caller.PollResult
if pollRes.Res != res {
t.Error("PollCall failed")
}

Expand Down

0 comments on commit 169e50e

Please sign in to comment.