Skip to content

Commit

Permalink
runtime-v2: use SensitiveDataHolder for task parameter masking (#1050)
Browse files Browse the repository at this point in the history
  • Loading branch information
ibodrov authored Dec 31, 2024
1 parent 4376458 commit c98cc85
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.walmartlabs.concord.runtime.v2.model.Location;
import com.walmartlabs.concord.runtime.v2.model.Step;
import com.walmartlabs.concord.runtime.v2.runner.EventReportingService;
import com.walmartlabs.concord.runtime.v2.runner.SensitiveDataHolder;
import com.walmartlabs.concord.runtime.v2.runner.tasks.TaskCallEvent;
import com.walmartlabs.concord.runtime.v2.runner.tasks.TaskCallListener;
import com.walmartlabs.concord.runtime.v2.sdk.*;
Expand Down Expand Up @@ -64,7 +65,9 @@ public void onEvent(TaskCallEvent event) {

List<Object> inVars = event.input();
if (inVars != null && eventConfiguration.recordTaskInVars()) {
Map<String, Object> vars = maskVars(convertInput(hideSensitiveData(inVars, event.inputAnnotations())), eventConfiguration.inVarsBlacklist());
Map<String, Object> input = convertInput(processSensitiveDataAnnotations(inVars, event.inputAnnotations()));
input = processSensitiveData(input);
Map<String, Object> vars = maskVars(input, eventConfiguration.inVarsBlacklist());
if (eventConfiguration.truncateInVars()) {
vars = ObjectTruncater.truncateMap(vars, eventConfiguration.truncateMaxStringLength(), eventConfiguration.truncateMaxArrayLength(), eventConfiguration.truncateMaxDepth());
}
Expand All @@ -75,7 +78,9 @@ public void onEvent(TaskCallEvent event) {

Object outVars = event.result();
if (outVars != null && eventConfiguration.recordTaskOutVars()) {
Map<String, Object> vars = maskVars(asMapOrNull(outVars), eventConfiguration.outVarsBlacklist());
Map<String, Object> output = asMapOrNull(outVars);
output = processSensitiveData(output);
Map<String, Object> vars = maskVars(output, eventConfiguration.outVarsBlacklist());
if (eventConfiguration.truncateOutVars()) {
vars = ObjectTruncater.truncateMap(vars, eventConfiguration.truncateMaxStringLength(), eventConfiguration.truncateMaxArrayLength(), eventConfiguration.truncateMaxDepth());
}
Expand All @@ -86,7 +91,9 @@ public void onEvent(TaskCallEvent event) {

Object metaVars = event.meta();
if (metaVars != null && eventConfiguration.recordTaskMeta()) {
Map<String, Object> meta = maskVars(asMapOrNull(metaVars), eventConfiguration.metaBlacklist());
Map<String, Object> rawMeta = asMapOrNull(metaVars);
Map<String, Object> meta = processSensitiveData(rawMeta);
meta = maskVars(meta, eventConfiguration.metaBlacklist());
if (eventConfiguration.truncateMeta()) {
meta = ObjectTruncater.truncateMap(meta, eventConfiguration.truncateMaxStringLength(), eventConfiguration.truncateMaxArrayLength(), eventConfiguration.truncateMaxDepth());
}
Expand Down Expand Up @@ -170,6 +177,34 @@ static Map<String, Object> maskVars(Map<String, Object> vars, Collection<String>
return result;
}

@SuppressWarnings({"unchecked", "rawtypes"})
static <T> T processSensitiveData(T v) {
Set<String> sensitiveStrings = SensitiveDataHolder.getInstance().get();
if (sensitiveStrings.isEmpty()) {
return v;
}

if (v instanceof String s) {
for (String sensitiveString : sensitiveStrings) {
s = s.replace(sensitiveString, MASK);
}
return (T) s;
} else if (v instanceof List<?> l) {
List<Object> result = new ArrayList<>(l.size());
for (Object vv : l) {
vv = processSensitiveData(vv);
result.add(vv);
}
return (T) result;
} else if (v instanceof Map m) {
Map<String, Object> result = new HashMap<>(m);
result.replaceAll((k, vv) -> processSensitiveData(vv));
return (T) result;
}

return v;
}

@SuppressWarnings("unchecked")
private static Map<String, Object> ensureModifiable(Map<String, Object> m, int depth, String[] path) {
if (depth == 0) {
Expand Down Expand Up @@ -217,7 +252,7 @@ private static Map<String, Object> convertInput(List<Object> input) {
return result;
}

private static List<Object> hideSensitiveData(List<Object> input, List<List<Annotation>> annotations) {
private static List<Object> processSensitiveDataAnnotations(List<Object> input, List<List<Annotation>> annotations) {
if (annotations.isEmpty()) {
return input;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.walmartlabs.concord.runtime.v2.runner.SensitiveDataHolder;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
Expand All @@ -38,33 +39,33 @@ public class TaskCallEventRecordingListenerTest {
@Test
public void testMaskVars() throws Exception {
String in = "{" +
" \"a\":1," +
" \"b\":2," +
" \"c\":{" +
" \"c1\":3," +
" \"c2\":4," +
" \"c3\":{" +
" \"c31\":5," +
" \"c32\":6" +
" }" +
" }" +
"}";
" \"a\":1," +
" \"b\":2," +
" \"c\":{" +
" \"c1\":3," +
" \"c2\":4," +
" \"c3\":{" +
" \"c31\":5," +
" \"c32\":6" +
" }" +
" }" +
"}";

List<String> blackList = Arrays.asList("b", "c.c1", "c.c3.c31");
Map<String, Object> result = TaskCallEventRecordingListener.maskVars(vars(in), blackList);

String expected = "{" +
" \"a\":1," +
" \"b\":\"***\"," +
" \"c\":{" +
" \"c1\":\"***\"," +
" \"c2\":4," +
" \"c3\":{" +
" \"c31\":\"***\"," +
" \"c32\":6" +
" }" +
" }" +
"}";
" \"a\":1," +
" \"b\":\"***\"," +
" \"c\":{" +
" \"c1\":\"***\"," +
" \"c2\":4," +
" \"c3\":{" +
" \"c31\":\"***\"," +
" \"c32\":6" +
" }" +
" }" +
"}";
assertEquals(vars(expected), result);
}

Expand All @@ -80,6 +81,29 @@ public void testMaskVarsUnmodifiable() {
assertEquals("{x={y={z=***}}}", result.toString());
}

@Test
public void testSensitiveDataMasking() throws JsonProcessingException {
SensitiveDataHolder holder = SensitiveDataHolder.getInstance();
holder.add("foo");
holder.add("bar");

String in = "{" +
"\"a\": \"foo\"," +
"\"b\": \"bar\"," +
"\"c\": \"baz\"," +
"\"d\": { \"e\": \"foo\" }" +
"}";

Map<String, Object> result = TaskCallEventRecordingListener.processSensitiveData(vars(in));
String expected = "{" +
" \"a\": \"***\"," +
" \"b\": \"***\"," +
" \"c\": \"baz\"," +
" \"d\": { \"e\": \"***\" }" +
"}";
assertEquals(vars(expected), result);
}

@SuppressWarnings("unchecked")
private static Map<String, Object> vars(String in) throws JsonProcessingException {
return om.readValue(in, Map.class);
Expand Down

0 comments on commit c98cc85

Please sign in to comment.