Skip to content

Commit

Permalink
Add tests for validation functions
Browse files Browse the repository at this point in the history
Signed-off-by: Derek Ho <[email protected]>
  • Loading branch information
derek-ho committed Dec 11, 2024
1 parent 0c43e9d commit 1468c9c
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
package org.opensearch.security.action.apitokens;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -40,7 +39,14 @@
public class ApiTokenAction extends BaseRestHandler {
private final ApiTokenRepository apiTokenRepository;

public static final String NAME_JSON_PROPERTY = "name";
private static final String NAME_JSON_PROPERTY = "name";
private static final String CLUSTER_PERMISSIONS_FIELD = "cluster_permissions";
private static final String INDEX_PERMISSIONS_FIELD = "index_permissions";
private static final String INDEX_PATTERN_FIELD = "index_pattern";
private static final String ALLOWED_ACTIONS_FIELD = "allowed_actions";
private static final String DLS_FIELD = "dls";
private static final String FLS_FIELD = "fls";
private static final String MASKED_FIELDS_FIELD = "masked_fields";

private static final List<RestHandler.Route> ROUTES = addRoutesPrefix(
ImmutableList.of(
Expand Down Expand Up @@ -106,14 +112,13 @@ private RestChannelConsumer handleGet(RestRequest request, NodeClient client) {
}

private RestChannelConsumer handlePost(RestRequest request, NodeClient client) {
// TODO: Enforce unique token description
return channel -> {
final XContentBuilder builder = channel.newBuilder();
BytesRestResponse response;
try {
final Map<String, Object> requestBody = request.contentOrSourceParamParser().map();
validateRequestParameters(requestBody);

validateRequestParametersForCreate(requestBody);
List<String> clusterPermissions = extractClusterPermissions(requestBody);
List<RoleV7.Index> indexPermissions = extractIndexPermissions(requestBody);

Expand All @@ -129,173 +134,170 @@ private RestChannelConsumer handlePost(RestRequest request, NodeClient client) {

response = new BytesRestResponse(RestStatus.OK, builder);
} catch (final Exception exception) {
builder.startObject().field("error", "An unexpected error occurred. Please check the input and try again.").endObject();
response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, builder);
}
builder.close();
channel.sendResponse(response);
};

}

private RestChannelConsumer handleDelete(RestRequest request, NodeClient client) {
return channel -> {
final XContentBuilder builder = channel.newBuilder();
BytesRestResponse response;
try {
final Map<String, Object> requestBody = request.contentOrSourceParamParser().map();

validateRequestParameters(requestBody);
apiTokenRepository.deleteApiToken((String) requestBody.get(NAME_JSON_PROPERTY));

builder.startObject();
builder.field("message", "token " + requestBody.get(NAME_JSON_PROPERTY) + " deleted successfully.");
builder.endObject();

response = new BytesRestResponse(RestStatus.OK, builder);
} catch (final ApiTokenException exception) {
builder.startObject().field("error", exception.getMessage()).endObject();
response = new BytesRestResponse(RestStatus.NOT_FOUND, builder);
} catch (final Exception exception) {
builder.startObject().field("error", "An unexpected error occurred. Please check the input and try again.").endObject();
builder.startObject()
.field("error", "An unexpected error occurred. Please check the input and try again.")
.field("message", exception.getMessage())
.endObject();
response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, builder);
}
builder.close();
channel.sendResponse(response);
};

}

private void validateRequestParameters(Map<String, Object> requestBody) {
if (!requestBody.containsKey(NAME_JSON_PROPERTY)) {
throw new IllegalArgumentException("Name parameter is required and cannot be empty.");
}
}

/**
* Validates the index permissions list structure
* Safely casts an Object to List<String> with validation
*/
private void validateIndexPermissionsList(List<Map<String, Object>> indexPermsList) {
for (Map<String, Object> indexPerm : indexPermsList) {
// Validate index pattern
if (!indexPerm.containsKey("index_pattern")) {
throw new IllegalArgumentException("Each index permission must contain an index_pattern");
}
Object indexPatternObj = indexPerm.get("index_pattern");
if (!(indexPatternObj instanceof String) && !(indexPatternObj instanceof List)) {
throw new IllegalArgumentException("index_pattern must be a string or array of strings");
}

// Validate allowed actions
if (!indexPerm.containsKey("allowed_actions")) {
throw new IllegalArgumentException("Each index permission must contain allowed_actions");
}
if (!(indexPerm.get("allowed_actions") instanceof List)) {
throw new IllegalArgumentException("allowed_actions must be an array");
}

// Validate DLS if present
if (indexPerm.containsKey("dls") && !(indexPerm.get("dls") instanceof String)) {
throw new IllegalArgumentException("dls must be a string");
}

// Validate FLS if present
if (indexPerm.containsKey("fls") && !(indexPerm.get("fls") instanceof List)) {
throw new IllegalArgumentException("fls must be an array");
}
List<String> safeStringList(Object obj, String fieldName) {
if (!(obj instanceof List<?> list)) {
throw new IllegalArgumentException(fieldName + " must be an array");
}

// Validate masked fields if present
if (indexPerm.containsKey("masked_fields") && !(indexPerm.get("masked_fields") instanceof List)) {
throw new IllegalArgumentException("masked_fields must be an array");
for (Object item : list) {
if (!(item instanceof String)) {
throw new IllegalArgumentException(fieldName + " must contain only strings");
}
}
}

private void validateRequestParametersForCreate(Map<String, Object> requestBody) {
if (!requestBody.containsKey(NAME_JSON_PROPERTY)) {
throw new IllegalArgumentException("Missing required parameter: " + NAME_JSON_PROPERTY);
}
return list.stream().map(String.class::cast).collect(Collectors.toList());
}

// Validate cluster permissions if present
if (requestBody.containsKey("cluster_permissions")) {
Object permissions = requestBody.get("cluster_permissions");
if (!(permissions instanceof List)) {
throw new IllegalArgumentException("cluster_permissions must be an array");
}
/**
* Safely casts an Object to List<Map<String, Object>> with validation
*/
@SuppressWarnings("unchecked")
List<Map<String, Object>> safeMapList(Object obj, String fieldName) {
if (!(obj instanceof List<?> list)) {
throw new IllegalArgumentException(fieldName + " must be an array");
}

// Validate index permissions if present
if (requestBody.containsKey("index_permissions")) {
Object indexPerms = requestBody.get("index_permissions");
if (!(indexPerms instanceof List)) {
throw new IllegalArgumentException("index_permissions must be an array");
for (Object item : list) {
if (!(item instanceof Map)) {
throw new IllegalArgumentException(fieldName + " must contain object entries");
}

@SuppressWarnings("unchecked")
List<Map<String, Object>> indexPermsList = (List<Map<String, Object>>) indexPerms;
validateIndexPermissionsList(indexPermsList);
}
return list.stream().map(item -> (Map<String, Object>) item).collect(Collectors.toList());
}

/**
* Extracts cluster permissions from the request body
*/
private List<String> extractClusterPermissions(Map<String, Object> requestBody) {
if (!requestBody.containsKey("cluster_permissions")) {
List<String> extractClusterPermissions(Map<String, Object> requestBody) {
if (!requestBody.containsKey(CLUSTER_PERMISSIONS_FIELD)) {
return Collections.emptyList();
}

@SuppressWarnings("unchecked")
List<String> permissions = (List<String>) requestBody.get("cluster_permissions");
return new ArrayList<>(permissions);
return safeStringList(requestBody.get(CLUSTER_PERMISSIONS_FIELD), CLUSTER_PERMISSIONS_FIELD);
}

/**
* Extracts and builds index permissions from the request body
*/
private List<RoleV7.Index> extractIndexPermissions(Map<String, Object> requestBody) {
if (!requestBody.containsKey("index_permissions")) {
List<RoleV7.Index> extractIndexPermissions(Map<String, Object> requestBody) {
if (!requestBody.containsKey(INDEX_PERMISSIONS_FIELD)) {
return Collections.emptyList();
}

@SuppressWarnings("unchecked")
List<Map<String, Object>> indexPerms = (List<Map<String, Object>>) requestBody.get("index_permissions");
List<Map<String, Object>> indexPerms = safeMapList(requestBody.get(INDEX_PERMISSIONS_FIELD), INDEX_PERMISSIONS_FIELD);

return indexPerms.stream().map(this::createIndexPermission).collect(Collectors.toList());
}

/**
* Creates a single RoleV7.Index permission from a permission map
*/
private RoleV7.Index createIndexPermission(Map<String, Object> indexPerm) {
// Get index patterns (can be single string or list)
RoleV7.Index createIndexPermission(Map<String, Object> indexPerm) {
List<String> indexPatterns;
Object indexPatternObj = indexPerm.get("index_pattern");
Object indexPatternObj = indexPerm.get(INDEX_PATTERN_FIELD);
if (indexPatternObj instanceof String) {
indexPatterns = Collections.singletonList((String) indexPatternObj);
} else {
@SuppressWarnings("unchecked")
List<String> patterns = (List<String>) indexPatternObj;
indexPatterns = patterns;
indexPatterns = safeStringList(indexPatternObj, INDEX_PATTERN_FIELD);
}

// Get allowed actions
@SuppressWarnings("unchecked")
List<String> allowedActions = (List<String>) indexPerm.get("allowed_actions");
List<String> allowedActions = safeStringList(indexPerm.get(ALLOWED_ACTIONS_FIELD), ALLOWED_ACTIONS_FIELD);

// Get DLS (Document Level Security)
String dls = (String) indexPerm.getOrDefault("dls", "");
String dls = (String) indexPerm.getOrDefault(DLS_FIELD, "");

// Get FLS (Field Level Security)
@SuppressWarnings("unchecked")
List<String> fls = indexPerm.containsKey("fls") ? (List<String>) indexPerm.get("fls") : Collections.emptyList();
List<String> fls = indexPerm.containsKey(FLS_FIELD) ? safeStringList(indexPerm.get(FLS_FIELD), FLS_FIELD) : Collections.emptyList();

// Get masked fields
@SuppressWarnings("unchecked")
List<String> maskedFields = indexPerm.containsKey("masked_fields")
? (List<String>) indexPerm.get("masked_fields")
List<String> maskedFields = indexPerm.containsKey(MASKED_FIELDS_FIELD)
? safeStringList(indexPerm.get(MASKED_FIELDS_FIELD), MASKED_FIELDS_FIELD)
: Collections.emptyList();

return new RoleV7.Index(indexPatterns, allowedActions, dls, fls, maskedFields);
}

/**
* Validates the request parameters
*/
void validateRequestParameters(Map<String, Object> requestBody) {
if (!requestBody.containsKey(NAME_JSON_PROPERTY)) {
throw new IllegalArgumentException("Missing required parameter: " + NAME_JSON_PROPERTY);
}

if (requestBody.containsKey(CLUSTER_PERMISSIONS_FIELD)) {
Object permissions = requestBody.get(CLUSTER_PERMISSIONS_FIELD);
if (!(permissions instanceof List)) {
throw new IllegalArgumentException(CLUSTER_PERMISSIONS_FIELD + " must be an array");
}
}

if (requestBody.containsKey(INDEX_PERMISSIONS_FIELD)) {
List<Map<String, Object>> indexPermsList = safeMapList(requestBody.get(INDEX_PERMISSIONS_FIELD), INDEX_PERMISSIONS_FIELD);
validateIndexPermissionsList(indexPermsList);
}
}

/**
* Validates the index permissions list structure
*/
void validateIndexPermissionsList(List<Map<String, Object>> indexPermsList) {
for (Map<String, Object> indexPerm : indexPermsList) {
if (!indexPerm.containsKey(INDEX_PATTERN_FIELD)) {
throw new IllegalArgumentException("Each index permission must contain " + INDEX_PATTERN_FIELD);
}
if (!indexPerm.containsKey(ALLOWED_ACTIONS_FIELD)) {
throw new IllegalArgumentException("Each index permission must contain " + ALLOWED_ACTIONS_FIELD);
}

Object indexPatternObj = indexPerm.get(INDEX_PATTERN_FIELD);
if (!(indexPatternObj instanceof String) && !(indexPatternObj instanceof List)) {
throw new IllegalArgumentException(INDEX_PATTERN_FIELD + " must be a string or array of strings");
}

if (indexPerm.containsKey(DLS_FIELD) && !(indexPerm.get(DLS_FIELD) instanceof String)) {
throw new IllegalArgumentException(DLS_FIELD + " must be a string");
}
}
}

private RestChannelConsumer handleDelete(RestRequest request, NodeClient client) {
return channel -> {
final XContentBuilder builder = channel.newBuilder();
BytesRestResponse response;
try {
final Map<String, Object> requestBody = request.contentOrSourceParamParser().map();

validateRequestParameters(requestBody);
apiTokenRepository.deleteApiToken((String) requestBody.get(NAME_JSON_PROPERTY));

builder.startObject();
builder.field("message", "token " + requestBody.get(NAME_JSON_PROPERTY) + " deleted successfully.");
builder.endObject();

response = new BytesRestResponse(RestStatus.OK, builder);
} catch (final ApiTokenException exception) {
builder.startObject().field("error", exception.getMessage()).endObject();
response = new BytesRestResponse(RestStatus.NOT_FOUND, builder);
} catch (final Exception exception) {
builder.startObject().field("error", "An unexpected error occurred. Please check the input and try again.").endObject();
response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, builder);
}
builder.close();
channel.sendResponse(response);
};

}

}
Loading

0 comments on commit 1468c9c

Please sign in to comment.