diff --git a/Makefile b/Makefile index 776da60..379c09c 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ test: build - go test ./... -count=1 --race + go test ./... -count=1 --race --timeout=5s proto: protoc --go_out=. --go-vtproto_out=. --go_opt=paths=source_relative --proto_path=. actor/actor.proto diff --git a/actor/context_test.go b/actor/context_test.go index e3229e3..3f6ffe9 100644 --- a/actor/context_test.go +++ b/actor/context_test.go @@ -1,6 +1,7 @@ package actor import ( + fmt "fmt" "sync" "testing" "time" @@ -21,7 +22,7 @@ func TestChildEventNoRaceCondition(t *testing.T) { c.engine.Subscribe(child) } }, "parent") - e.Poison(parentPID).Wait() + <-e.Poison(parentPID).Done() } func TestContextSendRepeat(t *testing.T) { @@ -75,8 +76,7 @@ func TestSpawnChildPID(t *testing.T) { func TestChild(t *testing.T) { var ( - wg = sync.WaitGroup{} - stopWg = sync.WaitGroup{} + wg = sync.WaitGroup{} ) e, err := NewEngine(NewEngineConfig()) require.NoError(t, err) @@ -89,8 +89,9 @@ func TestChild(t *testing.T) { c.SpawnChildFunc(func(_ *Context) {}, "child", WithID("3")) case Started: assert.Equal(t, 3, len(c.Children())) - c.Engine().Stop(c.Children()[0], &stopWg) - stopWg.Wait() + childPid := c.Children()[0] + fmt.Println("sending poison pill to ", childPid) + <-c.Engine().Stop(childPid).Done() assert.Equal(t, 2, len(c.Children())) wg.Done() } @@ -164,7 +165,7 @@ func TestSpawnChild(t *testing.T) { }, "parent", WithMaxRestarts(0)) wg.Wait() - e.Poison(pid).Wait() + <-e.Poison(pid).Done() assert.Nil(t, e.Registry.get(NewPID("local", "child"))) assert.Nil(t, e.Registry.get(pid)) diff --git a/actor/engine.go b/actor/engine.go index efe9fda..cb484e1 100644 --- a/actor/engine.go +++ b/actor/engine.go @@ -1,6 +1,7 @@ package actor import ( + "context" "fmt" "math" "math/rand" @@ -205,27 +206,31 @@ func (e *Engine) SendRepeat(pid *PID, msg any, interval time.Duration) SendRepea } // Stop will send a non-graceful poisonPill message to the process that is associated with the given PID. -// The process will shut down immediately, once it has processed the poisonPill messsage. -func (e *Engine) Stop(pid *PID, wg ...*sync.WaitGroup) *sync.WaitGroup { - return e.sendPoisonPill(pid, false, wg...) +// The process will shut down immediately. A context is being returned that can be used to block / wait +// until the process is stopped. +func (e *Engine) Stop(pid *PID) context.Context { + return e.sendPoisonPill(context.Background(), false, pid) } // Poison will send a graceful poisonPill message to the process that is associated with the given PID. // The process will shut down gracefully once it has processed all the messages in the inbox. -// If given a WaitGroup, it blocks till the process is completely shutdown. -func (e *Engine) Poison(pid *PID, wg ...*sync.WaitGroup) *sync.WaitGroup { - return e.sendPoisonPill(pid, true, wg...) +// A context is returned that can be used to block / wait until the process is stopped. +func (e *Engine) Poison(pid *PID) context.Context { + return e.sendPoisonPill(context.Background(), true, pid) } -func (e *Engine) sendPoisonPill(pid *PID, graceful bool, wg ...*sync.WaitGroup) *sync.WaitGroup { - var _wg *sync.WaitGroup - if len(wg) > 0 { - _wg = wg[0] - } else { - _wg = &sync.WaitGroup{} - } +// PoisonCtx behaves the exact same as Poison, the only difference is that it accepts +// a context as the first argument. The context can be used for custom timeouts and manual +// cancelation. +func (e *Engine) PoisonCtx(ctx context.Context, pid *PID) context.Context { + return e.sendPoisonPill(ctx, true, pid) +} + +func (e *Engine) sendPoisonPill(ctx context.Context, graceful bool, pid *PID) context.Context { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) pill := poisonPill{ - wg: _wg, + cancel: cancel, graceful: graceful, } // deadletter - if we didn't find a process, we will broadcast a DeadletterEvent @@ -235,11 +240,11 @@ func (e *Engine) sendPoisonPill(pid *PID, graceful bool, wg ...*sync.WaitGroup) Message: pill, Sender: nil, }) - return _wg + cancel() + return ctx } - _wg.Add(1) e.SendLocal(pid, pill, nil) - return _wg + return ctx } // SendLocal will send the given message to the given PID. If the recipient is not found in the diff --git a/actor/engine_test.go b/actor/engine_test.go index 3f61629..a58f36e 100644 --- a/actor/engine_test.go +++ b/actor/engine_test.go @@ -95,7 +95,7 @@ func TestRestartsMaxRestarts(t *testing.T) { case Started: case Stopped: case payload: - if msg.data != 10 { + if msg.data != 1 { panic("I failed to process this message") } else { fmt.Println("finally processed all my messages after borking.", msg.data) @@ -103,9 +103,10 @@ func TestRestartsMaxRestarts(t *testing.T) { } }, "foo", WithMaxRestarts(restarts)) - for i := 0; i < 11; i++ { + for i := 0; i < 2; i++ { e.Send(pid, payload{i}) } + <-e.Poison(pid).Done() } func TestProcessInitStartOrder(t *testing.T) { @@ -249,7 +250,7 @@ func TestSpawnDuplicateId(t *testing.T) { assert.Equal(t, int32(1), startsCount) // should only spawn one actor } -func TestStopWaitGroup(t *testing.T) { +func TestStopWaitContextDone(t *testing.T) { var ( wg = sync.WaitGroup{} x = int32(0) @@ -268,7 +269,7 @@ func TestStopWaitGroup(t *testing.T) { }, "foo") wg.Wait() - e.Stop(pid).Wait() + <-e.Stop(pid).Done() assert.Equal(t, int32(1), atomic.LoadInt32(&x)) } @@ -290,7 +291,7 @@ func TestStop(t *testing.T) { }, "foo", WithID(tag)) wg.Wait() - e.Stop(pid).Wait() + <-e.Stop(pid).Done() // When a process is poisoned it should be removed from the registry. // Hence, we should get nil when looking it up in the registry. assert.Nil(t, e.Registry.get(pid)) @@ -316,11 +317,11 @@ func TestPoisonWaitGroup(t *testing.T) { }, "foo") wg.Wait() - e.Poison(pid).Wait() + <-e.Poison(pid).Done() assert.Equal(t, int32(1), atomic.LoadInt32(&x)) // validate poisoning non exiting pid does not deadlock - wg = e.Poison(NewPID(LocalLookupAddr, "non-existing")) + e.Poison(NewPID(LocalLookupAddr, "non-existing")) done := make(chan struct{}) go func() { defer close(done) @@ -331,8 +332,6 @@ func TestPoisonWaitGroup(t *testing.T) { case <-time.After(20 * time.Millisecond): t.Error("poison waitGroup deadlocked") } - // ... or panic - e.Poison(nil).Wait() } func TestPoison(t *testing.T) { @@ -353,7 +352,7 @@ func TestPoison(t *testing.T) { }, "foo", WithID(tag)) wg.Wait() - e.Poison(pid).Wait() + <-e.Poison(pid).Done() // When a process is poisoned it should be removed from the registry. // Hence, we should get NIL when we try to get it. assert.Nil(t, e.Registry.get(pid)) @@ -407,7 +406,7 @@ func TestPoisonPillPrivate(t *testing.T) { time.Sleep(time.Millisecond) } }, "victim") - e.Poison(pid).Wait() + <-e.Poison(pid).Done() assert.Nil(t, e.Registry.get(pid)) select { case <-failCh: diff --git a/actor/event_stream_test.go b/actor/event_stream_test.go index 4d39210..03b9460 100644 --- a/actor/event_stream_test.go +++ b/actor/event_stream_test.go @@ -75,7 +75,7 @@ func TestEventStreamActorStoppedEvent(t *testing.T) { }, "b") e.Subscribe(pidb) - e.Poison(a).Wait() + <-e.Poison(a).Done() wg.Wait() } diff --git a/actor/event_test.go b/actor/event_test.go index f1a4eed..611e92a 100644 --- a/actor/event_test.go +++ b/actor/event_test.go @@ -23,5 +23,5 @@ func TestDuplicateIdEvent(t *testing.T) { e.SpawnFunc(func(c *Context) {}, "actor_a", WithID("1")) e.SpawnFunc(func(c *Context) {}, "actor_a", WithID("1")) wg.Wait() - e.Poison(monitor).Wait() + <-e.Poison(monitor).Done() } diff --git a/actor/inbox_test.go b/actor/inbox_test.go index 678fd5f..38e7dd8 100644 --- a/actor/inbox_test.go +++ b/actor/inbox_test.go @@ -1,7 +1,6 @@ package actor import ( - "sync" "sync/atomic" "testing" "time" @@ -70,7 +69,7 @@ func (m MockProcesser) Send(*PID, any, *PID) {} func (m MockProcesser) Invoke(envelopes []Envelope) { m.processFunc(envelopes) } -func (m MockProcesser) Shutdown(_ *sync.WaitGroup) {} +func (m MockProcesser) Shutdown() {} func TestInboxStop(t *testing.T) { inbox := NewInbox(10) diff --git a/actor/process.go b/actor/process.go index 3de42b7..41ee192 100644 --- a/actor/process.go +++ b/actor/process.go @@ -2,10 +2,10 @@ package actor import ( "bytes" + "context" "fmt" "log/slog" "runtime/debug" - "sync" "time" "github.com/DataDog/gostackparse" @@ -22,7 +22,7 @@ type Processer interface { PID() *PID Send(*PID, any, *PID) Invoke([]Envelope) - Shutdown(*sync.WaitGroup) + Shutdown() } type process struct { @@ -80,6 +80,7 @@ func (p *process) Invoke(msgs []Envelope) { p.tryRestart(v) } }() + for i := 0; i < len(msgs); i++ { nproc++ msg := msgs[i] @@ -92,7 +93,7 @@ func (p *process) Invoke(msgs []Envelope) { p.invokeMsg(m) } } - p.cleanup(pill.wg) + p.cleanup(pill.cancel) return } p.invokeMsg(msg) @@ -178,7 +179,9 @@ func (p *process) tryRestart(v any) { p.Start() } -func (p *process) cleanup(wg *sync.WaitGroup) { +func (p *process) cleanup(cancel context.CancelFunc) { + defer cancel() + if p.context.parentCtx != nil { p.context.parentCtx.children.Delete(p.pid.ID) } @@ -186,7 +189,7 @@ func (p *process) cleanup(wg *sync.WaitGroup) { if p.context.children.Len() > 0 { children := p.context.Children() for _, pid := range children { - p.context.engine.Poison(pid).Wait() + <-p.context.engine.Poison(pid).Done() } } @@ -196,16 +199,15 @@ func (p *process) cleanup(wg *sync.WaitGroup) { applyMiddleware(p.context.receiver.Receive, p.Opts.Middleware...)(p.context) p.context.engine.BroadcastEvent(ActorStoppedEvent{PID: p.pid, Timestamp: time.Now()}) - if wg != nil { - wg.Done() - } } func (p *process) PID() *PID { return p.pid } func (p *process) Send(_ *PID, msg any, sender *PID) { p.inbox.Send(Envelope{Msg: msg, Sender: sender}) } -func (p *process) Shutdown(wg *sync.WaitGroup) { p.cleanup(wg) } +func (p *process) Shutdown() { + p.cleanup(nil) +} func cleanTrace(stack []byte) []byte { goros, err := gostackparse.Parse(bytes.NewReader(stack)) diff --git a/actor/registry_test.go b/actor/registry_test.go index 9914fff..43ed047 100644 --- a/actor/registry_test.go +++ b/actor/registry_test.go @@ -1,7 +1,6 @@ package actor import ( - "sync" "testing" "github.com/stretchr/testify/assert" @@ -10,11 +9,11 @@ import ( type fooProc struct { } -func (p fooProc) Start() {} -func (p fooProc) PID() *PID { return NewPID(LocalLookupAddr, "foo") } -func (p fooProc) Send(*PID, any, *PID) {} -func (p fooProc) Invoke([]Envelope) {} -func (p fooProc) Shutdown(*sync.WaitGroup) {} +func (p fooProc) Start() {} +func (p fooProc) PID() *PID { return NewPID(LocalLookupAddr, "foo") } +func (p fooProc) Send(*PID, any, *PID) {} +func (p fooProc) Invoke([]Envelope) {} +func (p fooProc) Shutdown() {} func TestGetRemoveAdd(t *testing.T) { e, _ := NewEngine(NewEngineConfig()) diff --git a/actor/response.go b/actor/response.go index 61a8046..c6493f7 100644 --- a/actor/response.go +++ b/actor/response.go @@ -5,7 +5,6 @@ import ( "math" "math/rand" "strconv" - "sync" "time" ) @@ -44,7 +43,7 @@ func (r *Response) Send(_ *PID, msg any, _ *PID) { r.result <- msg } -func (r *Response) PID() *PID { return r.pid } -func (r *Response) Shutdown(_ *sync.WaitGroup) {} -func (r *Response) Start() {} -func (r *Response) Invoke([]Envelope) {} +func (r *Response) PID() *PID { return r.pid } +func (r *Response) Shutdown() {} +func (r *Response) Start() {} +func (r *Response) Invoke([]Envelope) {} diff --git a/actor/types.go b/actor/types.go index 259a014..843590b 100644 --- a/actor/types.go +++ b/actor/types.go @@ -1,6 +1,6 @@ package actor -import "sync" +import "context" type InternalError struct { From string @@ -8,7 +8,7 @@ type InternalError struct { } type poisonPill struct { - wg *sync.WaitGroup + cancel context.CancelFunc graceful bool } type Initialized struct{} diff --git a/cluster/cluster.go b/cluster/cluster.go index 0c4149d..1086924 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -6,7 +6,6 @@ import ( "math" "math/rand" "reflect" - "sync" "time" "github.com/anthdm/hollywood/actor" @@ -131,11 +130,9 @@ func (c *Cluster) Start() { } // Stop will shutdown the cluster poisoning all its actors. -func (c *Cluster) Stop() *sync.WaitGroup { - wg := sync.WaitGroup{} - c.engine.Poison(c.agentPID, &wg) - c.engine.Poison(c.providerPID, &wg) - return &wg +func (c *Cluster) Stop() { + <-c.engine.Poison(c.agentPID).Done() + <-c.engine.Poison(c.providerPID).Done() } // Spawn an actor locally on the node with cluster awareness. diff --git a/cluster/cluster_test.go b/cluster/cluster_test.go index f0d1df9..898f485 100644 --- a/cluster/cluster_test.go +++ b/cluster/cluster_test.go @@ -78,9 +78,9 @@ func TestClusterSelectMemberFunc(t *testing.T) { <-ctx.Done() require.Equal(t, context.DeadlineExceeded, ctx.Err()) - c1.Stop().Wait() - c2.Stop().Wait() - c3.Stop().Wait() + c1.Stop() + c2.Stop() + c3.Stop() } func TestClusterShouldWorkWithDefaultValues(t *testing.T) { @@ -135,8 +135,8 @@ func TestClusterSpawn(t *testing.T) { c2.Start() wg.Wait() - c1.Stop().Wait() - c2.Stop().Wait() + c1.Stop() + c2.Stop() } func TestMemberJoin(t *testing.T) { @@ -165,8 +165,8 @@ func TestMemberJoin(t *testing.T) { assert.Equal(t, len(c1.Members()), 2) assert.True(t, c1.HasKind("player")) - c1.Stop().Wait() - c2.Stop().Wait() + c1.Stop() + c2.Stop() } func TestActivate(t *testing.T) { @@ -203,8 +203,8 @@ func TestActivate(t *testing.T) { assert.True(t, c1.HasKind("player")) assert.True(t, c1.GetActivated("player/1").Equals(expectedPID)) - c1.Stop().Wait() - c2.Stop().Wait() + c1.Stop() + c2.Stop() } func TestDeactivate(t *testing.T) { @@ -239,8 +239,8 @@ func TestDeactivate(t *testing.T) { assert.True(t, c1.HasKind("player")) assert.Nil(t, c1.GetActivated("player/1")) - c1.Stop().Wait() - c2.Stop().Wait() + c1.Stop() + c2.Stop() } func TestMemberLeave(t *testing.T) { @@ -283,8 +283,8 @@ func TestMemberLeave(t *testing.T) { assert.Equal(t, len(c1.Members()), 1) assert.False(t, c1.HasKind("player")) - c1.Stop().Wait() - c2.Stop().Wait() + c1.Stop() + c2.Stop() } func TestMembersExcept(t *testing.T) { diff --git a/examples/chat/client/main.go b/examples/chat/client/main.go index 23bd2e1..01217bc 100644 --- a/examples/chat/client/main.go +++ b/examples/chat/client/main.go @@ -86,6 +86,6 @@ func main() { // When breaked out of the loop on error let the server know // we need to disconnect. e.SendWithSender(serverPID, &types.Disconnect{}, clientPID) - e.Poison(clientPID).Wait() + <-e.Poison(clientPID).Done() slog.Info("client disconnected") } diff --git a/examples/eventstream-monitor/main.go b/examples/eventstream-monitor/main.go index 78d0509..ec4b65f 100644 --- a/examples/eventstream-monitor/main.go +++ b/examples/eventstream-monitor/main.go @@ -79,7 +79,7 @@ func main() { repeater := e.SendRepeat(ua, customMessage{}, time.Millisecond*20) time.Sleep(time.Second * 5) repeater.Stop() - e.Poison(ua).Wait() + <-e.Poison(ua).Done() res, err := e.Request(monitor, query{}, time.Second).Result() if err != nil { fmt.Println("Query", err) @@ -89,6 +89,6 @@ func main() { fmt.Printf("Observed %d starts\n", q.starts) fmt.Printf("Observed %d stops\n", q.stops) fmt.Printf("Observed %d deadletters\n", q.deadletters) - e.Poison(monitor).Wait() // the monitor will output stats on stop. + <-e.Poison(monitor).Done() // the monitor will output stats on stop. fmt.Println("done") } diff --git a/examples/helloworld/main.go b/examples/helloworld/main.go index 461961f..7d9ac31 100644 --- a/examples/helloworld/main.go +++ b/examples/helloworld/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/anthdm/hollywood/actor" @@ -27,26 +28,17 @@ func (f *foo) Receive(ctx *actor.Context) { } } -type info struct { - data string -} - func main() { - - var ds []*info - ds = append(ds, &info{data: "hello"}) - - fmt.Printf("ds len: %d\n", len(ds)) - - fmt.Printf("%+v\n", ds) engine, err := actor.NewEngine(actor.NewEngineConfig()) if err != nil { panic(err) } pid := engine.Spawn(newFoo, "my_actor") - for i := 0; i < 3; i++ { + for i := 0; i < 5; i++ { engine.Send(pid, &message{data: "hello world!"}) } - engine.Poison(pid).Wait() + + ctx := engine.PoisonCtx(context.Background(), pid) + <-ctx.Done() } diff --git a/examples/middleware/hooks/main.go b/examples/middleware/hooks/main.go index a42f5b7..9297090 100644 --- a/examples/middleware/hooks/main.go +++ b/examples/middleware/hooks/main.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "sync" "time" "github.com/anthdm/hollywood/actor" @@ -54,8 +53,5 @@ func main() { e.Send(pid, "Hello sailor!") // We sleep here so we are sure foo received our message time.Sleep(time.Second) - // Create a waitgroup so we can wait until foo has been stopped gracefully - wg := &sync.WaitGroup{} - e.Poison(pid, wg) - wg.Wait() + <-e.Poison(pid).Done() } diff --git a/examples/persistance/main.go b/examples/persistance/main.go index 602fbec..e7628c6 100644 --- a/examples/persistance/main.go +++ b/examples/persistance/main.go @@ -7,7 +7,6 @@ import ( "os" "path" "regexp" - "sync" "time" "github.com/anthdm/hollywood/actor" @@ -158,9 +157,7 @@ func main() { time.Sleep(time.Second * 1) e.Send(pid, TakeDamage{Amount: 9}) time.Sleep(time.Second * 1) - wg := &sync.WaitGroup{} - e.Poison(pid, wg) - wg.Wait() + <-e.Poison(pid).Done() } var safeRx = regexp.MustCompile(`[^a-zA-Z0-9]`) diff --git a/examples/tcpserver/main.go b/examples/tcpserver/main.go index 1612db6..f075f1a 100644 --- a/examples/tcpserver/main.go +++ b/examples/tcpserver/main.go @@ -7,7 +7,6 @@ import ( "net" "os" "os/signal" - "sync" "syscall" "time" @@ -152,8 +151,5 @@ func main() { signal.Notify(sigch, syscall.SIGINT, syscall.SIGTERM) <-sigch - // wait till the server is gracefully shutdown by using a WaitGroup in the Poison call. - wg := &sync.WaitGroup{} - e.Poison(serverPID, wg) - wg.Wait() + <-e.Poison(serverPID).Done() } diff --git a/examples/trade-engine/main.go b/examples/trade-engine/main.go index ce27f18..3ed52b6 100644 --- a/examples/trade-engine/main.go +++ b/examples/trade-engine/main.go @@ -68,7 +68,7 @@ func main() { e.Send(tradeEnginePID, types.CancelOrderRequest{TradeID: tradeOrder.TradeID}) time.Sleep(2 * time.Second) - e.Poison(tradeEnginePID).Wait() + <-e.Poison(tradeEnginePID).Done() } // The GenID function generates a random ID string of length 16 using a cryptographic random number diff --git a/remote/remote_tls_test.go b/remote/remote_tls_test.go index a310caa..df450f0 100644 --- a/remote/remote_tls_test.go +++ b/remote/remote_tls_test.go @@ -63,8 +63,8 @@ func TestSend_TLS(t *testing.T) { wg.Wait() // wait for messages to be received by the actor. ra.Stop().Wait() // shutdown the remotes rb.Stop().Wait() - a.Poison(pida).Wait() - b.Poison(pidb).Wait() + <-a.Poison(pida).Done() + <-b.Poison(pidb).Done() } func generateTLSConfig() (*sharedConfig, error) { diff --git a/remote/stream_writer.go b/remote/stream_writer.go index ab237da..6c7dd24 100644 --- a/remote/stream_writer.go +++ b/remote/stream_writer.go @@ -7,7 +7,6 @@ import ( "io" "log/slog" "net" - "sync" "time" "github.com/anthdm/hollywood/actor" @@ -141,7 +140,7 @@ func (s *streamWriter) init() { // We could not reach the remote after retrying N times. Hence, shutdown the stream writer. // and notify RemoteUnreachableEvent. if rawconn == nil { - s.Shutdown(nil) + s.Shutdown() return } @@ -158,7 +157,7 @@ func (s *streamWriter) init() { stream, err := client.Receive(context.Background()) if err != nil { slog.Error("receive", "err", err, "remote", s.writeToAddr) - s.Shutdown(nil) + s.Shutdown() return } @@ -174,13 +173,13 @@ func (s *streamWriter) init() { slog.Debug("lost connection", "remote", s.writeToAddr, ) - s.Shutdown(nil) + s.Shutdown() }() } // TODO: is there a way that stream router can listen to event stream // instead of sending the event itself? -func (s *streamWriter) Shutdown(wg *sync.WaitGroup) { +func (s *streamWriter) Shutdown() { evt := actor.RemoteUnreachableEvent{ListenAddr: s.writeToAddr} s.engine.Send(s.routerPID, evt) s.engine.BroadcastEvent(evt) @@ -189,9 +188,6 @@ func (s *streamWriter) Shutdown(wg *sync.WaitGroup) { } s.inbox.Stop() s.engine.Registry.Remove(s.PID()) - if wg != nil { - wg.Done() - } } func (s *streamWriter) Start() {