Skip to content

Commit

Permalink
Remove "Optional[]" from getResourceGroupQueryType string
Browse files Browse the repository at this point in the history
TrinoQueryProperties.getResourceGroupQueryType() called toString on an Optional. This
unintended behavior produced output such as "Optional[DATA_DEFINITION]", complicating
rule writing for end users.

Co-authored-by: Yuya Ebihara <[email protected]>
  • Loading branch information
willmostly and ebyhr authored Jan 24, 2025
1 parent 26dc2fe commit 99cae08
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.gateway.ha.router;

import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
import io.trino.sql.tree.AddColumn;
import io.trino.sql.tree.Analyze;
import io.trino.sql.tree.Call;
Expand Down Expand Up @@ -89,7 +90,6 @@
import io.trino.sql.tree.Use;

import java.util.Map;
import java.util.Optional;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.gateway.ha.router.QueryType.ALTER_TABLE_EXECUTE;
Expand All @@ -108,6 +108,8 @@
//modified version of io.trino.util.StatementUtils
public final class StatementUtils
{
private static final Logger log = Logger.get(StatementUtils.class);

private StatementUtils() {}

private static final Map<Class<? extends Statement>, StatementTypeInfo<? extends Statement>> STATEMENT_QUERY_TYPES = ImmutableList.<StatementTypeInfo<?>>builder()
Expand Down Expand Up @@ -191,13 +193,17 @@ private StatementUtils() {}
.build().stream()
.collect(toImmutableMap(StatementTypeInfo::getStatementType, identity()));

public static Optional<QueryType> getQueryType(Statement statement)
public static String getResourceGroupQueryType(Statement statement)
{
if (statement instanceof ExplainAnalyze) {
return getQueryType(((ExplainAnalyze) statement).getStatement());
return getResourceGroupQueryType(((ExplainAnalyze) statement).getStatement());
}
StatementTypeInfo<? extends Statement> statementTypeInfo = STATEMENT_QUERY_TYPES.get(statement.getClass());
if (statementTypeInfo != null) {
return statementTypeInfo.getQueryType().toString();
}
return Optional.ofNullable(STATEMENT_QUERY_TYPES.get(statement.getClass()))
.map(StatementTypeInfo::getQueryType);
log.warn("Unsupported statement type: %s", statement.getClass());
return "UNKNOWN";
}

private static <T extends Statement> StatementTypeInfo<T> basicStatement(Class<T> statementType, QueryType queryType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ else if (statement instanceof ExecuteImmediate executeImmediate) {
}

queryType = statement.getClass().getSimpleName();
resourceGroupQueryType = StatementUtils.getQueryType(statement).toString();
resourceGroupQueryType = StatementUtils.getResourceGroupQueryType(statement);
ImmutableSet.Builder<QualifiedName> tableBuilder = ImmutableSet.builder();
ImmutableSet.Builder<String> catalogBuilder = ImmutableSet.builder();
ImmutableSet.Builder<String> schemaBuilder = ImmutableSet.builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,18 @@ void testTrinoQueryPropertiesQueryType()
assertThat(routingGroupSelector.findRoutingGroup(mockRequest)).isEqualTo("type-group");
}

@Test
void testTrinoQueryPropertiesResourceGroupQueryType()
throws IOException
{
RoutingGroupSelector routingGroupSelector =
RoutingGroupSelector.byRoutingRulesEngine("src/test/resources/rules/routing_rules_trino_query_properties.yml", requestAnalyzerConfig);
HttpServletRequest mockRequest = prepareMockRequest();
when(mockRequest.getReader()).thenReturn(new BufferedReader(new StringReader("CREATE TABLE cat.schem.foo (c1 int)")));

assertThat(routingGroupSelector.findRoutingGroup(mockRequest)).isEqualTo("resource-group-type-group");
}

@Test
void testTrinoQueryPropertiesAlternateStatementFormat()
throws IOException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ condition: |
actions:
- "result.put(\"routingGroup\", \"type-group\")"
---
name: "resource-group-query-type"
description: "test table type"
condition: |
trinoQueryProperties.getResourceGroupQueryType().equals("DATA_DEFINITION")
actions:
- "result.put(\"routingGroup\", \"resource-group-type-group\")"
---
name: "prepared-statement-header"
description: "test execute with multiple prepared statements"
condition: |
Expand Down

0 comments on commit 99cae08

Please sign in to comment.