Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow for signing multiple callsign domains and private keys #66

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions api/adscert.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ message RequestInfo {
bytes url_hash = 2;
bytes body_hash = 3;
repeated SignatureInfo signature_info = 4;
// useful if 1 signatory is managing multiple origin domains such as in resellers case.
string origin_domain = 5;
}

// SignatureInfo captures the signature generated for the signing request. It
Expand Down
4 changes: 0 additions & 4 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ func main() {
logger.SetLevel(parsedLogLevel)
logger.Infof("Log Level: %s, parsed as iota %v", *logLevel, parsedLogLevel)

if *origin == "" {
logger.Fatalf("Origin ads.cert Call Sign domain name is required")
}

if *privateKey == "" {
logger.Fatalf("Private key is required")
}
Expand Down
4 changes: 4 additions & 0 deletions examples/signer-client/signer-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

var (
serverAddress = flag.String("server_address", "localhost:3000", "address of grpc server")
originDomain = flag.String("origin_domain", "", "Origin domain")
destinationURL = flag.String("url", "https://google.com/gen_204", "URL to invoke")
body = flag.String("body", "", "POST request body")
signingTimeout = flag.Duration("signing_timeout", 5*time.Millisecond, "Specifies how long this client will wait for signing to finish before abandoning.")
Expand Down Expand Up @@ -49,6 +50,9 @@ func main() {
// destination URL and body, setting these value on the RequestInfo message.
reqInfo := &api.RequestInfo{}
signatory.SetRequestInfo(reqInfo, *destinationURL, []byte(*body))
if originDomain != nil {
reqInfo.OriginDomain = *originDomain
}

// Request the signature.
logger.Infof("signing request for url: %v", *destinationURL)
Expand Down
16 changes: 8 additions & 8 deletions internal/formats/adscert_connection_signature.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,21 @@ func EncodeSignatureSuffix(

func NewAuthenticatedConnectionSignature(status AuthenticatedConnectionProtocolStatus, from string, invoking string) (*AuthenticatedConnectionSignature, error) {

s := &AuthenticatedConnectionSignature{}
s.status = status
s.from = from
s.invoking = invoking

if status == StatusUnspecified {
return nil, ErrParamMissingStatus
return s, ErrParamMissingStatus
}
if from == "" {
return nil, ErrParamMissingFrom
return s, ErrParamMissingFrom
}
if invoking == "" {
return nil, ErrParamMissingInvoking
return s, ErrParamMissingInvoking
}

s := &AuthenticatedConnectionSignature{}
s.status = status
s.from = from
s.invoking = invoking

return s, nil
}

Expand Down
12 changes: 4 additions & 8 deletions internal/formats/adscert_connection_signature_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func TestNewAuthenticatedConnectionSignature(t *testing.T) {
nonce string

wantNewACSErr error
wantNilACS bool
wantAddParamsForSignatureErr error
wantUnsignedBaseMessage string
wantUnsignedExtendedMessage string
Expand Down Expand Up @@ -58,7 +57,6 @@ func TestNewAuthenticatedConnectionSignature(t *testing.T) {
invoking: "invoking.com",

wantNewACSErr: formats.ErrParamMissingStatus,
wantNilACS: true,
},
{
desc: "check ErrParamMissingFrom",
Expand All @@ -67,7 +65,6 @@ func TestNewAuthenticatedConnectionSignature(t *testing.T) {
invoking: "invoking.com",

wantNewACSErr: formats.ErrParamMissingFrom,
wantNilACS: true,
},
{
desc: "check ErrParamMissingInvoking",
Expand All @@ -76,7 +73,6 @@ func TestNewAuthenticatedConnectionSignature(t *testing.T) {
invoking: "",

wantNewACSErr: formats.ErrParamMissingInvoking,
wantNilACS: true,
},

{
Expand Down Expand Up @@ -167,12 +163,12 @@ func TestNewAuthenticatedConnectionSignature(t *testing.T) {
t.Errorf("NewAuthenticatedConnectionSignature() %s error check: got %v, want %v", tC.desc, gotErr, tC.wantNewACSErr)
}

gotNilACS := (acs == nil)
if tC.wantNilACS != gotNilACS {
t.Fatalf("NewAuthenticatedConnectionSignature() %s nil check: got (acs == nil) %v, want %v", tC.desc, gotNilACS, tC.wantNilACS)
if acs == nil {
t.Fatalf("NewAuthenticatedConnectionSignature() %s nil check: got (acs == nil), should not be nil", tC.desc)
}

if gotNilACS {
// skip rest of tests if an error was returned
if gotErr != nil {
return
}

Expand Down
305 changes: 158 additions & 147 deletions pkg/adscert/api/adscert.pb.go

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pkg/adscert/api/adscert_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 28 additions & 19 deletions pkg/adscert/discovery/domain_indexer_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func NewDefaultDomainIndexer(dnsResolver DNSResolver, domainStore DomainStore, d
domainRenewalInterval: domainRenewalInterval,
dnsResolver: dnsResolver,
domainStore: domainStore,
currentPrivateKey: make(map[string]keyAlias),
}

myPrivateKeys, err := privateKeysToKeyMap(base64PrivateKeys)
Expand All @@ -39,12 +40,14 @@ func NewDefaultDomainIndexer(dnsResolver DNSResolver, domainStore DomainStore, d
}
di.myPrivateKeys = myPrivateKeys

for _, privateKey := range di.myPrivateKeys {
// since iterating over a map is non-deterministic, we can make sure to set the key
// either if it is not already set or it is alphabetically less than current key at the index when
// iterating over the private keys map.
if di.currentPrivateKey == "" || di.currentPrivateKey < privateKey.alias {
di.currentPrivateKey = privateKey.alias
for originCallsign := range di.myPrivateKeys {
for _, privateKey := range di.myPrivateKeys[originCallsign] {
// since iterating over a map is non-deterministic, we can make sure to set the key
// either if it is not already set or it is alphabetically less than current key at the index when
// iterating over the private keys map.
if di.currentPrivateKey[originCallsign] == "" || di.currentPrivateKey[originCallsign] < privateKey.alias {
di.currentPrivateKey[originCallsign] = privateKey.alias
}
}
}

Expand All @@ -62,8 +65,8 @@ type defaultDomainIndexer struct {
lastRun time.Time
lastRunLock sync.RWMutex

myPrivateKeys keyMap
currentPrivateKey keyAlias
myPrivateKeys map[string]keyMap
currentPrivateKey map[string]keyAlias

dnsResolver DNSResolver
domainStore DomainStore
Expand Down Expand Up @@ -227,21 +230,27 @@ func (di *defaultDomainIndexer) checkDomainForKeyRecords(ctx context.Context, cu
}

// create shared secrets for each private key + public key combination
for _, myKey := range di.myPrivateKeys {
for _, theirKey := range currentDomainInfo.allPublicKeys {
keyPairAlias := newKeyPairAlias(myKey.alias, theirKey.alias)
if currentDomainInfo.allSharedSecrets[keyPairAlias] == nil {
currentDomainInfo.allSharedSecrets[keyPairAlias], err = calculateSharedSecret(myKey, theirKey)
if err != nil {
logger.Warningf("error calculating shared secret for record %s: %v", currentDomainInfo.Domain, err)
currentDomainInfo.domainStatus = DomainStatusErrorOnSharedSecretCalculation
for originCallsign := range di.myPrivateKeys {
if originCallsign != currentDomainInfo.Domain {
continue
}

for _, myKey := range di.myPrivateKeys[originCallsign] {
for _, theirKey := range currentDomainInfo.allPublicKeys {
keyPairAlias := newKeyPairAlias(myKey.alias, theirKey.alias)
if currentDomainInfo.allSharedSecrets[keyPairAlias] == nil {
currentDomainInfo.allSharedSecrets[keyPairAlias], err = calculateSharedSecret(myKey, theirKey)
if err != nil {
logger.Warningf("error calculating shared secret for record %s: %v", currentDomainInfo.Domain, err)
currentDomainInfo.domainStatus = DomainStatusErrorOnSharedSecretCalculation
}
}
}
}
}

currentDomainInfo.currentSharedSecretId = newKeyPairAlias(di.currentPrivateKey, currentDomainInfo.currentPublicKeyId)
currentDomainInfo.lastUpdateTime = time.Now()
currentDomainInfo.currentSharedSecretId = newKeyPairAlias(di.currentPrivateKey[originCallsign], currentDomainInfo.currentPublicKeyId)
currentDomainInfo.lastUpdateTime = time.Now()
}
}

func parsePolicyRecords(baseSubdomain string, baseSubdomainRecords []string) (foundDomains []string, parseError bool) {
Expand Down
23 changes: 17 additions & 6 deletions pkg/adscert/discovery/internal_base_key.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package discovery

import (
"errors"
"fmt"
"strings"

"github.com/IABTechLab/adscert/internal/formats"
"github.com/IABTechLab/adscert/pkg/adscert/logger"
Expand Down Expand Up @@ -76,11 +78,14 @@ func calculateSharedSecret(originPrivateKey *x25519Key, remotePublicKey *x25519K
return result, err
}

func privateKeysToKeyMap(privateKeys []string) (keyMap, error) {
result := keyMap{}

func privateKeysToKeyMap(privateKeys []string) (map[string]keyMap, error) {
results := map[string]keyMap{}
for _, privateKeyBase64 := range privateKeys {
privateKey, err := parseKeyFromString(privateKeyBase64)
sp := strings.SplitN(privateKeyBase64, "|", 2)
if len(sp) < 2 {
return nil, errors.New("missing origin callsign")
}
privateKey, err := parseKeyFromString(sp[1])
if err != nil {
return nil, err
}
Expand All @@ -90,10 +95,16 @@ func privateKeysToKeyMap(privateKeys []string) (keyMap, error) {

keyAlias := keyAlias(formats.ExtractKeyAliasFromPublicKeyBase64(formats.EncodeKeyBase64(publicBytes[:])))
privateKey.alias = keyAlias
result[keyAlias] = privateKey

km := results[sp[0]]
if km == nil {
km = keyMap{}
}
km[keyAlias] = privateKey
results[sp[0]] = km
}

return result, nil
return results, nil
}

func parseKeyFromString(base64EncodedKey string) (*x25519Key, error) {
Expand Down
15 changes: 13 additions & 2 deletions pkg/adscert/signatory/signatory_local_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ func NewLocalAuthenticatedConnectionsSignatory(
domainCheckInterval time.Duration,
domainRenewalInterval time.Duration,
base64PrivateKeys []string) *LocalAuthenticatedConnectionsSignatory {
if originCallsign != "" {
for i := range base64PrivateKeys {
base64PrivateKeys[i] = originCallsign + "|" + base64PrivateKeys[i]
}
}
return &LocalAuthenticatedConnectionsSignatory{
originCallsign: originCallsign,
secureRandom: secureRandom,
Expand Down Expand Up @@ -91,9 +96,15 @@ func (s *LocalAuthenticatedConnectionsSignatory) SignAuthenticatedConnection(req
}

func (s *LocalAuthenticatedConnectionsSignatory) signSingleMessage(request *api.AuthenticatedConnectionSignatureRequest, domainInfo discovery.DomainInfo) (*api.SignatureInfo, error) {

sigInfo := &api.SignatureInfo{}
acs, err := formats.NewAuthenticatedConnectionSignature(formats.StatusOK, s.originCallsign, request.RequestInfo.InvokingDomain)

var originCallsign string
if request.RequestInfo.OriginDomain != "" {
originCallsign = request.RequestInfo.OriginDomain
} else {
originCallsign = s.originCallsign
}
acs, err := formats.NewAuthenticatedConnectionSignature(formats.StatusOK, originCallsign, request.RequestInfo.InvokingDomain)
if err != nil {
acs.SetStatus(formats.StatusErrorOnSignature)
setSignatureInfoFromAuthenticatedConnection(sigInfo, acs)
Expand Down