Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsr committed Dec 21, 2023
1 parent 8bc208e commit de94320
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 64 deletions.
4 changes: 2 additions & 2 deletions x-pack/filebeat/input/salesforce/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ type cursorConfig struct {

func (c *config) Validate() error {
switch {
case !c.Auth.JWT.isEnabled() && !c.Auth.OAuth2.isEnabled():
case !c.Auth.OAuth2.JWTBearerFlow.isEnabled() && !c.Auth.OAuth2.UserPasswordFlow.isEnabled():
return errors.New("no auth provider enabled")
case c.Auth.JWT.isEnabled() && c.Auth.OAuth2.isEnabled():
case c.Auth.OAuth2.JWTBearerFlow.isEnabled() && c.Auth.OAuth2.UserPasswordFlow.isEnabled():
return errors.New("only one auth provider must be enabled")
case c.URL == "":
return errors.New("no instance url is configured")
Expand Down
28 changes: 15 additions & 13 deletions x-pack/filebeat/input/salesforce/config_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,28 @@

package salesforce

import (
"errors"
)
import "errors"

type authConfig struct {
OAuth2 *oAuth2Config `config:"oauth2"`
JWT *jwtConfig `config:"jwt"`
OAuth2 OAuth2 `config:"oauth2"`
}

type oAuth2Config struct {
type OAuth2 struct {
UserPasswordFlow *userPasswordFlow `config:"user_password_flow"`
JWTBearerFlow *jwtBearerFlow `config:"jwt_bearer_flow"`
}

type userPasswordFlow struct {
Enabled *bool `config:"enabled"`

ClientID string `config:"client.id"`
ClientSecret string `config:"client.secret"`
Password string `config:"password"`
TokenURL string `config:"token_url"`
User string `config:"user"`
Username string `config:"username"`
}

type jwtConfig struct {
type jwtBearerFlow struct {
Enabled *bool `config:"enabled"`

URL string `config:"url"`
Expand All @@ -33,12 +35,12 @@ type jwtConfig struct {
}

// isEnabled returns true if the `enable` field is set to true in the yaml.
func (o *oAuth2Config) isEnabled() bool {
func (o *userPasswordFlow) isEnabled() bool {
return o != nil && (o.Enabled != nil && *o.Enabled)
}

// Validate checks if oauth2 config is valid.
func (o *oAuth2Config) Validate() error {
func (o *userPasswordFlow) Validate() error {
if !o.isEnabled() {
return nil
}
Expand All @@ -50,7 +52,7 @@ func (o *oAuth2Config) Validate() error {
return errors.New("client.id must be provided")
case o.ClientSecret == "":
return errors.New("client.secret must be provided")
case o.User == "":
case o.Username == "":
return errors.New("user must be provided")
case o.Password == "":
return errors.New("password must be provided")
Expand All @@ -61,11 +63,11 @@ func (o *oAuth2Config) Validate() error {
}

// isEnabled returns true if the `enable` field is set to true in the yaml.
func (o *jwtConfig) isEnabled() bool {
func (o *jwtBearerFlow) isEnabled() bool {
return o != nil && (o.Enabled != nil && *o.Enabled)
}

func (o *jwtConfig) Validate() error {
func (o *jwtBearerFlow) Validate() error {
if !o.isEnabled() {
return nil
}
Expand Down
34 changes: 17 additions & 17 deletions x-pack/filebeat/input/salesforce/config_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ import (

func TestOAuth2Config(t *testing.T) {
tests := map[string]struct {
config oAuth2Config
config userPasswordFlow
wantErr error
}{
"auth disabled I": {config: oAuth2Config{}, wantErr: nil},
"auth disabled II": {config: oAuth2Config{Enabled: pointer(false)}, wantErr: nil},
"tokenURL missing": {config: oAuth2Config{Enabled: pointer(true), TokenURL: ""}, wantErr: errors.New("token_url must be provided")},
"clientID missing": {config: oAuth2Config{Enabled: pointer(true), TokenURL: "https://salesforce.com", ClientID: ""}, wantErr: errors.New("client.id must be provided")},
"clientSecret missing": {config: oAuth2Config{Enabled: pointer(true), TokenURL: "https://salesforce.com", ClientID: "xyz", ClientSecret: ""}, wantErr: errors.New("client.secret must be provided")},
"user missing": {config: oAuth2Config{Enabled: pointer(true), TokenURL: "https://salesforce.com", ClientID: "xyz", ClientSecret: "abc", User: ""}, wantErr: errors.New("user must be provided")},
"password missing": {config: oAuth2Config{Enabled: pointer(true), TokenURL: "https://salesforce.com", ClientID: "xyz", ClientSecret: "abc", User: "user", Password: ""}, wantErr: errors.New("password must be provided")},
"all present": {config: oAuth2Config{Enabled: pointer(true), TokenURL: "https://salesforce.com", ClientID: "xyz", ClientSecret: "abc", User: "user", Password: "pass"}, wantErr: nil},
"auth disabled I": {config: userPasswordFlow{}, wantErr: nil},
"auth disabled II": {config: userPasswordFlow{Enabled: pointer(false)}, wantErr: nil},
"tokenURL missing": {config: userPasswordFlow{Enabled: pointer(true), TokenURL: ""}, wantErr: errors.New("token_url must be provided")},
"clientID missing": {config: userPasswordFlow{Enabled: pointer(true), TokenURL: "https://salesforce.com", ClientID: ""}, wantErr: errors.New("client.id must be provided")},
"clientSecret missing": {config: userPasswordFlow{Enabled: pointer(true), TokenURL: "https://salesforce.com", ClientID: "xyz", ClientSecret: ""}, wantErr: errors.New("client.secret must be provided")},
"user missing": {config: userPasswordFlow{Enabled: pointer(true), TokenURL: "https://salesforce.com", ClientID: "xyz", ClientSecret: "abc", Username: ""}, wantErr: errors.New("user must be provided")},
"password missing": {config: userPasswordFlow{Enabled: pointer(true), TokenURL: "https://salesforce.com", ClientID: "xyz", ClientSecret: "abc", Username: "user", Password: ""}, wantErr: errors.New("password must be provided")},
"all present": {config: userPasswordFlow{Enabled: pointer(true), TokenURL: "https://salesforce.com", ClientID: "xyz", ClientSecret: "abc", Username: "user", Password: "pass"}, wantErr: nil},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
Expand All @@ -35,16 +35,16 @@ func TestOAuth2Config(t *testing.T) {

func TestJWTConfig(t *testing.T) {
tests := map[string]struct {
config jwtConfig
config jwtBearerFlow
wantErr error
}{
"auth disabled I": {config: jwtConfig{}, wantErr: nil},
"auth disabled II": {config: jwtConfig{Enabled: pointer(false)}, wantErr: nil},
"url missing": {config: jwtConfig{Enabled: pointer(true), URL: ""}, wantErr: errors.New("url must be provided")},
"clientID missing": {config: jwtConfig{Enabled: pointer(true), URL: "https://salesforce.com", ClientID: ""}, wantErr: errors.New("client.id must be provided")},
"clientUsername missing": {config: jwtConfig{Enabled: pointer(true), URL: "https://salesforce.com", ClientID: "xyz", ClientUsername: ""}, wantErr: errors.New("client.username must be provided")},
"clientKeyPath missing": {config: jwtConfig{Enabled: pointer(true), URL: "https://salesforce.com", ClientID: "xyz", ClientUsername: "abc", ClientKeyPath: ""}, wantErr: errors.New("client.key_path must be provided")},
"all present": {config: jwtConfig{Enabled: pointer(true), URL: "https://salesforce.com", ClientID: "xyz", ClientUsername: "abc", ClientKeyPath: "def"}, wantErr: nil},
"auth disabled I": {config: jwtBearerFlow{}, wantErr: nil},
"auth disabled II": {config: jwtBearerFlow{Enabled: pointer(false)}, wantErr: nil},
"url missing": {config: jwtBearerFlow{Enabled: pointer(true), URL: ""}, wantErr: errors.New("url must be provided")},
"clientID missing": {config: jwtBearerFlow{Enabled: pointer(true), URL: "https://salesforce.com", ClientID: ""}, wantErr: errors.New("client.id must be provided")},
"clientUsername missing": {config: jwtBearerFlow{Enabled: pointer(true), URL: "https://salesforce.com", ClientID: "xyz", ClientUsername: ""}, wantErr: errors.New("client.username must be provided")},
"clientKeyPath missing": {config: jwtBearerFlow{Enabled: pointer(true), URL: "https://salesforce.com", ClientID: "xyz", ClientUsername: "abc", ClientKeyPath: ""}, wantErr: errors.New("client.key_path must be provided")},
"all present": {config: jwtBearerFlow{Enabled: pointer(true), URL: "https://salesforce.com", ClientID: "xyz", ClientUsername: "abc", ClientKeyPath: "def"}, wantErr: nil},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
Expand Down
34 changes: 19 additions & 15 deletions x-pack/filebeat/input/salesforce/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,21 @@ func TestValidate(t *testing.T) {
"no auth provider enabled (no password or jwt)": {
inputCfg: config{
Auth: &authConfig{
OAuth2: &oAuth2Config{},
JWT: &jwtConfig{},
OAuth2: OAuth2{
UserPasswordFlow: &userPasswordFlow{},
JWTBearerFlow: &jwtBearerFlow{},
},
},
},
wantErr: errors.New("no auth provider enabled"),
},
"only one auth provider is allowed (either password or jwt)": {
inputCfg: config{
Auth: &authConfig{
OAuth2: &oAuth2Config{Enabled: pointer(true)},
JWT: &jwtConfig{Enabled: pointer(true)},
OAuth2: OAuth2{
UserPasswordFlow: &userPasswordFlow{Enabled: pointer(true)},
JWTBearerFlow: &jwtBearerFlow{Enabled: pointer(true)},
},
},
},
wantErr: errors.New("only one auth provider must be enabled"),
Expand All @@ -40,8 +44,8 @@ func TestValidate(t *testing.T) {
inputCfg: config{
URL: "",
Auth: &authConfig{
OAuth2: &oAuth2Config{
Enabled: pointer(true),
OAuth2: OAuth2{
UserPasswordFlow: &userPasswordFlow{Enabled: pointer(true)},
},
},
},
Expand All @@ -52,12 +56,12 @@ func TestValidate(t *testing.T) {
EventMonitoringMethod: &EventMonitoringMethod{},
URL: "https://some-dummy-subdomain.salesforce.com/services/oauth2/token",
Auth: &authConfig{
OAuth2: &oAuth2Config{
Enabled: pointer(true),
OAuth2: OAuth2{
UserPasswordFlow: &userPasswordFlow{Enabled: pointer(true)},
},
},
},
wantErr: errors.New(`at least one of "data_collection_method.event_log_file.enabled" or "data_collection_method.object.enabled" must be set to true`),
wantErr: errors.New(`at least one of "event_monitoring_method.event_log_file.enabled" or "event_monitoring_method.object.enabled" must be set to true`),
},
"invalid elf interval (1h)": {
inputCfg: config{
Expand All @@ -69,8 +73,8 @@ func TestValidate(t *testing.T) {
},
URL: "https://some-dummy-subdomain.salesforce.com/services/oauth2/token",
Auth: &authConfig{
OAuth2: &oAuth2Config{
Enabled: pointer(true),
OAuth2: OAuth2{
UserPasswordFlow: &userPasswordFlow{Enabled: pointer(true)},
},
},
},
Expand All @@ -86,8 +90,8 @@ func TestValidate(t *testing.T) {
},
URL: "https://some-dummy-subdomain.salesforce.com/services/oauth2/token",
Auth: &authConfig{
OAuth2: &oAuth2Config{
Enabled: pointer(true),
OAuth2: OAuth2{
UserPasswordFlow: &userPasswordFlow{Enabled: pointer(true)},
},
},
},
Expand All @@ -104,8 +108,8 @@ func TestValidate(t *testing.T) {
},
URL: "https://some-dummy-subdomain.salesforce.com/services/oauth2/token",
Auth: &authConfig{
OAuth2: &oAuth2Config{
Enabled: pointer(true),
OAuth2: OAuth2{
UserPasswordFlow: &userPasswordFlow{Enabled: pointer(true)},
},
},
},
Expand Down
20 changes: 10 additions & 10 deletions x-pack/filebeat/input/salesforce/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,8 @@ func getSFDCConfig(cfg *config) (*sfdc.Configuration, error) {
)

switch {
case cfg.Auth.JWT.isEnabled():
pemBytes, err := os.ReadFile(cfg.Auth.JWT.ClientKeyPath)
case cfg.Auth.OAuth2.JWTBearerFlow.isEnabled():
pemBytes, err := os.ReadFile(cfg.Auth.OAuth2.JWTBearerFlow.ClientKeyPath)
if err != nil {
return nil, fmt.Errorf("problem with client key path for JWT auth: %w", err)
}
Expand All @@ -370,23 +370,23 @@ func getSFDCConfig(cfg *config) (*sfdc.Configuration, error) {
}

passCreds := credentials.JwtCredentials{
URL: cfg.Auth.JWT.URL,
ClientId: cfg.Auth.JWT.ClientID,
ClientUsername: cfg.Auth.JWT.ClientUsername,
URL: cfg.Auth.OAuth2.JWTBearerFlow.URL,
ClientId: cfg.Auth.OAuth2.JWTBearerFlow.ClientID,
ClientUsername: cfg.Auth.OAuth2.JWTBearerFlow.ClientUsername,
ClientKey: signKey,
}

creds, err = credentials.NewJWTCredentials(passCreds)
if err != nil {
return nil, fmt.Errorf("problem with credentials: %w", err)
}
case cfg.Auth.OAuth2.isEnabled():
case cfg.Auth.OAuth2.UserPasswordFlow.isEnabled():
passCreds := credentials.PasswordCredentials{
URL: cfg.URL,
Username: cfg.Auth.OAuth2.User,
Password: cfg.Auth.OAuth2.Password,
ClientID: cfg.Auth.OAuth2.ClientID,
ClientSecret: cfg.Auth.OAuth2.ClientSecret,
Username: cfg.Auth.OAuth2.UserPasswordFlow.Username,
Password: cfg.Auth.OAuth2.UserPasswordFlow.Password,
ClientID: cfg.Auth.OAuth2.UserPasswordFlow.ClientID,
ClientSecret: cfg.Auth.OAuth2.UserPasswordFlow.ClientSecret,
}

creds, err = credentials.NewPasswordCredentials(passCreds)
Expand Down
4 changes: 2 additions & 2 deletions x-pack/filebeat/input/salesforce/input_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ func TestInputManager(t *testing.T) {
"url": "https://salesforce.com",
"version": 46,
"auth": &authConfig{
JWT: &jwtConfig{
OAuth2: OAuth2{JWTBearerFlow: &jwtBearerFlow{
Enabled: pointer(true),
URL: "https://salesforce.com",
ClientID: "xyz",
ClientUsername: "xyz",
ClientKeyPath: "xyz",
},
}},
},
"event_monitoring_method": &EventMonitoringMethod{
Object: EventMonitoringConfig{Enabled: pointer(true), Interval: 4},
Expand Down
6 changes: 3 additions & 3 deletions x-pack/filebeat/input/salesforce/input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,19 @@ func TestInput(t *testing.T) {
wantErr bool
}{
{
name: "data_collection_method_object_with_default_query_only",
name: "event_monitoring_method_object_with_default_query_only",
setupServer: newTestServer(httptest.NewServer),
baseConfig: map[string]interface{}{
"auth.oauth2": map[string]interface{}{
"enabled": true,
"enabled": pointer(true),
"client.id": "clientid",
"client.secret": "clientsecret",
"token_url": "https://instance_id.develop.my.salesforce.com/services/oauth2/token",
"user": "username",
"password": "password",
},
"version": 56,
"data_collection_method": map[string]interface{}{
"event_monitoring_method": map[string]interface{}{
"object": map[string]interface{}{
"interval": "5m",
"enabled": true,
Expand Down
3 changes: 1 addition & 2 deletions x-pack/filebeat/input/salesforce/value_tpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,9 @@ func (t *valueTpl) Unpack(in string) error {
}

func parseRFC3339Timestamp(s string) string {
now := timeNow().UTC()
_, err := time.Parse(time.RFC3339, s)
if err != nil {
return now.Format(time.RFC3339)
return ""
}
return s
}
Expand Down

0 comments on commit de94320

Please sign in to comment.