Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
Elazar Gershuni committed Mar 6, 2024
1 parent 6a831c2 commit cedcc6e
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 81 deletions.
108 changes: 46 additions & 62 deletions pkg/connectionset/connectionset.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ var dimensionsList = []Dimension{
icmpCode,
}

func getDimensionDomain(dim Dimension) *intervals.CanonicalIntervalSet {
func entireDimension(dim Dimension) *intervals.CanonicalIntervalSet {
switch dim {
case protocol:
return intervals.CreateFromInterval(minProtocol, maxProtocol)
Expand All @@ -70,7 +70,7 @@ func getDimensionDomain(dim Dimension) *intervals.CanonicalIntervalSet {
func getDimensionDomainsList() []*intervals.CanonicalIntervalSet {
res := make([]*intervals.CanonicalIntervalSet, len(dimensionsList))
for i := range dimensionsList {
res[i] = getDimensionDomain(dimensionsList[i])
res[i] = entireDimension(dimensionsList[i])
}
return res
}
Expand Down Expand Up @@ -129,7 +129,10 @@ func (conn *ConnectionSet) Union(other *ConnectionSet) *ConnectionSet {
if conn.IsEmpty() {
return other.Copy()
}
res := &ConnectionSet{AllowAll: false, connectionProperties: conn.connectionProperties.Union(other.connectionProperties)}
res := &ConnectionSet{
AllowAll: false,
connectionProperties: conn.connectionProperties.Union(other.connectionProperties),
}
if res.isAllConnectionsWithoutAllowAll() {
return NewConnectionSet(true)
}
Expand Down Expand Up @@ -244,7 +247,7 @@ func getProtocolStr(p int64) ProtocolStr {
}

func getDimensionStr(dimValue *intervals.CanonicalIntervalSet, dim Dimension) string {
domainValues := getDimensionDomain(dim)
domainValues := entireDimension(dim)
if dimValue.Equal(*domainValues) {
// avoid adding dimension str on full dimension values
return ""
Expand Down Expand Up @@ -330,14 +333,13 @@ func (conn *ConnectionSet) String() string {
return strings.Join(resStrings, "; ")
}

type ConnDetails []Protocol

func getCubeAsTCPItems(cube []*intervals.CanonicalIntervalSet, protocol TransportLayerProtocolName) []TCPUDP {
tcpItemsTemp := []TCPUDP{}
tcpItemsFinal := []TCPUDP{}
func getCubeAsTCPItems(cube []*intervals.CanonicalIntervalSet, protocol TransportLayerProtocolName) []Protocol {
tcpItemsTemp := []Protocol{}
// consider src ports
srcPorts := cube[srcPort]
if !srcPorts.Equal(*getDimensionDomain(srcPort)) {
if srcPorts.Equal(*entireDimension(srcPort)) {
tcpItemsTemp = append(tcpItemsTemp, TCPUDP{Protocol: protocol})
} else {
// iterate the intervals in the interval-set
for _, interval := range srcPorts.IntervalSet {
tcpRes := TCPUDP{
Expand All @@ -348,67 +350,58 @@ func getCubeAsTCPItems(cube []*intervals.CanonicalIntervalSet, protocol Transpor
}
tcpItemsTemp = append(tcpItemsTemp, tcpRes)
}
} else {
tcpItemsTemp = append(tcpItemsTemp, TCPUDP{Protocol: protocol})
}
// consider dst ports
dstPorts := cube[dstPort]
if !dstPorts.Equal(*getDimensionDomain(dstPort)) {
// iterate the intervals in the interval-set
for _, interval := range dstPorts.IntervalSet {
for _, tcpItemTemp := range tcpItemsTemp {
tcpRes := TCPUDP{
Protocol: protocol,
PortRangePair: PortRangePair{
SrcPort: tcpItemTemp.PortRangePair.SrcPort,
DstPort: PortRange{int(interval.Start), int(interval.End)},
},
}
tcpItemsFinal = append(tcpItemsFinal, tcpRes)
}
if dstPorts.Equal(*entireDimension(dstPort)) {
return tcpItemsTemp
}
tcpItemsFinal := []Protocol{}
for _, interval := range dstPorts.IntervalSet {
for _, tcpItemTemp := range tcpItemsTemp {
item, _ := tcpItemTemp.(TCPUDP)
tcpItemsFinal = append(tcpItemsFinal, TCPUDP{
Protocol: protocol,
PortRangePair: PortRangePair{
SrcPort: item.PortRangePair.SrcPort,
DstPort: PortRange{int(interval.Start), int(interval.End)},
},
})
}
} else {
tcpItemsFinal = tcpItemsTemp
}
return tcpItemsFinal
}

func getIntervalNumbers(c *intervals.CanonicalIntervalSet) []int {
res := []int{}
for _, interval := range c.IntervalSet {
for i := interval.Start; i <= interval.End; i++ {
res = append(res, int(i))
}
}
return res
}

func getCubeAsICMPItems(cube []*intervals.CanonicalIntervalSet) []ICMP {
func getCubeAsICMPItems(cube []*intervals.CanonicalIntervalSet) []Protocol {
icmpTypes := cube[icmpType]
icmpCodes := cube[icmpCode]
if icmpTypes.Equal(*getDimensionDomain(icmpType)) && icmpCodes.Equal(*getDimensionDomain(icmpCode)) {
return []ICMP{}
}
res := []ICMP{}
if icmpTypes.Equal(*getDimensionDomain(icmpCode)) {
typeNumbers := getIntervalNumbers(icmpTypes)
for _, t := range typeNumbers {
if icmpCodes.Equal(*entireDimension(icmpCode)) {
if icmpTypes.Equal(*entireDimension(icmpType)) {
return []Protocol{ICMP{}}
}
res := []Protocol{}
for _, t := range icmpTypes.Elements() {
res = append(res, ICMP{ICMPCodeType: &ICMPCodeType{Type: t}})
}
return res
}

// iterate both codes and types
typeNumbers := getIntervalNumbers(icmpTypes)
codeNumbers := getIntervalNumbers(icmpCodes)
for i := range typeNumbers {
for j := range codeNumbers {
res = append(res, ICMP{ICMPCodeType: &ICMPCodeType{Type: typeNumbers[i], Code: &codeNumbers[j]}})
res := []Protocol{}
for _, t := range icmpTypes.Elements() {
codes := icmpCodes.Elements()
for i := range codes {
c := codes[i]
if ValidateICMP(t, c) == nil {
res = append(res, ICMP{ICMPCodeType: &ICMPCodeType{Type: t, Code: &c}})
}
}
}
return res
}

type ConnDetails []Protocol

func ConnToJSONRep(c *ConnectionSet) ConnDetails {
if c == nil {
return nil // one of the connections in connectionDiff can be empty
Expand All @@ -422,22 +415,13 @@ func ConnToJSONRep(c *ConnectionSet) ConnDetails {
for _, cube := range cubes {
protocols := cube[protocol]
if protocols.Contains(TCPCode) {
tcpItems := getCubeAsTCPItems(cube, TCP)
for _, item := range tcpItems {
res = append(res, item)
}
res = append(res, getCubeAsTCPItems(cube, TCP)...)
}
if protocols.Contains(UDPCode) {
udpItems := getCubeAsTCPItems(cube, UDP)
for _, item := range udpItems {
res = append(res, item)
}
res = append(res, getCubeAsTCPItems(cube, UDP)...)
}
if protocols.Contains(ICMPCode) {
icmpItems := getCubeAsICMPItems(cube)
for _, item := range icmpItems {
res = append(res, item)
}
res = append(res, getCubeAsICMPItems(cube)...)
}
}

Expand Down
29 changes: 14 additions & 15 deletions pkg/connectionset/icmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package connectionset
import (
"fmt"
"log"
"slices"
)

type ICMPCodeType struct {
Expand Down Expand Up @@ -76,24 +75,24 @@ func inverseICMPType(t int) int {

//nolint:revive // magic numbers are fine here
func ValidateICMP(t, c int) error {
possibleCodes := map[int][]int{
echoReply: {0},
destinationUnreachable: {0, 1, 2, 3, 4, 5},
sourceQuench: {0},
redirect: {0, 1, 2, 3},
echo: {0},
timeExceeded: {0, 1},
parameterProblem: {0},
timestamp: {0},
timestampReply: {0},
informationRequest: {0},
informationReply: {0},
maxCodes := map[int]int{
echoReply: 0,
destinationUnreachable: 5,
sourceQuench: 0,
redirect: 3,
echo: 0,
timeExceeded: 1,
parameterProblem: 0,
timestamp: 0,
timestampReply: 0,
informationRequest: 0,
informationReply: 0,
}
options, ok := possibleCodes[t]
maxCode, ok := maxCodes[t]
if !ok {
return fmt.Errorf("invalid ICMP type %v", t)
}
if !slices.Contains(options, c) {
if c > maxCode {
return fmt.Errorf("ICMP code %v is invalid for ICMP type %v", c, t)
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/connectionset/statefulness.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (conn *ConnectionSet) switchSrcDstPortsOnTCP() *ConnectionSet {
srcPorts := cube[srcPort]
dstPorts := cube[dstPort]
// if the entire domain is enabled by both src and dst no need to switch
if !srcPorts.Equal(*getDimensionDomain(srcPort)) || !dstPorts.Equal(*getDimensionDomain(dstPort)) {
if !srcPorts.Equal(*entireDimension(srcPort)) || !dstPorts.Equal(*entireDimension(dstPort)) {
newCube := copyCube(cube)
newCube[srcPort], newCube[dstPort] = newCube[dstPort], newCube[srcPort]
res.connectionProperties = res.connectionProperties.Union(hypercubes.CreateFromCube(newCube))
Expand Down
2 changes: 1 addition & 1 deletion pkg/connectionset/statefulness_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type statefulnessTest struct {
// the computation applies only to the TCP protocol within those connections.
expectedIsStateful int
// expectedStatefulConn represents the subset from srcToDst which is not related to the "non-stateful" mark (*) on the srcToDst connection,
// the stateless part for TCP is srcToDst.Subtract(statefuleConn)
// the stateless part for TCP is srcToDst.Subtract(statefulConn)
expectedStatefulConn *ConnectionSet
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/connectionset/tcpudp.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const (
)

const DefaultMinPort = 1
const DefaultMaxPort = 65535
const DefaultMaxPort = MaxPort

type PortRange struct {
// Minimal port; default is DefaultMinPort
Expand Down
10 changes: 9 additions & 1 deletion pkg/intervals/intervalset.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,15 @@ func (c *CanonicalIntervalSet) IsSingleNumber() bool {
}
return false
}

func (c *CanonicalIntervalSet) Elements() []int {
res := []int{}
for _, interval := range c.IntervalSet {
for i := interval.Start; i <= interval.End; i++ {
res = append(res, int(i))
}
}
return res
}
func CreateFromInterval(start, end int64) *CanonicalIntervalSet {
return &CanonicalIntervalSet{IntervalSet: []Interval{{Start: start, End: end}}}
}

0 comments on commit cedcc6e

Please sign in to comment.