diff --git a/cmd/strawberry/main.go b/cmd/strawberry/main.go index 1143c13..93702fc 100644 --- a/cmd/strawberry/main.go +++ b/cmd/strawberry/main.go @@ -6,116 +6,44 @@ import ( "flag" "fmt" "log" - "time" + "net" - "github.com/eigerco/strawberry/pkg/network/cert" - "github.com/eigerco/strawberry/pkg/network/handlers" "github.com/eigerco/strawberry/pkg/network/peer" - "github.com/eigerco/strawberry/pkg/network/protocol" - "github.com/eigerco/strawberry/pkg/network/transport" ) // main starts a blockchain node. -// -// To run the first node (listener): -// -// go run main.go -addr localhost:9000 -// -// To run a second node that connects to the first node: -// -// go run main.go -addr localhost:9001 -connect localhost:9000 -// -// - The first node listens on port 9000. -// - The second node listens on port 9001 and connects to the first node's address (localhost:9000). +// go run main.go -addr localhost:9000 func main() { - listenAddr := flag.String("addr", "", "Listen address (e.g., 0.0.0.0:9000)") - connectTo := flag.String("connect", "", "Address to connect to (optional)") + ctx := context.Background() + listenAddr := flag.String("addr", "", "Listen address") flag.Parse() if *listenAddr == "" { log.Fatal("listen address is required") } - // Generate node keys pub, priv, err := ed25519.GenerateKey(nil) if err != nil { - log.Fatalf("Failed to generate keys: %v", err) + panic(err) } - - // Create certificate - certGen := cert.NewGenerator(cert.Config{ - PublicKey: pub, - PrivateKey: priv, - CertValidityPeriod: 24 * time.Hour, - }) - tlsCert, err := certGen.GenerateCertificate() - if err != nil { - log.Fatalf("Failed to generate certificate: %v", err) + keys := peer.ValidatorKeys{ + EdPrv: priv, + EdPub: pub, } - // Create protocol manager - protoConfig := protocol.Config{ - ChainHash: "12345678", // Example chain hash - IsBuilder: false, - MaxBuilderSlots: 20, - } - protoManager, err := protocol.NewManager(protoConfig) + address, err := net.ResolveUDPAddr("", *listenAddr) if err != nil { - log.Fatalf("Failed to create protocol manager: %v", err) + panic(err) } - - // Register protocol handlers - protoManager.Registry.RegisterHandler(protocol.StreamKindBlockRequest, handlers.NewBlockRequestHandler()) - - // Create transport with minimal config - transportConfig := transport.Config{ - PublicKey: pub, - PrivateKey: priv, - TLSCert: tlsCert, - ListenAddr: *listenAddr, - CertValidator: cert.NewValidator(), - Handler: protoManager, // Protocol manager implements ConnectionHandler - } - - tr, err := transport.NewTransport(transportConfig) + fmt.Printf("listening on: %v\n", address) + node, err := peer.NewNode(ctx, address, keys) if err != nil { - log.Fatalf("Failed to create transport: %v", err) + panic(err) } - - if err := tr.Start(); err != nil { - log.Fatalf("Failed to start transport: %v", err) - } - defer func() { - if err := tr.Stop(); err != nil { - fmt.Printf("Failed to stop transport: %v\n", err) - } - }() - - log.Printf("Node listening on %s", *listenAddr) - - // If we have an address to connect to, make a request - if *connectTo != "" { - log.Printf("Connecting to peer at %s", *connectTo) - - conn, err := tr.Connect(*connectTo) - if err != nil { - log.Fatalf("Failed to connect to peer: %v", err) - } - - // Create a dummy block hash for the request - hash := [32]byte{1, 2, 3, 4} // Example hash - - // Create peer with protocol connection - p := peer.NewPeer(conn, protoManager) - ctx := context.Background() - blocks, err := p.RequestBlocks(ctx, hash, true) - if err != nil { - log.Fatalf("Failed to request blocks: %v", err) - } - fmt.Printf("blocks: %v\n", blocks) - log.Printf("Block request completed") + err = node.Start() + if err != nil { + panic(err) } - // Keep the node running select {} } diff --git a/go.mod b/go.mod index 25a5bda..cf70601 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.22.5 require ( github.com/cockroachdb/pebble v1.1.3 github.com/ebitengine/purego v0.8.1 - github.com/golang/mock v1.6.0 github.com/pkg/errors v0.9.1 github.com/quic-go/quic-go v0.48.2 github.com/stretchr/testify v1.10.0 diff --git a/go.sum b/go.sum index e0c2d09..79d75c4 100644 --- a/go.sum +++ b/go.sum @@ -33,8 +33,6 @@ github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1v github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= -github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb h1:PBC98N2aIaM3XXiurYmW7fx4GZkL8feAMVq7nEjURHk= github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= @@ -79,7 +77,6 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= @@ -93,31 +90,24 @@ golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0 golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= @@ -128,7 +118,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE= golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/network/network_test.go b/pkg/network/network_test.go index 75fbb7d..66b071c 100644 --- a/pkg/network/network_test.go +++ b/pkg/network/network_test.go @@ -1,330 +1 @@ package network_test - -import ( - "context" - "crypto/ed25519" - "net" - "sync" - "testing" - "time" - - "github.com/eigerco/strawberry/pkg/network/cert" - "github.com/eigerco/strawberry/pkg/network/handlers" - "github.com/eigerco/strawberry/pkg/network/peer" - "github.com/eigerco/strawberry/pkg/network/protocol" - "github.com/eigerco/strawberry/pkg/network/transport" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// testNode represents a node instance for testing -type testNode struct { - transport *transport.Transport - protoManager *protocol.Manager - addr string - pubKey ed25519.PublicKey - privKey ed25519.PrivateKey -} - -// setupTestNode creates a new test node with all necessary components -func setupTestNode(t *testing.T) *testNode { - // Find available port - listener, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err) - addr := listener.Addr().String() - listener.Close() - - // Generate keys - pub, priv, err := ed25519.GenerateKey(nil) - require.NoError(t, err) - - // Create certificate - certGen := cert.NewGenerator(cert.Config{ - PublicKey: pub, - PrivateKey: priv, - CertValidityPeriod: 24 * time.Hour, - }) - tlsCert, err := certGen.GenerateCertificate() - require.NoError(t, err) - - // Create protocol manager - protoConfig := protocol.Config{ - ChainHash: "12345678", - IsBuilder: false, - } - protoManager, err := protocol.NewManager(protoConfig) - require.NoError(t, err) - - // Register handlers - blockHandler := handlers.NewBlockRequestHandler() - protoManager.Registry.RegisterHandler(protocol.StreamKindBlockRequest, blockHandler) - - // Create transport - transportConfig := transport.Config{ - PublicKey: pub, - PrivateKey: priv, - TLSCert: tlsCert, - ListenAddr: addr, - CertValidator: cert.NewValidator(), - Handler: protoManager, - } - - tr, err := transport.NewTransport(transportConfig) - require.NoError(t, err) - - return &testNode{ - transport: tr, - protoManager: protoManager, - addr: addr, - pubKey: pub, - privKey: priv, - } -} - -// setupTestPair creates and connects two test nodes -func setupTestPair(t *testing.T) (*testNode, *testNode, *peer.Peer) { - node1 := setupTestNode(t) - node2 := setupTestNode(t) - - require.NoError(t, node1.transport.Start()) - require.NoError(t, node2.transport.Start()) - - conn, err := node2.transport.Connect(node1.addr) - require.NoError(t, err) - - p := peer.NewPeer(conn, node2.protoManager) - return node1, node2, p -} - -// Helper function to safely stop transports -func cleanupNodes(t *testing.T, nodes ...*testNode) { - for _, node := range nodes { - if err := node.transport.Stop(); err != nil { - t.Errorf("failed to stop transport: %v", err) - } - } -} - -// TestBasicBlockRequest tests a simple block request -func TestBasicBlockRequest(t *testing.T) { - node1, node2, p := setupTestPair(t) - defer cleanupNodes(t, node1, node2) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - response, err := p.RequestBlocks(ctx, [32]byte{1, 2, 3, 4}, true) - require.NoError(t, err) - assert.Equal(t, "test block response", string(response), "unexpected response content") -} - -// TestConcurrentBlockRequests tests handling multiple concurrent requests -func TestConcurrentBlockRequests(t *testing.T) { - node1, node2, p := setupTestPair(t) - defer cleanupNodes(t, node1, node2) - - var wg sync.WaitGroup - numRequests := 5 - type result struct { - response []byte - err error - } - results := make(chan result, numRequests) - - for i := 0; i < numRequests; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - response, err := p.RequestBlocks(ctx, [32]byte{byte(i)}, true) - results <- result{response, err} - }(i) - } - - wg.Wait() - close(results) - - successCount := 0 - for res := range results { - if assert.NoError(t, res.err) { - assert.Equal(t, "test block response", string(res.response)) - successCount++ - } - } - assert.Equal(t, numRequests, successCount, "all requests should succeed") -} - -// TestRequestTimeout tests proper handling of timeouts -func TestRequestTimeout(t *testing.T) { - node1, node2, p := setupTestPair(t) - defer cleanupNodes(t, node1, node2) - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) - defer cancel() - - response, err := p.RequestBlocks(ctx, [32]byte{9, 9, 9, 9}, true) - assert.Error(t, err) - assert.Contains(t, err.Error(), context.DeadlineExceeded.Error()) - assert.Nil(t, response) -} - -// TestConnectionClosure tests behavior when connection is closed -func TestConnectionClosure(t *testing.T) { - node1, node2, p := setupTestPair(t) - defer cleanupNodes(t, node1, node2) - - // Close node1's transport - require.NoError(t, node1.transport.Stop()) - // Wait a bit to ensure connection is fully closed - time.Sleep(1 * time.Second) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - response, err := p.RequestBlocks(ctx, [32]byte{}, true) - assert.Error(t, err) - assert.Nil(t, response) -} - -// TestNetworkPartition tests behavior during network issues -func TestNetworkPartition(t *testing.T) { - node1, node2, p := setupTestPair(t) - defer cleanupNodes(t, node1, node2) - - // Simulate network partition by stopping node1 - require.NoError(t, node1.transport.Stop()) - - // Start node1 again - require.NoError(t, node1.transport.Start()) - - // Try request after reconnection - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - _, err := p.RequestBlocks(ctx, [32]byte{1, 2, 3, 4}, true) - assert.Error(t, err) // Should fail due to broken connection -} - -// TestReconnection tests reconnection behavior -func TestServerNodeRestart(t *testing.T) { - node1, node2, p := setupTestPair(t) - defer cleanupNodes(t, node1, node2) - - // Make successful request - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := p.RequestBlocks(ctx, [32]byte{1}, true) - cancel() - require.NoError(t, err) - - // Close and restart node1's transport - require.NoError(t, node1.transport.Stop()) - require.NoError(t, node1.transport.Start()) - - conn, err := node2.transport.Connect(node1.addr) - require.NoError(t, err) - - // Create new peer with new connection - p = peer.NewPeer(conn, node2.protoManager) - - // Try request with longer timeout - ctx, cancel = context.WithTimeout(context.Background(), time.Second) - defer cancel() - - res, err := p.RequestBlocks(ctx, [32]byte{1}, true) - require.NoError(t, err) - assert.Equal(t, "test block response", string(res)) -} - -func TestClientNodeRestart(t *testing.T) { - node1, node2, p := setupTestPair(t) - defer cleanupNodes(t, node1, node2) - - // Make initial successful request - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - response1, err := p.RequestBlocks(ctx, [32]byte{1, 2, 3, 4}, true) - cancel() - require.NoError(t, err) - assert.Equal(t, "test block response", string(response1)) - - // Close connection from node2 side - require.NoError(t, node2.transport.Stop()) - - // Restart node2 - require.NoError(t, node2.transport.Start()) - - // Create new connection - conn, err := node2.transport.Connect(node1.addr) - require.NoError(t, err) - - // Create new peer with new connection - newPeer := peer.NewPeer(conn, node2.protoManager) - - // Try request with new peer - ctx, cancel = context.WithTimeout(context.Background(), time.Second) - response2, err := newPeer.RequestBlocks(ctx, [32]byte{1, 2, 3, 4}, true) - cancel() - require.NoError(t, err) - assert.Equal(t, "test block response", string(response2)) -} - -func TestConnectWithWrongChainHash(t *testing.T) { - // Setup node1 - node1 := setupTestNode(t) - defer cleanupNodes(t, node1) - - // Generate a unique address for node2 - listener, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err) - node2Addr := listener.Addr().String() - listener.Close() - - // Generate keys for node2 - pub, priv, err := ed25519.GenerateKey(nil) - require.NoError(t, err) - - // Create certificate for node2 - certGen := cert.NewGenerator(cert.Config{ - PublicKey: pub, - PrivateKey: priv, - CertValidityPeriod: 24 * time.Hour, - }) - tlsCert, err := certGen.GenerateCertificate() - require.NoError(t, err) - - // Create a protocol manager with a wrong chain hash for node2 - protoConfig := protocol.Config{ - ChainHash: "12345679", // Mismatched chain hash - IsBuilder: false, - } - protoManager, err := protocol.NewManager(protoConfig) - require.NoError(t, err) - - // Create transport for node2 - transportConfig := transport.Config{ - PublicKey: pub, - PrivateKey: priv, - TLSCert: tlsCert, - ListenAddr: node2Addr, // Use the new address for node2 - CertValidator: cert.NewValidator(), - Handler: protoManager, - } - - node2Transport, err := transport.NewTransport(transportConfig) - require.NoError(t, err) - - // Start node1 and node2 - require.NoError(t, node1.transport.Start()) - require.NoError(t, node2Transport.Start()) - defer func() { - if err := node2Transport.Stop(); err != nil { - t.Errorf("failed to stop node2 transport: %v", err) - } - }() - - // Attempt to connect node2 to node1 - _, err = node2Transport.Connect(node1.addr) - assert.Error(t, err, "connection should fail due to ALPN mismatch") - assert.Contains(t, err.Error(), "no application protocol", "error should indicate ALPN failure") -} diff --git a/pkg/network/peer/node.go b/pkg/network/peer/node.go new file mode 100644 index 0000000..254a590 --- /dev/null +++ b/pkg/network/peer/node.go @@ -0,0 +1,265 @@ +package peer + +import ( + "context" + "crypto/ed25519" + "crypto/tls" + "fmt" + "log" + "net" + "sync" + "time" + + "github.com/eigerco/strawberry/internal/crypto" + "github.com/eigerco/strawberry/pkg/network/cert" + "github.com/eigerco/strawberry/pkg/network/handlers" + "github.com/eigerco/strawberry/pkg/network/protocol" + "github.com/eigerco/strawberry/pkg/network/transport" +) + +// Node manages peer connections, handles protocol messages, and coordinates network operations. +// Each Node can act as both a client and server, maintaining connections with multiple peers simultaneously. +type Node struct { + Context context.Context + Cancel context.CancelFunc + transport *transport.Transport + protocolManager *protocol.Manager + peersLock sync.RWMutex + peersSet *PeerSet + blockRequester *handlers.BlockRequester +} + +// ValidatorKeys holds the cryptographic keys required for a validator node. +// These keys are used for signing messages, participating in consensus, +// and establishing secure connections with other nodes. +type ValidatorKeys struct { + EdPrv ed25519.PrivateKey + EdPub ed25519.PublicKey + BanderPrv crypto.BandersnatchPrivateKey + BanderPub crypto.BandersnatchPublicKey + Bls crypto.BlsKey +} + +// PeerSet maintains mappings between peer identifiers +// (Ed25519 keys, network addresses, validator indices) and Peer objects. +type PeerSet struct { + // Map from Ed25519 public key to peer + byEd25519Key map[string]*Peer + // Map from string representation of address to peer + byAddress map[string]*Peer + // Map from validator index to peer (only for validator peers) + byValidatorIndex map[uint16]*Peer +} + +// NewPeerSet creates a new PeerSet instance with initialized internal maps. +func NewPeerSet() *PeerSet { + return &PeerSet{ + byEd25519Key: make(map[string]*Peer), + byAddress: make(map[string]*Peer), + byValidatorIndex: make(map[uint16]*Peer), + } +} + +// AddPeer adds a peer to all relevant lookup maps in the PeerSet. +// If the peer is a validator index, it will also have a validator index. +func (ps *PeerSet) AddPeer(peer *Peer) { + ps.byEd25519Key[string(peer.Ed25519Key)] = peer + ps.byAddress[peer.Address.String()] = peer + + if peer.ValidatorIndex != nil { + ps.byValidatorIndex[*peer.ValidatorIndex] = peer + } +} + +// RemovePeer removes a peer from all lookup maps in the PeerSet. +func (ps *PeerSet) RemovePeer(peer *Peer) { + delete(ps.byEd25519Key, string(peer.Ed25519Key)) + delete(ps.byAddress, peer.Address.String()) + + if peer.ValidatorIndex != nil { + delete(ps.byValidatorIndex, *peer.ValidatorIndex) + } +} + +// GetByEd25519Key looks up a peer by their Ed25519 public key. +// Returns nil if no peer is found with the given key. +func (ps *PeerSet) GetByEd25519Key(key ed25519.PublicKey) *Peer { + return ps.byEd25519Key[string(key)] +} + +// GetByAddress looks up a peer by their network address. +// Returns nil if no peer is found with the given address. +func (ps *PeerSet) GetByAddress(addr string) *Peer { + return ps.byAddress[addr] +} + +// GetByValidatorIndex looks up a peer by their validator index. +// Returns nil if no peer is found with the given validator index. +func (ps *PeerSet) GetByValidatorIndex(index uint16) *Peer { + return ps.byValidatorIndex[index] +} + +// NewNode creates a new Node instance with the specified configuration. +// It initializes the TLS certificate, protocol manager, and network transport. +func NewNode(nodeCtx context.Context, listenAddr *net.UDPAddr, keys ValidatorKeys) (*Node, error) { + nodeCtx, cancel := context.WithCancel(nodeCtx) + node := &Node{ + peersSet: NewPeerSet(), + Context: nodeCtx, + Cancel: cancel, + } + + // Create TLS certificate using the node's Ed25519 key pair + certGen := cert.NewGenerator(cert.Config{ + PublicKey: keys.EdPub, + PrivateKey: keys.EdPrv, + CertValidityPeriod: 24 * time.Hour, + }) + tlsCert, err := certGen.GenerateCertificate() + if err != nil { + return nil, fmt.Errorf("failed to generate certificate: %w", err) + } + + // Initialize protocol manager with chain-specific configuration. + // These are just testing values. + protoConfig := protocol.Config{ + ChainHash: "12345678", + IsBuilder: true, + MaxBuilderSlots: 20, + } + protoManager, err := protocol.NewManager(protoConfig) + if err != nil { + return nil, fmt.Errorf("failed to create protocol manager: %w", err) + } + + // Register what type of streams the Node will support. + protoManager.Registry.RegisterHandler(protocol.StreamKindBlockRequest, handlers.NewBlockRequestHandler()) + + // Create transport + transportConfig := transport.Config{ + PublicKey: keys.EdPub, + PrivateKey: keys.EdPrv, + TLSCert: tlsCert, + ListenAddr: listenAddr, + CertValidator: cert.NewValidator(), + Handler: node, + Context: nodeCtx, + } + + tr, err := transport.NewTransport(transportConfig) + if err != nil { + return nil, fmt.Errorf("failed to create transport: %w", err) + } + node.transport = tr + node.protocolManager = protoManager + return node, nil +} + +// OnConnection is called by the transport layer whenever a new QUIC connection is established. +// This is a callback-style interface where transport.Conn represents an authenticated QUIC connection +// with a verified peer certificate. The connection flow is: +// +// 1. Transport layer accepts QUIC connection +// 2. TLS handshake completes, peer's Ed25519 key verified +// 3. Transport calls this OnConnection method +// 4. We check for existing connection from same peer +// 5. If exists: Close old connection (This will change soon), cleanup peer state. +// 6. Create protocol-level connection wrapper +// 7. Add new peer to connection registry +// +// This design separates transport-level connection handling (TLS, QUIC) +// from protocol-level peer management (stream handling, peer state). +func (n *Node) OnConnection(conn *transport.Conn) { + n.peersLock.Lock() + defer n.peersLock.Unlock() + // If peer already exists, close existing connection and replace with new one. + if existingPeer := n.peersSet.GetByEd25519Key(conn.PeerKey()); existingPeer != nil { + // Close existing connection + if err := existingPeer.ProtoConn.Close(); err != nil { + log.Printf("Failed to close existing peer connection: %v", err) + } + n.peersSet.RemovePeer(existingPeer) + } + + pConn := n.protocolManager.OnConnection(conn) + peer := NewPeer(pConn) + if peer == nil { + log.Printf("Failed to create peer: invalid remote address type") + // Clean up the connection since we can't use it + if err := pConn.Close(); err != nil { + log.Printf("Failed to close protocol connection: %v", err) + } + return + } + // Add to peer set + n.peersSet.AddPeer(peer) +} + +// ConnectToPeer initiates a connection to a peer at the specified address. +// It prevents duplicate connections to the same peer. +func (n *Node) ConnectToPeer(addr *net.UDPAddr) error { + // Check if peer already exists before attempting connection. + n.peersLock.RLock() + existingPeer := n.peersSet.GetByAddress(addr.String()) + n.peersLock.RUnlock() + + if existingPeer != nil { + return fmt.Errorf("peer already exists") + } + + // Establish connection + if err := n.transport.Connect(addr); err != nil { + return fmt.Errorf("failed to connect to peer: %w", err) + } + return nil +} + +// TODO somehwat of Mock atm. Will add full implementaion in the coming PR's. +func (n *Node) RequestBlock(ctx context.Context, hash crypto.Hash, ascending bool, peerKey ed25519.PublicKey) ([]byte, error) { + n.peersLock.RLock() + existingPeer := n.peersSet.GetByEd25519Key(peerKey) + n.peersLock.RUnlock() + + if existingPeer != nil { + stream, err := existingPeer.ProtoConn.OpenStream(ctx, protocol.StreamKindBlockRequest) + if err != nil { + return nil, fmt.Errorf("failed to open stream: %w", err) + } + + defer stream.Close() + blockData, err := n.blockRequester.RequestBlocks(ctx, stream, hash, ascending) + if err != nil { + return nil, fmt.Errorf("failed to request block from peer: %w", err) + } + return blockData, nil + } + return nil, fmt.Errorf("no peers available to request block from") +} + +// Start begins the node's network operations, including listening for incoming connections. +func (n *Node) Start() error { + if err := n.transport.Start(); err != nil { + return fmt.Errorf("failed to start transport: %w", err) + } + return nil +} + +// Stop gracefully shuts down the node's network operations and closes all +// peer connections. +func (n *Node) Stop() error { + n.Cancel() + return n.transport.Stop() +} + +// ValidateConnection verifies that an incoming TLS connection meets the +// protocol requirements, including certificate validation and protocol +// version checking. +func (n *Node) ValidateConnection(tlsState tls.ConnectionState) error { + return n.protocolManager.ValidateConnection(tlsState) +} + +// GetProtocols returns the list of supported protocol +// versions and variants for this node. +func (n *Node) GetProtocols() []string { + return n.protocolManager.GetProtocols() +} diff --git a/pkg/network/peer/peer.go b/pkg/network/peer/peer.go index 558c7b7..f80651d 100644 --- a/pkg/network/peer/peer.go +++ b/pkg/network/peer/peer.go @@ -2,51 +2,55 @@ package peer import ( "context" + "crypto/ed25519" "fmt" - - "github.com/eigerco/strawberry/pkg/network/handlers" "github.com/eigerco/strawberry/pkg/network/protocol" - "github.com/eigerco/strawberry/pkg/network/transport" + "net" + "net/netip" ) // Peer represents a remote peer and provides high-level protocol operations. // It wraps the underlying transport and protocol connections with a simpler interface. type Peer struct { - // conn is the underlying transport connection - conn *transport.Conn - // protoConn handles protocol-specific operations - protoConn *protocol.ProtocolConn + // ProtoConn handles protocol-specific operations + ProtoConn *protocol.ProtocolConn + Address *net.UDPAddr + ctx context.Context + cancel context.CancelFunc + Ed25519Key ed25519.PublicKey + // Optional validator index if this peer is a validator + ValidatorIndex *uint16 } // NewPeer creates a new peer instance from an established transport connection. -// It wraps the connection with protocol-specific functionality using the provided manager. -// Parameters: -// - conn: The underlying transport connection -// - pubKey: The peer's Ed25519 public key -// - protoManager: The protocol manager for handling streams -func NewPeer(conn *transport.Conn, protoManager *protocol.Manager) *Peer { - return &Peer{ - conn: conn, - protoConn: protoManager.WrapConnection(conn), +func NewPeer(pConn *protocol.ProtocolConn) *Peer { + ctx, cancel := context.WithCancel(pConn.TConn.Context()) + remoteAddr, ok := pConn.TConn.QConn.RemoteAddr().(*net.UDPAddr) + if !ok { + cancel() + return nil + } + p := &Peer{ + ProtoConn: pConn, + ctx: ctx, + cancel: cancel, + Ed25519Key: pConn.TConn.PeerKey(), + Address: remoteAddr, } + return p } -// RequestBlocks requests a sequence of blocks from the peer. -// Opens a block request stream and handles the protocol interaction. -// Parameters: -// - ctx: Context for cancellation -// - headerHash: Hash of the header to start from -// - ascending: If true, gets blocks after header, if false, gets blocks before -// -// Returns: -// - The requested blocks data or an error if the request fails -func (p *Peer) RequestBlocks(ctx context.Context, headerHash [32]byte, ascending bool) ([]byte, error) { - stream, err := p.protoConn.OpenStream(ctx, protocol.StreamKindBlockRequest) - if err != nil { - return nil, fmt.Errorf("failed to open stream: %w", err) +// The first 18 bytes of validator metadata, with the first 16 bytes being the IPv6 address +// and the latter 2 being a little endian representation of the port. +func NewPeerAddressFromMetadata(metadata []byte) (*net.UDPAddr, error) { + if len(metadata) < 18 { + return nil, fmt.Errorf("metadata too short: got %d bytes, want at least 18", len(metadata)) + } + + var address netip.AddrPort + if err := address.UnmarshalBinary(metadata[:18]); err != nil { + return nil, fmt.Errorf("failed to unmarshal address: %w", err) } - defer stream.Close() - requester := &handlers.BlockRequester{} - return requester.RequestBlocks(ctx, stream, headerHash, ascending) + return net.UDPAddrFromAddrPort(address), nil } diff --git a/pkg/network/protocol/conn.go b/pkg/network/protocol/conn.go index d22fca6..e33a463 100644 --- a/pkg/network/protocol/conn.go +++ b/pkg/network/protocol/conn.go @@ -12,19 +12,19 @@ import ( // ProtocolConn wraps a transport connection with protocol-specific functionality. // It manages stream multiplexing, handles stream kinds, and maintains unique persistent streams. type ProtocolConn struct { - tConn *transport.Conn - mu sync.RWMutex - upStreams map[StreamKind]quic.Stream - registry *JAMNPRegistry + TConn *transport.Conn + Registry *JAMNPRegistry + mu sync.RWMutex + streams map[StreamKind]quic.Stream } // NewProtocolConn creates a new protocol-level connection. // It initializes stream management and associates the connection with a handler registry. func NewProtocolConn(tConn *transport.Conn, registry *JAMNPRegistry) *ProtocolConn { return &ProtocolConn{ - tConn: tConn, - upStreams: make(map[StreamKind]quic.Stream), - registry: registry, + TConn: tConn, + streams: make(map[StreamKind]quic.Stream), + Registry: registry, } } @@ -33,7 +33,7 @@ func NewProtocolConn(tConn *transport.Conn, registry *JAMNPRegistry) *ProtocolCo // Returns an error if stream creation or initial write fails. func (pc *ProtocolConn) OpenStream(ctx context.Context, kind StreamKind) (quic.Stream, error) { // Use the passed context for opening the stream - stream, err := pc.tConn.OpenStream(ctx) + stream, err := pc.TConn.OpenStream(ctx) if err != nil { return nil, err } @@ -47,33 +47,12 @@ func (pc *ProtocolConn) OpenStream(ctx context.Context, kind StreamKind) (quic.S return stream, nil } -// TODO: to be used in the future -// handleUPStream manages unique persistent streams -// func (pc *ProtocolConn) handleUPStream(kind StreamKind, stream quic.Stream) (quic.Stream, error) { -// pc.mu.Lock() -// defer pc.mu.Unlock() - -// if existing, exists := pc.upStreams[kind]; exists { -// // Keep stream with higher ID -// if existing.StreamID() > stream.StreamID() { -// stream.Close() -// return existing, nil -// } else { -// existing.Close() -// pc.upStreams[kind] = stream -// } -// } else { -// pc.upStreams[kind] = stream -// } -// return stream, nil -// } - // AcceptStream accepts and handles an incoming stream. // It reads the stream kind byte, looks up the appropriate handler, // and starts a goroutine to handle the stream. // Returns an error if accepting the stream or reading the kind fails. func (pc *ProtocolConn) AcceptStream() error { - stream, err := pc.tConn.AcceptStream() + stream, err := pc.TConn.AcceptStream() if err != nil { return err } @@ -86,7 +65,7 @@ func (pc *ProtocolConn) AcceptStream() error { } // Get handler for this stream kind - handler, err := pc.registry.GetHandler(kind[0]) + handler, err := pc.Registry.GetHandler(kind[0]) if err != nil { stream.Close() return err @@ -94,7 +73,7 @@ func (pc *ProtocolConn) AcceptStream() error { // Handle the stream go func() { - if err := handler.HandleStream(pc.tConn.Context(), stream); err != nil { + if err := handler.HandleStream(pc.TConn.Context(), stream); err != nil { fmt.Printf("stream handler error: %v\n", err) } }() @@ -129,12 +108,12 @@ func (pc *ProtocolConn) Close() error { defer pc.mu.Unlock() // Close all UP streams - for _, stream := range pc.upStreams { + for _, stream := range pc.streams { if err := stream.Close(); err != nil { fmt.Printf("Error closing stream: %v\n", err) } } - pc.upStreams = make(map[StreamKind]quic.Stream) + pc.streams = make(map[StreamKind]quic.Stream) - return pc.tConn.Close() + return pc.TConn.Close() } diff --git a/pkg/network/protocol/manager.go b/pkg/network/protocol/manager.go index 6c48bc8..93a0277 100644 --- a/pkg/network/protocol/manager.go +++ b/pkg/network/protocol/manager.go @@ -46,16 +46,10 @@ func NewManager(config Config) (*Manager, error) { // OnConnection is called when a new transport connection is established. // It sets up a protocol connection and starts a stream handling goroutine. -// Implements the transport.ConnectionHandler interface. -func (m *Manager) OnConnection(conn *transport.Conn) error { - // Protocol connection creation could fail due to invalid parameters - protoConn, err := m.setupProtocolConn(conn) - if err != nil { - return fmt.Errorf("protocol connection setup failed: %w", err) - } +func (m *Manager) OnConnection(conn *transport.Conn) *ProtocolConn { + protoConn := NewProtocolConn(conn, m.Registry) go m.handleStreams(protoConn) - - return nil + return protoConn } // handleStreams manages the lifecycle of streams for a protocol connection. @@ -68,7 +62,7 @@ func (m *Manager) handleStreams(protoConn *ProtocolConn) { streamErr := protoConn.AcceptStream() if streamErr != nil { // Check if the connection's context has been canceled - if protoConn.tConn.Context().Err() != nil { + if protoConn.TConn.Context().Err() != nil { fmt.Println("Connection closed: context done") return } @@ -93,18 +87,6 @@ func isTimeoutError(err error) bool { return err != nil && strings.Contains(err.Error(), "timeout: no recent network activity") } -// setupProtocolConn creates and initializes a new protocol connection. -// Returns an error if the provided transport connection is nil. -func (m *Manager) setupProtocolConn(conn *transport.Conn) (*ProtocolConn, error) { - if conn == nil { - return nil, fmt.Errorf("invalid connection") - } - - protoConn := NewProtocolConn(conn, m.Registry) - - return protoConn, nil -} - // GetProtocols returns the list of supported ALPN protocol strings. // The returned protocols include both builder and non-builder variants. // Implements the transport.ConnectionHandler interface. @@ -139,9 +121,3 @@ func (m *Manager) ValidateConnection(tlsState tls.ConnectionState) error { return nil } - -// WrapConnection wraps a transport connection with protocol-specific functionality. -// Returns a new ProtocolConn that can handle protocol-specific stream operations. -func (m *Manager) WrapConnection(conn *transport.Conn) *ProtocolConn { - return NewProtocolConn(conn, m.Registry) -} diff --git a/pkg/network/protocol/streams.go b/pkg/network/protocol/streams.go index 91c3a65..fb4ebb4 100644 --- a/pkg/network/protocol/streams.go +++ b/pkg/network/protocol/streams.go @@ -1,8 +1,11 @@ package protocol import ( + "context" "fmt" - "github.com/eigerco/strawberry/pkg/network/transport" + "sync" + + "github.com/quic-go/quic-go" ) const ( @@ -29,18 +32,24 @@ const ( StreamKindJudgmentPublish StreamKind = 145 ) +// StreamHandler processes individual QUIC streams within a connection +type StreamHandler interface { + HandleStream(ctx context.Context, stream quic.Stream) error +} + // StreamKind represents the type of stream (Unique Persistent or Common Ephemeral) type StreamKind byte // JAMNPRegistry manages stream handlers for different protocol stream kinds type JAMNPRegistry struct { - handlers map[StreamKind]transport.StreamHandler + mu sync.RWMutex + handlers map[StreamKind]StreamHandler } // NewJAMNPRegistry creates a new registry for stream handlers func NewJAMNPRegistry() *JAMNPRegistry { return &JAMNPRegistry{ - handlers: make(map[StreamKind]transport.StreamHandler), + handlers: make(map[StreamKind]StreamHandler), } } @@ -54,18 +63,25 @@ func (r *JAMNPRegistry) ValidateKind(kindByte byte) error { return nil } -// RegisterHandler associates a stream handler with a specific stream kind -// The handler will be called when streams of the specified kind are opened -func (r *JAMNPRegistry) RegisterHandler(kind StreamKind, handler transport.StreamHandler) { +// RegisterHandler associates a stream handler with a specific stream kind. +// When a stream of the registered kind is opened, the corresponding handler +// will be invoked to process it. This method is called during protocol +// initialization to set up handlers for supported stream kinds. +func (r *JAMNPRegistry) RegisterHandler(kind StreamKind, handler StreamHandler) { + r.mu.Lock() + defer r.mu.Unlock() r.handlers[kind] = handler } // GetHandler retrieves the handler associated with a given stream kind byte // Returns an error if no handler is registered for the kind -func (r *JAMNPRegistry) GetHandler(kindByte byte) (transport.StreamHandler, error) { - // Convert raw byte to protocol's StreamKind here +func (r *JAMNPRegistry) GetHandler(kindByte byte) (StreamHandler, error) { + // Convert raw byte to protocol's StreamKind kind := StreamKind(kindByte) + r.mu.RLock() + defer r.mu.RUnlock() + handler, ok := r.handlers[kind] if !ok { return nil, fmt.Errorf("no handler for kind %d", kind) diff --git a/pkg/network/transport/conn.go b/pkg/network/transport/conn.go index e6ce276..a92a286 100644 --- a/pkg/network/transport/conn.go +++ b/pkg/network/transport/conn.go @@ -4,7 +4,6 @@ import ( "context" "crypto/ed25519" "fmt" - "sync" "time" "github.com/quic-go/quic-go" @@ -17,9 +16,8 @@ const StreamTimeout = 5 * time.Second // It manages the underlying QUIC connection, stream creation, // and connection lifecycle via context cancellation. type Conn struct { - qConn quic.Connection + QConn quic.Connection transport *Transport - mu sync.RWMutex // Protects peerKey peerKey ed25519.PublicKey ctx context.Context cancel context.CancelFunc @@ -32,37 +30,20 @@ func newConn(qConn quic.Connection, transport *Transport) *Conn { ctx, cancel := context.WithCancel(transport.ctx) conn := &Conn{ - qConn: qConn, + QConn: qConn, transport: transport, ctx: ctx, cancel: cancel, } - // Ensure cleanup when connection ends - go func() { - <-ctx.Done() - conn.cleanup() - }() - return conn } -// cleanup removes the connection from the transport's connection map. -// Called automatically when the connection context is cancelled. -func (c *Conn) cleanup() { - c.mu.RLock() - peerKey := c.peerKey - c.mu.RUnlock() - if peerKey != nil { - c.transport.cleanup(peerKey) - } -} - // OpenStream opens a new bidirectional QUIC stream. // The provided context can be used to cancel the stream opening operation. // Returns the new stream or an error if creation fails. func (c *Conn) OpenStream(ctx context.Context) (quic.Stream, error) { - stream, err := c.qConn.OpenStreamSync(ctx) + stream, err := c.QConn.OpenStreamSync(ctx) if err != nil { return nil, fmt.Errorf("failed to open QUIC stream: %w", err) } @@ -74,7 +55,7 @@ func (c *Conn) OpenStream(ctx context.Context) (quic.Stream, error) { // Uses the connection's context for cancellation. // Returns the accepted stream or an error if accepting fails. func (c *Conn) AcceptStream() (quic.Stream, error) { - stream, err := c.qConn.AcceptStream(c.ctx) + stream, err := c.QConn.AcceptStream(c.ctx) if err != nil { return nil, fmt.Errorf("failed to accept QUIC stream: %w", err) } @@ -84,23 +65,19 @@ func (c *Conn) AcceptStream() (quic.Stream, error) { // PeerKey returns the public key of the connected peer. // This key uniquely identifies the remote peer. func (c *Conn) PeerKey() ed25519.PublicKey { - c.mu.RLock() - defer c.mu.RUnlock() return c.peerKey } // SetPeerKey sets the peer's public key func (c *Conn) SetPeerKey(key ed25519.PublicKey) { - c.mu.Lock() c.peerKey = key - c.mu.Unlock() } // Close closes the connection and cancels all associated streams. // Returns an error if closing the QUIC connection fails. func (c *Conn) Close() error { c.cancel() - return c.qConn.CloseWithError(0, "") + return c.QConn.CloseWithError(0, "") } // Context returns the connection's context. diff --git a/pkg/network/transport/transport.go b/pkg/network/transport/transport.go index 63b79d4..aed4e26 100644 --- a/pkg/network/transport/transport.go +++ b/pkg/network/transport/transport.go @@ -6,7 +6,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" - "sync" + "net" "time" "github.com/quic-go/quic-go" @@ -15,19 +15,6 @@ import ( // MaxIdleTimeout defines the maximum duration a connection can be idle before timing out const MaxIdleTimeout = 30 * time.Minute -// StreamHandler processes individual QUIC streams within a connection -type StreamHandler interface { - HandleStream(ctx context.Context, stream quic.Stream) error -} - -// StreamRegistry manages stream handlers and validates stream kinds -type StreamRegistry interface { - // GetHandler returns the handler for a given stream kind byte - GetHandler(kindByte byte) (StreamHandler, error) - // ValidateKind checks if a stream kind byte is valid - ValidateKind(kindByte byte) error -} - // CertValidator performs TLS certificate validation and public key extraction type CertValidator interface { // ValidateCertificate checks if a certificate meets required criteria @@ -36,42 +23,41 @@ type CertValidator interface { ExtractPublicKey(cert *x509.Certificate) (ed25519.PublicKey, error) } -// ProtocolManager handles ALPN protocol negotiation and validation -type ProtocolManager interface { - // AcceptableProtocols returns valid protocol strings for a chain - AcceptableProtocols(chainHash string) []string - // NewProtocolID creates a protocol identifier string - NewProtocolID(chainHash string, isBuilder bool) string - // ValidateProtocol checks if a protocol string is valid - ValidateProtocol(protocol string) error -} - -// ConnectionHandler processes new connections and validates their protocols +// ConnectionHandler defines how new connections are processed at the protocol level. +// This interface separates transport-level connection handling from protocol-specific +// behaviors. type ConnectionHandler interface { - // OnConnection is called when a new connection is established - OnConnection(conn *Conn) error + // OnConnection is called after a new connection is established and authenticated. + // The handler typically sets up protocol-specific streams and state. + OnConnection(conn *Conn) // GetProtocols returns supported ALPN protocol strings GetProtocols() []string - // ValidateConnection verifies TLS connection parameters + // ValidateConnection verifies TLS parameters including ALPN protocol selection. + // This is called during the TLS handshake after certificate validation. ValidateConnection(tlsState tls.ConnectionState) error } -// Config contains all configuration parameters for a Transport +// Config holds all parameters needed to initialize a Transport. +// This includes cryptographic keys, network configuration, and handlers. type Config struct { PublicKey ed25519.PublicKey // Node's public key PrivateKey ed25519.PrivateKey // Node's private key TLSCert *tls.Certificate // TLS certificate - ListenAddr string // Address to listen on + ListenAddr *net.UDPAddr // Address to listen on CertValidator CertValidator // Certificate validator Handler ConnectionHandler // Connection handler + Context context.Context // Context for transport lifecycle } -// Transport manages QUIC connections and their lifecycles +// Transport manages the networking layer of a JAMNP-S node. +// It handles: +// - Starting/stopping the QUIC listener +// - Initiating outbound connections +// - TLS handshakes and certificate validation +// - Forwarding authenticated connections to protocol handlers type Transport struct { config Config listener *quic.Listener - mu sync.RWMutex - conns map[string]*Conn // Active connections mapped by peer key ctx context.Context cancel context.CancelFunc done chan struct{} // For clean shutdown of accept loop @@ -94,24 +80,28 @@ func NewTransport(config Config) (*Transport, error) { if err := config.CertValidator.ValidateCertificate(config.TLSCert.Leaf); err != nil { return nil, ErrInvalidCertificate } - + ctx, cancel := context.WithCancel(config.Context) return &Transport{ config: config, - conns: make(map[string]*Conn), + ctx: ctx, + cancel: cancel, }, nil } -// Start initializes the transport listener and begins accepting connections. -// Returns an error if starting the listener fails. +// Start begins listening for incoming QUIC connections. +// It configures TLS with: +// - Required client certificates +// - TLS 1.3 minimum version +// - JAMNP-S certificate validation +// - ALPN protocol validation func (t *Transport) Start() error { tlsConfig := &tls.Config{ Certificates: []tls.Certificate{*t.config.TLSCert}, NextProtos: t.config.Handler.GetProtocols(), ClientAuth: tls.RequireAnyClientCert, MinVersion: tls.VersionTLS13, - InsecureSkipVerify: true, + InsecureSkipVerify: true, // We do our own certificate validation VerifyConnection: func(cs tls.ConnectionState) error { - fmt.Printf("Negotiated Protocol: %s\n", cs.NegotiatedProtocol) if len(cs.PeerCertificates) == 0 { return fmt.Errorf("%w: no peer certificate provided", ErrInvalidCertificate) } @@ -126,7 +116,7 @@ func (t *Transport) Start() error { }, } - listener, err := quic.ListenAddr(t.config.ListenAddr, tlsConfig, &quic.Config{ + listener, err := quic.ListenAddr(t.config.ListenAddr.AddrPort().String(), tlsConfig, &quic.Config{ EnableDatagrams: true, MaxIdleTimeout: MaxIdleTimeout, }) @@ -134,7 +124,6 @@ func (t *Transport) Start() error { return fmt.Errorf("%w: %v", ErrListenerFailed, err) } - t.ctx, t.cancel = context.WithCancel(context.Background()) t.listener = listener t.done = make(chan struct{}) go func() { @@ -144,38 +133,10 @@ func (t *Transport) Start() error { return nil } -// Stop gracefully shuts down the transport and all active connections. -// Waits for the accept loop to finish before returning. -func (t *Transport) Stop() error { - // Cancel the accept loop - t.cancel() - - // Close all active connections - t.mu.Lock() - for _, conn := range t.conns { - if err := conn.Close(); err != nil { - fmt.Printf("Failed to close connection: %v\n", err) - } - } - // Clear the connection map - t.conns = make(map[string]*Conn) - t.mu.Unlock() - - // Close the listener - if t.listener != nil { - if err := t.listener.Close(); err != nil { - return fmt.Errorf("failed to close listener: %w", err) - } - } - - // Wait for the accept loop to finish - <-t.done - return nil -} - -// Connect initiates a connection to a remote peer. -// Returns the new connection or an error if connection fails. -func (t *Transport) Connect(addr string) (*Conn, error) { +// Connect initiates an outbound connection to a peer. +// The connection follows the same TLS configuration and validation +// as incoming connections. +func (t *Transport) Connect(addr *net.UDPAddr) error { tlsConf := &tls.Config{ Certificates: []tls.Certificate{*t.config.TLSCert}, NextProtos: t.config.Handler.GetProtocols(), @@ -194,40 +155,35 @@ func (t *Transport) Connect(addr string) (*Conn, error) { }, } - quicConn, err := quic.DialAddr(t.ctx, addr, tlsConf, &quic.Config{ + quicConn, err := quic.DialAddr(t.ctx, addr.AddrPort().String(), tlsConf, &quic.Config{ EnableDatagrams: true, MaxIdleTimeout: MaxIdleTimeout, }) if err != nil { - return nil, fmt.Errorf("%w: %v", ErrDialFailed, err) - } - - conn := t.handleConnection(quicConn) - if conn == nil { - return nil, ErrConnFailed + return fmt.Errorf("%w: %v", ErrDialFailed, err) } - return conn, nil -} -// GetConnection retrieves an active connection by peer key. -// Returns the connection and whether it was found. -func (t *Transport) GetConnection(peerKey string) (*Conn, bool) { - t.mu.RLock() - conn, ok := t.conns[peerKey] - t.mu.RUnlock() - return conn, ok + t.handleConnection(quicConn) + return nil } -// ListConnections returns a slice of all active connections. -func (t *Transport) ListConnections() []*Conn { - t.mu.RLock() - defer t.mu.RUnlock() - - conns := make([]*Conn, 0, len(t.conns)) - for _, conn := range t.conns { - conns = append(conns, conn) +// Stop gracefully shuts down the transport and all active connections. +// Waits for the accept loop to finish before returning. +func (t *Transport) Stop() error { + // Only call cancel if it wasn't already cancelled by parent + select { + case <-t.ctx.Done(): + // Context was already cancelled by parent + default: + t.cancel() + } + if t.listener != nil { + if err := t.listener.Close(); err != nil { + return fmt.Errorf("failed to close listener: %w", err) + } } - return conns + <-t.done + return nil } // acceptLoop continuously accepts incoming connections @@ -255,55 +211,21 @@ func (t *Transport) acceptLoop() { } } -// handleConnection processes a new QUIC connection -func (t *Transport) handleConnection(qConn quic.Connection) *Conn { +// handleConnection processes a new QUIC connection after acceptance/dialing. +// It: +// 1. Extracts the peer's Ed25519 key from their certificate +// 2. Creates a Conn wrapper around the QUIC connection +// 3. Passes the connection to the protocol handler +func (t *Transport) handleConnection(qConn quic.Connection) { peerKey, err := t.config.CertValidator.ExtractPublicKey(qConn.ConnectionState().TLS.PeerCertificates[0]) if err != nil { fmt.Printf("Failed to extract peer key: %v\n", err) if cerr := qConn.CloseWithError(0, fmt.Sprintf("%s: %v", ErrInvalidCertificate.Error(), err)); cerr != nil { fmt.Printf("Failed to close connection: %v\n", cerr) } - return nil } - conn := t.storeConnection(peerKey, qConn) - - if err := t.config.Handler.OnConnection(conn); err != nil { - t.cleanup(peerKey) - if cerr := qConn.CloseWithError(0, err.Error()); cerr != nil { - fmt.Printf("Failed to close connection: %v\n", cerr) - } - return nil - } - - return conn -} - -// storeConnection handles connection storage and replacement -func (t *Transport) storeConnection(peerKey ed25519.PublicKey, qConn quic.Connection) *Conn { - t.mu.Lock() - defer t.mu.Unlock() - - // Close existing connection if any - if existingConn, exists := t.conns[string(peerKey)]; exists { - fmt.Println("Found existing connection, closing it") - if err := existingConn.Close(); err != nil { - fmt.Printf("Failed to close existing connection: %v\n", err) - } - delete(t.conns, string(peerKey)) - } - - // Create and store new connection conn := newConn(qConn, t) conn.SetPeerKey(peerKey) - t.conns[string(peerKey)] = conn - - return conn -} - -// Cleanup removes a connection from the map -func (t *Transport) cleanup(peerKey ed25519.PublicKey) { - t.mu.Lock() - delete(t.conns, string(peerKey)) - t.mu.Unlock() + t.config.Handler.OnConnection(conn) }