diff --git a/pkg/service/verifycredential/linkeddomain.go b/pkg/service/verifycredential/linkeddomain.go index 4099f23f6..368b90760 100644 --- a/pkg/service/verifycredential/linkeddomain.go +++ b/pkg/service/verifycredential/linkeddomain.go @@ -23,10 +23,10 @@ type serviceEndpoint struct { Origins []string `json:"origins"` } -func (s *Service) ValidateLinkedDomain(_ context.Context, signingDID string) error { - didDocResolution, vdrErr := s.vdr.Resolve(signingDID) +func (s *Service) ValidateLinkedDomain(_ context.Context, issuerSigningDID string) error { + didDocResolution, vdrErr := s.vdr.Resolve(issuerSigningDID) if vdrErr != nil { - return fmt.Errorf("failed to resolve DID %s, err: %w", signingDID, vdrErr) + return fmt.Errorf("failed to resolve DID %s, err: %w", issuerSigningDID, vdrErr) } for _, service := range didDocResolution.DIDDocument.Service { @@ -52,11 +52,11 @@ func (s *Service) ValidateLinkedDomain(_ context.Context, signingDID string) err didconfig.WithHTTPClient(s.httpClient), ) - return didConfigurationClient.VerifyDIDAndDomain(signingDID, + return didConfigurationClient.VerifyDIDAndDomain(issuerSigningDID, strings.TrimSuffix(serviceEndpoint.Origins[0], "/")) } - return fmt.Errorf("no LinkedDomains service in DID %s", signingDID) + return fmt.Errorf("no LinkedDomains service in DID %s", issuerSigningDID) } func getServiceType(serviceType interface{}) string { diff --git a/pkg/service/verifycredential/verifycredential_service.go b/pkg/service/verifycredential/verifycredential_service.go index efd51f42b..d6be5f183 100644 --- a/pkg/service/verifycredential/verifycredential_service.go +++ b/pkg/service/verifycredential/verifycredential_service.go @@ -78,7 +78,7 @@ func (s *Service) VerifyCredential(ctx context.Context, credential *verifiable.C var result []CredentialsVerificationCheckResult if checks.LinkedDomain { - if err := s.ValidateLinkedDomain(ctx, profile.SigningDID.DID); err != nil { + if err := s.ValidateLinkedDomain(ctx, credential.Contents().Issuer.ID); err != nil { result = append(result, CredentialsVerificationCheckResult{ Check: "linkedDomain", Error: err.Error(), diff --git a/pkg/service/verifypresentation/verifypresentation_service.go b/pkg/service/verifypresentation/verifypresentation_service.go index 41a466761..1ba4e025c 100644 --- a/pkg/service/verifypresentation/verifypresentation_service.go +++ b/pkg/service/verifypresentation/verifypresentation_service.go @@ -170,7 +170,8 @@ func (s *Service) VerifyPresentation( //nolint:funlen,gocognit if profile.Checks.Credential.LinkedDomain { st := time.Now() - err := s.vcVerifier.ValidateLinkedDomain(ctx, profile.SigningDID.DID) + err := s.checkLinkedDomain(ctx, credentials) + result.Checks = append(result.Checks, &Check{ Check: "linkedDomain", Error: err, @@ -297,6 +298,22 @@ func (s *Service) checkIssuerTrustList( return nil } +func (s *Service) checkLinkedDomain(ctx context.Context, credentials []*verifiable.Credential) error { + for _, cred := range credentials { + var issuerID string + + if cred.Contents().Issuer != nil { + issuerID = cred.Contents().Issuer.ID + } + + if err := s.vcVerifier.ValidateLinkedDomain(ctx, issuerID); err != nil { + return err + } + } + + return nil +} + func (s *Service) validatePresentationProof(targetPresentation interface{}, opts *Options) error { var final *verifiable.Presentation switch pres := targetPresentation.(type) { diff --git a/pkg/service/verifypresentation/verifypresentation_service_test.go b/pkg/service/verifypresentation/verifypresentation_service_test.go index 160969dad..e796a2faf 100644 --- a/pkg/service/verifypresentation/verifypresentation_service_test.go +++ b/pkg/service/verifypresentation/verifypresentation_service_test.go @@ -405,9 +405,7 @@ func TestService_VerifyPresentation(t *testing.T) { return } - if !reflect.DeepEqual(got, tt.want) { //nolint:govet - t.Errorf("VerifyPresentation() got = %v, want %v", got, tt.want) - } + assert.Equal(t, tt.want, got) }) } } @@ -998,6 +996,93 @@ func TestService_validateCredentialsStatus(t *testing.T) { } } +func TestService_checkLinkedDomain(t *testing.T) { + type fields struct { + getVcVerifier func(t *testing.T) vcVerifier + } + type args struct { + getCredentials func(t *testing.T) []*verifiable.Credential + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "OK", + fields: fields{ + getVcVerifier: func(t *testing.T) vcVerifier { + mockVerifier := NewMockVcVerifier(gomock.NewController(t)) + mockVerifier.EXPECT().ValidateLinkedDomain( + context.Background(), + "IssuerID", + ).Times(1).Return(nil) + return mockVerifier + }, + }, + args: args{ + getCredentials: func(t *testing.T) []*verifiable.Credential { + credContent := verifiable.CredentialContents{ + Types: []string{ + "VerifiableCredential", + "UniversityDegreeCredential", + }, + Issuer: &verifiable.Issuer{ID: "IssuerID"}, + } + + cred1, err := verifiable.CreateCredential(credContent, nil) + assert.NoError(t, err) + + return []*verifiable.Credential{cred1} + }, + }, + wantErr: false, + }, + { + name: "Error", + fields: fields{ + getVcVerifier: func(t *testing.T) vcVerifier { + mockVerifier := NewMockVcVerifier(gomock.NewController(t)) + mockVerifier.EXPECT().ValidateLinkedDomain( + context.Background(), + "", + ).Times(1).Return(errors.New("some error")) + return mockVerifier + }, + }, + args: args{ + getCredentials: func(t *testing.T) []*verifiable.Credential { + credContent := verifiable.CredentialContents{ + Types: []string{ + "VerifiableCredential", + "UniversityDegreeCredential", + }, + } + + cred1, err := verifiable.CreateCredential(credContent, nil) + assert.NoError(t, err) + + return []*verifiable.Credential{cred1} + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Service{ + vcVerifier: tt.fields.getVcVerifier(t), + } + if err := s.checkLinkedDomain( + context.Background(), + tt.args.getCredentials(t)); (err != nil) != tt.wantErr { + t.Errorf("checkLinkedDomain() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + func TestExtractCredentialStatus(t *testing.T) { s := &Service{}