-
Notifications
You must be signed in to change notification settings - Fork 176
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fixes #2723 - fail on startup for cert mismatch. rebase from main, fi…
…x conflicts
- Loading branch information
1 parent
15fd986
commit c624794
Showing
4 changed files
with
278 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package tls | ||
|
||
import ( | ||
"fmt" | ||
"github.com/openziti/identity" | ||
"net" | ||
"sort" | ||
"strings" | ||
) | ||
|
||
func ValidFor(i *identity.TokenId, address string) error { | ||
if strings.HasPrefix(address, "tls:") { | ||
address = address[len("tls:"):] | ||
} | ||
|
||
host, _, err := net.SplitHostPort(address) | ||
if err != nil { | ||
return fmt.Errorf("invalid address: %s", address) | ||
} | ||
return validForHostname(i, host) | ||
} | ||
|
||
// ValidForHostname checks if the identity is valid for a given hostname | ||
func validForHostname(i *identity.TokenId, hostname string) error { | ||
var err error | ||
|
||
// Check server certificate | ||
if len(i.ServerCert()) > 0 { | ||
err = i.ServerCert()[0].Leaf.VerifyHostname(hostname) | ||
} | ||
|
||
// Check client certificate if server cert validation fails | ||
if err != nil && i.Cert() != nil && i.Cert().Leaf != nil { | ||
err = i.Cert().Leaf.VerifyHostname(hostname) | ||
} | ||
|
||
if err != nil { | ||
return fmt.Errorf("identity is not valid for provided host: [%s]. is valid for: [%v]", hostname, getUniqueAddresses(i)) | ||
} | ||
return nil | ||
} | ||
|
||
// getUniqueAddresses extracts unique DNS names and IP addresses from the identity's certificates | ||
func getUniqueAddresses(i *identity.TokenId) string { | ||
addresses := make(map[string]struct{}) | ||
|
||
if certs := i.ServerCert(); certs != nil && len(certs) > 0 && certs[0].Leaf != nil { | ||
for _, dns := range certs[0].Leaf.DNSNames { | ||
addresses[dns] = struct{}{} | ||
} | ||
for _, ip := range certs[0].Leaf.IPAddresses { | ||
addresses[ip.String()] = struct{}{} | ||
} | ||
} | ||
|
||
if cert := i.Cert(); cert != nil && cert.Leaf != nil { | ||
for _, dns := range cert.Leaf.DNSNames { | ||
addresses[dns] = struct{}{} | ||
} | ||
for _, ip := range cert.Leaf.IPAddresses { | ||
addresses[ip.String()] = struct{}{} | ||
} | ||
} | ||
|
||
uniqueList := make([]string, 0, len(addresses)) | ||
for addr := range addresses { | ||
uniqueList = append(uniqueList, addr) | ||
} | ||
sort.Strings(uniqueList) // Ensure consistent order, mostly for testing | ||
|
||
return strings.Join(uniqueList, ", ") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
package tls_test | ||
|
||
import ( | ||
"crypto/tls" | ||
"crypto/x509" | ||
"github.com/stretchr/testify/assert" | ||
"net" | ||
"testing" | ||
|
||
"github.com/openziti/identity" | ||
ziti_tls "github.com/openziti/ziti/internal/tls" | ||
) | ||
|
||
// mockIdentity implements the Identity interface for testing | ||
type mockIdentity struct { | ||
serverCerts []*tls.Certificate | ||
clientCert *tls.Certificate | ||
} | ||
|
||
func (m *mockIdentity) Cert() *tls.Certificate { return m.clientCert } | ||
func (m *mockIdentity) ServerCert() []*tls.Certificate { return m.serverCerts } | ||
func (m *mockIdentity) CA() *x509.CertPool { return nil } | ||
func (m *mockIdentity) CaPool() *identity.CaPool { return nil } | ||
func (m *mockIdentity) ServerTLSConfig() *tls.Config { return nil } | ||
func (m *mockIdentity) ClientTLSConfig() *tls.Config { return nil } | ||
func (m *mockIdentity) Reload() error { return nil } | ||
func (m *mockIdentity) WatchFiles() error { return nil } | ||
func (m *mockIdentity) StopWatchingFiles() {} | ||
func (m *mockIdentity) SetCert(pem string) error { return nil } | ||
func (m *mockIdentity) SetServerCert(pem string) error { return nil } | ||
func (m *mockIdentity) GetConfig() *identity.Config { return nil } | ||
func (m *mockIdentity) ValidFor(address string) bool { return true } | ||
|
||
const ( | ||
validDNS = "example.com" | ||
invalidDNS = "invalid.com" | ||
validIP4 = "192.168.1.1" | ||
invalidIP4 = "10.0.0.1" | ||
validIP6 = "::1" | ||
invalidIP6 = "fe80::1" | ||
validPort = "443" | ||
) | ||
|
||
// Helper to create a mock identity with certs | ||
func createmockIdentity(dnsNames []string, ipAddresses []string) *identity.TokenId { | ||
leaf := &x509.Certificate{} | ||
for _, dns := range dnsNames { | ||
leaf.DNSNames = append(leaf.DNSNames, dns) | ||
} | ||
for _, ip := range ipAddresses { | ||
leaf.IPAddresses = append(leaf.IPAddresses, net.ParseIP(ip)) | ||
} | ||
|
||
tlsCert := &tls.Certificate{Leaf: leaf} | ||
mi := &mockIdentity{ | ||
serverCerts: []*tls.Certificate{tlsCert}, | ||
clientCert: tlsCert, | ||
} | ||
id := &identity.TokenId{ | ||
Identity: mi, | ||
Token: "", | ||
Data: nil, | ||
} | ||
return id | ||
} | ||
|
||
func TestValidFor_ValidHostname(t *testing.T) { | ||
id := createmockIdentity([]string{validDNS}, []string{}) | ||
|
||
err := ziti_tls.ValidFor(id, validDNS+":"+validPort) | ||
if err != nil { | ||
t.Errorf("Expected valid hostname, got error: %v", err) | ||
} | ||
} | ||
|
||
func TestValidFor_InvalidHostname(t *testing.T) { | ||
id := createmockIdentity([]string{validDNS}, []string{}) | ||
|
||
err := ziti_tls.ValidFor(id, invalidDNS+":"+validPort) | ||
if err == nil { | ||
t.Errorf("Expected error for invalid hostname, got nil") | ||
} | ||
assert.Equal(t, "identity is not valid for provided host: ["+invalidDNS+"]. is valid for: ["+validDNS+"]", err.Error()) | ||
} | ||
|
||
func TestValidFor_ValidIPv4(t *testing.T) { | ||
id := createmockIdentity([]string{}, []string{validIP4}) | ||
|
||
err := ziti_tls.ValidFor(id, validIP4+":"+validPort) | ||
if err != nil { | ||
t.Errorf("Expected valid IP, got error: %v", err) | ||
} | ||
} | ||
|
||
func TestValidFor_InvalidIPv4(t *testing.T) { | ||
id := createmockIdentity([]string{}, []string{validIP4}) | ||
|
||
err := ziti_tls.ValidFor(id, invalidIP4+":"+validPort) | ||
if err == nil { | ||
t.Errorf("Expected error for invalid IP, got nil") | ||
} | ||
assert.Equal(t, "identity is not valid for provided host: ["+invalidIP4+"]. is valid for: ["+validIP4+"]", err.Error()) | ||
} | ||
|
||
func TestValidFor_ValidIPv6(t *testing.T) { | ||
id := createmockIdentity([]string{}, []string{validIP6}) | ||
|
||
err := ziti_tls.ValidFor(id, "["+validIP6+"]:"+validPort) | ||
if err != nil { | ||
t.Errorf("Expected valid IPv6, got error: %v", err) | ||
} | ||
} | ||
|
||
func TestValidFor_InvalidIPv6(t *testing.T) { | ||
id := createmockIdentity([]string{}, []string{validIP6}) | ||
|
||
err := ziti_tls.ValidFor(id, "["+invalidIP6+"]:"+validPort) | ||
if err == nil { | ||
t.Errorf("Expected error for invalid IPv6, got nil") | ||
} | ||
assert.Equal(t, "identity is not valid for provided host: ["+invalidIP6+"]. is valid for: ["+validIP6+"]", err.Error()) | ||
} | ||
|
||
func TestValidFor_ValidMixed(t *testing.T) { | ||
id := createmockIdentity([]string{validDNS}, []string{validIP4}) | ||
|
||
err1 := ziti_tls.ValidFor(id, validDNS+":"+validPort) | ||
err2 := ziti_tls.ValidFor(id, validIP4+":"+validPort) | ||
|
||
if err1 != nil { | ||
t.Errorf("Expected valid hostname, got error: %v", err1) | ||
} | ||
if err2 != nil { | ||
t.Errorf("Expected valid IP, got error: %v", err2) | ||
} | ||
} | ||
|
||
func TestValidFor_InvalidMixed(t *testing.T) { | ||
id := createmockIdentity([]string{validDNS}, []string{validIP4}) | ||
|
||
err1 := ziti_tls.ValidFor(id, invalidDNS+":"+validPort) | ||
err2 := ziti_tls.ValidFor(id, invalidIP4+":"+validPort) | ||
|
||
if err1 == nil { | ||
t.Errorf("Expected error for invalid hostname, got nil") | ||
} | ||
assert.Equal(t, "identity is not valid for provided host: ["+invalidDNS+"]. is valid for: ["+validIP4+", "+validDNS+"]", err1.Error()) | ||
if err2 == nil { | ||
t.Errorf("Expected error for invalid IP, got nil") | ||
} | ||
assert.Equal(t, "identity is not valid for provided host: ["+invalidIP4+"]. is valid for: ["+validIP4+", "+validDNS+"]", err2.Error()) | ||
} | ||
|
||
func TestValidFor_NoCerts(t *testing.T) { | ||
id := createmockIdentity([]string{}, []string{}) | ||
|
||
err := ziti_tls.ValidFor(id, validDNS+":"+validPort) | ||
if err == nil { | ||
t.Errorf("Expected error for no valid certs, got nil") | ||
} | ||
assert.Equal(t, "identity is not valid for provided host: ["+validDNS+"]. is valid for: []", err.Error()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters