diff --git a/api/http/common.go b/api/http/common.go index d2e3f49..ce3a1a4 100644 --- a/api/http/common.go +++ b/api/http/common.go @@ -51,7 +51,8 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) { errors.Contains(err, ErrEmptyToken), errors.Contains(err, ErrInvalidQueryParams), errors.Contains(err, ErrValidation), - errors.Contains(err, ErrInvalidRequest): + errors.Contains(err, ErrInvalidRequest), + errors.Contains(err, ErrMissingCN): err = unwrap(err) w.WriteHeader(http.StatusBadRequest) @@ -63,7 +64,8 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(http.StatusUnprocessableEntity) case errors.Contains(err, certs.ErrNotFound), - errors.Contains(err, certs.ErrRootCANotFound): + errors.Contains(err, certs.ErrRootCANotFound), + errors.Contains(err, certs.ErrIntermediateCANotFound): err = unwrap(err) w.WriteHeader(http.StatusNotFound) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index aafdcb6..8d0fe65 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -91,7 +91,7 @@ func issueCertEndpoint(svc certs.Service) endpoint.Endpoint { return issueCertRes{}, err } - serialNumber, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs) + serialNumber, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options) if err != nil { return issueCertRes{}, err } @@ -219,3 +219,20 @@ func ocspEndpoint(svc certs.Service) endpoint.Endpoint { }, nil } } + +func generateCRLEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(crlReq) + if err := req.validate(); err != nil { + return crlRes{}, err + } + crlBytes, err := svc.GenerateCRL(ctx, req.certtype) + if err != nil { + return crlRes{}, err + } + + return crlRes{ + CrlBytes: crlBytes, + }, nil + } +} diff --git a/api/http/errors.go b/api/http/errors.go index bc243a7..e4123dc 100644 --- a/api/http/errors.go +++ b/api/http/errors.go @@ -29,4 +29,7 @@ var ( // ErrInvalidRequest indicates that the request is invalid. ErrInvalidRequest = errors.New("invalid request") + + // ErrMissingCN indicates missing common name. + ErrMissingCN = errors.New("missing common name") ) diff --git a/api/http/requests.go b/api/http/requests.go index f9940fd..f11f497 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -35,10 +35,22 @@ func (req viewReq) validate() error { return nil } +type crlReq struct { + certtype certs.CertType +} + +func (req crlReq) validate() error { + if req.certtype != certs.IntermediateCA { + return errors.Wrap(certs.ErrMalformedEntity, errors.New("invalid CA type")) + } + return nil +} + type issueCertReq struct { - entityID string `json:"-"` - TTL string `json:"ttl"` - IpAddrs []string `json:"ip_addresses"` + entityID string `json:"-"` + TTL string `json:"ttl"` + IpAddrs []string `json:"ip_addresses"` + Options certs.SubjectOptions `json:"options"` } func (req issueCertReq) validate() error { diff --git a/api/http/responses.go b/api/http/responses.go index 4ef1e90..5635b0b 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -135,8 +135,8 @@ func (res listCertsRes) Empty() bool { type viewCertRes struct { SerialNumber string `json:"serial_number"` - Certificate string `json:"certificate"` - Key string `json:"key"` + Certificate string `json:"certificate,omitempty"` + Key string `json:"key,omitempty"` Revoked bool `json:"revoked"` ExpiryTime time.Time `json:"expiry_time"` EntityID string `json:"entity_id"` @@ -154,6 +154,22 @@ func (res viewCertRes) Empty() bool { return false } +type crlRes struct { + CrlBytes []byte `json:"crl"` +} + +func (res crlRes) Code() int { + return http.StatusOK +} + +func (res crlRes) Headers() map[string]string { + return map[string]string{} +} + +func (res crlRes) Empty() bool { + return false +} + type ocspRes struct { template ocsp.Response signer crypto.Signer diff --git a/api/http/transport.go b/api/http/transport.go index f7f4fca..b8bc352 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -28,10 +28,13 @@ const ( offsetKey = "offset" limitKey = "limit" entityKey = "entity_id" + commonName = "common_name" + token = "token" ocspStatusParam = "force_status" entityIDParam = "entityID" defOffset = 0 defLimit = 10 + defType = 1 ) // MakeHandler returns a HTTP handler for API endpoints. @@ -91,6 +94,12 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http encodeOSCPResponse, opts..., ), "ocsp").ServeHTTP) + r.Get("/crl", otelhttp.NewHandler(kithttp.NewServer( + generateCRLEndpoint(svc), + decodeCRL, + EncodeResponse, + opts..., + ), "generate_crl").ServeHTTP) }) r.Get("/health", certs.Health("certs", instanceID)) @@ -106,8 +115,19 @@ func decodeView(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } +func decodeCRL(_ context.Context, r *http.Request) (interface{}, error) { + certType, err := readNumQuery(r, "", defType) + if err != nil { + return nil, err + } + req := crlReq{ + certtype: certs.CertType(certType), + } + return req, nil +} + func decodeDownloadCerts(_ context.Context, r *http.Request) (interface{}, error) { - token, err := readStringQuery(r, "token", "") + token, err := readStringQuery(r, token, "") if err != nil { return nil, err } @@ -140,13 +160,22 @@ func decodeIssueCert(_ context.Context, r *http.Request) (interface{}, error) { if err != nil { return nil, err } + cn, err := readStringQuery(r, commonName, "") + if err != nil { + return nil, err + } + if cn == "" { + return nil, ErrMissingCN + } req := issueCertReq{ entityID: chi.URLParam(r, entityIDParam), + Options: certs.SubjectOptions{ + CommonName: cn, + }, } if err := json.Unmarshal(body, &req); err != nil { return nil, errors.Wrap(ErrInvalidRequest, err) } - return req, nil } diff --git a/api/logging.go b/api/logging.go index 7fe48ba..e2078bd 100644 --- a/api/logging.go +++ b/api/logging.go @@ -73,7 +73,7 @@ func (lm *loggingMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri return lm.svc.RetrieveCertDownloadToken(ctx, serialNumber) } -func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string) (serialNumber string, err error) { +func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (serialNumber string, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin)) if err != nil { @@ -82,7 +82,7 @@ func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string } lm.logger.Info(message) }(time.Now()) - return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs) + return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) } func (lm *loggingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (cp certs.CertificatePage, err error) { @@ -132,3 +132,15 @@ func (lm *loggingMiddleware) GetEntityID(ctx context.Context, serialNumber strin }(time.Now()) return lm.svc.GetEntityID(ctx, serialNumber) } + +func (lm *loggingMiddleware) GenerateCRL(ctx context.Context, caType certs.CertType) (crl []byte, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method generate_crl took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.GenerateCRL(ctx, caType) +} diff --git a/api/metrics.go b/api/metrics.go index e2d7bda..ab64e87 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -61,12 +61,12 @@ func (mm *metricsMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri return mm.svc.RetrieveCertDownloadToken(ctx, serialNumber) } -func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string) (string, error) { +func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (string, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_certificate").Add(1) mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs) + return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) } func (mm *metricsMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { @@ -100,3 +100,11 @@ func (mm *metricsMiddleware) GetEntityID(ctx context.Context, serialNumber strin }(time.Now()) return mm.svc.GetEntityID(ctx, serialNumber) } + +func (mm *metricsMiddleware) GenerateCRL(ctx context.Context, caType certs.CertType) ([]byte, error) { + defer func(begin time.Time) { + mm.counter.With("method", "generate_crl").Add(1) + mm.latency.With("method", "generate_crl").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.GenerateCRL(ctx, caType) +} diff --git a/certs.go b/certs.go index bfbe54e..d8d68b8 100644 --- a/certs.go +++ b/certs.go @@ -16,6 +16,7 @@ type Certificate struct { Revoked bool `db:"revoked"` ExpiryTime time.Time `db:"expiry_time"` EntityID string `db:"entity_id"` + Type CertType `db:"type"` DownloadUrl string `db:"-"` } @@ -51,13 +52,16 @@ type Service interface { RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (string, error) // IssueCert issues a certificate from the database. - IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string) (string, error) + IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (string, error) // OCSP retrieves the OCSP response for a certificate. OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error) // GetEntityID retrieves the entity ID for a certificate. GetEntityID(ctx context.Context, serialNumber string) (string, error) + + // GenerateCRL creates + GenerateCRL(ctx context.Context, caType CertType) ([]byte, error) } type Repository interface { @@ -72,4 +76,10 @@ type Repository interface { // ListCerts retrieves the certificates from the database while applying filters. ListCerts(ctx context.Context, pm PageMetadata) (CertificatePage, error) + + // GetCAs retrieves rootCA and intermediateCA from database. + GetCAs(ctx context.Context, caType ...CertType) ([]Certificate, error) + + // ListRevokedCerts retrieves revoked lists from database. + ListRevokedCerts(ctx context.Context) ([]Certificate, error) } diff --git a/certs_test.go b/certs_test.go index 57cd98a..23cef22 100644 --- a/certs_test.go +++ b/certs_test.go @@ -30,14 +30,19 @@ var invalidToken = "123" func TestIssueCert(t *testing.T) { cRepo := new(mocks.MockRepository) + repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) + repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) svc, err := certs.NewService(context.Background(), cRepo) require.NoError(t, err) + repoCall.Unset() + repoCall1.Unset() testCases := []struct { desc string backendId string ttl string err error + getCAErr error }{ { desc: "successful issue", @@ -56,10 +61,10 @@ func TestIssueCert(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(tc.err) - defer repoCall1.Unset() - _, err = svc.IssueCert(context.Background(), tc.backendId, tc.ttl, []string{}) + _, err = svc.IssueCert(context.Background(), tc.backendId, tc.ttl, []string{}, certs.SubjectOptions{}) require.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) + repoCall1.Unset() }) } } @@ -69,13 +74,12 @@ func TestRevokeCert(t *testing.T) { invalidSerialNumber := "invalid serial number" - listCall := cRepo.On("ListCerts", mock.Anything, mock.Anything, mock.Anything).Return(certs.CertificatePage{}, nil) - t.Cleanup(func() { - listCall.Unset() - }) - + repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) + repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) svc, err := certs.NewService(context.Background(), cRepo) require.NoError(t, err) + repoCall.Unset() + repoCall1.Unset() testCases := []struct { desc string @@ -122,8 +126,12 @@ func TestRevokeCert(t *testing.T) { func TestGetCertDownloadToken(t *testing.T) { cRepo := new(mocks.MockRepository) + repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) + repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) svc, err := certs.NewService(context.Background(), cRepo) require.NoError(t, err) + repoCall.Unset() + repoCall1.Unset() testCases := []struct { desc string @@ -152,8 +160,12 @@ func TestGetCert(t *testing.T) { validToken, err := jwtToken.SignedString([]byte(serialNumber)) require.NoError(t, err) + repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) + repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) svc, err := certs.NewService(context.Background(), cRepo) require.NoError(t, err) + repoCall.Unset() + repoCall1.Unset() testCases := []struct { desc string @@ -243,8 +255,12 @@ func TestRenewCert(t *testing.T) { }, &x509.Certificate{}, &testKey.PublicKey, testKey) require.NoError(t, err) + repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) + repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) svc, err := certs.NewService(context.Background(), cRepo) require.NoError(t, err) + repoCall.Unset() + repoCall1.Unset() testCases := []struct { desc string @@ -328,10 +344,15 @@ func TestRenewCert(t *testing.T) { } } -func TestService_GetEntityID(t *testing.T) { +func TestGetEntityID(t *testing.T) { cRepo := new(mocks.MockRepository) + + repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) + repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) svc, err := certs.NewService(context.Background(), cRepo) require.NoError(t, err) + repoCall.Unset() + repoCall1.Unset() ctx := context.Background() serialNumber := "1234567890" @@ -354,10 +375,15 @@ func TestService_GetEntityID(t *testing.T) { }) } -func TestService_ListCerts(t *testing.T) { +func TestListCerts(t *testing.T) { cRepo := new(mocks.MockRepository) + + repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) + repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) svc, err := certs.NewService(context.Background(), cRepo) require.NoError(t, err) + repoCall.Unset() + repoCall1.Unset() ctx := context.Background() pageMetadata := certs.PageMetadata{Limit: 10, Offset: 0, EntityID: "entity-123"} @@ -385,3 +411,83 @@ func TestService_ListCerts(t *testing.T) { assert.Empty(t, certPage) }) } + +func TestGenerateCRL(t *testing.T) { + cRepo := new(mocks.MockRepository) + + privateKey, _ := rsa.GenerateKey(rand.Reader, 2048) + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "Test CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 365), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + + repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{ + {Type: certs.RootCA, Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})}, + {Type: certs.IntermediateCA, Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})}, + }, nil) + repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) + svc, err := certs.NewService(context.Background(), cRepo) + require.NoError(t, err) + repoCall.Unset() + repoCall1.Unset() + + testCases := []struct { + desc string + caType certs.CertType + certs []certs.Certificate + repoErr error + err error + }{ + { + desc: "generate CRL with root CA", + caType: certs.RootCA, + certs: []certs.Certificate{ + {SerialNumber: "1", ExpiryTime: time.Now(), EntityID: "123"}, + {SerialNumber: "2", ExpiryTime: time.Now(), EntityID: "456"}, + }, + err: nil, + }, + { + desc: "generate CRL with intermediate CA", + caType: certs.IntermediateCA, + certs: []certs.Certificate{ + {SerialNumber: "3", ExpiryTime: time.Now()}, + }, + err: nil, + }, + { + desc: "invalid CA type", + caType: certs.CertType(999), + err: errors.New("invalid CA type"), + }, + { + desc: "ListRevokedCerts error", + caType: certs.RootCA, + repoErr: certs.ErrViewEntity, + err: certs.ErrViewEntity, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := cRepo.On("ListRevokedCerts", mock.Anything).Return(tc.certs, tc.repoErr) + _, err := svc.GenerateCRL(context.Background(), tc.caType) + if tc.err != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.err.Error()) + } else { + assert.NoError(t, err) + } + repoCall.Unset() + }) + } +} diff --git a/cli/certs.go b/cli/certs.go index ea11c26..ab5cbf9 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -130,12 +130,12 @@ var cmdCerts = []cobra.Command{ logUsageCmd(*cmd, cmd.Use) return } - cert, err := sdk.DownloadCert(args[1], args[0]) + certBundle, err := sdk.DownloadCert(args[1], args[0]) if err != nil { logErrorCmd(*cmd, err) return } - logJSONCmd(*cmd, cert) + logSaveCertFiles(*cmd, certBundle) }, }, { @@ -161,20 +161,31 @@ var cmdCerts = []cobra.Command{ func NewCertsCmd() *cobra.Command { var ttl string issueCmd := cobra.Command{ - Use: "issue '[\"\", \"\"]' [--ttl=8760h]", + Use: "issue '[\"\", \"\"] '{\"organization\":[\"organization_name\"]}' [--ttl=8760h]", Short: "Issue certificate", Long: `Issues a certificate for a given entity ID.`, Run: func(cmd *cobra.Command, args []string) { - if len(args) != 2 { + if len(args) < 3 || len(args) > 4 { logUsageCmd(*cmd, cmd.Use) return } var ipAddrs []string - if err := json.Unmarshal([]byte(args[1]), &ipAddrs); err != nil { + if err := json.Unmarshal([]byte(args[2]), &ipAddrs); err != nil { logErrorCmd(*cmd, err) return } - serial, err := sdk.IssueCert(args[0], ttl, ipAddrs) + + var option ctxsdk.Options + option.CommonName = args[1] + + if len(args) == 4 { + if err := json.Unmarshal([]byte(args[3]), &option); err != nil { + logErrorCmd(*cmd, err) + return + } + } + + serial, err := sdk.IssueCert(args[0], ttl, ipAddrs, option) if err != nil { logErrorCmd(*cmd, err) return diff --git a/cli/certs_test.go b/cli/certs_test.go index 76637bb..e562b08 100644 --- a/cli/certs_test.go +++ b/cli/certs_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "net/http" + "os" "strings" "testing" @@ -32,6 +33,7 @@ const ( var ( serialNumber = "39054620502613157373429341617471746606" id = "5b4c9ee3-e719-4a0a-9ee5-354932c5e6a4" + commonName = "test-name" extraArg = "extra-arg" ) @@ -56,6 +58,7 @@ func TestIssueCertCmd(t *testing.T) { desc: "issue cert successfully", args: []string{ id, + commonName, ipAddrs, }, logType: entityLog, @@ -66,7 +69,6 @@ func TestIssueCertCmd(t *testing.T) { args: []string{ id, ipAddrs, - extraArg, }, logType: usageLog, }, @@ -74,17 +76,29 @@ func TestIssueCertCmd(t *testing.T) { desc: "issue cert failed", args: []string{ id, + commonName, ipAddrs, }, sdkErr: errors.NewSDKErrorWithStatus(certs.ErrCreateEntity, http.StatusUnprocessableEntity), errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrCreateEntity, http.StatusUnprocessableEntity)), logType: errLog, }, + { + desc: "issue cert with 4 args", + args: []string{ + id, + commonName, + ipAddrs, + "{\"organization\":[\"organization_name\"]}", + }, + logType: entityLog, + serial: sdk.SerialNumber{SerialNumber: serialNumber}, + }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("IssueCert", mock.Anything, mock.Anything, mock.Anything).Return(tc.serial, tc.sdkErr) + sdkCall := sdkMock.On("IssueCert", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.serial, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{issueCmd}, tc.args...)...) switch tc.logType { case entityLog: @@ -375,7 +389,7 @@ func TestGetTokenCmd(t *testing.T) { } } -func TestRetrieveCertCmd(t *testing.T) { +func TestDownloadCertCmd(t *testing.T) { sdkMock := new(sdkmocks.MockSDK) cli.SetSDK(sdkMock) certCmd := cli.NewCertsCmd() @@ -387,18 +401,26 @@ func TestRetrieveCertCmd(t *testing.T) { args []string sdkErr errors.SDKError errLogMessage string + logMessage string logType outputLog + certBundle sdk.CertificateBundle }{ { - desc: "retrieve cert successfully", + desc: "download cert successfully", args: []string{ serialNumber, token, }, logType: entityLog, + certBundle: sdk.CertificateBundle{ + CA: []byte("ca"), + Certificate: []byte("certificate"), + PrivateKey: []byte("privatekey"), + }, + logMessage: "Saved ca.pem\nSaved cert.pem\nSaved key.pem\n\nAll certificate files have been saved successfully.\n", }, { - desc: "retrieve cert with invalid args", + desc: "download cert with invalid args", args: []string{ serialNumber, token, @@ -407,7 +429,7 @@ func TestRetrieveCertCmd(t *testing.T) { logType: usageLog, }, { - desc: "retrieve cert failed", + desc: "download cert failed", args: []string{ serialNumber, token, @@ -415,14 +437,22 @@ func TestRetrieveCertCmd(t *testing.T) { sdkErr: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity)), logType: errLog, + certBundle: sdk.CertificateBundle{}, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("RetrieveCert", mock.Anything, mock.Anything).Return([]byte{}, tc.sdkErr) + defer func() { + cleanupFiles(t, []string{"ca.pem", "cert.pem", "key.pem"}) + }() + sdkCall := sdkMock.On("DownloadCert", mock.Anything, mock.Anything).Return(tc.certBundle, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{downloadCmd}, tc.args...)...) switch tc.logType { + case entityLog: + assert.True(t, strings.Contains(out, "Saved key.pem"), fmt.Sprintf("%s invalid output: %s", tc.desc, out)) + assert.True(t, strings.Contains(out, "Saved cert.pem"), fmt.Sprintf("%s invalid output: %s", tc.desc, out)) + assert.True(t, strings.Contains(out, "Saved ca.pem"), fmt.Sprintf("%s invalid output: %s", tc.desc, out)) case usageLog: assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) case errLog: @@ -432,3 +462,12 @@ func TestRetrieveCertCmd(t *testing.T) { }) } } + +func cleanupFiles(t *testing.T, filenames []string) { + for _, filename := range filenames { + err := os.Remove(filename) + if err != nil && !os.IsNotExist(err) { + t.Logf("Failed to remove file %s: %v", filename, err) + } + } +} diff --git a/cli/utils.go b/cli/utils.go index 758d569..c7e922f 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -6,12 +6,18 @@ package cli import ( "encoding/json" "fmt" + "io/fs" + "os" + "path/filepath" + ctxsdk "github.com/absmach/certs/sdk" "github.com/fatih/color" "github.com/hokaccha/go-prettyjson" "github.com/spf13/cobra" ) +const fileMode = fs.FileMode(600) + var ( // Limit query parameter. Limit uint64 = 10 @@ -57,3 +63,36 @@ func logErrorCmd(cmd cobra.Command, err error) { func logOKCmd(cmd cobra.Command) { fmt.Fprintf(cmd.OutOrStdout(), "\n%s\n\n", color.BlueString("ok")) } + +func logSaveCertFiles(cmd cobra.Command, certBundle ctxsdk.CertificateBundle) { + files := map[string][]byte{ + "ca.pem": certBundle.CA, + "cert.pem": certBundle.Certificate, + "key.pem": certBundle.PrivateKey, + } + + for filename, content := range files { + err := saveToFile(filename, content) + if err != nil { + logErrorCmd(cmd, err) + return + } + fmt.Fprintf(cmd.OutOrStdout(), "Saved %s\n", filename) + } + fmt.Fprintf(cmd.OutOrStdout(), "\nAll certificate files have been saved successfully.\n") +} + +func saveToFile(filename string, content []byte) error { + cwd, err := os.Getwd() + if err != nil { + return fmt.Errorf("failed to get current working directory: %w", err) + } + + filePath := filepath.Join(cwd, filename) + err = os.WriteFile(filePath, content, fileMode) + if err != nil { + return fmt.Errorf("failed to write file %s: %w", filename, err) + } + + return nil +} diff --git a/mocks/repository.go b/mocks/repository.go index eecefdd..2ef8365 100644 --- a/mocks/repository.go +++ b/mocks/repository.go @@ -73,6 +73,79 @@ func (_c *MockRepository_CreateCert_Call) RunAndReturn(run func(context.Context, return _c } +// GetCAs provides a mock function with given fields: ctx, caType +func (_m *MockRepository) GetCAs(ctx context.Context, caType ...certs.CertType) ([]certs.Certificate, error) { + _va := make([]interface{}, len(caType)) + for _i := range caType { + _va[_i] = caType[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetCAs") + } + + var r0 []certs.Certificate + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ...certs.CertType) ([]certs.Certificate, error)); ok { + return rf(ctx, caType...) + } + if rf, ok := ret.Get(0).(func(context.Context, ...certs.CertType) []certs.Certificate); ok { + r0 = rf(ctx, caType...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]certs.Certificate) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, ...certs.CertType) error); ok { + r1 = rf(ctx, caType...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRepository_GetCAs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCAs' +type MockRepository_GetCAs_Call struct { + *mock.Call +} + +// GetCAs is a helper method to define mock.On call +// - ctx context.Context +// - caType ...certs.CertType +func (_e *MockRepository_Expecter) GetCAs(ctx interface{}, caType ...interface{}) *MockRepository_GetCAs_Call { + return &MockRepository_GetCAs_Call{Call: _e.mock.On("GetCAs", + append([]interface{}{ctx}, caType...)...)} +} + +func (_c *MockRepository_GetCAs_Call) Run(run func(ctx context.Context, caType ...certs.CertType)) *MockRepository_GetCAs_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]certs.CertType, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(certs.CertType) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *MockRepository_GetCAs_Call) Return(_a0 []certs.Certificate, _a1 error) *MockRepository_GetCAs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRepository_GetCAs_Call) RunAndReturn(run func(context.Context, ...certs.CertType) ([]certs.Certificate, error)) *MockRepository_GetCAs_Call { + _c.Call.Return(run) + return _c +} + // ListCerts provides a mock function with given fields: ctx, pm func (_m *MockRepository) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { ret := _m.Called(ctx, pm) @@ -130,6 +203,64 @@ func (_c *MockRepository_ListCerts_Call) RunAndReturn(run func(context.Context, return _c } +// ListRevokedCerts provides a mock function with given fields: ctx +func (_m *MockRepository) ListRevokedCerts(ctx context.Context) ([]certs.Certificate, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for ListRevokedCerts") + } + + var r0 []certs.Certificate + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]certs.Certificate, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []certs.Certificate); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]certs.Certificate) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRepository_ListRevokedCerts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListRevokedCerts' +type MockRepository_ListRevokedCerts_Call struct { + *mock.Call +} + +// ListRevokedCerts is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockRepository_Expecter) ListRevokedCerts(ctx interface{}) *MockRepository_ListRevokedCerts_Call { + return &MockRepository_ListRevokedCerts_Call{Call: _e.mock.On("ListRevokedCerts", ctx)} +} + +func (_c *MockRepository_ListRevokedCerts_Call) Run(run func(ctx context.Context)) *MockRepository_ListRevokedCerts_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockRepository_ListRevokedCerts_Call) Return(_a0 []certs.Certificate, _a1 error) *MockRepository_ListRevokedCerts_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRepository_ListRevokedCerts_Call) RunAndReturn(run func(context.Context) ([]certs.Certificate, error)) *MockRepository_ListRevokedCerts_Call { + _c.Call.Return(run) + return _c +} + // RetrieveCert provides a mock function with given fields: ctx, serialNumber func (_m *MockRepository) RetrieveCert(ctx context.Context, serialNumber string) (certs.Certificate, error) { ret := _m.Called(ctx, serialNumber) diff --git a/mocks/service.go b/mocks/service.go index f17589f..9eb3634 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -28,6 +28,65 @@ func (_m *MockService) EXPECT() *MockService_Expecter { return &MockService_Expecter{mock: &_m.Mock} } +// GenerateCRL provides a mock function with given fields: ctx, caType +func (_m *MockService) GenerateCRL(ctx context.Context, caType certs.CertType) ([]byte, error) { + ret := _m.Called(ctx, caType) + + if len(ret) == 0 { + panic("no return value specified for GenerateCRL") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, certs.CertType) ([]byte, error)); ok { + return rf(ctx, caType) + } + if rf, ok := ret.Get(0).(func(context.Context, certs.CertType) []byte); ok { + r0 = rf(ctx, caType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, certs.CertType) error); ok { + r1 = rf(ctx, caType) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockService_GenerateCRL_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GenerateCRL' +type MockService_GenerateCRL_Call struct { + *mock.Call +} + +// GenerateCRL is a helper method to define mock.On call +// - ctx context.Context +// - caType certs.CertType +func (_e *MockService_Expecter) GenerateCRL(ctx interface{}, caType interface{}) *MockService_GenerateCRL_Call { + return &MockService_GenerateCRL_Call{Call: _e.mock.On("GenerateCRL", ctx, caType)} +} + +func (_c *MockService_GenerateCRL_Call) Run(run func(ctx context.Context, caType certs.CertType)) *MockService_GenerateCRL_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(certs.CertType)) + }) + return _c +} + +func (_c *MockService_GenerateCRL_Call) Return(_a0 []byte, _a1 error) *MockService_GenerateCRL_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockService_GenerateCRL_Call) RunAndReturn(run func(context.Context, certs.CertType) ([]byte, error)) *MockService_GenerateCRL_Call { + _c.Call.Return(run) + return _c +} + // GetEntityID provides a mock function with given fields: ctx, serialNumber func (_m *MockService) GetEntityID(ctx context.Context, serialNumber string) (string, error) { ret := _m.Called(ctx, serialNumber) @@ -85,9 +144,9 @@ func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, s return _c } -// IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs -func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string) (string, error) { - ret := _m.Called(ctx, entityID, ttl, ipAddrs) +// IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option +func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions) (string, error) { + ret := _m.Called(ctx, entityID, ttl, ipAddrs, option) if len(ret) == 0 { panic("no return value specified for IssueCert") @@ -95,17 +154,17 @@ func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl strin var r0 string var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string) (string, error)); ok { - return rf(ctx, entityID, ttl, ipAddrs) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) (string, error)); ok { + return rf(ctx, entityID, ttl, ipAddrs, option) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string) string); ok { - r0 = rf(ctx, entityID, ttl, ipAddrs) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) string); ok { + r0 = rf(ctx, entityID, ttl, ipAddrs, option) } else { r0 = ret.Get(0).(string) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, []string) error); ok { - r1 = rf(ctx, entityID, ttl, ipAddrs) + if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions) error); ok { + r1 = rf(ctx, entityID, ttl, ipAddrs, option) } else { r1 = ret.Error(1) } @@ -123,13 +182,14 @@ type MockService_IssueCert_Call struct { // - entityID string // - ttl string // - ipAddrs []string -func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}) *MockService_IssueCert_Call { - return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", ctx, entityID, ttl, ipAddrs)} +// - option certs.SubjectOptions +func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}) *MockService_IssueCert_Call { + return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", ctx, entityID, ttl, ipAddrs, option)} } -func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions)) *MockService_IssueCert_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions)) }) return _c } @@ -139,7 +199,7 @@ func (_c *MockService_IssueCert_Call) Return(_a0 string, _a1 error) *MockService return _c } -func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string) (string, error)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions) (string, error)) *MockService_IssueCert_Call { _c.Call.Return(run) return _c } diff --git a/postgres/certs.go b/postgres/certs.go index 0c8dcc5..cab1b81 100644 --- a/postgres/certs.go +++ b/postgres/certs.go @@ -44,8 +44,8 @@ func NewRepository(db postgres.Database) certs.Repository { // CreateLog creates computation log in the database. func (repo certsRepo) CreateCert(ctx context.Context, cert certs.Certificate) error { q := ` - INSERT INTO certs (serial_number, certificate, key, entity_id, revoked, expiry_time) - VALUES (:serial_number, :certificate, :key, :entity_id, :revoked, :expiry_time)` + INSERT INTO certs (serial_number, certificate, key, entity_id, revoked, expiry_time, type) + VALUES (:serial_number, :certificate, :key, :entity_id, :revoked, :expiry_time, :type)` _, err := repo.db.NamedExecContext(ctx, q, cert) if err != nil { return handleError(certs.ErrCreateEntity, err) @@ -55,7 +55,7 @@ func (repo certsRepo) CreateCert(ctx context.Context, cert certs.Certificate) er // RetrieveLog retrieves computation log from the database. func (repo certsRepo) RetrieveCert(ctx context.Context, serialNumber string) (certs.Certificate, error) { - q := `SELECT * FROM certs WHERE serial_number = $1` + q := `SELECT serial_number, certificate, key, entity_id, revoked, expiry_time FROM certs WHERE serial_number = $1` var cert certs.Certificate if err := repo.db.QueryRowxContext(ctx, q, serialNumber).StructScan(&cert); err != nil { if err == sql.ErrNoRows { @@ -66,6 +66,56 @@ func (repo certsRepo) RetrieveCert(ctx context.Context, serialNumber string) (ce return cert, nil } +// GetCAs reterives rootCA and intermediateCA from database. +func (repo certsRepo) GetCAs(ctx context.Context, caType ...certs.CertType) ([]certs.Certificate, error) { + q := `SELECT serial_number, key, certificate, expiry_time, revoked, type FROM certs WHERE type = ANY($1)` + var certificates []certs.Certificate + + types := make([]string, 0, len(caType)) + for i, t := range caType { + types[i] = t.String() + } + + if len(types) == 0 { + types = []string{certs.RootCA.String(), certs.IntermediateCA.String()} + } + + rows, err := repo.db.QueryContext(ctx, q, types) + if err != nil { + return []certs.Certificate{}, handleError(certs.ErrViewEntity, err) + } + defer rows.Close() + + var certType string + for rows.Next() { + cert := &certs.Certificate{} + if err := rows.Scan( + &cert.SerialNumber, + &cert.Key, + &cert.Certificate, + &cert.ExpiryTime, + &cert.Revoked, + &certType, + ); err != nil { + return []certs.Certificate{}, errors.Wrap(certs.ErrViewEntity, err) + } + + crtType, err := certs.CertTypeFromString(certType) + if err != nil { + return []certs.Certificate{}, errors.Wrap(certs.ErrViewEntity, err) + } + cert.Type = crtType + + certificates = append(certificates, *cert) + } + + if err = rows.Err(); err != nil { + return []certs.Certificate{}, errors.Wrap(certs.ErrViewEntity, err) + } + + return certificates, nil +} + // UpdateLog updates computation log in the database. func (repo certsRepo) UpdateCert(ctx context.Context, cert certs.Certificate) error { q := `UPDATE certs SET certificate = :certificate, key = :key, revoked = :revoked, expiry_time = :expiry_time WHERE serial_number = :serial_number` @@ -85,13 +135,13 @@ func (repo certsRepo) UpdateCert(ctx context.Context, cert certs.Certificate) er func (repo certsRepo) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { q := `SELECT serial_number, revoked, expiry_time, entity_id FROM certs %s LIMIT :limit OFFSET :offset` - condition := `` + var condition string if pm.EntityID != "" { - condition = `WHERE entity_id = :entity_id` - q = fmt.Sprintf(q, condition) + condition = fmt.Sprintf(`WHERE entity_id = :entity_id AND type = '%s'`, certs.ClientCert.String()) } else { - q = fmt.Sprintf(q, condition) + condition = fmt.Sprintf(`WHERE type = '%s'`, certs.ClientCert.String()) } + q = fmt.Sprintf(q, condition) var certificates []certs.Certificate params := map[string]interface{}{ @@ -125,6 +175,30 @@ func (repo certsRepo) ListCerts(ctx context.Context, pm certs.PageMetadata) (cer }, nil } +func (repo certsRepo) ListRevokedCerts(ctx context.Context) ([]certs.Certificate, error) { + query := ` + SELECT serial_number, entity_id, expiry_time + FROM certs + WHERE revoked = true + ` + rows, err := repo.db.QueryContext(ctx, query) + if err != nil { + return nil, handleError(certs.ErrViewEntity, err) + } + defer rows.Close() + + var revokedCerts []certs.Certificate + for rows.Next() { + var cert certs.Certificate + if err := rows.Scan(&cert.SerialNumber, &cert.EntityID, &cert.ExpiryTime); err != nil { + return nil, handleError(certs.ErrViewEntity, err) + } + revokedCerts = append(revokedCerts, cert) + } + + return revokedCerts, nil +} + func (repo certsRepo) total(ctx context.Context, query string, params interface{}) (uint64, error) { rows, err := repo.db.NamedQueryContext(ctx, query, params) if err != nil { diff --git a/postgres/init.go b/postgres/init.go index 36e2468..9f52b7a 100644 --- a/postgres/init.go +++ b/postgres/init.go @@ -21,6 +21,7 @@ func Migration() *migrate.MemoryMigrationSource { revoked BOOLEAN, expiry_time TIMESTAMP, entity_id VARCHAR(36), + type TEXT CHECK (type IN ('RootCA', 'IntermediateCA', 'ClientCert')), PRIMARY KEY (serial_number) )`, }, diff --git a/sdk/certs_test.go b/sdk/certs_test.go index 4ef6d45..b106f29 100644 --- a/sdk/certs_test.go +++ b/sdk/certs_test.go @@ -25,6 +25,7 @@ const ( serialNum = "8e7a30c-bc9f-22de-ae67-1342bc139507" id = "c333e6f1-59bb-4c39-9e13-3a2766af8ba5" ttl = "10h" + commonName = "test" ) func setupCerts() (*httptest.Server, *mocks.MockService) { @@ -49,21 +50,23 @@ func TestIssueCert(t *testing.T) { ipAddr := []string{"192.128.101.82"} cases := []struct { - desc string - entityID string - ttl string - ipAddrs []string - svcresp string - svcerr error - err errors.SDKError - sdkCert sdk.Certificate + desc string + entityID string + ttl string + ipAddrs []string + commonName string + svcresp string + svcerr error + err errors.SDKError + sdkCert sdk.Certificate }{ { - desc: "IssueCert success", - entityID: id, - ttl: ttl, - ipAddrs: ipAddr, - svcresp: serialNum, + desc: "IssueCert success", + entityID: id, + ttl: ttl, + ipAddrs: ipAddr, + commonName: commonName, + svcresp: serialNum, sdkCert: sdk.Certificate{ SerialNumber: serialNum, }, @@ -71,28 +74,31 @@ func TestIssueCert(t *testing.T) { err: nil, }, { - desc: "IssueCert failure", - entityID: id, - ttl: ttl, - ipAddrs: ipAddr, - svcresp: "", - svcerr: certs.ErrCreateEntity, - err: errors.NewSDKErrorWithStatus(certs.ErrCreateEntity, http.StatusUnprocessableEntity), + desc: "IssueCert failure", + entityID: id, + ttl: ttl, + ipAddrs: ipAddr, + commonName: commonName, + svcresp: "", + svcerr: certs.ErrCreateEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrCreateEntity, http.StatusUnprocessableEntity), }, { - desc: "IssueCert with empty entityID", - entityID: `""`, - ttl: ttl, - ipAddrs: ipAddr, - svcresp: "", - svcerr: certs.ErrMalformedEntity, - err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), + desc: "IssueCert with empty entityID", + entityID: `""`, + ttl: ttl, + ipAddrs: ipAddr, + commonName: commonName, + svcresp: "", + svcerr: certs.ErrMalformedEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), }, { - desc: "IssueCert with empty ipAddrs", - entityID: id, - ttl: ttl, - svcresp: serialNum, + desc: "IssueCert with empty ipAddrs", + entityID: id, + ttl: ttl, + commonName: commonName, + svcresp: serialNum, sdkCert: sdk.Certificate{ SerialNumber: serialNum, }, @@ -100,28 +106,39 @@ func TestIssueCert(t *testing.T) { err: nil, }, { - desc: "IssueCert with empty ttl", - entityID: id, - ttl: "", - ipAddrs: ipAddr, - svcresp: serialNum, + desc: "IssueCert with empty ttl", + entityID: id, + ttl: "", + ipAddrs: ipAddr, + commonName: commonName, + svcresp: serialNum, sdkCert: sdk.Certificate{ SerialNumber: serialNum, }, svcerr: nil, err: nil, }, + { + desc: "IssueCert with empty commonName", + entityID: id, + ttl: ttl, + ipAddrs: ipAddr, + commonName: "", + svcresp: "", + svcerr: httpapi.ErrMissingCN, + err: errors.NewSDKErrorWithStatus(httpapi.ErrMissingCN, http.StatusBadRequest), + }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("IssueCert", mock.Anything, tc.entityID, tc.ttl, tc.ipAddrs).Return(tc.svcresp, tc.svcerr) + svcCall := svc.On("IssueCert", mock.Anything, tc.entityID, tc.ttl, tc.ipAddrs, mock.Anything).Return(tc.svcresp, tc.svcerr) - resp, err := ctsdk.IssueCert(tc.entityID, tc.ttl, tc.ipAddrs) + resp, err := ctsdk.IssueCert(tc.entityID, tc.ttl, tc.ipAddrs, sdk.Options{CommonName: tc.commonName}) assert.Equal(t, tc.err, err) if tc.err == nil { assert.Equal(t, tc.sdkCert.SerialNumber, resp.SerialNumber) - ok := svcCall.Parent.AssertCalled(t, "IssueCert", mock.Anything, tc.entityID, tc.ttl, tc.ipAddrs) + ok := svcCall.Parent.AssertCalled(t, "IssueCert", mock.Anything, tc.entityID, tc.ttl, tc.ipAddrs, certs.SubjectOptions{CommonName: tc.commonName}) assert.True(t, ok) } svcCall.Unset() diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index 714d5c7..e95c94d 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -28,24 +28,22 @@ func (_m *MockSDK) EXPECT() *MockSDK_Expecter { } // DownloadCert provides a mock function with given fields: token, serialNumber -func (_m *MockSDK) DownloadCert(token string, serialNumber string) ([]byte, errors.SDKError) { +func (_m *MockSDK) DownloadCert(token string, serialNumber string) (sdk.CertificateBundle, errors.SDKError) { ret := _m.Called(token, serialNumber) if len(ret) == 0 { panic("no return value specified for DownloadCert") } - var r0 []byte + var r0 sdk.CertificateBundle var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(string, string) ([]byte, errors.SDKError)); ok { + if rf, ok := ret.Get(0).(func(string, string) (sdk.CertificateBundle, errors.SDKError)); ok { return rf(token, serialNumber) } - if rf, ok := ret.Get(0).(func(string, string) []byte); ok { + if rf, ok := ret.Get(0).(func(string, string) sdk.CertificateBundle); ok { r0 = rf(token, serialNumber) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } + r0 = ret.Get(0).(sdk.CertificateBundle) } if rf, ok := ret.Get(1).(func(string, string) errors.SDKError); ok { @@ -78,19 +76,19 @@ func (_c *MockSDK_DownloadCert_Call) Run(run func(token string, serialNumber str return _c } -func (_c *MockSDK_DownloadCert_Call) Return(_a0 []byte, _a1 errors.SDKError) *MockSDK_DownloadCert_Call { +func (_c *MockSDK_DownloadCert_Call) Return(_a0 sdk.CertificateBundle, _a1 errors.SDKError) *MockSDK_DownloadCert_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockSDK_DownloadCert_Call) RunAndReturn(run func(string, string) ([]byte, errors.SDKError)) *MockSDK_DownloadCert_Call { +func (_c *MockSDK_DownloadCert_Call) RunAndReturn(run func(string, string) (sdk.CertificateBundle, errors.SDKError)) *MockSDK_DownloadCert_Call { _c.Call.Return(run) return _c } -// IssueCert provides a mock function with given fields: entityID, ttl, ipAddrs -func (_m *MockSDK) IssueCert(entityID string, ttl string, ipAddrs []string) (sdk.SerialNumber, errors.SDKError) { - ret := _m.Called(entityID, ttl, ipAddrs) +// IssueCert provides a mock function with given fields: entityID, ttl, ipAddrs, opts +func (_m *MockSDK) IssueCert(entityID string, ttl string, ipAddrs []string, opts sdk.Options) (sdk.SerialNumber, errors.SDKError) { + ret := _m.Called(entityID, ttl, ipAddrs, opts) if len(ret) == 0 { panic("no return value specified for IssueCert") @@ -98,17 +96,17 @@ func (_m *MockSDK) IssueCert(entityID string, ttl string, ipAddrs []string) (sdk var r0 sdk.SerialNumber var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(string, string, []string) (sdk.SerialNumber, errors.SDKError)); ok { - return rf(entityID, ttl, ipAddrs) + if rf, ok := ret.Get(0).(func(string, string, []string, sdk.Options) (sdk.SerialNumber, errors.SDKError)); ok { + return rf(entityID, ttl, ipAddrs, opts) } - if rf, ok := ret.Get(0).(func(string, string, []string) sdk.SerialNumber); ok { - r0 = rf(entityID, ttl, ipAddrs) + if rf, ok := ret.Get(0).(func(string, string, []string, sdk.Options) sdk.SerialNumber); ok { + r0 = rf(entityID, ttl, ipAddrs, opts) } else { r0 = ret.Get(0).(sdk.SerialNumber) } - if rf, ok := ret.Get(1).(func(string, string, []string) errors.SDKError); ok { - r1 = rf(entityID, ttl, ipAddrs) + if rf, ok := ret.Get(1).(func(string, string, []string, sdk.Options) errors.SDKError); ok { + r1 = rf(entityID, ttl, ipAddrs, opts) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -127,13 +125,14 @@ type MockSDK_IssueCert_Call struct { // - entityID string // - ttl string // - ipAddrs []string -func (_e *MockSDK_Expecter) IssueCert(entityID interface{}, ttl interface{}, ipAddrs interface{}) *MockSDK_IssueCert_Call { - return &MockSDK_IssueCert_Call{Call: _e.mock.On("IssueCert", entityID, ttl, ipAddrs)} +// - opts sdk.Options +func (_e *MockSDK_Expecter) IssueCert(entityID interface{}, ttl interface{}, ipAddrs interface{}, opts interface{}) *MockSDK_IssueCert_Call { + return &MockSDK_IssueCert_Call{Call: _e.mock.On("IssueCert", entityID, ttl, ipAddrs, opts)} } -func (_c *MockSDK_IssueCert_Call) Run(run func(entityID string, ttl string, ipAddrs []string)) *MockSDK_IssueCert_Call { +func (_c *MockSDK_IssueCert_Call) Run(run func(entityID string, ttl string, ipAddrs []string, opts sdk.Options)) *MockSDK_IssueCert_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].([]string)) + run(args[0].(string), args[1].(string), args[2].([]string), args[3].(sdk.Options)) }) return _c } @@ -143,7 +142,7 @@ func (_c *MockSDK_IssueCert_Call) Return(_a0 sdk.SerialNumber, _a1 errors.SDKErr return _c } -func (_c *MockSDK_IssueCert_Call) RunAndReturn(run func(string, string, []string) (sdk.SerialNumber, errors.SDKError)) *MockSDK_IssueCert_Call { +func (_c *MockSDK_IssueCert_Call) RunAndReturn(run func(string, string, []string, sdk.Options) (sdk.SerialNumber, errors.SDKError)) *MockSDK_IssueCert_Call { _c.Call.Return(run) return _c } diff --git a/sdk/sdk.go b/sdk/sdk.go index a371ca3..29500c0 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -4,6 +4,7 @@ package sdk import ( + "archive/zip" "bytes" "crypto/tls" "encoding/json" @@ -40,11 +41,23 @@ const ( type ContentType string type PageMetadata struct { - Total uint64 `json:"total,omitempty"` - Offset uint64 `json:"offset,omitempty"` - Limit uint64 `json:"limit,omitempty"` - EntityID string `json:"entity_id,omitempty"` - Token string `json:"token,omitempty"` + Total uint64 `json:"total,omitempty"` + Offset uint64 `json:"offset,omitempty"` + Limit uint64 `json:"limit,omitempty"` + EntityID string `json:"entity_id,omitempty"` + Token string `json:"token,omitempty"` + CommonName string `json:"common_name,omitempty"` +} + +type Options struct { + CommonName string + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` } type SerialNumber struct { @@ -90,20 +103,26 @@ type mgSDK struct { curlFlag bool } +type CertificateBundle struct { + CA []byte `json:"ca"` + Certificate []byte `json:"certificate"` + PrivateKey []byte `json:"private_key"` +} + type SDK interface { // IssueCert issues a certificate for a thing required for mTLS. // // example: - // serial , _ := sdk.IssueCert("entityID", "10h", []string{"ipAddr1", "ipAddr2"}) + // serial , _ := sdk.IssueCert("entityID", "10h", []string{"ipAddr1", "ipAddr2"}, sdk.Options{CommonName: "commonName"}) // fmt.Println(serial) - IssueCert(entityID, ttl string, ipAddrs []string) (SerialNumber, errors.SDKError) + IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (SerialNumber, errors.SDKError) // DownloadCert returns a certificate given certificate ID // // example: - // cert, _ := sdk.DownloadCert("serialNumber", "download-token") - // fmt.Println(cert) - DownloadCert(token, serialNumber string) ([]byte, errors.SDKError) + // certBundle, _ := sdk.DownloadCert("serialNumber", "download-token") + // fmt.Println(certBundle) + DownloadCert(token, serialNumber string) (CertificateBundle, errors.SDKError) // RevokeCert revokes certificate for thing with thingID // @@ -148,23 +167,26 @@ type SDK interface { OCSP(serialNumber string) (*ocsp.Response, errors.SDKError) } -func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string) (SerialNumber, errors.SDKError) { +func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (SerialNumber, errors.SDKError) { r := certReq{ IpAddrs: ipAddrs, TTL: ttl, + Options: opts, } d, err := json.Marshal(r) if err != nil { return SerialNumber{}, errors.NewSDKError(err) } + url := fmt.Sprintf("%s/%s", issueCertEndpoint, entityID) - url := fmt.Sprintf("%s/%s/%s", sdk.certsURL, issueCertEndpoint, entityID) - + url, err = sdk.withQueryParams(sdk.certsURL, url, PageMetadata{CommonName: opts.CommonName}) + if err != nil { + return SerialNumber{}, errors.NewSDKError(err) + } _, body, sdkerr := sdk.processRequest(http.MethodPost, url, d, nil, http.StatusCreated) if sdkerr != nil { return SerialNumber{}, sdkerr } - var sn SerialNumber if err := json.Unmarshal(body, &sn); err != nil { return SerialNumber{}, errors.NewSDKError(err) @@ -173,20 +195,41 @@ func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string) (SerialNumber return sn, nil } -func (sdk mgSDK) DownloadCert(token, serialNumber string) ([]byte, errors.SDKError) { +func (sdk mgSDK) DownloadCert(token, serialNumber string) (CertificateBundle, errors.SDKError) { pm := PageMetadata{ Token: token, } url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s/download", certsEndpoint, serialNumber), pm) if err != nil { - return []byte{}, errors.NewSDKError(err) + return CertificateBundle{}, errors.NewSDKError(err) } _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) if sdkerr != nil { - return []byte{}, sdkerr + return CertificateBundle{}, sdkerr } - return body, nil + zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) + if err != nil { + return CertificateBundle{}, errors.NewSDKError(err) + } + + var bundle CertificateBundle + for _, file := range zipReader.File { + fileContent, err := readZipFile(file) + if err != nil { + return CertificateBundle{}, errors.NewSDKError(err) + } + switch file.Name { + case "ca.pem": + bundle.CA = fileContent + case "cert.pem": + bundle.Certificate = fileContent + case "key.pem": + bundle.PrivateKey = fileContent + } + } + + return bundle, nil } func (sdk mgSDK) ViewCert(serialNumber string) (Certificate, errors.SDKError) { @@ -342,11 +385,24 @@ func (pm PageMetadata) query() (string, error) { if pm.Token != "" { q.Add("token", pm.Token) } + if pm.CommonName != "" { + q.Add("common_name", pm.CommonName) + } return q.Encode(), nil } +func readZipFile(file *zip.File) ([]byte, error) { + fc, err := file.Open() + if err != nil { + return nil, err + } + defer fc.Close() + return io.ReadAll(fc) +} + type certReq struct { IpAddrs []string `json:"ip_addresses"` TTL string `json:"ttl"` + Options Options `json:"options"` } diff --git a/service.go b/service.go index 1441451..55b336d 100644 --- a/service.go +++ b/service.go @@ -20,53 +20,129 @@ import ( ) const ( - CommonName = "AbstractMachines_Selfsigned_ca" - Organization = "AbstractMacines" - OrganizationalUnit = "AbstractMachines_ca" - Country = "Sirbea" - Province = "Sirbea" - Locality = "Sirbea" - StreetAddress = "Sirbea" - PostalCode = "Sirbea" - emailAddress = "info@abstractmachines.rs" - PrivateKeyBytes = 2048 - certValidityPeriod = time.Hour * 24 * 90 // 90 days + CommonName = "AbstractMachines_Selfsigned_ca" + Organization = "AbstractMacines" + OrganizationalUnit = "AbstractMachines_ca" + Country = "Sirbea" + Province = "Sirbea" + Locality = "Sirbea" + StreetAddress = "Sirbea" + PostalCode = "Sirbea" + emailAddress = "info@abstractmachines.rs" + PrivateKeyBytes = 2048 + RootCAValidityPeriod = time.Hour * 24 * 365 // 365 days + IntermediateCAVAlidityPeriod = time.Hour * 24 * 90 // 90 days + certValidityPeriod = time.Hour * 24 * 90 // 30 days + rCertExpiryThreshold = time.Hour * 24 * 30 // 30 days + iCertExpiryThreshold = time.Hour * 24 * 10 // 10 days ) +type CertType int + +const ( + RootCA CertType = iota + IntermediateCA + ClientCert +) + +const ( + Root = "RootCA" + Inter = "IntermediateCA" + Client = "ClientCert" + Unknown = "Unknown" +) + +func (c CertType) String() string { + switch c { + case RootCA: + return Root + case IntermediateCA: + return Inter + case ClientCert: + return Client + default: + return Unknown + } +} + +func CertTypeFromString(s string) (CertType, error) { + switch s { + case Root: + return RootCA, nil + case Inter: + return IntermediateCA, nil + case Client: + return ClientCert, nil + default: + return -1, errors.New("unknown cert type") + } +} + +type CA struct { + Type CertType + Certificate *x509.Certificate + PrivateKey *rsa.PrivateKey + SerialNumber string +} + var ( - serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 128) - ErrNotFound = errors.New("entity not found") - ErrConflict = errors.New("entity already exists") - ErrCreateEntity = errors.New("failed to create entity") - ErrViewEntity = errors.New("view entity failed") - ErrGetToken = errors.New("failed to get token") - ErrUpdateEntity = errors.New("update entity failed") - ErrMalformedEntity = errors.New("malformed entity specification") - ErrRootCANotFound = errors.New("root CA not found") - ErrCertExpired = errors.New("certificate expired before renewal") - ErrCertRevoked = errors.New("certificate has been revoked and cannot be renewed") + serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 128) + ErrNotFound = errors.New("entity not found") + ErrConflict = errors.New("entity already exists") + ErrCreateEntity = errors.New("failed to create entity") + ErrViewEntity = errors.New("view entity failed") + ErrGetToken = errors.New("failed to get token") + ErrUpdateEntity = errors.New("update entity failed") + ErrMalformedEntity = errors.New("malformed entity specification") + ErrRootCANotFound = errors.New("root CA not found") + ErrIntermediateCANotFound = errors.New("intermediate CA not found") + ErrCertExpired = errors.New("certificate expired before renewal") + ErrCertRevoked = errors.New("certificate has been revoked and cannot be renewed") + ErrCertInvalidType = errors.New("invalid cert type") ) +type SubjectOptions struct { + CommonName string + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` +} + type service struct { - repo Repository - rootCACert *x509.Certificate - rootCAKey *rsa.PrivateKey + repo Repository + rootCA *CA + intermediateCA *CA } var _ Service = (*service)(nil) func NewService(ctx context.Context, repo Repository) (Service, error) { - cert, key, err := generateRootCA() - if err != nil { - return &service{}, err + var svc service + svc.repo = repo + if err := svc.loadCACerts(ctx); err != nil { + return &svc, err } - svc := &service{ - repo: repo, - rootCACert: cert, - rootCAKey: key, + // check if root ca should be rotated + rotateRoot := svc.shouldRotateCA(RootCA) + if rotateRoot { + if err := svc.rotateCA(ctx, RootCA); err != nil { + return &svc, err + } } - return svc, nil + + rotateIntermediate := svc.shouldRotateCA(IntermediateCA) + if rotateIntermediate { + if err := svc.rotateCA(ctx, IntermediateCA); err != nil { + return &svc, err + } + } + + return &svc, nil } // issueCert generates and issues a certificate for a given backendID. @@ -74,7 +150,7 @@ func NewService(ctx context.Context, repo Repository) (Service, error) { // using the provided template and the generated private key. // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. -func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string) (string, error) { +func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions) (string, error) { privKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) if err != nil { return "", err @@ -85,8 +161,8 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ return "", err } - if s.rootCACert == nil || s.rootCAKey == nil { - return "", ErrRootCANotFound + if s.intermediateCA.Certificate == nil || s.intermediateCA.PrivateKey == nil { + return "", ErrIntermediateCANotFound } // Parse the TTL if provided, otherwise use the default certValidityPeriod. @@ -100,30 +176,20 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ validity = certValidityPeriod } + subject := s.getSubject(options) + template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{Organization}, - OrganizationalUnit: []string{OrganizationalUnit}, - Country: []string{Country}, - Province: []string{Province}, - Locality: []string{Locality}, - StreetAddress: []string{StreetAddress}, - PostalCode: []string{PostalCode}, - CommonName: s.rootCACert.Subject.CommonName, - Names: s.rootCACert.Subject.Names, - ExtraNames: s.rootCACert.Subject.ExtraNames, - SerialNumber: serialNumber.String(), - }, + SerialNumber: serialNumber, + Subject: subject, NotBefore: time.Now(), NotAfter: time.Now().Add(validity), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, BasicConstraintsValid: true, - DNSNames: append(s.rootCACert.DNSNames, ipAddrs...), + DNSNames: append(s.intermediateCA.Certificate.DNSNames, ipAddrs...), } - certBytes, err := x509.CreateCertificate(rand.Reader, &template, s.rootCACert, &privKey.PublicKey, s.rootCAKey) + certBytes, err := x509.CreateCertificate(rand.Reader, &template, s.intermediateCA.Certificate, &privKey.PublicKey, s.intermediateCA.PrivateKey) if err != nil { return "", err } @@ -133,6 +199,7 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ SerialNumber: template.SerialNumber.String(), EntityID: entityID, ExpiryTime: template.NotAfter, + Type: ClientCert, } if err = s.repo.CreateCert(ctx, dbCert); err != nil { return "", errors.Wrap(ErrCreateEntity, err) @@ -172,7 +239,7 @@ func (s *service) RetrieveCert(ctx context.Context, token, serialNumber string) if err != nil { return Certificate{}, []byte{}, errors.Wrap(ErrViewEntity, err) } - return cert, pem.EncodeToMemory(&pem.Block{Bytes: s.rootCACert.Raw, Type: "CERTIFICATE"}), nil + return cert, pem.EncodeToMemory(&pem.Block{Bytes: s.intermediateCA.Certificate.Raw, Type: "CERTIFICATE"}), nil } func (s *service) ListCerts(ctx context.Context, pm PageMetadata) (CertificatePage, error) { @@ -238,10 +305,10 @@ func (s *service) RenewCert(ctx context.Context, serialNumber string) error { if err != nil { return err } - if s.rootCACert == nil || s.rootCAKey == nil { - return ErrRootCANotFound + if s.intermediateCA.Certificate == nil || s.intermediateCA.PrivateKey == nil { + return ErrIntermediateCANotFound } - newCertBytes, err := x509.CreateCertificate(rand.Reader, oldCert, s.rootCACert, &privKey.PublicKey, s.rootCAKey) + newCertBytes, err := x509.CreateCertificate(rand.Reader, oldCert, s.intermediateCA.Certificate, &privKey.PublicKey, s.intermediateCA.PrivateKey) if err != nil { return err } @@ -264,14 +331,14 @@ func (s *service) OCSP(ctx context.Context, serialNumber string) (*Certificate, cert, err := s.repo.RetrieveCert(ctx, serialNumber) if err != nil { if errors.Contains(err, ErrNotFound) { - return nil, ocsp.Unknown, s.rootCACert, nil + return nil, ocsp.Unknown, s.intermediateCA.Certificate, nil } - return nil, ocsp.ServerFailed, s.rootCACert, err + return nil, ocsp.ServerFailed, s.intermediateCA.Certificate, err } if cert.Revoked { - return &cert, ocsp.Revoked, s.rootCACert, nil + return &cert, ocsp.Revoked, s.intermediateCA.Certificate, nil } - return &cert, ocsp.Good, s.rootCACert, nil + return &cert, ocsp.Good, s.intermediateCA.Certificate, nil } func (s *service) GetEntityID(ctx context.Context, serialNumber string) (string, error) { @@ -282,15 +349,73 @@ func (s *service) GetEntityID(ctx context.Context, serialNumber string) (string, return cert.EntityID, nil } -func generateRootCA() (*x509.Certificate, *rsa.PrivateKey, error) { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) +func (s *service) GenerateCRL(ctx context.Context, caType CertType) ([]byte, error) { + var ca *CA + + switch caType { + case RootCA: + if s.rootCA == nil { + return nil, errors.New("root CA not initialized") + } + ca = s.rootCA + case IntermediateCA: + if s.intermediateCA == nil { + return nil, errors.New("intermediate CA not initialized") + } + ca = s.intermediateCA + default: + return nil, errors.New("invalid CA type") + } + + revokedCerts, err := s.repo.ListRevokedCerts(ctx) if err != nil { - return nil, nil, err + return nil, err + } + + revokedCertificates := make([]pkix.RevokedCertificate, len(revokedCerts)) + for i, cert := range revokedCerts { + serialNumber := new(big.Int) + serialNumber.SetString(cert.SerialNumber, 10) + revokedCertificates[i] = pkix.RevokedCertificate{ + SerialNumber: serialNumber, + RevocationTime: cert.ExpiryTime, + } + } + + // CRL valid for 24 hours + now := time.Now() + expiry := now.Add(24 * time.Hour) + + crlTemplate := &x509.RevocationList{ + Number: big.NewInt(time.Now().UnixNano()), + ThisUpdate: now, + NextUpdate: expiry, + RevokedCertificates: revokedCertificates, + } + + crlBytes, err := x509.CreateRevocationList(rand.Reader, crlTemplate, ca.Certificate, ca.PrivateKey) + if err != nil { + return nil, err + } + + pemBlock := &pem.Block{ + Type: "X509 CRL", + Bytes: crlBytes, + } + pemBytes := pem.EncodeToMemory(pemBlock) + + return pemBytes, nil +} + +func (s *service) generateRootCA(ctx context.Context) (*CA, error) { + rootKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + if err != nil { + return nil, err } serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) if err != nil { - return nil, nil, err + return nil, err } certTemplate := &x509.Certificate{ @@ -313,21 +438,270 @@ func generateRootCA() (*x509.Certificate, *rsa.PrivateKey, error) { }, }, NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour * 24), + NotAfter: time.Now().Add(RootCAValidityPeriod), KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, BasicConstraintsValid: true, IsCA: true, } - certDER, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey) + certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &rootKey.PublicKey, rootKey) + if err != nil { + return nil, err + } + + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, err + } + + if err != s.saveCA(ctx, cert, rootKey, RootCA) { + return nil, err + } + + return &CA{ + Type: RootCA, + Certificate: cert, + PrivateKey: rootKey, + SerialNumber: cert.SerialNumber.String(), + }, nil +} + +func (s *service) saveCA(ctx context.Context, cert *x509.Certificate, privateKey *rsa.PrivateKey, CertType CertType) error { + dbCert := Certificate{ + Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}), + Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}), + SerialNumber: cert.SerialNumber.String(), + ExpiryTime: cert.NotAfter, + Type: CertType, + } + if err := s.repo.CreateCert(ctx, dbCert); err != nil { + return errors.Wrap(ErrCreateEntity, err) + } + return nil +} + +func (s *service) createIntermediateCA(ctx context.Context, rootCA *CA) (*CA, error) { + intermediateKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) if err != nil { - return nil, nil, err + return nil, err } - cert, err := x509.ParseCertificate(certDER) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { - return nil, nil, err + return nil, err } - return cert, privateKey, nil + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: CommonName, + Organization: []string{Organization}, + OrganizationalUnit: []string{OrganizationalUnit}, + Country: []string{Country}, + Province: []string{Province}, + Locality: []string{Locality}, + StreetAddress: []string{StreetAddress}, + PostalCode: []string{PostalCode}, + SerialNumber: serialNumber.String(), + ExtraNames: []pkix.AttributeTypeAndValue{ + { + Type: asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 1}, + Value: emailAddress, + }, + }, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(IntermediateCAVAlidityPeriod), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: true, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, rootCA.Certificate, &intermediateKey.PublicKey, rootCA.PrivateKey) + if err != nil { + return nil, err + } + + intermediateCert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, err + } + + if err != s.saveCA(ctx, intermediateCert, intermediateKey, IntermediateCA) { + return nil, err + } + + intermediateCA := &CA{ + Type: IntermediateCA, + Certificate: intermediateCert, + PrivateKey: intermediateKey, + SerialNumber: serialNumber.String(), + } + + return intermediateCA, nil +} + +func (s *service) getSubject(options SubjectOptions) pkix.Name { + subject := pkix.Name{ + CommonName: options.CommonName, + } + + if len(options.Organization) > 0 { + subject.Organization = options.Organization + } + if len(options.OrganizationalUnit) > 0 { + subject.OrganizationalUnit = options.OrganizationalUnit + } + if len(options.Country) > 0 { + subject.Country = options.Country + } + if len(options.Province) > 0 { + subject.Province = options.Province + } + if len(options.Locality) > 0 { + subject.Locality = options.Locality + } + if len(options.StreetAddress) > 0 { + subject.StreetAddress = options.StreetAddress + } + if len(options.PostalCode) > 0 { + subject.PostalCode = options.PostalCode + } + + return subject +} + +func (s *service) rotateCA(ctx context.Context, ctype CertType) error { + switch ctype { + case RootCA: + certificates, err := s.repo.GetCAs(ctx) + if err != nil { + return err + } + for _, cert := range certificates { + if err := s.RevokeCert(ctx, cert.SerialNumber); err != nil { + return err + } + } + newRootCA, err := s.generateRootCA(ctx) + if err != nil { + return err + } + s.rootCA = newRootCA + newIntermediateCA, err := s.createIntermediateCA(ctx, newRootCA) + if err != nil { + return err + } + s.intermediateCA = newIntermediateCA + + case IntermediateCA: + certificates, err := s.repo.GetCAs(ctx, IntermediateCA) + if err != nil { + return err + } + for _, cert := range certificates { + if err := s.RevokeCert(ctx, cert.SerialNumber); err != nil { + return err + } + } + newIntermediateCA, err := s.createIntermediateCA(ctx, s.rootCA) + if err != nil { + return err + } + s.intermediateCA = newIntermediateCA + + default: + return ErrCertInvalidType + } + + return nil +} + +func (s *service) shouldRotateCA(ctype CertType) bool { + switch ctype { + case RootCA: + if s.rootCA == nil { + return true + } + now := time.Now() + + // Check if the certificate is expiring soon i.e., within 30 days. + if now.Add(rCertExpiryThreshold).After(s.rootCA.Certificate.NotAfter) { + return true + } + case IntermediateCA: + if s.intermediateCA == nil { + return true + } + now := time.Now() + + // Check if the certificate is expiring soon i.e., within 10 days. + if now.Add(iCertExpiryThreshold).After(s.intermediateCA.Certificate.NotAfter) { + return true + } + } + + return false +} + +func (s *service) loadCACerts(ctx context.Context) error { + certificates, err := s.repo.GetCAs(ctx) + if err != nil { + return err + } + + for _, c := range certificates { + if c.Type == RootCA { + rblock, _ := pem.Decode(c.Certificate) + if rblock == nil { + return errors.New("failed to parse certificate PEM") + } + + rootCert, err := x509.ParseCertificate(rblock.Bytes) + if err != nil { + return err + } + rkey, _ := pem.Decode(c.Key) + if rkey == nil { + return errors.New("failed to parse key PEM") + } + rootKey, err := x509.ParsePKCS1PrivateKey(rkey.Bytes) + if err != nil { + return err + } + s.rootCA = &CA{ + Type: c.Type, + Certificate: rootCert, + PrivateKey: rootKey, + SerialNumber: c.SerialNumber, + } + } + + iblock, _ := pem.Decode(c.Certificate) + if iblock == nil { + return errors.New("failed to parse certificate PEM") + } + if c.Type == IntermediateCA { + interCert, err := x509.ParseCertificate(iblock.Bytes) + if err != nil { + return err + } + ikey, _ := pem.Decode(c.Key) + if ikey == nil { + return errors.New("failed to parse key PEM") + } + interKey, err := x509.ParsePKCS1PrivateKey(ikey.Bytes) + if err != nil { + return err + } + s.intermediateCA = &CA{ + Type: c.Type, + Certificate: interCert, + PrivateKey: interKey, + SerialNumber: c.SerialNumber, + } + } + } + return nil } diff --git a/tracing/certs.go b/tracing/certs.go index 3470759..765b803 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -47,10 +47,10 @@ func (tm *tracingMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri return tm.svc.RetrieveCertDownloadToken(ctx, serialNumber) } -func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string) (string, error) { +func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (string, error) { ctx, span := tm.tracer.Start(ctx, "issue_cert") defer span.End() - return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs) + return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) } func (tm *tracingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { @@ -76,3 +76,9 @@ func (tm *tracingMiddleware) GetEntityID(ctx context.Context, serialNumber strin defer span.End() return tm.svc.GetEntityID(ctx, serialNumber) } + +func (tm *tracingMiddleware) GenerateCRL(ctx context.Context, caType certs.CertType) ([]byte, error) { + ctx, span := tm.tracer.Start(ctx, "generate_crl") + defer span.End() + return tm.svc.GenerateCRL(ctx, caType) +}