diff --git a/statecheck/extract_bool_value.go b/statecheck/extract_bool_value.go new file mode 100644 index 000000000..1ce668424 --- /dev/null +++ b/statecheck/extract_bool_value.go @@ -0,0 +1,86 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package statecheck + +import ( + "context" + "fmt" + + tfjson "github.com/hashicorp/terraform-json" + "github.com/hashicorp/terraform-plugin-testing/tfjsonpath" +) + +var _ StateCheck = extractBoolValue{} + +type extractBoolValue struct { + resourceAddress string + attributePath tfjsonpath.Path + targetVar *bool +} + +func (e extractBoolValue) CheckState(ctx context.Context, req CheckStateRequest, resp *CheckStateResponse) { + var resource *tfjson.StateResource + + if req.State == nil { + resp.Error = fmt.Errorf("state is nil") + + return + } + + if req.State.Values == nil { + resp.Error = fmt.Errorf("state does not contain any state values") + + return + } + + if req.State.Values.RootModule == nil { + resp.Error = fmt.Errorf("state does not contain a root module") + + return + } + + for _, r := range req.State.Values.RootModule.Resources { + if e.resourceAddress == r.Address { + resource = r + + break + } + } + + if resource == nil { + resp.Error = fmt.Errorf("%s - Resource not found in state", e.resourceAddress) + + return + } + + result, err := tfjsonpath.Traverse(resource.AttributeValues, e.attributePath) + if err != nil { + resp.Error = err + + return + } + + if result == nil { + resp.Error = fmt.Errorf("nil: result for attribute '%s' in '%s'", e.attributePath, e.resourceAddress) + + return + } + + switch t := result.(type) { + case bool: + *e.targetVar = t + default: + resp.Error = fmt.Errorf("invalid type for attribute '%s' in '%s'. Expected: bool, Got: %T", e.attributePath, e.resourceAddress, t) + + return + } +} + +func ExtractBoolValue(resourceAddress string, attributePath tfjsonpath.Path, targetVar *bool) StateCheck { + return extractBoolValue{ + resourceAddress: resourceAddress, + attributePath: attributePath, + targetVar: targetVar, + } +} diff --git a/statecheck/extract_bool_value_test.go b/statecheck/extract_bool_value_test.go new file mode 100644 index 000000000..c2eec4a29 --- /dev/null +++ b/statecheck/extract_bool_value_test.go @@ -0,0 +1,153 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package statecheck_test + +import ( + "fmt" + "regexp" + "testing" + + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + + r "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/statecheck" + "github.com/hashicorp/terraform-plugin-testing/tfjsonpath" +) + +func TestExtractBoolValue_Basic(t *testing.T) { + t.Parallel() + + // targetVar will be set to the extracted value. + var targetVar bool + + r.Test(t, r.TestCase{ + ProviderFactories: map[string]func() (*schema.Provider, error){ + "test": func() (*schema.Provider, error) { //nolint:unparam // required signature + return testProvider(), nil + }, + }, + Steps: []r.TestStep{ + { + Config: `resource "test_resource" "one" { + bool_attribute = true + } + `, + ConfigStateChecks: []statecheck.StateCheck{ + statecheck.ExtractBoolValue( + "test_resource.one", + tfjsonpath.New("bool_attribute"), + &targetVar, + ), + }, + }, + }, + }) + + t.Run("CheckTargetVar", func(t *testing.T) { + if err := testAccAssertBoolEquals(true, targetVar); err != nil { + t.Errorf("Error in testAccAssertBoolEquals: %v", err) + } + }) +} + +func TestExtractBoolValue_KnownValueWrongType(t *testing.T) { + t.Parallel() + + // targetVar will be set to the extracted value. + var targetVar bool + + r.Test(t, r.TestCase{ + ProviderFactories: map[string]func() (*schema.Provider, error){ + "test": func() (*schema.Provider, error) { //nolint:unparam // required signature + return testProvider(), nil + }, + }, + Steps: []r.TestStep{ + { + Config: `resource "test_resource" "one" { + float_attribute = 1.23 + } + `, + ConfigStateChecks: []statecheck.StateCheck{ + statecheck.ExtractBoolValue( + "test_resource.one", + tfjsonpath.New("float_attribute"), + &targetVar, + ), + }, + ExpectError: regexp.MustCompile(`invalid type for attribute \'float_attribute\' in \'test_resource\.one\'. Expected: bool, Got: json\.Number`), + }, + }, + }) +} + +func TestExtractBoolValue_Null(t *testing.T) { + t.Parallel() + + // targetVar will be set to the extracted value. + var targetVar bool + + r.Test(t, r.TestCase{ + ProviderFactories: map[string]func() (*schema.Provider, error){ + "test": func() (*schema.Provider, error) { //nolint:unparam // required signature + return testProvider(), nil + }, + }, + Steps: []r.TestStep{ + { + Config: `resource "test_resource" "one" { + bool_attribute = null + } + `, + ConfigStateChecks: []statecheck.StateCheck{ + statecheck.ExtractBoolValue( + "test_resource.one", + tfjsonpath.New("bool_attribute"), + &targetVar, + ), + }, + ExpectError: regexp.MustCompile(`nil: result for attribute \'bool_attribute\' in \'test_resource.one\'`), + }, + }, + }) +} + +func TestExtractBoolValue_ResourceNotFound(t *testing.T) { + t.Parallel() + + // targetVar will be set to the extracted value. + var targetVar bool + + r.Test(t, r.TestCase{ + ProviderFactories: map[string]func() (*schema.Provider, error){ + "test": func() (*schema.Provider, error) { //nolint:unparam // required signature + return testProvider(), nil + }, + }, + Steps: []r.TestStep{ + { + Config: `resource "test_resource" "one" { + bool_attribute = true + } + `, + ConfigStateChecks: []statecheck.StateCheck{ + statecheck.ExtractBoolValue( + "test_resource.two", + tfjsonpath.New("bool_attribute"), + &targetVar, + ), + }, + ExpectError: regexp.MustCompile("test_resource.two - Resource not found in state"), + }, + }, + }) +} + +// testAccAssertBoolEquals compares the expected and target bool values. +func testAccAssertBoolEquals(expected bool, targetVar bool) error { + if targetVar != expected { + return fmt.Errorf("expected targetVar to be %v, got %v", expected, targetVar) + } + return nil +} diff --git a/statecheck/extract_string_value.go b/statecheck/extract_string_value.go new file mode 100644 index 000000000..ae3d63cf0 --- /dev/null +++ b/statecheck/extract_string_value.go @@ -0,0 +1,87 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package statecheck + +import ( + "context" + "fmt" + + tfjson "github.com/hashicorp/terraform-json" + "github.com/hashicorp/terraform-plugin-testing/tfjsonpath" +) + +var _ StateCheck = extractStringValue{} + +type extractStringValue struct { + resourceAddress string + attributePath tfjsonpath.Path + targetVar *string +} + +func (e extractStringValue) CheckState(ctx context.Context, req CheckStateRequest, resp *CheckStateResponse) { + var resource *tfjson.StateResource + + if req.State == nil { + resp.Error = fmt.Errorf("state is nil") + + return + } + + if req.State.Values == nil { + resp.Error = fmt.Errorf("state does not contain any state values") + + return + } + + if req.State.Values.RootModule == nil { + resp.Error = fmt.Errorf("state does not contain a root module") + + return + } + + for _, r := range req.State.Values.RootModule.Resources { + if e.resourceAddress == r.Address { + resource = r + + break + } + } + + if resource == nil { + resp.Error = fmt.Errorf("%s - Resource not found in state", e.resourceAddress) + + return + } + + result, err := tfjsonpath.Traverse(resource.AttributeValues, e.attributePath) + if err != nil { + resp.Error = err + + return + } + + if result == nil { + resp.Error = fmt.Errorf("nil: result for attribute '%s' in '%s'", e.attributePath, e.resourceAddress) + + return + } + + switch t := result.(type) { + case string: + *e.targetVar = t + return + default: + resp.Error = fmt.Errorf("invalid type for attribute '%s' in '%s'. Expected: string, Got: %T", e.attributePath, e.resourceAddress, t) + + return + } +} + +func ExtractStringValue(resourceAddress string, attributePath tfjsonpath.Path, targetVar *string) StateCheck { + return extractStringValue{ + resourceAddress: resourceAddress, + attributePath: attributePath, + targetVar: targetVar, + } +} diff --git a/statecheck/extract_string_value_test.go b/statecheck/extract_string_value_test.go new file mode 100644 index 000000000..9cb3eb8bc --- /dev/null +++ b/statecheck/extract_string_value_test.go @@ -0,0 +1,59 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package statecheck_test + +import ( + "fmt" + "testing" + + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + + r "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/statecheck" + "github.com/hashicorp/terraform-plugin-testing/tfjsonpath" +) + +func TestExtractStringValue_Basic(t *testing.T) { + t.Parallel() + + // targetVar will be set to the extracted value. + var targetVar string + + r.Test(t, r.TestCase{ + ProviderFactories: map[string]func() (*schema.Provider, error){ + "test": func() (*schema.Provider, error) { //nolint:unparam // required signature + return testProvider(), nil + }, + }, + Steps: []r.TestStep{ + { + Config: `resource "test_resource" "one" { + string_attribute = "str" + } + `, + ConfigStateChecks: []statecheck.StateCheck{ + statecheck.ExtractStringValue( + "test_resource.one", + tfjsonpath.New("string_attribute"), + &targetVar, + ), + }, + }, + }, + }) + + t.Run("CheckTargetVar", func(t *testing.T) { + if err := testAccAssertStringEquals("str", targetVar); err != nil { + t.Errorf("Error in testAccAssertBoolEquals: %v", err) + } + }) +} + +// testAccAssertStringEquals compares the expected and target string values. +func testAccAssertStringEquals(expected string, targetVar string) error { + if targetVar != expected { + return fmt.Errorf("expected targetVar to be %v, got %v", expected, targetVar) + } + return nil +}