diff --git a/pkg/transparentproxy/config/config.go b/pkg/transparentproxy/config/config.go index 68b171614c35..f6e7fdb88d80 100644 --- a/pkg/transparentproxy/config/config.go +++ b/pkg/transparentproxy/config/config.go @@ -28,8 +28,41 @@ type Owner struct { // ranges and multiple values can be mixed e.g. 1000,1005:1006 meaning 1000,1005,1006 type ValueOrRangeList string +// NewValueOrRangeList creates a ValueOrRangeList from a given value or range of +// values. It accepts a parameter of type []uint16, uint16, or string and +// converts it to a ValueOrRangeList, which is a comma-separated string +// representation of the values. +// +// Args: +// - v (T): The input value which can be a slice of uint16, a single uint16, +// or a string. +// +// Returns: +// - ValueOrRangeList: A comma-separated string representation of the input +// values. +// +// The function panics if an unsupported type is provided, although the type +// constraints should prevent this from occurring. +func NewValueOrRangeList[T ~[]uint16 | ~uint16 | ~string](v T) ValueOrRangeList { + switch value := any(v).(type) { + case []uint16: + var ports []string + for _, port := range value { + ports = append(ports, strconv.Itoa(int(port))) + } + return ValueOrRangeList(strings.Join(ports, ",")) + case uint16: + return ValueOrRangeList(strconv.Itoa(int(value))) + case string: + return ValueOrRangeList(value) + default: + // Shouldn't be possible to catch this + panic(errors.Errorf("invalid value type: %T", value)) + } +} + type Exclusion struct { - Protocol string + Protocol ProtocolL4 Address string UIDs ValueOrRangeList Ports ValueOrRangeList @@ -771,19 +804,20 @@ func parseExcludePortsForUIDs(exclusionRules []string) ([]Exclusion, error) { return nil, errors.Wrap(err, "invalid UID range") } - var protocols []string + var protocols []ProtocolL4 if protocolOpts == "" || protocolOpts == "*" { - protocols = []string{"tcp", "udp"} + protocols = []ProtocolL4{ProtocolTCP, ProtocolUDP} } else { - for _, p := range strings.Split(protocolOpts, ",") { - pCleaned := strings.ToLower(strings.TrimSpace(p)) - if pCleaned != "tcp" && pCleaned != "udp" { - return nil, errors.Errorf( - "invalid or unsupported protocol: '%s'", - pCleaned, - ) + for _, s := range strings.Split(protocolOpts, ",") { + if p := ParseProtocolL4(s); p != ProtocolUndefined { + protocols = append(protocols, p) + continue } - protocols = append(protocols, pCleaned) + + return nil, errors.Errorf( + "invalid or unsupported protocol: '%s'", + s, + ) } } diff --git a/pkg/transparentproxy/config/config_executables_functionality.go b/pkg/transparentproxy/config/config_executables_functionality.go index 2bcfce0d70a5..7026995a8f2f 100644 --- a/pkg/transparentproxy/config/config_executables_functionality.go +++ b/pkg/transparentproxy/config/config_executables_functionality.go @@ -32,6 +32,7 @@ type FunctionalityModules struct { // more connection tracking information than the "state" match. // ref. iptables-extensions(8) > conntrack Conntrack bool + Multiport bool } type FunctionalityChains struct { @@ -120,6 +121,7 @@ func verifyFunctionality( Udp: verifyModule(ctx, iptables, ModuleUdp), Comment: verifyModule(ctx, iptables, ModuleComment), Conntrack: verifyModule(ctx, iptables, ModuleConntrack), + Multiport: verifyModule(ctx, iptables, ModuleMultiport), }, } diff --git a/pkg/transparentproxy/config/config_suite_test.go b/pkg/transparentproxy/config/config_suite_test.go new file mode 100644 index 000000000000..275d9e89b63c --- /dev/null +++ b/pkg/transparentproxy/config/config_suite_test.go @@ -0,0 +1,13 @@ +package config_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func Test(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Transparent Proxy Config Suite") +} diff --git a/pkg/transparentproxy/config/config_test.go b/pkg/transparentproxy/config/config_test.go new file mode 100644 index 000000000000..90391193e15f --- /dev/null +++ b/pkg/transparentproxy/config/config_test.go @@ -0,0 +1,30 @@ +package config + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("NewValueOrRangeList", func() { + DescribeTable("should create ValueOrRangeList", + func(input interface{}, expected string) { + // when + var result ValueOrRangeList + + switch v := input.(type) { + case []uint16: + result = NewValueOrRangeList(v) + case uint16: + result = NewValueOrRangeList(v) + case string: + result = NewValueOrRangeList(v) + } + + // then + Expect(string(result)).To(Equal(expected)) + }, + Entry("from uint16 slice", []uint16{80, 443}, "80,443"), + Entry("from single uint16", uint16(8080), "8080"), + Entry("from string", "1000-2000", "1000-2000"), + ) +}) diff --git a/pkg/transparentproxy/iptables/builder/builder_table_mangle.go b/pkg/transparentproxy/iptables/builder/builder_table_mangle.go index 1546bfe77c95..f84c4fe561bd 100644 --- a/pkg/transparentproxy/iptables/builder/builder_table_mangle.go +++ b/pkg/transparentproxy/iptables/builder/builder_table_mangle.go @@ -17,7 +17,7 @@ func buildMangleTable(cfg config.InitializedConfigIPvX) *tables.MangleTable { if cfg.DropInvalidPackets { mangle.Prerouting().AddRules( rules. - NewRule( + NewAppendRule( Match(Conntrack(Ctstate(INVALID))), Jump(Drop()), ). diff --git a/pkg/transparentproxy/iptables/builder/builder_table_nat.go b/pkg/transparentproxy/iptables/builder/builder_table_nat.go index 0949ad97f855..b7c836c1fa23 100644 --- a/pkg/transparentproxy/iptables/builder/builder_table_nat.go +++ b/pkg/transparentproxy/iptables/builder/builder_table_nat.go @@ -18,7 +18,7 @@ func buildMeshInbound(cfg config.InitializedTrafficFlow) *Chain { if !cfg.Enabled { return meshInbound.AddRules( rules. - NewRule( + NewAppendRule( Protocol(Tcp()), Jump(Return()), ). @@ -29,7 +29,7 @@ func buildMeshInbound(cfg config.InitializedTrafficFlow) *Chain { for _, port := range cfg.IncludePorts { meshInbound.AddRules( rules. - NewRule( + NewAppendRule( Protocol(Tcp(DestinationPort(port))), Jump(ToUserDefinedChain(cfg.RedirectChainName)), ). @@ -41,7 +41,7 @@ func buildMeshInbound(cfg config.InitializedTrafficFlow) *Chain { for _, port := range cfg.ExcludePorts { meshInbound.AddRules( rules. - NewRule( + NewAppendRule( Protocol(Tcp(DestinationPort(port))), Jump(Return()), ). @@ -51,7 +51,7 @@ func buildMeshInbound(cfg config.InitializedTrafficFlow) *Chain { meshInbound.AddRules( rules. - NewRule( + NewAppendRule( Protocol(Tcp()), Jump(ToUserDefinedChain(cfg.RedirectChainName)), ). @@ -68,7 +68,7 @@ func buildMeshOutbound(cfg config.InitializedConfigIPvX) *Chain { if !cfg.Redirect.Outbound.Enabled { return meshOutbound.AddRules( rules. - NewRule( + NewAppendRule( Protocol(Tcp()), Jump(Return()), ). @@ -80,7 +80,7 @@ func buildMeshOutbound(cfg config.InitializedConfigIPvX) *Chain { for _, port := range cfg.Redirect.Outbound.ExcludePorts { meshOutbound.AddRules( rules. - NewRule( + NewAppendRule( Protocol(Tcp(DestinationPort(port))), Jump(Return()), ). @@ -117,14 +117,14 @@ func buildMeshOutbound(cfg config.InitializedConfigIPvX) *Chain { // localhost:7777 AddRules( rules. - NewRule( + NewAppendRule( Source(Address(cfg.InboundPassthroughCIDR)), OutInterface(cfg.LoopbackInterfaceName), Jump(Return()), ). WithCommentf("prevent traffic loops by ensuring traffic from the sidecar proxy (using %s) to loopback interface is not redirected again", cfg.InboundPassthroughCIDR), rules. - NewRule( + NewAppendRule( Protocol(Tcp(NotDestinationPortIfBool(cfg.Redirect.DNS.Enabled, DNSPort))), OutInterface(cfg.LoopbackInterfaceName), NotDestination(cfg.LocalhostCIDR), @@ -133,7 +133,7 @@ func buildMeshOutbound(cfg config.InitializedConfigIPvX) *Chain { ). WithCommentf("redirect outbound TCP traffic (except to DNS port %d) destined for loopback interface, but not targeting address %s, and owned by UID %s (kuma-dp user) to %s chain for proper handling", DNSPort, cfg.LocalhostCIDR, cfg.Owner.UID, cfg.Redirect.Inbound.RedirectChainName), rules. - NewRule( + NewAppendRule( Protocol(Tcp(NotDestinationPortIfBool(cfg.Redirect.DNS.Enabled, DNSPort))), OutInterface(cfg.LoopbackInterfaceName), Match(Owner(NotUid(cfg.Owner.UID))), @@ -141,7 +141,7 @@ func buildMeshOutbound(cfg config.InitializedConfigIPvX) *Chain { ). WithCommentf("return outbound TCP traffic (except to DNS port %d) destined for loopback interface, owned by any UID other than %s (kuma-dp user)", DNSPort, cfg.Owner.UID), rules. - NewRule( + NewAppendRule( Match(Owner(Uid(cfg.Owner.UID))), Jump(Return()), ). @@ -152,7 +152,7 @@ func buildMeshOutbound(cfg config.InitializedConfigIPvX) *Chain { if cfg.Redirect.DNS.CaptureAll { meshOutbound.AddRules( rules. - NewRule( + NewAppendRule( Protocol(Tcp(DestinationPort(DNSPort))), Jump(ToPort(cfg.Redirect.DNS.Port)), ). @@ -162,7 +162,7 @@ func buildMeshOutbound(cfg config.InitializedConfigIPvX) *Chain { for _, dnsIp := range cfg.Redirect.DNS.Servers { meshOutbound.AddRules( rules. - NewRule( + NewAppendRule( Destination(dnsIp), Protocol(Tcp(DestinationPort(DNSPort))), Jump(ToPort(cfg.Redirect.DNS.Port)), @@ -175,7 +175,7 @@ func buildMeshOutbound(cfg config.InitializedConfigIPvX) *Chain { meshOutbound.AddRules( rules. - NewRule( + NewAppendRule( Destination(cfg.LocalhostCIDR), Jump(Return()), ). @@ -185,7 +185,7 @@ func buildMeshOutbound(cfg config.InitializedConfigIPvX) *Chain { for _, port := range cfg.Redirect.Outbound.IncludePorts { meshOutbound.AddRules( rules. - NewRule( + NewAppendRule( Protocol(Tcp(DestinationPort(port))), Jump(ToUserDefinedChain(cfg.Redirect.Outbound.RedirectChainName)), ). @@ -196,7 +196,7 @@ func buildMeshOutbound(cfg config.InitializedConfigIPvX) *Chain { if len(cfg.Redirect.Outbound.IncludePorts) == 0 { meshOutbound.AddRules( rules. - NewRule( + NewAppendRule( Jump(ToUserDefinedChain(cfg.Redirect.Outbound.RedirectChainName)), ). WithComment("redirect all other outbound traffic to our custom chain for further processing"), @@ -212,7 +212,7 @@ func buildMeshOutbound(cfg config.InitializedConfigIPvX) *Chain { func buildMeshRedirect(cfg config.InitializedTrafficFlow) *Chain { return MustNewChain(TableNat, cfg.RedirectChainName).AddRules( rules. - NewRule( + NewAppendRule( Protocol(Tcp()), Jump(ToPort(cfg.Port)), ). @@ -221,30 +221,26 @@ func buildMeshRedirect(cfg config.InitializedTrafficFlow) *Chain { } func addOutputRules(cfg config.InitializedConfigIPvX, nat *tables.NatTable) { - rulePosition := uint(1) - if cfg.Log.Enabled { nat.Output().AddRules( rules. - NewRule(Jump(Log(OutputLogPrefix, cfg.Log.Level))). - WithPosition(rulePosition). + NewInsertRule(Jump(Log(OutputLogPrefix, cfg.Log.Level))). WithComment("log matching packets using kernel logging"), ) - rulePosition++ } for _, exclusion := range cfg.Redirect.Outbound.Exclusions { nat.Output().AddRules( rules. - NewRule( + NewInsertRule( MatchIf(exclusion.Ports != "", Multiport()), Protocol( TcpIf( - exclusion.Protocol == TCP, + exclusion.Protocol == ProtocolTCP, DestinationPortRangeOrValue(exclusion), ), UdpIf( - exclusion.Protocol == UDP, + exclusion.Protocol == ProtocolUDP, DestinationPortRangeOrValue(exclusion), ), ), @@ -255,17 +251,15 @@ func addOutputRules(cfg config.InitializedConfigIPvX, nat *tables.NatTable) { Destination(exclusion.Address), Jump(Return()), ). - WithPosition(rulePosition). WithComment("skip further processing for configured IP addresses, ports and UIDs"), ) - rulePosition++ } // Conditionally add DNS redirection rules if DNS redirection is enabled. if cfg.Redirect.DNS.Enabled { nat.Output().AddRules( rules. - NewRule( + NewInsertRule( Protocol(Udp(DestinationPort(DNSPort))), Match(Owner(Uid(cfg.Owner.UID))), JumpConditional( @@ -274,7 +268,6 @@ func addOutputRules(cfg config.InitializedConfigIPvX, nat *tables.NatTable) { Return(), // else RETURN ), ). - WithPosition(rulePosition). WithConditionalComment( cfg.Executables.Functionality.Chains.DockerOutput, fmt.Sprintf( @@ -284,38 +277,34 @@ func addOutputRules(cfg config.InitializedConfigIPvX, nat *tables.NatTable) { "return early for DNS traffic from kuma-dp", ), ) - rulePosition++ if cfg.Redirect.DNS.CaptureAll { nat.Output().AddRules( rules. - NewRule( + NewInsertRule( Protocol(Udp(DestinationPort(DNSPort))), Jump(ToPort(cfg.Redirect.DNS.Port)), ). - WithPosition(rulePosition). WithCommentf("redirect all DNS requests to the kuma-dp DNS proxy (listening on port %d)", cfg.Redirect.DNS.Port), ) } else { for _, dnsIp := range cfg.Redirect.DNS.Servers { nat.Output().AddRules( rules. - NewRule( + NewInsertRule( Destination(dnsIp), Protocol(Udp(DestinationPort(DNSPort))), Jump(ToPort(cfg.Redirect.DNS.Port)), ). - WithPosition(rulePosition). WithCommentf("redirect DNS requests to %s to the kuma-dp DNS proxy (listening on port %d)", dnsIp, cfg.Redirect.DNS.Port), ) - rulePosition++ } } } nat.Output().AddRules( rules. - NewRule( + NewAppendRule( Protocol(Tcp()), Jump(ToUserDefinedChain(cfg.Redirect.Outbound.ChainName)), ). @@ -326,13 +315,11 @@ func addOutputRules(cfg config.InitializedConfigIPvX, nat *tables.NatTable) { // addPreroutingRules adds rules to the PREROUTING chain of the NAT table to // handle inbound traffic according to the provided configuration. func addPreroutingRules(cfg config.InitializedConfigIPvX, nat *tables.NatTable) { - rulePosition := uint(1) - // Add a logging rule if logging is enabled. if cfg.Log.Enabled { nat.Prerouting().AddRules( rules. - NewRule(Jump(Log(PreroutingLogPrefix, cfg.Log.Level))). + NewAppendRule(Jump(Log(PreroutingLogPrefix, cfg.Log.Level))). WithComment("log matching packets using kernel logging"), ) } @@ -340,7 +327,7 @@ func addPreroutingRules(cfg config.InitializedConfigIPvX, nat *tables.NatTable) if len(cfg.Redirect.VNet.InterfaceCIDRs) == 0 { nat.Prerouting().AddRules( rules. - NewRule( + NewAppendRule( Protocol(Tcp()), Jump(ToUserDefinedChain(cfg.Redirect.Inbound.ChainName)), ). @@ -352,34 +339,30 @@ func addPreroutingRules(cfg config.InitializedConfigIPvX, nat *tables.NatTable) for _, iface := range maps.SortedKeys(cfg.Redirect.VNet.InterfaceCIDRs) { nat.Prerouting().AddRules( rules. - NewRule( + NewInsertRule( InInterface(iface), Match(MatchUdp()), Protocol(Udp(DestinationPort(DNSPort))), Jump(ToPort(cfg.Redirect.DNS.Port)), ). - WithPosition(rulePosition). WithCommentf("redirect DNS requests on interface %s to the kuma-dp DNS proxy (listening on port %d)", iface, cfg.Redirect.DNS.Port), rules. - NewRule( + NewInsertRule( NotDestination(cfg.Redirect.VNet.InterfaceCIDRs[iface]), InInterface(iface), Protocol(Tcp()), Jump(ToPort(cfg.Redirect.Outbound.Port)), ). - WithPosition(rulePosition+1). WithCommentf("redirect TCP traffic on interface %s, excluding destination %s, to the envoy's outbound passthrough port %d", iface, cfg.Redirect.VNet.InterfaceCIDRs[iface], cfg.Redirect.Outbound.Port), ) - rulePosition += 2 } nat.Prerouting().AddRules( rules. - NewRule( + NewInsertRule( Protocol(Tcp()), Jump(ToUserDefinedChain(cfg.Redirect.Inbound.ChainName)), ). - WithPosition(rulePosition). WithComment("redirect remaining TCP traffic to our custom chain for processing"), ) } diff --git a/pkg/transparentproxy/iptables/builder/builder_table_raw.go b/pkg/transparentproxy/iptables/builder/builder_table_raw.go index 342d07810ecf..30e0a7ace37c 100644 --- a/pkg/transparentproxy/iptables/builder/builder_table_raw.go +++ b/pkg/transparentproxy/iptables/builder/builder_table_raw.go @@ -18,14 +18,14 @@ func buildRawTable(cfg config.InitializedConfigIPvX) *tables.RawTable { if cfg.Redirect.DNS.ConntrackZoneSplit { raw.Output().AddRules( rules. - NewRule( + NewAppendRule( Protocol(Udp(DestinationPort(DNSPort))), Match(Owner(Uid(cfg.Owner.UID))), Jump(Ct(Zone("1"))), ). WithCommentf("assign connection tracking zone 1 to DNS traffic from the kuma-dp user (UID %s)", cfg.Owner.UID), rules. - NewRule( + NewAppendRule( Protocol(Udp(SourcePort(cfg.Redirect.DNS.Port))), Match(Owner(Uid(cfg.Owner.UID))), Jump(Ct(Zone("2"))), @@ -36,7 +36,7 @@ func buildRawTable(cfg config.InitializedConfigIPvX) *tables.RawTable { if cfg.Redirect.DNS.CaptureAll { raw.Output().AddRules( rules. - NewRule( + NewAppendRule( Protocol(Udp(DestinationPort(DNSPort))), Jump(Ct(Zone("2"))), ). @@ -45,7 +45,7 @@ func buildRawTable(cfg config.InitializedConfigIPvX) *tables.RawTable { raw.Prerouting().AddRules( rules. - NewRule( + NewAppendRule( Protocol(Udp(SourcePort(DNSPort))), Jump(Ct(Zone("1"))), ). @@ -55,7 +55,7 @@ func buildRawTable(cfg config.InitializedConfigIPvX) *tables.RawTable { for _, ip := range cfg.Redirect.DNS.Servers { raw.Output().AddRules( rules. - NewRule( + NewAppendRule( Destination(ip), Protocol(Udp(DestinationPort(DNSPort))), Jump(Ct(Zone("2"))), @@ -64,7 +64,7 @@ func buildRawTable(cfg config.InitializedConfigIPvX) *tables.RawTable { ) raw.Prerouting().AddRules( rules. - NewRule( + NewAppendRule( Destination(ip), Protocol(Udp(SourcePort(DNSPort))), Jump(Ct(Zone("1"))), diff --git a/pkg/transparentproxy/iptables/chains/chains.go b/pkg/transparentproxy/iptables/chains/chains.go index 44d3d78848d4..dfff60d7ff92 100644 --- a/pkg/transparentproxy/iptables/chains/chains.go +++ b/pkg/transparentproxy/iptables/chains/chains.go @@ -9,9 +9,18 @@ import ( ) type Chain struct { + // The name of the iptables table (e.g., "nat", "filter") to which this + // chain belongs. table consts.TableName - name string + // The name of the iptables chain (e.g., "PREROUTING", "OUTPUT"). + name string + // A slice of rules contained within this chain. rules []*rules.Rule + // position reflects the current position for "insert" rules, indicating + // where new rules should be inserted within the chain. This is crucial for + // maintaining the correct order of rules when specific positioning is + // required. + position uint } func (c *Chain) Name() string { @@ -19,8 +28,10 @@ func (c *Chain) Name() string { } func (c *Chain) AddRules(rules ...*rules.RuleBuilder) *Chain { - for _, rule := range rules { - c.rules = append(c.rules, rule.Build(c.table, c.name)) + for _, r := range rules { + rule, newPosition := r.Build(c.table, c.name, c.position) + c.rules = append(c.rules, rule) + c.position = newPosition } return c diff --git a/pkg/transparentproxy/iptables/consts/consts.go b/pkg/transparentproxy/iptables/consts/consts.go index b100235fdee3..0467e4d83d3a 100644 --- a/pkg/transparentproxy/iptables/consts/consts.go +++ b/pkg/transparentproxy/iptables/consts/consts.go @@ -2,6 +2,7 @@ package consts import ( "regexp" + "strings" ) const ( @@ -45,10 +46,37 @@ const ( InboundPassthroughSourceAddressCIDRIPv6 = "::6/128" OutputLogPrefix = "OUTPUT:" PreroutingLogPrefix = "PREROUTING:" - UDP = "udp" - TCP = "tcp" ) +type ProtocolL4 string + +const ( + ProtocolUDP ProtocolL4 = "udp" + ProtocolTCP ProtocolL4 = "tcp" + // ProtocolUndefined represents an undefined or unsupported protocol. + ProtocolUndefined ProtocolL4 = "" +) + +// ParseProtocolL4 parses a string and returns the corresponding ProtocolL4 +// constant. If the input string is not "udp" or "tcp", it returns +// ProtocolUndefined. +// +// Args: +// - s (string): The input string representing the protocol type. +// +// Returns: +// - ProtocolL4: The parsed ProtocolL4 constant. It will be ProtocolUDP for +// "udp", ProtocolTCP for "tcp", and ProtocolUndefined for any other input +// string. +func ParseProtocolL4(s string) ProtocolL4 { + switch s := strings.ToLower(strings.TrimSpace(s)); s { + case "udp", "tcp": + return ProtocolL4(s) + default: + return ProtocolUndefined + } +} + type TableName string const ( @@ -111,6 +139,7 @@ const ( ModuleUdp = "udp" ModuleComment = "comment" ModuleConntrack = "conntrack" + ModuleMultiport = "multiport" ) type IptablesMode string diff --git a/pkg/transparentproxy/iptables/parameters/protocol.go b/pkg/transparentproxy/iptables/parameters/protocol.go index 69b3812de9dc..839fe63e6876 100644 --- a/pkg/transparentproxy/iptables/parameters/protocol.go +++ b/pkg/transparentproxy/iptables/parameters/protocol.go @@ -4,6 +4,7 @@ import ( "strconv" "github.com/kumahq/kuma/pkg/transparentproxy/config" + "github.com/kumahq/kuma/pkg/transparentproxy/iptables/consts" ) var _ ParameterBuilder = &ProtocolParameter{} @@ -127,7 +128,7 @@ func SourcePort(port uint16) *TcpUdpParameter { return sourcePort(port, false) } -func tcpUdp(proto string, params []*TcpUdpParameter) *ProtocolParameter { +func tcpUdp(proto consts.ProtocolL4, params []*TcpUdpParameter) *ProtocolParameter { var parameters []ParameterBuilder for _, parameter := range params { @@ -137,13 +138,13 @@ func tcpUdp(proto string, params []*TcpUdpParameter) *ProtocolParameter { } return &ProtocolParameter{ - name: proto, + name: string(proto), parameters: parameters, } } func Udp(udpParameters ...*TcpUdpParameter) *ProtocolParameter { - return tcpUdp("udp", udpParameters) + return tcpUdp(consts.ProtocolUDP, udpParameters) } func UdpIf(predicate bool, udpParameters ...*TcpUdpParameter) *ProtocolParameter { @@ -151,11 +152,11 @@ func UdpIf(predicate bool, udpParameters ...*TcpUdpParameter) *ProtocolParameter return nil } - return tcpUdp("udp", udpParameters) + return tcpUdp(consts.ProtocolUDP, udpParameters) } func Tcp(tcpParameters ...*TcpUdpParameter) *ProtocolParameter { - return tcpUdp("tcp", tcpParameters) + return tcpUdp(consts.ProtocolTCP, tcpParameters) } func TcpIf(predicate bool, tcpParameters ...*TcpUdpParameter) *ProtocolParameter { @@ -163,7 +164,7 @@ func TcpIf(predicate bool, tcpParameters ...*TcpUdpParameter) *ProtocolParameter return nil } - return tcpUdp("tcp", tcpParameters) + return tcpUdp(consts.ProtocolTCP, tcpParameters) } func Protocol(p ...*ProtocolParameter) *Parameter { diff --git a/pkg/transparentproxy/iptables/parameters/source.go b/pkg/transparentproxy/iptables/parameters/source.go index 1f3c1ef05e5d..efc041e50b11 100644 --- a/pkg/transparentproxy/iptables/parameters/source.go +++ b/pkg/transparentproxy/iptables/parameters/source.go @@ -15,6 +15,10 @@ func (p *SourceParameter) Negate() ParameterBuilder { } func Address(address string) *SourceParameter { + if address == "" { + return nil + } + return &SourceParameter{address: address} } @@ -28,6 +32,10 @@ func Address(address string) *SourceParameter { // // ref. iptables(8) > PARAMETERS func Source(parameter *SourceParameter) *Parameter { + if parameter == nil { + return nil + } + return &Parameter{ long: "--source", short: "-s", diff --git a/pkg/transparentproxy/iptables/rules/rules.go b/pkg/transparentproxy/iptables/rules/rules.go index fc7f5e78faa4..f06bc3810f39 100644 --- a/pkg/transparentproxy/iptables/rules/rules.go +++ b/pkg/transparentproxy/iptables/rules/rules.go @@ -20,7 +20,7 @@ type Rule struct { type RuleBuilder struct { parameters parameters.Parameters comment string - position uint + insert bool } func (b *RuleBuilder) WithComment(comment string) *RuleBuilder { @@ -47,25 +47,59 @@ func (b *RuleBuilder) WithConditionalComment( return b } -func (b *RuleBuilder) WithPosition(position uint) *RuleBuilder { - b.position = position - return b -} - -func (b *RuleBuilder) Build(table consts.TableName, chain string) *Rule { - return &Rule{ +func (b *RuleBuilder) Build( + table consts.TableName, + chain string, + position uint, +) (*Rule, uint) { + rule := &Rule{ table: table, chain: chain, - position: b.position, parameters: b.parameters, comment: b.comment, } + + if b.insert { + position++ + rule.position = position + } + + return rule, position } -func NewRule(parameters ...*parameters.Parameter) *RuleBuilder { +// NewAppendRule creates a new RuleBuilder for an iptables rule that will be +// appended to the end of an existing chain. This function takes a variable +// number of parameters, each represented as a pointer to a Parameter object, +// which specify the various conditions and actions for the rule. +// +// Args: +// - parameters (...*parameters.Parameter): A variadic list of pointers to +// Parameter objects that define the rule's conditions and actions. +// +// Returns: +// - *RuleBuilder: A pointer to a RuleBuilder configured to append a new rule +// with the specified parameters. +func NewAppendRule(parameters ...*parameters.Parameter) *RuleBuilder { return &RuleBuilder{parameters: parameters} } +// NewInsertRule creates a new RuleBuilder for an iptables rule that will be +// inserted at a specific position within an existing chain. This function takes +// a variable number of parameters, each represented as a pointer to a Parameter +// object, which specify the various conditions and actions for the rule. +// The rule will be marked for insertion rather than appending. +// +// Args: +// - parameters (...*parameters.Parameter): A variadic list of pointers to +// Parameter objects that define the rule's conditions and actions. +// +// Returns: +// - *RuleBuilder: A pointer to a RuleBuilder configured to insert a new rule +// with the specified parameters. +func NewInsertRule(parameters ...*parameters.Parameter) *RuleBuilder { + return &RuleBuilder{parameters: parameters, insert: true} +} + // BuildForRestore generates an iptables rule formatted for use with // `iptables-restore`. //