Skip to content

Commit

Permalink
Merge pull request #596 from Semisol/bugfix/fix-tests
Browse files Browse the repository at this point in the history
Fix failing tests and test network
  • Loading branch information
lthibault authored Jan 6, 2025
2 parents 776eaca + cb8ddef commit 567b3ac
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 56 deletions.
86 changes: 37 additions & 49 deletions rpc/internal/testnetwork/testnetwork.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@ func (e edge) Flip() edge {
}
}

// This test network uses the same set of options for all
// participants. The rpc.Options instance can be cloned
// without issue.
type network struct {
myID PeerID
global *Joiner
myID PeerID
options rpc.Options
global *Joiner
}

// A Joiner is a global view of a test network, which can be joined by a
Expand All @@ -51,12 +55,13 @@ func NewJoiner() *Joiner {
}
}

func (j *Joiner) Join() rpc.Network {
func (j *Joiner) Join(opts *rpc.Options) rpc.Network {
j.mu.Lock()
defer j.mu.Unlock()
ret := network{
myID: j.nextID,
global: j,
myID: j.nextID,
global: j,
options: *opts,
}
j.nextID++
return ret
Expand All @@ -72,13 +77,11 @@ func (j *Joiner) getAcceptQueue(id PeerID) spsc.Queue[PeerID] {
}

func (n network) LocalID() rpc.PeerID {
return rpc.PeerID{n.myID}
return rpc.PeerID{Value: n.myID}
}

func (n network) Dial(dst rpc.PeerID, opts *rpc.Options) (*rpc.Conn, error) {
if opts == nil {
opts = &rpc.Options{}
}
func (n network) Dial(dst rpc.PeerID) (*rpc.Conn, error) {
opts := n.options
opts.Network = n
opts.RemotePeerID = dst
dstID := dst.Value.(PeerID)
Expand All @@ -101,7 +104,7 @@ func (n network) Dial(dst rpc.PeerID, opts *rpc.Options) (*rpc.Conn, error) {

}
if ent.Conn == nil {
ent.Conn = rpc.NewConn(ent.Transport, opts)
ent.Conn = rpc.NewConn(ent.Transport, &opts)
} else {
// There's already a connection, so we're not going to use this, but
// we own it. So drop it:
Expand All @@ -110,30 +113,32 @@ func (n network) Dial(dst rpc.PeerID, opts *rpc.Options) (*rpc.Conn, error) {
return ent.Conn, nil
}

func (n network) Accept(ctx context.Context, opts *rpc.Options) (*rpc.Conn, error) {
func (n network) Serve(ctx context.Context) error {
n.global.mu.Lock()
q := n.global.getAcceptQueue(n.myID)
n.global.mu.Unlock()

incoming, err := q.Recv(ctx)
if err != nil {
return nil, err
}
opts.Network = n
opts.RemotePeerID = rpc.PeerID{incoming}
n.global.mu.Lock()
defer n.global.mu.Unlock()
edge := edge{
From: n.myID,
To: incoming,
for {
incoming, err := q.Recv(ctx)
if err != nil {
return err
}
opts := n.options
opts.Network = n
opts.RemotePeerID = rpc.PeerID{incoming}
n.global.mu.Lock()
defer n.global.mu.Unlock()
edge := edge{
From: n.myID,
To: incoming,
}
ent := n.global.connections[edge]
if ent.Conn == nil {
ent.Conn = rpc.NewConn(ent.Transport, &opts)
} else {
opts.BootstrapClient.Release()
}
}
ent := n.global.connections[edge]
if ent.Conn == nil {
ent.Conn = rpc.NewConn(ent.Transport, opts)
} else {
opts.BootstrapClient.Release()
}
return ent.Conn, nil
}

func (n network) Introduce(provider, recipient *rpc.Conn) (rpc.IntroductionInfo, error) {
Expand All @@ -157,24 +162,7 @@ func (n network) Introduce(provider, recipient *rpc.Conn) (rpc.IntroductionInfo,
sendToRecipient.SetNonce(nonce)
sendToProvider.SetPeerId(uint64(recipientPeer.Value.(PeerID)))
sendToProvider.SetNonce(nonce)
ret.SendToRecipient = rpc.ThirdPartyCapID(sendToRecipient.ToPtr())
ret.SendToProvider = rpc.RecipientID(sendToProvider.ToPtr())
ret.SendToRecipient = rpc.ThirdPartyToContact(sendToRecipient.ToPtr())
ret.SendToProvider = rpc.ThirdPartyToAwait(sendToProvider.ToPtr())
return ret, nil
}
func (n network) DialIntroduced(capID rpc.ThirdPartyCapID, introducedBy *rpc.Conn) (*rpc.Conn, rpc.ProvisionID, error) {
cid := PeerAndNonce(capnp.Ptr(capID).Struct())

_, seg := capnp.NewSingleSegmentMessage(nil)
pid, err := NewPeerAndNonce(seg)
if err != nil {
return nil, rpc.ProvisionID{}, err
}
pid.SetPeerId(uint64(introducedBy.RemotePeerID().Value.(PeerID)))
pid.SetNonce(cid.Nonce())

conn, err := n.Dial(rpc.PeerID{PeerID(cid.PeerId())}, nil)
return conn, rpc.ProvisionID(pid.ToPtr()), err
}
func (n network) AcceptIntroduced(recipientID rpc.RecipientID, introducedBy *rpc.Conn) (*rpc.Conn, error) {
panic("TODO")
}
2 changes: 1 addition & 1 deletion rpc/level1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1053,5 +1053,5 @@ type rpcDisembargoContext struct {
Which rpccp.Disembargo_context_Which
SenderLoopback uint32
ReceiverLoopback uint32
Provide uint32
Accept capnp.Ptr
}
12 changes: 6 additions & 6 deletions schemas/schemas_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ func TestDefaultFind(t *testing.T) {
if s := schemas.Find(0xdeadbeef); s != nil {
t.Errorf("schemas.Find(0xdeadbeef) = %d-byte slice; want nil", len(s))
}
s := schemas.Find(gocp.Package_)
s := schemas.Find(gocp.Package)
if s == nil {
t.Fatalf("schemas.Find(%#x) = nil", gocp.Package_)
t.Fatalf("schemas.Find(%#x) = nil", gocp.Package)
}
msg, err := capnp.Unmarshal(s)
if err != nil {
t.Fatalf("capnp.Unmarshal(schemas.Find(%#x)) error: %v", gocp.Package_, err)
t.Fatalf("capnp.Unmarshal(schemas.Find(%#x)) error: %v", gocp.Package, err)
}
req, err := schema.ReadRootCodeGeneratorRequest(msg)
if err != nil {
Expand All @@ -32,15 +32,15 @@ func TestDefaultFind(t *testing.T) {
}
for i := 0; i < nodes.Len(); i++ {
n := nodes.At(i)
if n.Id() == gocp.Package_ {
if n.Id() == gocp.Package {
// Found
if n.Which() != schema.Node_Which_annotation {
t.Errorf("found node %#x which = %v; want annotation", gocp.Package_, n.Which())
t.Errorf("found node %#x which = %v; want annotation", gocp.Package, n.Which())
}
return
}
}
t.Fatalf("could not find node %#x in registry", gocp.Package_)
t.Fatalf("could not find node %#x in registry", gocp.Package)
}

func TestNotFound(t *testing.T) {
Expand Down

0 comments on commit 567b3ac

Please sign in to comment.