diff --git a/api/logging.go b/api/logging.go index c249e4d..4e797ea 100644 --- a/api/logging.go +++ b/api/logging.go @@ -5,7 +5,6 @@ package api import ( "context" - "crypto/rsa" "crypto/x509" "fmt" "log/slog" @@ -86,7 +85,7 @@ func (lm *loggingMiddleware) RetrieveCAToken(ctx context.Context) (tokenString s return lm.svc.RetrieveCAToken(ctx) } -func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (cert certs.Certificate, err error) { +func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...any) (cert certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin)) if err != nil { diff --git a/api/metrics.go b/api/metrics.go index fcbb650..d9db09c 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -5,7 +5,6 @@ package api import ( "context" - "crypto/rsa" "crypto/x509" "time" @@ -72,7 +71,7 @@ func (mm *metricsMiddleware) RetrieveCAToken(ctx context.Context) (string, error return mm.svc.RetrieveCAToken(ctx) } -func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { +func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...any) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_certificate").Add(1) mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds()) diff --git a/certs.go b/certs.go index 1208ae6..8cc1a2f 100644 --- a/certs.go +++ b/certs.go @@ -158,7 +158,7 @@ type Service interface { RetrieveCAToken(ctx context.Context) (string, error) // IssueCert issues a certificate from the database. - IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey ...*rsa.PrivateKey) (Certificate, error) + IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey ...any) (Certificate, error) // OCSP retrieves the OCSP response for a certificate. OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error) diff --git a/mocks/service.go b/mocks/service.go index 84daa16..5e91bc7 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -12,8 +12,6 @@ import ( mock "github.com/stretchr/testify/mock" - rsa "crypto/rsa" - x509 "crypto/x509" ) @@ -262,14 +260,10 @@ func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, s } // IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option, privKey -func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { - _va := make([]interface{}, len(privKey)) - for _i := range privKey { - _va[_i] = privKey[_i] - } +func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...interface{}) (certs.Certificate, error) { var _ca []interface{} _ca = append(_ca, ctx, entityID, ttl, ipAddrs, option) - _ca = append(_ca, _va...) + _ca = append(_ca, privKey...) ret := _m.Called(_ca...) if len(ret) == 0 { @@ -278,16 +272,16 @@ func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl strin var r0 certs.Certificate var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) (certs.Certificate, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) (certs.Certificate, error)); ok { return rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) certs.Certificate); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) certs.Certificate); ok { r0 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } else { r0 = ret.Get(0).(certs.Certificate) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) error); ok { r1 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } else { r1 = ret.Error(1) @@ -307,18 +301,18 @@ type MockService_IssueCert_Call struct { // - ttl string // - ipAddrs []string // - option certs.SubjectOptions -// - privKey ...*rsa.PrivateKey +// - privKey ...interface{} func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}, privKey ...interface{}) *MockService_IssueCert_Call { return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", append([]interface{}{ctx, entityID, ttl, ipAddrs, option}, privKey...)...)} } -func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...*rsa.PrivateKey)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...interface{})) *MockService_IssueCert_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]*rsa.PrivateKey, len(args)-5) + variadicArgs := make([]interface{}, len(args)-5) for i, a := range args[5:] { if a != nil { - variadicArgs[i] = a.(*rsa.PrivateKey) + variadicArgs[i] = a.(interface{}) } } run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions), variadicArgs...) @@ -331,7 +325,7 @@ func (_c *MockService_IssueCert_Call) Return(_a0 certs.Certificate, _a1 error) * return _c } -func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) (certs.Certificate, error)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) (certs.Certificate, error)) *MockService_IssueCert_Call { _c.Call.Return(run) return _c } diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index ce5cd68..ad87e77 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -638,9 +638,9 @@ func (_c *MockSDK_RevokeCert_Call) RunAndReturn(run func(string) errors.SDKError return _c } -// SignCSR provides a mock function with given fields: entityID, ttl, csr -func (_m *MockSDK) SignCSR(entityID string, ttl string, csr string) (sdk.Certificate, errors.SDKError) { - ret := _m.Called(entityID, ttl, csr) +// SignCSR provides a mock function with given fields: entityID, ttl, csr, privKey +func (_m *MockSDK) SignCSR(entityID string, ttl string, csr string, privKey string) (sdk.Certificate, errors.SDKError) { + ret := _m.Called(entityID, ttl, csr, privKey) if len(ret) == 0 { panic("no return value specified for SignCSR") @@ -648,17 +648,17 @@ func (_m *MockSDK) SignCSR(entityID string, ttl string, csr string) (sdk.Certifi var r0 sdk.Certificate var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(string, string, string) (sdk.Certificate, errors.SDKError)); ok { - return rf(entityID, ttl, csr) + if rf, ok := ret.Get(0).(func(string, string, string, string) (sdk.Certificate, errors.SDKError)); ok { + return rf(entityID, ttl, csr, privKey) } - if rf, ok := ret.Get(0).(func(string, string, string) sdk.Certificate); ok { - r0 = rf(entityID, ttl, csr) + if rf, ok := ret.Get(0).(func(string, string, string, string) sdk.Certificate); ok { + r0 = rf(entityID, ttl, csr, privKey) } else { r0 = ret.Get(0).(sdk.Certificate) } - if rf, ok := ret.Get(1).(func(string, string, string) errors.SDKError); ok { - r1 = rf(entityID, ttl, csr) + if rf, ok := ret.Get(1).(func(string, string, string, string) errors.SDKError); ok { + r1 = rf(entityID, ttl, csr, privKey) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -677,13 +677,14 @@ type MockSDK_SignCSR_Call struct { // - entityID string // - ttl string // - csr string -func (_e *MockSDK_Expecter) SignCSR(entityID interface{}, ttl interface{}, csr interface{}) *MockSDK_SignCSR_Call { - return &MockSDK_SignCSR_Call{Call: _e.mock.On("SignCSR", entityID, ttl, csr)} +// - privKey string +func (_e *MockSDK_Expecter) SignCSR(entityID interface{}, ttl interface{}, csr interface{}, privKey interface{}) *MockSDK_SignCSR_Call { + return &MockSDK_SignCSR_Call{Call: _e.mock.On("SignCSR", entityID, ttl, csr, privKey)} } -func (_c *MockSDK_SignCSR_Call) Run(run func(entityID string, ttl string, csr string)) *MockSDK_SignCSR_Call { +func (_c *MockSDK_SignCSR_Call) Run(run func(entityID string, ttl string, csr string, privKey string)) *MockSDK_SignCSR_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(string), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -693,7 +694,7 @@ func (_c *MockSDK_SignCSR_Call) Return(_a0 sdk.Certificate, _a1 errors.SDKError) return _c } -func (_c *MockSDK_SignCSR_Call) RunAndReturn(run func(string, string, string) (sdk.Certificate, errors.SDKError)) *MockSDK_SignCSR_Call { +func (_c *MockSDK_SignCSR_Call) RunAndReturn(run func(string, string, string, string) (sdk.Certificate, errors.SDKError)) *MockSDK_SignCSR_Call { _c.Call.Return(run) return _c } diff --git a/service.go b/service.go index 88de439..cfb1a4a 100644 --- a/service.go +++ b/service.go @@ -5,6 +5,7 @@ package certs import ( "context" + "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rand" @@ -88,17 +89,17 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service, // using the provided template and the generated private key. // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. -func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...*rsa.PrivateKey) (Certificate, error) { - var privKey rsa.PrivateKey +func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...any) (Certificate, error) { + var privKey any var err error if len(key) == 0 { pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) - privKey = *pKey + privKey = pKey if err != nil { return Certificate{}, err } } else { - privKey = *key[0] + privKey = key[0] } serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { @@ -132,12 +133,37 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ DNSNames: append(s.intermediateCA.Certificate.DNSNames, ipAddrs...), } - certBytes, err := x509.CreateCertificate(rand.Reader, &template, s.intermediateCA.Certificate, &privKey.PublicKey, s.intermediateCA.PrivateKey) + var pubKey crypto.PublicKey + var privKeyBytes []byte + var privKeyType string + + switch key := privKey.(type) { + case *rsa.PrivateKey: + pubKey = key.Public() + privKeyBytes = x509.MarshalPKCS1PrivateKey(key) + privKeyType = "RSA PRIVATE KEY" + case *ecdsa.PrivateKey: + pubKey = key.Public() + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) + privKeyType = "EC PRIVATE KEY" + case ed25519.PrivateKey: + pubKey = key.Public() + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) + privKeyType = "PRIVATE KEY" + default: + return Certificate{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported private key type")) + } + + if err != nil { + return Certificate{}, err + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, s.intermediateCA.Certificate, pubKey, s.intermediateCA.PrivateKey) if err != nil { return Certificate{}, err } dbCert := Certificate{ - Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(&privKey)}), + Key: pem.EncodeToMemory(&pem.Block{Type: privKeyType, Bytes: privKeyBytes}), Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), SerialNumber: template.SerialNumber.String(), EntityID: entityID, diff --git a/tracing/certs.go b/tracing/certs.go index 4a720db..b9fd6ce 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -5,7 +5,6 @@ package tracing import ( "context" - "crypto/rsa" "crypto/x509" "github.com/absmach/certs" @@ -54,7 +53,7 @@ func (tm *tracingMiddleware) RetrieveCAToken(ctx context.Context) (string, error return tm.svc.RetrieveCAToken(ctx) } -func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { +func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...any) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "issue_cert") defer span.End() return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...)