diff --git a/transport/tlscommon/types.go b/transport/tlscommon/types.go index 2dbf5359..5d178302 100644 --- a/transport/tlscommon/types.go +++ b/transport/tlscommon/types.go @@ -159,28 +159,37 @@ func (m TLSVerificationMode) MarshalText() ([]byte, error) { return nil, fmt.Errorf("could not marshal '%+v' to text", m) } -// Unpack unpacks the string into constants. +// Unpack unpacks the input into a TLSVerificationMode. func (m *TLSVerificationMode) Unpack(in interface{}) error { if in == nil { *m = VerifyFull return nil } - s, ok := in.(string) - if !ok { - return fmt.Errorf("verification mode must be an identifier") - } - if s == "" { - *m = VerifyFull - return nil + switch o := in.(type) { + case string: + if o == "" { + *m = VerifyFull + return nil + } + + mode, found := tlsVerificationModes[o] + if !found { + return fmt.Errorf("unknown verification mode '%v'", o) + } + *m = mode + case uint64: + *m = TLSVerificationMode(o) + default: + return fmt.Errorf("verification mode is an unknown type: %T", o) } + return nil +} - mode, found := tlsVerificationModes[s] - if !found { - return fmt.Errorf("unknown verification mode '%v'", s) +func (m *TLSVerificationMode) Validate() error { + if *m > VerifyStrict { + return fmt.Errorf("unsupported verification mode: %v", m) } - - *m = mode return nil } @@ -214,13 +223,20 @@ func (m *TLSClientAuth) Unpack(s string) error { type CipherSuite uint16 -func (cs *CipherSuite) Unpack(s string) error { - suite, found := tlsCipherSuites[s] - if !found { - return fmt.Errorf("invalid tls cipher suite '%v'", s) +func (cs *CipherSuite) Unpack(i interface{}) error { + switch o := i.(type) { + case string: + suite, found := tlsCipherSuites[o] + if !found { + return fmt.Errorf("invalid tls cipher suite '%v'", o) + } + + *cs = suite + case uint64: + *cs = CipherSuite(o) + default: + return fmt.Errorf("cipher suite is an unknown type: %T", o) } - - *cs = suite return nil } @@ -233,13 +249,20 @@ func (cs CipherSuite) String() string { type tlsCurveType tls.CurveID -func (ct *tlsCurveType) Unpack(s string) error { - t, found := tlsCurveTypes[s] - if !found { - return fmt.Errorf("invalid tls curve type '%v'", s) +func (ct *tlsCurveType) Unpack(i interface{}) error { + switch o := i.(type) { + case string: + t, found := tlsCurveTypes[o] + if !found { + return fmt.Errorf("invalid tls curve type '%v'", o) + } + + *ct = t + case uint64: + *ct = tlsCurveType(o) + default: + return fmt.Errorf("tls curve type is an unsupported input type: %T", o) } - - *ct = t return nil } @@ -252,13 +275,20 @@ func (r TLSRenegotiationSupport) String() string { return "<" + unknownType + ">" } -func (r *TLSRenegotiationSupport) Unpack(s string) error { - t, found := tlsRenegotiationSupportTypes[s] - if !found { - return fmt.Errorf("invalid tls renegotiation type '%v'", s) +func (r *TLSRenegotiationSupport) Unpack(i interface{}) error { + switch o := i.(type) { + case string: + t, found := tlsRenegotiationSupportTypes[o] + if !found { + return fmt.Errorf("invalid tls renegotiation type '%v'", o) + } + + *r = t + case uint64: + *r = TLSRenegotiationSupport(o) + default: + return fmt.Errorf("tls renegotation support is an unknown type: %T", o) } - - *r = t return nil } diff --git a/transport/tlscommon/types_test.go b/transport/tlscommon/types_test.go index 7a58c5e9..fb88d694 100644 --- a/transport/tlscommon/types_test.go +++ b/transport/tlscommon/types_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/elastic/elastic-agent-libs/config" + "github.com/elastic/go-ucfg" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -69,6 +70,36 @@ func TestLoadWithEmptyVerificationMode(t *testing.T) { assert.Equal(t, cfg.VerificationMode, VerifyFull) } +func TestRepackConfig(t *testing.T) { + cfg, err := load(` + enabled: true + verification_mode: certificate + supported_protocols: [TLSv1.1, TLSv1.2] + cipher_suites: + - RSA-AES-256-CBC-SHA + certificate_authorities: + - /path/to/ca.crt + certificate: /path/to/cert.crt + key: /path/to/key.crt + curve_types: + - P-521 + renegotiation: freely + ca_sha256: + - example + ca_trusted_fingerprint: fingerprint + `) + + assert.NoError(t, err) + assert.Equal(t, cfg.VerificationMode, VerifyCertificate) + + tmp, err := ucfg.NewFrom(cfg) + assert.NoError(t, err) + + err = tmp.Unpack(cfg) + assert.NoError(t, err) + assert.Equal(t, cfg.VerificationMode, VerifyCertificate) +} + func TestTLSClientAuthUnpack(t *testing.T) { tests := []struct { val string diff --git a/transport/tlscommon/versions.go b/transport/tlscommon/versions.go index e630ef4d..3e2ad498 100644 --- a/transport/tlscommon/versions.go +++ b/transport/tlscommon/versions.go @@ -38,12 +38,25 @@ func (v TLSVersion) Details() *TLSVersionDetails { } // Unpack transforms the string into a constant. -func (v *TLSVersion) Unpack(s string) error { - version, found := tlsProtocolVersions[s] - if !found { - return fmt.Errorf("invalid tls version '%v'", s) +func (v *TLSVersion) Unpack(i interface{}) error { + switch o := i.(type) { + case string: + version, found := tlsProtocolVersions[o] + if !found { + return fmt.Errorf("invalid tls version '%v'", o) + } + *v = version + case uint64: + *v = TLSVersion(o) + default: + return fmt.Errorf("tls version is an unknown type: %T", o) } + return nil +} - *v = version +func (v *TLSVersion) Validate() error { + if *v < TLSVersionMin || *v > TLSVersionMax { + return fmt.Errorf("unsupported tls version: %v", v) + } return nil }