From cbe39cd7dfc6783002f087bc4ddf9ac1ce43b158 Mon Sep 17 00:00:00 2001 From: Jorropo Date: Thu, 15 Jun 2023 02:47:17 +0200 Subject: [PATCH] refactor: remove goprocess --- dht.go | 214 ++++++++++++++-------------- dht_bootstrap_test.go | 4 +- dht_test.go | 73 +++++----- fullrt/dht.go | 5 +- go.mod | 4 +- handlers_test.go | 2 +- internal/net/message_manager.go | 2 +- pb/protocol_messenger.go | 6 + providers/providers_manager.go | 210 ++++++++++++++------------- providers/providers_manager_test.go | 20 +-- query.go | 2 +- rtrefresh/rt_refresh_manager.go | 17 +-- subscriber_notifee.go | 149 +++++++------------ 13 files changed, 328 insertions(+), 380 deletions(-) diff --git a/dht.go b/dht.go index 88f0e899b..da3e74d83 100644 --- a/dht.go +++ b/dht.go @@ -33,11 +33,10 @@ import ( "github.com/gogo/protobuf/proto" ds "github.com/ipfs/go-datastore" logging "github.com/ipfs/go-log" - "github.com/jbenet/goprocess" - goprocessctx "github.com/jbenet/goprocess/context" "github.com/multiformats/go-base32" ma "github.com/multiformats/go-multiaddr" "go.opencensus.io/tag" + "go.uber.org/multierr" "go.uber.org/zap" ) @@ -92,13 +91,12 @@ type IpfsDHT struct { Validator record.Validator - ctx context.Context - proc goprocess.Process + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup protoMessenger *pb.ProtocolMessenger - msgSender pb.MessageSender - - plk sync.Mutex + msgSender pb.MessageSenderWithDisconnect stripedPutLocks [256]sync.Mutex @@ -187,7 +185,7 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) return nil, err } - dht, err := makeDHT(ctx, h, cfg) + dht, err := makeDHT(h, cfg) if err != nil { return nil, fmt.Errorf("failed to create DHT, err=%s", err) } @@ -225,30 +223,27 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) } // register for event bus and network notifications - sn, err := newSubscriberNotifiee(dht) - if err != nil { + if err := dht.startNetworkSubscriber(); err != nil { return nil, err } - dht.proc.Go(sn.subscribe) - // handle providers - if mgr, ok := dht.providerStore.(interface{ Process() goprocess.Process }); ok { - dht.proc.AddChild(mgr.Process()) - } // go-routine to make sure we ALWAYS have RT peer addresses in the peerstore // since RT membership is decoupled from connectivity go dht.persistRTPeersInPeerStore() - dht.proc.Go(dht.rtPeerLoop) + dht.rtPeerLoop() // Fill routing table with currently connected peers that are DHT servers - dht.plk.Lock() for _, p := range dht.host.Network().Peers() { - dht.peerFound(dht.ctx, p) + dht.peerFound(p) } - dht.plk.Unlock() - dht.proc.Go(dht.populatePeers) + dht.rtRefreshManager.Start() + + // listens to the fix low peers chan and tries to fix the Routing Table + if !dht.disableFixLowPeers { + dht.runFixLowPeersLoop() + } return dht, nil } @@ -275,7 +270,7 @@ func NewDHTClient(ctx context.Context, h host.Host, dstore ds.Batching) *IpfsDHT return dht } -func makeDHT(ctx context.Context, h host.Host, cfg dhtcfg.Config) (*IpfsDHT, error) { +func makeDHT(h host.Host, cfg dhtcfg.Config) (*IpfsDHT, error) { var protocols, serverProtocols []protocol.ID v1proto := cfg.ProtocolPrefix + kad1 @@ -346,26 +341,19 @@ func makeDHT(ctx context.Context, h host.Host, cfg dhtcfg.Config) (*IpfsDHT, err } // rt refresh manager - rtRefresh, err := makeRtRefreshManager(dht, cfg, maxLastSuccessfulOutboundThreshold) + dht.rtRefreshManager, err = makeRtRefreshManager(dht, cfg, maxLastSuccessfulOutboundThreshold) if err != nil { return nil, fmt.Errorf("failed to construct RT Refresh 
Manager,err=%s", err) } - dht.rtRefreshManager = rtRefresh - - // create a DHT proc with the given context - dht.proc = goprocessctx.WithContextAndTeardown(ctx, func() error { - return rtRefresh.Close() - }) // create a tagged context derived from the original context - ctxTags := dht.newContextWithLocalTags(ctx) // the DHT context should be done when the process is closed - dht.ctx = goprocessctx.WithProcessClosing(ctxTags, dht.proc) + dht.ctx, dht.cancel = context.WithCancel(dht.newContextWithLocalTags(context.Background())) if cfg.ProviderStore != nil { dht.providerStore = cfg.ProviderStore } else { - dht.providerStore, err = providers.NewProviderManager(dht.ctx, h.ID(), dht.peerstore, cfg.Datastore) + dht.providerStore, err = providers.NewProviderManager(h.ID(), dht.peerstore, cfg.Datastore) if err != nil { return nil, fmt.Errorf("initializing default provider manager (%v)", err) } @@ -468,42 +456,32 @@ func (dht *IpfsDHT) Mode() ModeOpt { return dht.auto } -func (dht *IpfsDHT) populatePeers(_ goprocess.Process) { - if !dht.disableFixLowPeers { - dht.fixLowPeers(dht.ctx) - } +// runFixLowPeersLoop manages simultaneous requests to fixLowPeers +func (dht *IpfsDHT) runFixLowPeersLoop() { + dht.wg.Add(1) + go func() { + defer dht.wg.Done() - if err := dht.rtRefreshManager.Start(); err != nil { - logger.Error(err) - } + dht.fixLowPeers() - // listens to the fix low peers chan and tries to fix the Routing Table - if !dht.disableFixLowPeers { - dht.proc.Go(dht.fixLowPeersRoutine) - } - -} + ticker := time.NewTicker(periodicBootstrapInterval) + defer ticker.Stop() -// fixLowPeersRouting manages simultaneous requests to fixLowPeers -func (dht *IpfsDHT) fixLowPeersRoutine(proc goprocess.Process) { - ticker := time.NewTicker(periodicBootstrapInterval) - defer ticker.Stop() + for { + select { + case <-dht.fixLowPeersChan: + case <-ticker.C: + case <-dht.ctx.Done(): + return + } - for { - select { - case <-dht.fixLowPeersChan: - case <-ticker.C: - case <-proc.Closing(): - return + dht.fixLowPeers() } - - dht.fixLowPeers(dht.Context()) - } - + }() } // fixLowPeers tries to get more peers into the routing table if we're below the threshold -func (dht *IpfsDHT) fixLowPeers(ctx context.Context) { +func (dht *IpfsDHT) fixLowPeers() { if dht.routingTable.Size() > minRTRefreshThreshold { return } @@ -511,7 +489,7 @@ func (dht *IpfsDHT) fixLowPeers(ctx context.Context) { // we try to add all peers we are connected to to the Routing Table // in case they aren't already there. 
for _, p := range dht.host.Network().Peers() { - dht.peerFound(ctx, p) + dht.peerFound(p) } // TODO Active Bootstrapping @@ -528,7 +506,7 @@ func (dht *IpfsDHT) fixLowPeers(ctx context.Context) { found := 0 for _, i := range rand.Perm(len(bootstrapPeers)) { ai := bootstrapPeers[i] - err := dht.Host().Connect(ctx, ai) + err := dht.Host().Connect(dht.ctx, ai) if err == nil { found++ } else { @@ -613,54 +591,59 @@ func (dht *IpfsDHT) putLocal(ctx context.Context, key string, rec *recpb.Record) return dht.datastore.Put(ctx, mkDsKey(key), data) } -func (dht *IpfsDHT) rtPeerLoop(proc goprocess.Process) { - bootstrapCount := 0 - isBootsrapping := false - var timerCh <-chan time.Time +func (dht *IpfsDHT) rtPeerLoop() { + dht.wg.Add(1) + go func() { + defer dht.wg.Done() + + var bootstrapCount uint + var isBootsrapping bool + var timerCh <-chan time.Time + + for { + select { + case <-timerCh: + dht.routingTable.MarkAllPeersIrreplaceable() + case p := <-dht.addPeerToRTChan: + if dht.routingTable.Size() == 0 { + isBootsrapping = true + bootstrapCount = 0 + timerCh = nil + } + // queryPeer set to true as we only try to add queried peers to the RT + newlyAdded, err := dht.routingTable.TryAddPeer(p, true, isBootsrapping) + if err != nil { + // peer not added. + continue + } + if !newlyAdded { + // the peer is already in our RT, but we just successfully queried it and so let's give it a + // bump on the query time so we don't ping it too soon for a liveliness check. + dht.routingTable.UpdateLastSuccessfulOutboundQueryAt(p, time.Now()) + } + case <-dht.refreshFinishedCh: + bootstrapCount = bootstrapCount + 1 + if bootstrapCount == 2 { + timerCh = time.NewTimer(dht.rtFreezeTimeout).C + } - for { - select { - case <-timerCh: - dht.routingTable.MarkAllPeersIrreplaceable() - case p := <-dht.addPeerToRTChan: - if dht.routingTable.Size() == 0 { - isBootsrapping = true - bootstrapCount = 0 - timerCh = nil - } - // queryPeer set to true as we only try to add queried peers to the RT - newlyAdded, err := dht.routingTable.TryAddPeer(p, true, isBootsrapping) - if err != nil { - // peer not added. - continue - } - if !newlyAdded { - // the peer is already in our RT, but we just successfully queried it and so let's give it a - // bump on the query time so we don't ping it too soon for a liveliness check. - dht.routingTable.UpdateLastSuccessfulOutboundQueryAt(p, time.Now()) - } - case <-dht.refreshFinishedCh: - bootstrapCount = bootstrapCount + 1 - if bootstrapCount == 2 { - timerCh = time.NewTimer(dht.rtFreezeTimeout).C - } + old := isBootsrapping + isBootsrapping = false + if old { + dht.rtRefreshManager.RefreshNoWait() + } - old := isBootsrapping - isBootsrapping = false - if old { - dht.rtRefreshManager.RefreshNoWait() + case <-dht.ctx.Done(): + return } - - case <-proc.Closing(): - return } - } + }() } // peerFound verifies whether the found peer advertises DHT protocols // and probe it to make sure it answers DHT queries as expected. If // it fails to answer, it isn't added to the routingTable. 
-func (dht *IpfsDHT) peerFound(ctx context.Context, p peer.ID) { +func (dht *IpfsDHT) peerFound(p peer.ID) { // if the peer is already in the routing table or the appropriate bucket is // already full, don't try to add the new peer.ID if !dht.routingTable.UsefulNewPeer(p) { @@ -685,7 +668,7 @@ func (dht *IpfsDHT) peerFound(ctx context.Context, p peer.ID) { dht.lookupChecksLk.Unlock() go func() { - livelinessCtx, cancel := context.WithTimeout(ctx, dht.lookupCheckTimeout) + livelinessCtx, cancel := context.WithTimeout(dht.ctx, dht.lookupCheckTimeout) defer cancel() // performing a FIND_NODE query @@ -701,14 +684,14 @@ func (dht *IpfsDHT) peerFound(ctx context.Context, p peer.ID) { } // if the FIND_NODE succeeded, the peer is considered as valid - dht.validPeerFound(ctx, p) + dht.validPeerFound(p) }() } } // validPeerFound signals the routingTable that we've found a peer that // supports the DHT protocol, and just answered correctly to a DHT FindPeers -func (dht *IpfsDHT) validPeerFound(ctx context.Context, p peer.ID) { +func (dht *IpfsDHT) validPeerFound(p peer.ID) { if c := baseLogger.Check(zap.DebugLevel, "peer found"); c != nil { c.Write(zap.String("peer", p.String())) } @@ -852,11 +835,6 @@ func (dht *IpfsDHT) Context() context.Context { return dht.ctx } -// Process returns the DHT's process. -func (dht *IpfsDHT) Process() goprocess.Process { - return dht.proc -} - // RoutingTable returns the DHT's routingTable. func (dht *IpfsDHT) RoutingTable() *kb.RoutingTable { return dht.routingTable @@ -864,7 +842,25 @@ func (dht *IpfsDHT) RoutingTable() *kb.RoutingTable { // Close calls Process Close. func (dht *IpfsDHT) Close() error { - return dht.proc.Close() + dht.cancel() + dht.wg.Wait() + + var wg sync.WaitGroup + closes := [...]func() error{ + dht.rtRefreshManager.Close, + dht.providerStore.Close, + } + var errors [len(closes)]error + wg.Add(len(errors)) + for i, c := range closes { + go func(i int, c func() error) { + defer wg.Done() + errors[i] = c() + }(i, c) + } + wg.Wait() + + return multierr.Combine(errors[:]...) 
} func mkDsKey(s string) ds.Key { diff --git a/dht_bootstrap_test.go b/dht_bootstrap_test.go index 9dd496c0a..e2236f5a1 100644 --- a/dht_bootstrap_test.go +++ b/dht_bootstrap_test.go @@ -191,8 +191,8 @@ func TestBootstrappersReplacable(t *testing.T) { require.NoError(t, d.host.Network().ClosePeer(d5.self)) connectNoSync(t, ctx, d, d1) connectNoSync(t, ctx, d, d5) - d.peerFound(ctx, d5.self) - d.peerFound(ctx, d1.self) + d.peerFound(d5.self) + d.peerFound(d1.self) time.Sleep(1 * time.Second) require.Len(t, d.routingTable.ListPeers(), 2) diff --git a/dht_test.go b/dht_test.go index 5f028d540..ab02b8869 100644 --- a/dht_test.go +++ b/dht_test.go @@ -166,9 +166,7 @@ func connectNoSync(t *testing.T, ctx context.Context, a, b *IpfsDHT) { t.Fatal("peers setup incorrectly: no local address") } - a.peerstore.AddAddrs(idB, addrB, peerstore.TempAddrTTL) - pi := peer.AddrInfo{ID: idB} - if err := a.host.Connect(ctx, pi); err != nil { + if err := a.host.Connect(ctx, peer.AddrInfo{ID: idB, Addrs: addrB}); err != nil { t.Fatal(err) } } @@ -273,6 +271,7 @@ func TestValueGetSet(t *testing.T) { defer dhts[i].host.Close() } + t.Log("before connect") connect(t, ctx, dhts[0], dhts[1]) t.Log("adding value on: ", dhts[0].self) @@ -291,13 +290,13 @@ func TestValueGetSet(t *testing.T) { if err != nil { t.Fatal(err) } + t.Log("after get value") if string(val) != "world" { t.Fatalf("Expected 'world' got '%s'", string(val)) } - // late connect - + t.Log("late connect") connect(t, ctx, dhts[2], dhts[0]) connect(t, ctx, dhts[2], dhts[1]) @@ -320,6 +319,7 @@ func TestValueGetSet(t *testing.T) { t.Fatalf("Expected 'world' got '%s'", string(val)) } + t.Log("very late connect") for _, d := range dhts[:3] { connect(t, ctx, dhts[3], d) } @@ -610,25 +610,6 @@ func waitForWellFormedTables(t *testing.T, dhts []*IpfsDHT, minPeers, avgPeers i // test "well-formed-ness" (>= minPeers peers in every routing table) t.Helper() - checkTables := func() bool { - totalPeers := 0 - for _, dht := range dhts { - rtlen := dht.routingTable.Size() - totalPeers += rtlen - if minPeers > 0 && rtlen < minPeers { - // t.Logf("routing table for %s only has %d peers (should have >%d)", dht.self, rtlen, minPeers) - return false - } - } - actualAvgPeers := totalPeers / len(dhts) - t.Logf("avg rt size: %d", actualAvgPeers) - if avgPeers > 0 && actualAvgPeers < avgPeers { - t.Logf("avg rt size: %d < %d", actualAvgPeers, avgPeers) - return false - } - return true - } - timeoutA := time.After(timeout) for { select { @@ -636,7 +617,7 @@ func waitForWellFormedTables(t *testing.T, dhts []*IpfsDHT, minPeers, avgPeers i t.Errorf("failed to reach well-formed routing tables after %s", timeout) return case <-time.After(5 * time.Millisecond): - if checkTables() { + if checkForWellFormedTablesOnce(t, dhts, minPeers, avgPeers) { // succeeded return } @@ -644,6 +625,26 @@ func waitForWellFormedTables(t *testing.T, dhts []*IpfsDHT, minPeers, avgPeers i } } +func checkForWellFormedTablesOnce(t *testing.T, dhts []*IpfsDHT, minPeers, avgPeers int) bool { + t.Helper() + totalPeers := 0 + for _, dht := range dhts { + rtlen := dht.routingTable.Size() + totalPeers += rtlen + if minPeers > 0 && rtlen < minPeers { + t.Logf("routing table for %s only has %d peers (should have >%d)", dht.self, rtlen, minPeers) + return false + } + } + actualAvgPeers := totalPeers / len(dhts) + t.Logf("avg rt size: %d", actualAvgPeers) + if avgPeers > 0 && actualAvgPeers < avgPeers { + t.Logf("avg rt size: %d < %d", actualAvgPeers, avgPeers) + return false + } + return true +} + func 
printRoutingTables(dhts []*IpfsDHT) { // the routing tables should be full now. let's inspect them. fmt.Printf("checking routing table of %d\n", len(dhts)) @@ -679,24 +680,16 @@ func TestRefresh(t *testing.T) { <-time.After(100 * time.Millisecond) // bootstrap a few times until we get good tables. t.Logf("bootstrapping them so they find each other %d", nDHTs) - ctxT, cancelT := context.WithTimeout(ctx, 5*time.Second) - defer cancelT() - for ctxT.Err() == nil { - bootstrap(t, ctxT, dhts) + for { + bootstrap(t, ctx, dhts) - // wait a bit. - select { - case <-time.After(50 * time.Millisecond): - continue // being explicit - case <-ctxT.Done(): - return + if checkForWellFormedTablesOnce(t, dhts, 7, 10) { + break } - } - - waitForWellFormedTables(t, dhts, 7, 10, 10*time.Second) - cancelT() + time.Sleep(time.Microsecond * 50) + } if u.Debug { // the routing tables should be full now. let's inspect them. @@ -2121,7 +2114,7 @@ func TestBootstrapPeersFunc(t *testing.T) { bootstrapPeersB = []peer.AddrInfo{addrA} lock.Unlock() - dhtB.fixLowPeers(ctx) + dhtB.fixLowPeers() require.NotEqual(t, 0, len(dhtB.host.Network().Peers())) } diff --git a/fullrt/dht.go b/fullrt/dht.go index f1a26e70a..3b0cd3e94 100644 --- a/fullrt/dht.go +++ b/fullrt/dht.go @@ -151,7 +151,7 @@ func NewFullRT(h host.Host, protocolPrefix protocol.ID, options ...Option) (*Ful ctx, cancel := context.WithCancel(context.Background()) self := h.ID() - pm, err := providers.NewProviderManager(ctx, self, h.Peerstore(), dhtcfg.Datastore, fullrtcfg.pmOpts...) + pm, err := providers.NewProviderManager(self, h.Peerstore(), dhtcfg.Datastore, fullrtcfg.pmOpts...) if err != nil { cancel() return nil, err @@ -355,9 +355,8 @@ func (dht *FullRT) runCrawler(ctx context.Context) { func (dht *FullRT) Close() error { dht.cancel() - err := dht.ProviderManager.Process().Close() dht.wg.Wait() - return err + return dht.ProviderManager.Close() } func (dht *FullRT) Bootstrap(ctx context.Context) error { diff --git a/go.mod b/go.mod index a237e6026..fa233ef4a 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,6 @@ require ( github.com/ipfs/go-datastore v0.6.0 github.com/ipfs/go-detect-race v0.0.1 github.com/ipfs/go-log v1.0.5 - github.com/jbenet/goprocess v0.1.4 github.com/libp2p/go-libp2p v0.27.6 github.com/libp2p/go-libp2p-kbucket v0.6.3-0.20230615004129-e99cd472ed1e github.com/libp2p/go-libp2p-record v0.2.0 @@ -32,6 +31,7 @@ require ( go.opencensus.io v0.24.0 go.opentelemetry.io/otel v1.16.0 go.opentelemetry.io/otel/trace v1.16.0 + go.uber.org/multierr v1.11.0 go.uber.org/zap v1.24.0 gonum.org/v1/gonum v0.13.0 ) @@ -63,6 +63,7 @@ require ( github.com/ipld/go-ipld-prime v0.20.0 // indirect github.com/jackpal/go-nat-pmp v1.0.2 // indirect github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect + github.com/jbenet/goprocess v0.1.4 // indirect github.com/klauspost/compress v1.16.5 // indirect github.com/klauspost/cpuid/v2 v2.2.5 // indirect github.com/koron/go-ssdp v0.0.4 // indirect @@ -108,7 +109,6 @@ require ( go.uber.org/atomic v1.11.0 // indirect go.uber.org/dig v1.17.0 // indirect go.uber.org/fx v1.19.2 // indirect - go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.10.0 // indirect golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect golang.org/x/mod v0.10.0 // indirect diff --git a/handlers_test.go b/handlers_test.go index d829e38b1..35959df62 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -111,7 +111,7 @@ func BenchmarkHandleFindPeer(b *testing.B) { panic(err) } - d.peerFound(ctx, id) + d.peerFound(id) peers = append(peers, 
id) a, err := ma.NewMultiaddr(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", 2000+i)) diff --git a/internal/net/message_manager.go b/internal/net/message_manager.go index f04dd0889..7908be2bd 100644 --- a/internal/net/message_manager.go +++ b/internal/net/message_manager.go @@ -43,7 +43,7 @@ type messageSenderImpl struct { protocols []protocol.ID } -func NewMessageSenderImpl(h host.Host, protos []protocol.ID) pb.MessageSender { +func NewMessageSenderImpl(h host.Host, protos []protocol.ID) pb.MessageSenderWithDisconnect { return &messageSenderImpl{ host: h, strmap: make(map[peer.ID]*peerMessageSender), diff --git a/pb/protocol_messenger.go b/pb/protocol_messenger.go index e175dde10..48aba3e35 100644 --- a/pb/protocol_messenger.go +++ b/pb/protocol_messenger.go @@ -45,6 +45,12 @@ func NewProtocolMessenger(msgSender MessageSender, opts ...ProtocolMessengerOpti return pm, nil } +type MessageSenderWithDisconnect interface { + MessageSender + + OnDisconnect(context.Context, peer.ID) +} + // MessageSender handles sending wire protocol messages to a given peer type MessageSender interface { // SendRequest sends a peer a message and waits for its response diff --git a/providers/providers_manager.go b/providers/providers_manager.go index f2a7ad17c..b7f1d7d90 100644 --- a/providers/providers_manager.go +++ b/providers/providers_manager.go @@ -4,7 +4,9 @@ import ( "context" "encoding/binary" "fmt" + "io" "strings" + "sync" "time" lru "github.com/hashicorp/golang-lru/simplelru" @@ -12,8 +14,6 @@ import ( "github.com/ipfs/go-datastore/autobatch" dsq "github.com/ipfs/go-datastore/query" logging "github.com/ipfs/go-log" - "github.com/jbenet/goprocess" - goprocessctx "github.com/jbenet/goprocess/context" "github.com/libp2p/go-libp2p-kad-dht/internal" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" @@ -45,6 +45,7 @@ var log = logging.Logger("providers") type ProviderStore interface { AddProvider(ctx context.Context, key []byte, prov peer.AddrInfo) error GetProviders(ctx context.Context, key []byte) ([]peer.AddrInfo, error) + io.Closer } // ProviderManager adds and pulls providers out of the datastore, @@ -59,9 +60,12 @@ type ProviderManager struct { newprovs chan *addProv getprovs chan *getProv - proc goprocess.Process cleanupInterval time.Duration + + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup } var _ ProviderStore = (*ProviderManager)(nil) @@ -109,7 +113,7 @@ type getProv struct { } // NewProviderManager constructor -func NewProviderManager(ctx context.Context, local peer.ID, ps peerstore.Peerstore, dstore ds.Batching, opts ...Option) (*ProviderManager, error) { +func NewProviderManager(local peer.ID, ps peerstore.Peerstore, dstore ds.Batching, opts ...Option) (*ProviderManager, error) { pm := new(ProviderManager) pm.self = local pm.getprovs = make(chan *getProv) @@ -125,117 +129,121 @@ func NewProviderManager(ctx context.Context, local peer.ID, ps peerstore.Peersto if err := pm.applyOptions(opts...); err != nil { return nil, err } - pm.proc = goprocessctx.WithContext(ctx) - pm.proc.Go(func(proc goprocess.Process) { pm.run(ctx, proc) }) + pm.ctx, pm.cancel = context.WithCancel(context.Background()) + pm.run() return pm, nil } -// Process returns the ProviderManager process -func (pm *ProviderManager) Process() goprocess.Process { - return pm.proc -} +func (pm *ProviderManager) run() { + pm.wg.Add(1) + go func() { + defer pm.wg.Done() -func (pm *ProviderManager) run(ctx context.Context, proc goprocess.Process) { - var ( - gcQuery dsq.Results - gcQueryRes 
<-chan dsq.Result - gcSkip map[string]struct{} - gcTime time.Time - gcTimer = time.NewTimer(pm.cleanupInterval) - ) - - defer func() { - gcTimer.Stop() - if gcQuery != nil { - // don't really care if this fails. - _ = gcQuery.Close() - } - if err := pm.dstore.Flush(ctx); err != nil { - log.Error("failed to flush datastore: ", err) - } - }() + var gcQuery dsq.Results + gcTimer := time.NewTimer(pm.cleanupInterval) - for { - select { - case np := <-pm.newprovs: - err := pm.addProv(np.ctx, np.key, np.val) - if err != nil { - log.Error("error adding new providers: ", err) - continue + defer func() { + gcTimer.Stop() + if gcQuery != nil { + // don't really care if this fails. + _ = gcQuery.Close() } - if gcSkip != nil { - // we have an gc, tell it to skip this provider - // as we've updated it since the GC started. - gcSkip[mkProvKeyFor(np.key, np.val)] = struct{}{} + if err := pm.dstore.Flush(context.Background()); err != nil { + log.Error("failed to flush datastore: ", err) } - case gp := <-pm.getprovs: - provs, err := pm.getProvidersForKey(gp.ctx, gp.key) - if err != nil && err != ds.ErrNotFound { - log.Error("error reading providers: ", err) - } - - // set the cap so the user can't append to this. - gp.resp <- provs[0:len(provs):len(provs)] - case res, ok := <-gcQueryRes: - if !ok { - if err := gcQuery.Close(); err != nil { - log.Error("failed to close provider GC query: ", err) + }() + + var gcQueryRes <-chan dsq.Result + var gcSkip map[string]struct{} + var gcTime time.Time + for { + select { + case np := <-pm.newprovs: + err := pm.addProv(np.ctx, np.key, np.val) + if err != nil { + log.Error("error adding new providers: ", err) + continue + } + if gcSkip != nil { + // we have an gc, tell it to skip this provider + // as we've updated it since the GC started. + gcSkip[mkProvKeyFor(np.key, np.val)] = struct{}{} + } + case gp := <-pm.getprovs: + provs, err := pm.getProvidersForKey(gp.ctx, gp.key) + if err != nil && err != ds.ErrNotFound { + log.Error("error reading providers: ", err) } - gcTimer.Reset(pm.cleanupInterval) - // cleanup GC round - gcQueryRes = nil - gcSkip = nil - gcQuery = nil - continue - } - if res.Error != nil { - log.Error("got error from GC query: ", res.Error) - continue - } - if _, ok := gcSkip[res.Key]; ok { - // We've updated this record since starting the - // GC round, skip it. - continue - } + // set the cap so the user can't append to this. + gp.resp <- provs[0:len(provs):len(provs)] + case res, ok := <-gcQueryRes: + if !ok { + if err := gcQuery.Close(); err != nil { + log.Error("failed to close provider GC query: ", err) + } + gcTimer.Reset(pm.cleanupInterval) + + // cleanup GC round + gcQueryRes = nil + gcSkip = nil + gcQuery = nil + continue + } + if res.Error != nil { + log.Error("got error from GC query: ", res.Error) + continue + } + if _, ok := gcSkip[res.Key]; ok { + // We've updated this record since starting the + // GC round, skip it. 
+ continue + } - // check expiration time - t, err := readTimeValue(res.Value) - switch { - case err != nil: - // couldn't parse the time - log.Error("parsing providers record from disk: ", err) - fallthrough - case gcTime.Sub(t) > ProvideValidity: - // or expired - err = pm.dstore.Delete(ctx, ds.RawKey(res.Key)) - if err != nil && err != ds.ErrNotFound { - log.Error("failed to remove provider record from disk: ", err) + // check expiration time + t, err := readTimeValue(res.Value) + switch { + case err != nil: + // couldn't parse the time + log.Error("parsing providers record from disk: ", err) + fallthrough + case gcTime.Sub(t) > ProvideValidity: + // or expired + err = pm.dstore.Delete(pm.ctx, ds.RawKey(res.Key)) + if err != nil && err != ds.ErrNotFound { + log.Error("failed to remove provider record from disk: ", err) + } } - } - case gcTime = <-gcTimer.C: - // You know the wonderful thing about caches? You can - // drop them. - // - // Much faster than GCing. - pm.cache.Purge() - - // Now, kick off a GC of the datastore. - q, err := pm.dstore.Query(ctx, dsq.Query{ - Prefix: ProvidersKeyPrefix, - }) - if err != nil { - log.Error("provider record GC query failed: ", err) - continue + case gcTime = <-gcTimer.C: + // You know the wonderful thing about caches? You can + // drop them. + // + // Much faster than GCing. + pm.cache.Purge() + + // Now, kick off a GC of the datastore. + q, err := pm.dstore.Query(pm.ctx, dsq.Query{ + Prefix: ProvidersKeyPrefix, + }) + if err != nil { + log.Error("provider record GC query failed: ", err) + continue + } + gcQuery = q + gcQueryRes = q.Next() + gcSkip = make(map[string]struct{}) + case <-pm.ctx.Done(): + return } - gcQuery = q - gcQueryRes = q.Next() - gcSkip = make(map[string]struct{}) - case <-proc.Closing(): - return } - } + }() +} + +func (pm *ProviderManager) Close() error { + pm.cancel() + pm.wg.Wait() + return nil } // AddProvider adds a provider diff --git a/providers/providers_manager_test.go b/providers/providers_manager_test.go index ba238a59e..e830929ef 100644 --- a/providers/providers_manager_test.go +++ b/providers/providers_manager_test.go @@ -31,7 +31,7 @@ func TestProviderManager(t *testing.T) { if err != nil { t.Fatal(err) } - p, err := NewProviderManager(ctx, mid, ps, dssync.MutexWrap(ds.NewMapDatastore())) + p, err := NewProviderManager(mid, ps, dssync.MutexWrap(ds.NewMapDatastore())) if err != nil { t.Fatal(err) } @@ -60,7 +60,7 @@ func TestProviderManager(t *testing.T) { t.Fatalf("Should have got 3 providers, got %d", len(resp)) } - p.proc.Close() + p.Close() } func TestProvidersDatastore(t *testing.T) { @@ -77,11 +77,11 @@ func TestProvidersDatastore(t *testing.T) { t.Fatal(err) } - p, err := NewProviderManager(ctx, mid, ps, dssync.MutexWrap(ds.NewMapDatastore())) + p, err := NewProviderManager(mid, ps, dssync.MutexWrap(ds.NewMapDatastore())) if err != nil { t.Fatal(err) } - defer p.proc.Close() + defer p.Close() friend := peer.ID("friend") var mhs []mh.Multihash @@ -166,7 +166,7 @@ func TestProvidesExpire(t *testing.T) { if err != nil { t.Fatal(err) } - p, err := NewProviderManager(ctx, mid, ps, ds) + p, err := NewProviderManager(mid, ps, ds) if err != nil { t.Fatal(err) } @@ -216,7 +216,7 @@ func TestProvidesExpire(t *testing.T) { time.Sleep(time.Second / 2) // Stop to prevent data races - p.Process().Close() + p.Close() if p.cache.Len() != 0 { t.Fatal("providers map not cleaned up") @@ -278,11 +278,11 @@ func TestLargeProvidersSet(t *testing.T) { t.Fatal(err) } - p, err := NewProviderManager(ctx, mid, ps, dstore) + p, err := 
NewProviderManager(mid, ps, dstore) if err != nil { t.Fatal(err) } - defer p.proc.Close() + defer p.Close() var mhs []mh.Multihash for i := 0; i < 1000; i++ { @@ -318,7 +318,7 @@ func TestUponCacheMissProvidersAreReadFromDatastore(t *testing.T) { t.Fatal(err) } - pm, err := NewProviderManager(ctx, p1, ps, dssync.MutexWrap(ds.NewMapDatastore())) + pm, err := NewProviderManager(p1, ps, dssync.MutexWrap(ds.NewMapDatastore())) if err != nil { t.Fatal(err) } @@ -347,7 +347,7 @@ func TestWriteUpdatesCache(t *testing.T) { t.Fatal(err) } - pm, err := NewProviderManager(ctx, p1, ps, dssync.MutexWrap(ds.NewMapDatastore())) + pm, err := NewProviderManager(p1, ps, dssync.MutexWrap(ds.NewMapDatastore())) if err != nil { t.Fatal(err) } diff --git a/query.go b/query.go index c8b07d650..524269aec 100644 --- a/query.go +++ b/query.go @@ -446,7 +446,7 @@ func (q *query) queryPeer(ctx context.Context, ch chan<- *queryUpdate, p peer.ID queryDuration := time.Since(startQuery) // query successful, try to add to RT - q.dht.validPeerFound(q.dht.ctx, p) + q.dht.validPeerFound(p) // process new peers saw := []peer.ID{} diff --git a/rtrefresh/rt_refresh_manager.go b/rtrefresh/rt_refresh_manager.go index d08983702..6de69b026 100644 --- a/rtrefresh/rt_refresh_manager.go +++ b/rtrefresh/rt_refresh_manager.go @@ -31,10 +31,9 @@ type triggerRefreshReq struct { } type RtRefreshManager struct { - ctx context.Context - cancel context.CancelFunc - refcount sync.WaitGroup - closeOnce sync.Once + ctx context.Context + cancel context.CancelFunc + refcount sync.WaitGroup // peerId of this DHT peer i.e. self peerId. h host.Host @@ -89,17 +88,14 @@ func NewRtRefreshManager(h host.Host, rt *kbucket.RoutingTable, autoRefresh bool }, nil } -func (r *RtRefreshManager) Start() error { +func (r *RtRefreshManager) Start() { r.refcount.Add(1) go r.loop() - return nil } func (r *RtRefreshManager) Close() error { - r.closeOnce.Do(func() { - r.cancel() - r.refcount.Wait() - }) + r.cancel() + r.refcount.Wait() return nil } @@ -117,6 +113,7 @@ func (r *RtRefreshManager) Refresh(force bool) <-chan error { case r.triggerRefresh <- &triggerRefreshReq{respCh: resp, forceCplRefresh: force}: case <-r.ctx.Done(): resp <- r.ctx.Err() + close(resp) } }() diff --git a/subscriber_notifee.go b/subscriber_notifee.go index 23c21ffb9..759db76c6 100644 --- a/subscriber_notifee.go +++ b/subscriber_notifee.go @@ -1,27 +1,15 @@ package dht import ( - "context" "fmt" "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/p2p/host/eventbus" - - "github.com/jbenet/goprocess" - ma "github.com/multiformats/go-multiaddr" ) -// subscriberNotifee implements network.Notifee and also manages the subscriber to the event bus. We consume peer -// identification events to trigger inclusion in the routing table, and we consume Disconnected events to eject peers -// from it. -type subscriberNotifee struct { - dht *IpfsDHT - subs event.Subscription -} - -func newSubscriberNotifiee(dht *IpfsDHT) (*subscriberNotifee, error) { +func (dht *IpfsDHT) startNetworkSubscriber() error { bufSize := eventbus.BufSize(256) evts := []interface{}{ @@ -35,6 +23,9 @@ func newSubscriberNotifiee(dht *IpfsDHT) (*subscriberNotifee, error) { // register for event bus notifications for when our local address/addresses change so we can // advertise those to the network new(event.EvtLocalAddressesUpdated), + + // we want to know when we are disconnecting from other peers. 
+ new(event.EvtPeerConnectednessChanged), } // register for event bus local routability changes in order to trigger switching between client and server modes @@ -45,61 +36,57 @@ func newSubscriberNotifiee(dht *IpfsDHT) (*subscriberNotifee, error) { subs, err := dht.host.EventBus().Subscribe(evts, bufSize) if err != nil { - return nil, fmt.Errorf("dht could not subscribe to eventbus events; err: %s", err) + return fmt.Errorf("dht could not subscribe to eventbus events: %w", err) } - nn := &subscriberNotifee{ - dht: dht, - subs: subs, - } - - // register for network notifications - dht.host.Network().Notify(nn) - - return nn, nil -} - -func (nn *subscriberNotifee) subscribe(proc goprocess.Process) { - dht := nn.dht - defer dht.host.Network().StopNotify(nn) - defer nn.subs.Close() - - for { - select { - case e, more := <-nn.subs.Out(): - if !more { - return - } + dht.wg.Add(1) + go func() { + defer dht.wg.Done() + defer subs.Close() - switch evt := e.(type) { - case event.EvtLocalAddressesUpdated: - // when our address changes, we should proactively tell our closest peers about it so - // we become discoverable quickly. The Identify protocol will push a signed peer record - // with our new address to all peers we are connected to. However, we might not necessarily be connected - // to our closet peers & so in the true spirit of Zen, searching for ourself in the network really is the best way - // to to forge connections with those matter. - if dht.autoRefresh || dht.testAddressUpdateProcessing { - dht.rtRefreshManager.RefreshNoWait() + for { + select { + case e, more := <-subs.Out(): + if !more { + return } - case event.EvtPeerProtocolsUpdated: - handlePeerChangeEvent(dht, evt.Peer) - case event.EvtPeerIdentificationCompleted: - handlePeerChangeEvent(dht, evt.Peer) - case event.EvtLocalReachabilityChanged: - if dht.auto == ModeAuto || dht.auto == ModeAutoServer { - handleLocalReachabilityChangedEvent(dht, evt) - } else { - // something has gone really wrong if we get an event we did not subscribe to - logger.Errorf("received LocalReachabilityChanged event that was not subscribed to") + + switch evt := e.(type) { + case event.EvtLocalAddressesUpdated: + // when our address changes, we should proactively tell our closest peers about it so + // we become discoverable quickly. The Identify protocol will push a signed peer record + // with our new address to all peers we are connected to. However, we might not necessarily be connected + // to our closet peers & so in the true spirit of Zen, searching for ourself in the network really is the best way + // to to forge connections with those matter. 
+ if dht.autoRefresh || dht.testAddressUpdateProcessing { + dht.rtRefreshManager.RefreshNoWait() + } + case event.EvtPeerProtocolsUpdated: + handlePeerChangeEvent(dht, evt.Peer) + case event.EvtPeerIdentificationCompleted: + handlePeerChangeEvent(dht, evt.Peer) + case event.EvtPeerConnectednessChanged: + if evt.Connectedness != network.Connected { + dht.msgSender.OnDisconnect(dht.ctx, evt.Peer) + } + case event.EvtLocalReachabilityChanged: + if dht.auto == ModeAuto || dht.auto == ModeAutoServer { + handleLocalReachabilityChangedEvent(dht, evt) + } else { + // something has gone really wrong if we get an event we did not subscribe to + logger.Errorf("received LocalReachabilityChanged event that was not subscribed to") + } + default: + // something has gone really wrong if we get an event for another type + logger.Errorf("got wrong type from subscription: %T", e) } - default: - // something has gone really wrong if we get an event for another type - logger.Errorf("got wrong type from subscription: %T", e) + case <-dht.ctx.Done(): + return } - case <-proc.Closing(): - return } - } + }() + + return nil } func handlePeerChangeEvent(dht *IpfsDHT, p peer.ID) { @@ -108,7 +95,7 @@ func handlePeerChangeEvent(dht *IpfsDHT, p peer.ID) { logger.Errorf("could not check peerstore for protocol support: err: %s", err) return } else if valid { - dht.peerFound(dht.ctx, p) + dht.peerFound(p) dht.fixRTIfNeeded() } else { dht.peerStoppedDHT(p) @@ -153,41 +140,3 @@ func (dht *IpfsDHT) validRTPeer(p peer.ID) (bool, error) { return dht.routingTablePeerFilter == nil || dht.routingTablePeerFilter(dht, p), nil } - -type disconnector interface { - OnDisconnect(ctx context.Context, p peer.ID) -} - -func (nn *subscriberNotifee) Disconnected(n network.Network, v network.Conn) { - dht := nn.dht - - ms, ok := dht.msgSender.(disconnector) - if !ok { - return - } - - select { - case <-dht.Process().Closing(): - return - default: - } - - p := v.RemotePeer() - - // Lock and check to see if we're still connected. We lock to make sure - // we don't concurrently process a connect event. - dht.plk.Lock() - defer dht.plk.Unlock() - if dht.host.Network().Connectedness(p) == network.Connected { - // We're still connected. - return - } - - ms.OnDisconnect(dht.Context(), p) -} - -func (nn *subscriberNotifee) Connected(network.Network, network.Conn) {} -func (nn *subscriberNotifee) OpenedStream(network.Network, network.Stream) {} -func (nn *subscriberNotifee) ClosedStream(network.Network, network.Stream) {} -func (nn *subscriberNotifee) Listen(network.Network, ma.Multiaddr) {} -func (nn *subscriberNotifee) ListenClose(network.Network, ma.Multiaddr) {}
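The same lifecycle shape recurs throughout this change (IpfsDHT, ProviderManager, RtRefreshManager): a context.Context/context.CancelFunc pair plus a sync.WaitGroup take over the role goprocess used to play, long-running loops select on ctx.Done(), and Close cancels the context and then waits for every goroutine to return. Below is a minimal, self-contained sketch of that pattern; the worker type and its names are illustrative only and are not part of this patch.

	package lifecycle

	import (
		"context"
		"sync"
		"time"
	)

	// worker owns one background loop, mirroring the ctx/cancel/wg trio
	// this patch adds to IpfsDHT and ProviderManager.
	type worker struct {
		ctx    context.Context
		cancel context.CancelFunc
		wg     sync.WaitGroup
	}

	func newWorker() *worker {
		w := &worker{}
		w.ctx, w.cancel = context.WithCancel(context.Background())
		w.run()
		return w
	}

	// run starts the loop; the goroutine registers with the WaitGroup
	// before being spawned so Close can reliably wait for it.
	func (w *worker) run() {
		w.wg.Add(1)
		go func() {
			defer w.wg.Done()
			ticker := time.NewTicker(time.Minute)
			defer ticker.Stop()
			for {
				select {
				case <-ticker.C:
					// periodic work would go here
				case <-w.ctx.Done():
					return
				}
			}
		}()
	}

	// Close cancels the context and blocks until the loop has exited,
	// the same shutdown order used by IpfsDHT.Close and ProviderManager.Close.
	func (w *worker) Close() error {
		w.cancel()
		w.wg.Wait()
		return nil
	}

One design note when reading the diff: IpfsDHT.Close additionally fans out to the child components' Close methods in parallel and folds their errors together with multierr.Combine, so a slow provider-store shutdown does not serialize behind the routing-table refresher's, and vice versa.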