From 09c34ec889215b7ea9002df8a72e49cacd0d69b6 Mon Sep 17 00:00:00 2001 From: Elazar Gershuni Date: Sun, 3 Mar 2024 16:30:09 +0200 Subject: [PATCH] move connectionset from analyzer Signed-off-by: Elazar Gershuni --- go.mod | 7 +- go.sum | 4 + pkg/connectionset/connectionset.go | 466 ++++++++++++++++++++++++ pkg/connectionset/connectionset_test.go | 50 +++ pkg/connectionset/statefulness.go | 100 +++++ pkg/connectionset/statefulness_test.go | 147 ++++++++ 6 files changed, 771 insertions(+), 3 deletions(-) create mode 100644 pkg/connectionset/connectionset.go create mode 100644 pkg/connectionset/connectionset_test.go create mode 100644 pkg/connectionset/statefulness.go create mode 100644 pkg/connectionset/statefulness_test.go diff --git a/go.mod b/go.mod index c809e5b..d3262be 100644 --- a/go.mod +++ b/go.mod @@ -2,12 +2,13 @@ module github.com/np-guard/models go 1.21 -require ( - github.com/stretchr/testify v1.8.4 -) +require github.com/stretchr/testify v1.8.4 + +require golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/np-guard/vpc-network-config-synthesis v0.1.0 github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 8cf6655..82ebdc2 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,13 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/np-guard/vpc-network-config-synthesis v0.1.0 h1:yAKR2w4TXcs4ir12dQwoglIll/AeQfFNcLH0NGCQIFc= +github.com/np-guard/vpc-network-config-synthesis v0.1.0/go.mod h1:wQkZxRT4t8Ut0YwOOyIbZzVP8578sy6RvLoUngSUAeI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/connectionset/connectionset.go b/pkg/connectionset/connectionset.go new file mode 100644 index 0000000..9b2cb48 --- /dev/null +++ b/pkg/connectionset/connectionset.go @@ -0,0 +1,466 @@ +package connectionset + +import ( + "sort" + "strings" + + "github.com/np-guard/models/pkg/hypercubes" + "github.com/np-guard/models/pkg/intervals" + "github.com/np-guard/vpc-network-config-synthesis/pkg/io/jsonio" +) + +type ProtocolStr string + +const numDimensions = 5 + +const ( + // ProtocolTCP is the TCP protocol. + ProtocolTCP ProtocolStr = "TCP" + // ProtocolUDP is the UDP protocol. + ProtocolUDP ProtocolStr = "UDP" + // ProtocolICMP is the ICMP protocol. + ProtocolICMP ProtocolStr = "ICMP" +) + +const ( + MinICMPtype int64 = 0 + MaxICMPtype int64 = 255 + MinICMPcode int64 = 0 + MaxICMPcode int64 = 254 + minProtocol int64 = TCP + maxProtocol int64 = ICMP + MinPort int64 = 1 + MaxPort int64 = 65535 +) + +const ( + // since iota starts with 0, the first value + // defined here will be the default + TCP int64 = iota + UDP + ICMP +) + +const ( + AllConnections = "All Connections" + NoConnections = "No Connections" +) + +type Dimension int + +const ( + protocol Dimension = 0 + srcPort Dimension = 1 + dstPort Dimension = 2 + icmpType Dimension = 3 + icmpCode Dimension = 4 +) + +const propertySeparator string = " " + +// dimensionsList is the ordered list of dimensions in the ConnectionSet object +// this should be the only place where the order is hard-coded +var dimensionsList = []Dimension{ + protocol, + srcPort, + dstPort, + icmpType, + icmpCode, +} + +func getDimensionDomain(dim Dimension) *intervals.CanonicalIntervalSet { + switch dim { + case protocol: + return intervals.CreateFromInterval(minProtocol, maxProtocol) + case srcPort: + return intervals.CreateFromInterval(MinPort, MaxPort) + case dstPort: + return intervals.CreateFromInterval(MinPort, MaxPort) + case icmpType: + return intervals.CreateFromInterval(MinICMPtype, MaxICMPtype) + case icmpCode: + return intervals.CreateFromInterval(MinICMPcode, MaxICMPcode) + } + return nil +} + +func getDimensionDomainsList() []*intervals.CanonicalIntervalSet { + res := make([]*intervals.CanonicalIntervalSet, len(dimensionsList)) + for i := range dimensionsList { + res[i] = getDimensionDomain(dimensionsList[i]) + } + return res +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// new connection set dimensions: +// protocol +// src port +// dst port +// icmp type +// icmp code + +type ConnectionSet struct { + AllowAll bool + connectionProperties *hypercubes.CanonicalHypercubeSet + IsStateful int // default is StatefulUnknown +} + +func NewConnectionSet(all bool) *ConnectionSet { + return &ConnectionSet{AllowAll: all, connectionProperties: hypercubes.NewCanonicalHypercubeSet(numDimensions)} +} + +func NewConnectionSetWithCube(cube *hypercubes.CanonicalHypercubeSet) *ConnectionSet { + res := NewConnectionSet(false) + res.connectionProperties.Union(cube) + if res.isAllConnectionsWithoutAllowAll() { + return NewConnectionSet(true) + } + return res +} + +func (conn *ConnectionSet) Copy() *ConnectionSet { + return &ConnectionSet{ + AllowAll: conn.AllowAll, + connectionProperties: conn.connectionProperties.Copy(), + IsStateful: conn.IsStateful, + } +} + +func (conn *ConnectionSet) Intersection(other *ConnectionSet) *ConnectionSet { + if other.AllowAll { + return conn.Copy() + } + if conn.AllowAll { + return other.Copy() + } + return &ConnectionSet{AllowAll: false, connectionProperties: conn.connectionProperties.Intersection(other.connectionProperties)} +} + +func (conn *ConnectionSet) IsEmpty() bool { + if conn.AllowAll { + return false + } + return conn.connectionProperties.IsEmpty() +} + +func (conn *ConnectionSet) Union(other *ConnectionSet) *ConnectionSet { + if conn.AllowAll || other.AllowAll { + return NewConnectionSet(true) + } + if other.IsEmpty() { + return conn.Copy() + } + if conn.IsEmpty() { + return other.Copy() + } + res := &ConnectionSet{AllowAll: false, connectionProperties: conn.connectionProperties.Union(other.connectionProperties)} + if res.isAllConnectionsWithoutAllowAll() { + return NewConnectionSet(true) + } + return res +} + +func getAllPropertiesObject() *hypercubes.CanonicalHypercubeSet { + return hypercubes.CreateFromCube(getDimensionDomainsList()) +} + +func (conn *ConnectionSet) isAllConnectionsWithoutAllowAll() bool { + if conn.AllowAll { + return false + } + return conn.connectionProperties.Equals(getAllPropertiesObject()) +} + +// Subtract +// ToDo: Subtract seems to ignore IsStateful (see https://github.com/np-guard/vpc-network-config-analyzer/issues/199): +// 1. is the delta connection stateful +// 2. connectionProperties is identical but conn stateful while other is not +// the 2nd item can be computed here, with enhancement to relevant structure +// the 1st can not since we do not know where exactly the statefullness came from +func (conn *ConnectionSet) Subtract(other *ConnectionSet) *ConnectionSet { + if conn.IsEmpty() || other.IsEmpty() { + return conn + } + if other.AllowAll { + return NewConnectionSet(false) + } + var connProperites *hypercubes.CanonicalHypercubeSet + if conn.AllowAll { + connProperites = getAllPropertiesObject() + } else { + connProperites = conn.connectionProperties + } + return &ConnectionSet{AllowAll: false, connectionProperties: connProperites.Subtraction(other.connectionProperties)} +} + +func (conn *ConnectionSet) ContainedIn(other *ConnectionSet) (bool, error) { + if other.AllowAll { + return true, nil + } + if conn.AllowAll { + return false, nil + } + res, err := conn.connectionProperties.ContainedIn(other.connectionProperties) + return res, err +} + +func (conn *ConnectionSet) AddTCPorUDPConn(protocol ProtocolStr, srcMinP, srcMaxP, dstMinP, dstMaxP int64) { + var cube *hypercubes.CanonicalHypercubeSet + switch protocol { + case ProtocolTCP: + cube = hypercubes.CreateFromCubeShort(TCP, TCP, srcMinP, srcMaxP, dstMinP, dstMaxP, MinICMPtype, MaxICMPtype, MinICMPcode, MaxICMPcode) + case ProtocolUDP: + cube = hypercubes.CreateFromCubeShort(UDP, UDP, srcMinP, srcMaxP, dstMinP, dstMaxP, MinICMPtype, MaxICMPtype, MinICMPcode, MaxICMPcode) + } + conn.connectionProperties = conn.connectionProperties.Union(cube) + // check if all connections allowed after this union + if conn.isAllConnectionsWithoutAllowAll() { + conn.AllowAll = true + conn.connectionProperties = hypercubes.NewCanonicalHypercubeSet(numDimensions) + } +} + +func (conn *ConnectionSet) AddICMPConnection(icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax int64) { + cube := hypercubes.CreateFromCubeShort(ICMP, ICMP, MinPort, MaxPort, MinPort, MaxPort, icmpTypeMin, icmpTypeMax, icmpCodeMin, icmpCodeMax) + conn.connectionProperties = conn.connectionProperties.Union(cube) + // check if all connections allowed after this union + if conn.isAllConnectionsWithoutAllowAll() { + conn.AllowAll = true + conn.connectionProperties = hypercubes.NewCanonicalHypercubeSet(numDimensions) + } +} + +func (conn *ConnectionSet) Equal(other *ConnectionSet) bool { + if conn.AllowAll != other.AllowAll { + return false + } + if conn.AllowAll { + return true + } + return conn.connectionProperties.Equals(other.connectionProperties) +} + +func getProtocolStr(p int64) string { + switch p { + case TCP: + return string(ProtocolTCP) + case UDP: + return string(ProtocolUDP) + case ICMP: + return string(ProtocolICMP) + } + return "" +} + +func getDimensionStr(dimValue *intervals.CanonicalIntervalSet, dim Dimension) string { + domainValues := getDimensionDomain(dim) + if dimValue.Equal(*domainValues) { + // avoid adding dimension str on full dimension values + return "" + } + switch dim { + case protocol: + pList := []string{} + for p := minProtocol; p <= maxProtocol; p++ { + if dimValue.Contains(p) { + pList = append(pList, getProtocolStr(p)) + } + } + return "protocol: " + strings.Join(pList, ",") + case srcPort: + return "src-ports: " + dimValue.String() + case dstPort: + return "dst-ports: " + dimValue.String() + case icmpType: + return "icmp-type: " + dimValue.String() + case icmpCode: + return "icmp-code: " + dimValue.String() + } + return "" +} + +func filterEmptyPropertiesStr(inputList []string) []string { + res := []string{} + for _, propertyStr := range inputList { + if propertyStr != "" { + res = append(res, propertyStr) + } + } + return res +} + +func getICMPbasedCubeStr(protocolsValues, icmpTypeValues, icmpCodeValues *intervals.CanonicalIntervalSet) string { + strList := []string{ + getDimensionStr(protocolsValues, protocol), + getDimensionStr(icmpTypeValues, icmpType), + getDimensionStr(icmpCodeValues, icmpCode), + } + return strings.Join(filterEmptyPropertiesStr(strList), propertySeparator) +} + +func getPortBasedCubeStr(protocolsValues, srcPortsValues, dstPortsValues *intervals.CanonicalIntervalSet) string { + strList := []string{ + getDimensionStr(protocolsValues, protocol), + getDimensionStr(srcPortsValues, srcPort), + getDimensionStr(dstPortsValues, dstPort), + } + return strings.Join(filterEmptyPropertiesStr(strList), propertySeparator) +} + +func getMixedProtocolsCubeStr(protocols *intervals.CanonicalIntervalSet) string { + // TODO: make sure other dimension values are full + return getDimensionStr(protocols, protocol) +} + +func getConnsCubeStr(cube []*intervals.CanonicalIntervalSet) string { + protocols := cube[protocol] + if (protocols.Contains(TCP) || protocols.Contains(UDP)) && !protocols.Contains(ICMP) { + return getPortBasedCubeStr(cube[protocol], cube[srcPort], cube[dstPort]) + } + if protocols.Contains(ICMP) && !(protocols.Contains(TCP) || protocols.Contains(UDP)) { + return getICMPbasedCubeStr(cube[protocol], cube[icmpType], cube[icmpCode]) + } + return getMixedProtocolsCubeStr(protocols) +} + +// String returns a string representation of a ConnectionSet object +func (conn *ConnectionSet) String() string { + if conn.AllowAll { + return AllConnections + } else if conn.IsEmpty() { + return NoConnections + } + resStrings := []string{} + // get cubes and cube str per each cube + cubes := conn.connectionProperties.GetCubesList() + for _, cube := range cubes { + resStrings = append(resStrings, getConnsCubeStr(cube)) + } + + sort.Strings(resStrings) + return strings.Join(resStrings, "; ") +} + +type ConnDetails jsonio.ProtocolList + +func getCubeAsTCPItems(cube []*intervals.CanonicalIntervalSet, protocol jsonio.TcpUdpProtocol) []jsonio.TcpUdp { + tcpItemsTemp := []jsonio.TcpUdp{} + tcpItemsFinal := []jsonio.TcpUdp{} + // consider src ports + srcPorts := cube[srcPort] + if !srcPorts.Equal(*getDimensionDomain(srcPort)) { + // iterate the intervals in the interval-set + for _, interval := range srcPorts.IntervalSet { + tcpRes := jsonio.TcpUdp{Protocol: protocol, MinSourcePort: int(interval.Start), MaxSourcePort: int(interval.End)} + tcpItemsTemp = append(tcpItemsTemp, tcpRes) + } + } else { + tcpItemsTemp = append(tcpItemsTemp, jsonio.TcpUdp{Protocol: protocol}) + } + // consider dst ports + dstPorts := cube[dstPort] + if !dstPorts.Equal(*getDimensionDomain(dstPort)) { + // iterate the intervals in the interval-set + for _, interval := range dstPorts.IntervalSet { + for _, tcpItemTemp := range tcpItemsTemp { + tcpRes := jsonio.TcpUdp{ + Protocol: protocol, + MinSourcePort: tcpItemTemp.MinSourcePort, + MaxSourcePort: tcpItemTemp.MaxSourcePort, + MinDestinationPort: int(interval.Start), + MaxDestinationPort: int(interval.End), + } + tcpItemsFinal = append(tcpItemsFinal, tcpRes) + } + } + } else { + tcpItemsFinal = tcpItemsTemp + } + return tcpItemsFinal +} + +func getIntervalNumbers(c *intervals.CanonicalIntervalSet) []int { + res := []int{} + for _, interval := range c.IntervalSet { + for i := interval.Start; i <= interval.End; i++ { + res = append(res, int(i)) + } + } + return res +} + +func getCubeAsICMPItems(cube []*intervals.CanonicalIntervalSet) []jsonio.Icmp { + icmpTypes := cube[icmpType] + icmpCodes := cube[icmpCode] + if icmpTypes.Equal(*getDimensionDomain(icmpType)) && icmpCodes.Equal(*getDimensionDomain(icmpCode)) { + return []jsonio.Icmp{{Protocol: jsonio.IcmpProtocolICMP}} + } + res := []jsonio.Icmp{} + if icmpTypes.Equal(*getDimensionDomain(icmpType)) { + codeNumbers := getIntervalNumbers(icmpCodes) + for i := range codeNumbers { + res = append(res, jsonio.Icmp{Protocol: jsonio.IcmpProtocolICMP, Code: &codeNumbers[i]}) + } + return res + } + if icmpCodes.Equal(*getDimensionDomain(icmpCode)) { + typeNumbers := getIntervalNumbers(icmpTypes) + for i := range typeNumbers { + res = append(res, jsonio.Icmp{Protocol: jsonio.IcmpProtocolICMP, Type: &typeNumbers[i]}) + } + return res + } + // iterate both codes and types + typeNumbers := getIntervalNumbers(icmpTypes) + codeNumbers := getIntervalNumbers(icmpCodes) + for i := range typeNumbers { + for j := range codeNumbers { + res = append(res, jsonio.Icmp{Protocol: jsonio.IcmpProtocolICMP, Type: &typeNumbers[i], Code: &codeNumbers[j]}) + } + } + return res +} + +func ConnToJSONRep(c *ConnectionSet) ConnDetails { + if c == nil { + return nil // one of the connections in connectionDiff can be empty + } + if c.AllowAll { + return ConnDetails(jsonio.ProtocolList{jsonio.AnyProtocol{Protocol: jsonio.AnyProtocolProtocolANY}}) + } + res := jsonio.ProtocolList{} + + cubes := c.connectionProperties.GetCubesList() + for _, cube := range cubes { + protocols := cube[protocol] + if protocols.Contains(TCP) { + tcpItems := getCubeAsTCPItems(cube, jsonio.TcpUdpProtocolTCP) + for _, item := range tcpItems { + res = append(res, item) + } + } + if protocols.Contains(UDP) { + udpItems := getCubeAsTCPItems(cube, jsonio.TcpUdpProtocolUDP) + for _, item := range udpItems { + res = append(res, item) + } + } + if protocols.Contains(ICMP) { + icmpItems := getCubeAsICMPItems(cube) + for _, item := range icmpItems { + res = append(res, item) + } + } + } + + return ConnDetails(res) +} + +// NewTCPConnectionSet returns a ConnectionSet object with TCP protocol (all ports) +func NewTCPConnectionSet() *ConnectionSet { + res := NewConnectionSet(false) + res.AddTCPorUDPConn(ProtocolTCP, MinPort, MaxPort, MinPort, MaxPort) + return res +} diff --git a/pkg/connectionset/connectionset_test.go b/pkg/connectionset/connectionset_test.go new file mode 100644 index 0000000..32c77ac --- /dev/null +++ b/pkg/connectionset/connectionset_test.go @@ -0,0 +1,50 @@ +package connectionset + +import ( + "fmt" + "testing" +) + +// TODO: Add test assertions +func TestBasicConnectionSet(t *testing.T) { + c := NewConnectionSet(false) + fmt.Println(c.String()) + c.AddICMPConnection(7, 7, 5, 5) + fmt.Println(c.String()) + + d := NewConnectionSet(true) + fmt.Println(d.String()) + e := NewConnectionSet(false) + e.AddTCPorUDPConn(ProtocolTCP, 1, 65535, 1, 65535) + d = d.Subtract(e) + fmt.Println(d.String()) + d = d.Union(e) + fmt.Println(d.String()) + + fmt.Println("done") +} + +func TestBasicConnectionSet2(t *testing.T) { + c := NewConnectionSet(false) + c.AddICMPConnection(7, 7, 5, 5) + d := NewConnectionSet(true) + e := NewConnectionSet(false) + e.AddTCPorUDPConn(ProtocolTCP, 1, 65535, 1, 65535) + d = d.Subtract(e) + d = d.Subtract(c) + fmt.Println(d.String()) + + fmt.Println("done") +} + +func TestBasicConnectionSet3(t *testing.T) { + c := NewConnectionSet(false) + c.AddICMPConnection(7, 7, 5, 5) + d := NewConnectionSet(true) + d = d.Subtract(c) + d.AddICMPConnection(7, 7, 5, 5) + + fmt.Println(d.String()) + + fmt.Println("done") +} diff --git a/pkg/connectionset/statefulness.go b/pkg/connectionset/statefulness.go new file mode 100644 index 0000000..b0babe5 --- /dev/null +++ b/pkg/connectionset/statefulness.go @@ -0,0 +1,100 @@ +package connectionset + +import ( + "github.com/np-guard/models/pkg/hypercubes" + "github.com/np-guard/models/pkg/intervals" +) + +const ( + // StatefulUnknown is the default value for a ConnectionSet object, + StatefulUnknown int = iota + // StatefulTrue represents a connection object for which any allowed TCP (on all allowed src/dst ports) + // has an allowed response connection + StatefulTrue + // StatefulFalse represents a connection object for which there exists some allowed TCP + // (on any allowed subset from the allowed src/dst ports) that does not have an allowed response connection + StatefulFalse +) + +// EnhancedString returns a connection string with possibly added asterisk for stateless connection +func (conn *ConnectionSet) EnhancedString() string { + if conn.IsStateful == StatefulFalse { + return conn.String() + " *" + } + return conn.String() +} + +// ConnectionWithStatefulness updates `conn` object with `IsStateful` property, based on input `secondDirectionConn`. +// `conn` represents a src-to-dst connection, and `secondDirectionConn` represents dst-to-src connection. +// The property `IsStateful` of `conn` is set as `StatefulFalse` if there is at least some subset within TCP from `conn` +// which is not stateful (such that the response direction for this subset is not enabled). +// This function also returns a connection object with the exact subset of the stateful part (within TCP) +// from the entire connection `conn`, and with the original connections on other protocols. +func (conn *ConnectionSet) ConnectionWithStatefulness(secondDirectionConn *ConnectionSet) *ConnectionSet { + connTCP := conn.tcpConn() + if connTCP.IsEmpty() { + conn.IsStateful = StatefulTrue + return conn + } + secondDirectionConnTCP := secondDirectionConn.tcpConn() + statefulCombinedConnTCP := connTCP.connTCPWithStatefulness(secondDirectionConnTCP) + conn.IsStateful = connTCP.IsStateful + nonTCP := conn.Subtract(connTCP) + return nonTCP.Union(statefulCombinedConnTCP) +} + +// connTCPWithStatefulness assumes that both `conn` and `secondDirectionConn` are within TCP. +// it assigns IsStateful a value within `conn`, and returns the subset from `conn` which is stateful. +func (conn *ConnectionSet) connTCPWithStatefulness(secondDirectionConn *ConnectionSet) *ConnectionSet { + secondDirectionSwitchPortsDirection := secondDirectionConn.switchSrcDstPortsOnTCP() + // flip src/dst ports before intersection + statefulCombinedConn := conn.Intersection(secondDirectionSwitchPortsDirection) + if !conn.Equal(statefulCombinedConn) { + conn.IsStateful = StatefulFalse + } else { + conn.IsStateful = StatefulTrue + } + return statefulCombinedConn +} + +// tcpConn returns a new ConnectionSet object, which is the intersection of `conn` with TCP +func (conn *ConnectionSet) tcpConn() *ConnectionSet { + res := NewConnectionSet(false) + res.AddTCPorUDPConn(ProtocolTCP, MinPort, MaxPort, MinPort, MaxPort) + return conn.Intersection(res) +} + +// switchSrcDstPortsOnTCP returns a new ConnectionSet object, built from the input ConnectionSet object. +// It assumes the input connection object is only within TCP protocol. +// For TCP the src and dst ports on relevant cubes are being switched. +func (conn *ConnectionSet) switchSrcDstPortsOnTCP() *ConnectionSet { + if conn.AllowAll || conn.IsEmpty() { + return conn.Copy() + } + res := NewConnectionSet(false) + cubes := conn.connectionProperties.GetCubesList() + for _, cube := range cubes { + // assuming cube[protocol] contains TCP only + srcPorts := cube[srcPort] + dstPorts := cube[dstPort] + // if the entire domain is enabled by both src and dst no need to switch + if !srcPorts.Equal(*getDimensionDomain(srcPort)) || !dstPorts.Equal(*getDimensionDomain(dstPort)) { + newCube := copyCube(cube) + newCube[srcPort], newCube[dstPort] = newCube[dstPort], newCube[srcPort] + res.connectionProperties = res.connectionProperties.Union(hypercubes.CreateFromCube(newCube)) + } else { + res.connectionProperties = res.connectionProperties.Union(hypercubes.CreateFromCube(cube)) + } + } + return res +} + +// copyCube returns a new slice of intervals copied from input cube +func copyCube(cube []*intervals.CanonicalIntervalSet) []*intervals.CanonicalIntervalSet { + newCube := make([]*intervals.CanonicalIntervalSet, len(cube)) + for i, interval := range cube { + newInterval := interval.Copy() + newCube[i] = &newInterval + } + return newCube +} diff --git a/pkg/connectionset/statefulness_test.go b/pkg/connectionset/statefulness_test.go new file mode 100644 index 0000000..095cff9 --- /dev/null +++ b/pkg/connectionset/statefulness_test.go @@ -0,0 +1,147 @@ +package connectionset + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func newTCPConn(t *testing.T, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *ConnectionSet { + t.Helper() + res := NewConnectionSet(false) + res.AddTCPorUDPConn(ProtocolTCP, srcMinP, srcMaxP, dstMinP, dstMaxP) + return res +} + +func newUDPConn(t *testing.T, srcMinP, srcMaxP, dstMinP, dstMaxP int64) *ConnectionSet { + t.Helper() + res := NewConnectionSet(false) + res.AddTCPorUDPConn(ProtocolUDP, srcMinP, srcMaxP, dstMinP, dstMaxP) + return res +} + +func newICMPconn(t *testing.T) *ConnectionSet { + t.Helper() + res := NewConnectionSet(false) + res.AddICMPConnection(MinICMPtype, MaxICMPtype, MinICMPcode, MaxICMPcode) + return res +} + +func allButTCP(t *testing.T) *ConnectionSet { + t.Helper() + res := NewConnectionSet(true) + tcpOnly := res.tcpConn() + return res.Subtract(tcpOnly) +} + +type statefulnessTest struct { + name string + srcToDst *ConnectionSet + dstToSrc *ConnectionSet + // expectedIsStateful represents the expected IsStateful computed value for srcToDst, + // which should be either StatefulTrue or StatefulFalse, given the input dstToSrc connection. + // the computation applies only to the TCP protocol within those connections. + expectedIsStateful int + // expectedStatefulConn represents the subset from srcToDst which is not related to the "non-stateful" mark (*) on the srcToDst connection, + // the stateless part for TCP is srcToDst.Subtract(statefuleConn) + expectedStatefulConn *ConnectionSet +} + +func (tt statefulnessTest) runTest(t *testing.T) { + t.Helper() + statefuleConn := tt.srcToDst.ConnectionWithStatefulness(tt.dstToSrc) + require.Equal(t, tt.expectedIsStateful, tt.srcToDst.IsStateful) + require.True(t, tt.expectedStatefulConn.Equal(statefuleConn)) +} + +func TestAll(t *testing.T) { + var testCasesStatefulness = []statefulnessTest{ + { + name: "tcp_all_ports_on_both_directions", + srcToDst: newTCPConn(t, MinPort, MaxPort, MinPort, MaxPort), // TCP all ports + dstToSrc: newTCPConn(t, MinPort, MaxPort, MinPort, MaxPort), // TCP all ports + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newTCPConn(t, MinPort, MaxPort, MinPort, MaxPort), // TCP all ports + }, + { + name: "first_all_cons_second_tcp_with_ports", + srcToDst: NewConnectionSet(true), // all connections + dstToSrc: newTCPConn(t, 80, 80, MinPort, MaxPort), // TCP , src-ports: 80, dst-ports: all + + // there is a subset of the tcp connection which is not stateful + expectedIsStateful: StatefulFalse, + + // TCP src-ports: all, dst-port: 80 , union: all non-TCP conns + expectedStatefulConn: allButTCP(t).Union(newTCPConn(t, MinPort, MaxPort, 80, 80)), + }, + { + name: "first_all_conns_second_no_tcp", + srcToDst: NewConnectionSet(true), // all connections + dstToSrc: newICMPconn(t), // ICMP + expectedIsStateful: StatefulFalse, + expectedStatefulConn: allButTCP(t), // UDP, ICMP (all TCP is considered stateless here) + }, + { + name: "tcp_with_ports_both_directions_exact_match", + srcToDst: newTCPConn(t, 80, 80, 443, 443), + dstToSrc: newTCPConn(t, 443, 443, 80, 80), + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newTCPConn(t, 80, 80, 443, 443), + }, + { + name: "tcp_with_ports_both_directions_partial_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443), + dstToSrc: newTCPConn(t, 443, 443, 80, 80), + expectedIsStateful: StatefulFalse, + expectedStatefulConn: newTCPConn(t, 80, 80, 443, 443), + }, + { + name: "tcp_with_ports_both_directions_no_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443), + dstToSrc: newTCPConn(t, 80, 80, 80, 80), + expectedIsStateful: StatefulFalse, + expectedStatefulConn: NewConnectionSet(false), + }, + { + name: "udp_and_tcp_with_ports_both_directions_no_match", + srcToDst: newTCPConn(t, 80, 100, 443, 443).Union(newUDPConn(t, 80, 100, 443, 443)), + dstToSrc: newTCPConn(t, 80, 80, 80, 80).Union(newUDPConn(t, 80, 80, 80, 80)), + expectedIsStateful: StatefulFalse, + expectedStatefulConn: newUDPConn(t, 80, 100, 443, 443), + }, + { + name: "no_tcp_in_first_direction", + srcToDst: newUDPConn(t, 80, 100, 443, 443), + dstToSrc: newTCPConn(t, 80, 80, 80, 80).Union(newUDPConn(t, 80, 80, 80, 80)), + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newUDPConn(t, 80, 100, 443, 443), + }, + { + name: "empty_conn_in_first_direction", + srcToDst: NewConnectionSet(false), + dstToSrc: newTCPConn(t, 80, 80, 80, 80).Union(newUDPConn(t, MinPort, MaxPort, MinPort, MaxPort)), + expectedIsStateful: StatefulTrue, + expectedStatefulConn: NewConnectionSet(false), + }, + { + name: "only_udp_icmp_in_first_direction_and_empty_second_direction", + srcToDst: newUDPConn(t, MinPort, MaxPort, MinPort, MaxPort).Union(newICMPconn(t)), + dstToSrc: NewConnectionSet(false), + // stateful analysis does not apply to udp/icmp, thus considered in the result as "stateful" + // (to avoid marking it as stateless in the output) + expectedIsStateful: StatefulTrue, + expectedStatefulConn: newUDPConn(t, MinPort, MaxPort, MinPort, MaxPort).Union(newICMPconn(t)), + }, + } + t.Parallel() + // explainTests is the list of tests to run + for testIdx := range testCasesStatefulness { + tt := testCasesStatefulness[testIdx] + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.runTest(t) + }) + } + fmt.Println("done") +}