From 72cdb1a5551ca6d144adef894c114790c221608e Mon Sep 17 00:00:00 2001
From: Derek Brown <derek@allderek.com>
Date: Mon, 11 Mar 2024 10:38:05 -0700
Subject: [PATCH] Fix ALL + GRANT issue (#121)

---
 mysql/resource_grant.go      | 35 ++++++++++-----
 mysql/resource_grant_test.go | 87 +++++++++++++++++++-----------------
 2 files changed, 72 insertions(+), 50 deletions(-)

diff --git a/mysql/resource_grant.go b/mysql/resource_grant.go
index 5a7e500d4..60a4ca7c1 100644
--- a/mysql/resource_grant.go
+++ b/mysql/resource_grant.go
@@ -176,20 +176,33 @@ func (t *TablePrivilegeGrant) SQLGrantStatement() string {
 	return stmtSql
 }
 
+// containsAllPrivilege returns true if the privileges list contains an ALL PRIVILEGES grant
+// this is used because there is special case behavior for ALL PRIVILEGES grants. In particular,
+// if a user has ALL PRIVILEGES, we _cannot_ revoke ALL PRIVILEGES, GRANT OPTION because this is
+// invalid syntax.
+// See: https://github.com/petoju/terraform-provider-mysql/issues/120
+func containsAllPrivilege(privileges []string) bool {
+	for _, p := range privileges {
+		if kReAllPrivileges.MatchString(p) {
+			return true
+		}
+	}
+	return false
+}
+
 func (t *TablePrivilegeGrant) SQLRevokeStatement() string {
 	privs := t.Privileges
-	if t.Grant {
+	if t.Grant && !containsAllPrivilege(privs) {
 		privs = append(privs, "GRANT OPTION")
 	}
 	return fmt.Sprintf("REVOKE %s ON %s.%s FROM %s", strings.Join(privs, ", "), t.GetDatabase(), t.GetTable(), t.UserOrRole.SQLString())
 }
 
 func (t *TablePrivilegeGrant) SQLPartialRevokePrivilegesStatement(privilegesToRevoke []string) string {
-	if t.Grant {
+	if t.Grant && !containsAllPrivilege(privilegesToRevoke) {
 		privilegesToRevoke = append(privilegesToRevoke, "GRANT OPTION")
 	}
-	stmt := fmt.Sprintf("REVOKE %s ON %s.%s FROM %s", strings.Join(privilegesToRevoke, ", "), t.GetDatabase(), t.GetTable(), t.UserOrRole.SQLString())
-	return stmt
+	return fmt.Sprintf("REVOKE %s ON %s.%s FROM %s", strings.Join(privilegesToRevoke, ", "), t.GetDatabase(), t.GetTable(), t.UserOrRole.SQLString())
 }
 
 type ProcedurePrivilegeGrant struct {
@@ -246,7 +259,7 @@ func (t *ProcedurePrivilegeGrant) SQLGrantStatement() string {
 
 func (t *ProcedurePrivilegeGrant) SQLRevokeStatement() string {
 	privs := t.Privileges
-	if t.Grant {
+	if t.Grant && !containsAllPrivilege(privs) {
 		privs = append(privs, "GRANT OPTION")
 	}
 	stmt := fmt.Sprintf("REVOKE %s ON %s %s.%s FROM %s", strings.Join(privs, ", "), t.ObjectT, t.GetDatabase(), t.GetCallableName(), t.UserOrRole.SQLString())
@@ -255,11 +268,10 @@ func (t *ProcedurePrivilegeGrant) SQLRevokeStatement() string {
 
 func (t *ProcedurePrivilegeGrant) SQLPartialRevokePrivilegesStatement(privilegesToRevoke []string) string {
 	privs := privilegesToRevoke
-	if t.Grant {
+	if t.Grant && !containsAllPrivilege(privilegesToRevoke) {
 		privs = append(privs, "GRANT OPTION")
 	}
-	stmt := fmt.Sprintf("REVOKE %s ON %s %s.%s FROM %s", strings.Join(privs, ", "), t.ObjectT, t.GetDatabase(), t.GetCallableName(), t.UserOrRole.SQLString())
-	return stmt
+	return fmt.Sprintf("REVOKE %s ON %s %s.%s FROM %s", strings.Join(privs, ", "), t.ObjectT, t.GetDatabase(), t.GetCallableName(), t.UserOrRole.SQLString())
 }
 
 type RoleGrant struct {
@@ -1061,14 +1073,17 @@ func normalizeColumnOrder(perm string) string {
 	return fmt.Sprintf("%s(%s)", precursor, partsTogether)
 }
 
+var kReAllPrivileges = regexp.MustCompile(`ALL ?(PRIVILEGES)?`)
+
 func normalizePerms(perms []string) []string {
 	ret := []string{}
 	for _, perm := range perms {
 		// Remove leading and trailing backticks and spaces
 		permNorm := strings.Trim(perm, "` ")
-
 		permUcase := strings.ToUpper(permNorm)
-		if permUcase == "ALL" || permUcase == "ALLPRIVILEGES" {
+
+		// Normalize ALL and ALLPRIVILEGES to ALL PRIVILEGES
+		if kReAllPrivileges.MatchString(permUcase) {
 			permUcase = "ALL PRIVILEGES"
 		}
 		permSortedColumns := normalizeColumnOrder(permUcase)
diff --git a/mysql/resource_grant_test.go b/mysql/resource_grant_test.go
index 470172608..77931bfbe 100644
--- a/mysql/resource_grant_test.go
+++ b/mysql/resource_grant_test.go
@@ -44,38 +44,6 @@ func TestAccGrant(t *testing.T) {
 	})
 }
 
-func TestAccGrantWithGrantOption(t *testing.T) {
-	dbName := fmt.Sprintf("tf-test-%d", rand.Intn(100))
-	resource.Test(t, resource.TestCase{
-		PreCheck:     func() { testAccPreCheck(t) },
-		Providers:    testAccProviders,
-		CheckDestroy: testAccGrantCheckDestroy,
-		Steps: []resource.TestStep{
-			{
-				Config: testAccGrantConfigBasic(dbName),
-				Check: resource.ComposeTestCheckFunc(
-					testAccPrivilege("mysql_grant.test", "SELECT", true, false),
-					resource.TestCheckResourceAttr("mysql_grant.test", "grant", "false"),
-				),
-			},
-			{
-				Config: testAccGrantConfigBasicWithGrant(dbName),
-				Check: resource.ComposeTestCheckFunc(
-					testAccPrivilege("mysql_grant.test", "SELECT", true, true),
-					resource.TestCheckResourceAttr("mysql_grant.test", "grant", "true"),
-				),
-			},
-			{
-				Config: testAccGrantConfigBasic(dbName),
-				Check: resource.ComposeTestCheckFunc(
-					testAccPrivilege("mysql_grant.test", "SELECT", true, false),
-					resource.TestCheckResourceAttr("mysql_grant.test", "grant", "false"),
-				),
-			},
-		},
-	})
-}
-
 func TestAccRevokePrivRefresh(t *testing.T) {
 	dbName := fmt.Sprintf("tf-test-%d", rand.Intn(100))
 
@@ -204,7 +172,7 @@ func TestAccGrantComplex(t *testing.T) {
 				),
 			},
 			{
-				Config: testAccGrantConfigWithPrivs(dbName, `"SELECT (c1, c2)"`),
+				Config: testAccGrantConfigWithPrivs(dbName, `"SELECT (c1, c2)"`, false),
 				Check: resource.ComposeTestCheckFunc(
 					testAccPrivilege("mysql_grant.test", "SELECT (c1,c2)", true, false),
 					resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
@@ -214,7 +182,7 @@ func TestAccGrantComplex(t *testing.T) {
 				),
 			},
 			{
-				Config: testAccGrantConfigWithPrivs(dbName, `"DROP", "SELECT (c1)", "INSERT(c3, c4)", "REFERENCES(c5)"`),
+				Config: testAccGrantConfigWithPrivs(dbName, `"DROP", "SELECT (c1)", "INSERT(c3, c4)", "REFERENCES(c5)"`, false),
 				Check: resource.ComposeTestCheckFunc(
 					testAccPrivilege("mysql_grant.test", "INSERT (c3,c4)", true, false),
 					testAccPrivilege("mysql_grant.test", "SELECT (c1)", true, false),
@@ -227,7 +195,7 @@ func TestAccGrantComplex(t *testing.T) {
 				),
 			},
 			{
-				Config: testAccGrantConfigWithPrivs(dbName, `"DROP", "SELECT (c1)", "INSERT(c4, c3, c2)"`),
+				Config: testAccGrantConfigWithPrivs(dbName, `"DROP", "SELECT (c1)", "INSERT(c4, c3, c2)"`, false),
 				Check: resource.ComposeTestCheckFunc(
 					testAccPrivilege("mysql_grant.test", "REFERENCES (c5)", false, false),
 					resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
@@ -237,7 +205,7 @@ func TestAccGrantComplex(t *testing.T) {
 				),
 			},
 			{
-				Config: testAccGrantConfigWithPrivs(dbName, `"ALL PRIVILEGES"`),
+				Config: testAccGrantConfigWithPrivs(dbName, `"ALL PRIVILEGES"`, false),
 				Check: resource.ComposeTestCheckFunc(
 					testAccPrivilege("mysql_grant.test", "ALL", true, false),
 					resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
@@ -247,7 +215,7 @@ func TestAccGrantComplex(t *testing.T) {
 				),
 			},
 			{
-				Config: testAccGrantConfigWithPrivs(dbName, `"ALL"`),
+				Config: testAccGrantConfigWithPrivs(dbName, `"ALL"`, false),
 				Check: resource.ComposeTestCheckFunc(
 					testAccPrivilege("mysql_grant.test", "ALL", true, false),
 					resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
@@ -257,7 +225,7 @@ func TestAccGrantComplex(t *testing.T) {
 				),
 			},
 			{
-				Config: testAccGrantConfigWithPrivs(dbName, `"DROP", "SELECT (c1, c2)", "INSERT(c5)", "REFERENCES(c1)"`),
+				Config: testAccGrantConfigWithPrivs(dbName, `"DROP", "SELECT (c1, c2)", "INSERT(c5)", "REFERENCES(c1)"`, false),
 				Check: resource.ComposeTestCheckFunc(
 					testAccPrivilege("mysql_grant.test", "ALL", false, false),
 					testAccPrivilege("mysql_grant.test", "DROP", true, false),
@@ -270,6 +238,38 @@ func TestAccGrantComplex(t *testing.T) {
 					resource.TestCheckResourceAttr("mysql_grant.test", "table", "tbl"),
 				),
 			},
+			// Grant SELECT and UPDATE privileges WITH grant option
+			{
+				Config: testAccGrantConfigWithPrivs(dbName, `"SELECT (c1, c2)","UPDATE(c1, c2)"`, true),
+				Check: resource.ComposeTestCheckFunc(
+					testAccPrivilege("mysql_grant.test", "SELECT (c1,c2)", true, true),
+					testAccPrivilege("mysql_grant.test", "UPDATE (c1,c2)", true, true),
+					testAccPrivilege("mysql_grant.test", "ALL", false, true),
+					testAccPrivilege("mysql_grant.test", "DROP", false, true),
+					resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
+					resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
+					resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
+					resource.TestCheckResourceAttr("mysql_grant.test", "table", "tbl"),
+				),
+			},
+			// Grant ALL privileges WITH grant option
+			{
+				Config: testAccGrantConfigWithPrivs(dbName, `"ALL"`, true),
+				Check: resource.ComposeTestCheckFunc(
+					testAccPrivilege("mysql_grant.test", "ALL", true, true),
+					testAccPrivilege("mysql_grant.test", "SELECT (c1,c2)", false, true),
+					testAccPrivilege("mysql_grant.test", "UPDATE (c1,c2)", false, true),
+					testAccPrivilege("mysql_grant.test", "DROP", false, true),
+					resource.TestCheckResourceAttr("mysql_grant.test", "user", fmt.Sprintf("jdoe-%s", dbName)),
+					resource.TestCheckResourceAttr("mysql_grant.test", "host", "example.com"),
+					resource.TestCheckResourceAttr("mysql_grant.test", "database", dbName),
+					resource.TestCheckResourceAttr("mysql_grant.test", "table", "tbl"),
+				),
+			},
+			// Finally, revoke all privileges
+			{
+				Config: testAccGrantConfigNoGrant(dbName),
+			},
 		},
 	})
 }
@@ -540,7 +540,13 @@ resource "mysql_user" "test_global" {
 `, dbName, dbName, dbName)
 }
 
-func testAccGrantConfigWithPrivs(dbName, privs string) string {
+func testAccGrantConfigWithPrivs(dbName, privs string, grantOption bool) string {
+
+	grantOptionStr := "false"
+	if grantOption {
+		grantOptionStr = "true"
+	}
+
 	return fmt.Sprintf(`
 resource "mysql_database" "test" {
   name = "%s"
@@ -570,8 +576,9 @@ resource "mysql_grant" "test" {
   table      = "tbl"
   database   = "${mysql_database.test.name}"
   privileges = [%s]
+  grant      = %s
 }
-`, dbName, dbName, dbName, privs)
+`, dbName, dbName, dbName, privs, grantOptionStr)
 }
 
 func testAccGrantConfigWithDynamicMySQL8(dbName string) string {