Skip to content

Commit

Permalink
Fixes #213
Browse files Browse the repository at this point in the history
and tests commutativity
  • Loading branch information
Sergey Shekyan committed Dec 18, 2018
1 parent 79eb836 commit 36b24e0
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 53 deletions.
83 changes: 45 additions & 38 deletions src/main/java/com/shapesecurity/salvation/data/Policy.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -52,41 +51,53 @@ public void setOrigin(@Nonnull Origin origin) {
}

public void intersect(@Nonnull Policy other) {
this.mergeUsingStrategy(other, this::intersectDirectivePrivate);
}
checkForMergeValidity(this);
checkForMergeValidity(other);

public void union(@Nonnull Policy other) {
this.mergeUsingStrategy(other, this::unionDirectivePrivate);
}
this.resolveSelf();
other.resolveSelf();

private void mergeUsingStrategy(@Nonnull Policy other, Consumer<Directive<? extends DirectiveValue>> strategy) {
if (this.directives.containsKey(ReportUriDirective.class) || other.directives
.containsKey(ReportUriDirective.class)) {
throw new IllegalArgumentException(
"Cannot merge policies if either policy contains a report-uri directive.");
}
this.expandDefaultSrc();
other.expandDefaultSrc();

if (this.directives.containsKey(ReportToDirective.class) || other.directives
.containsKey(ReportToDirective.class)) {
throw new IllegalArgumentException(
"Cannot merge policies if either policy contains a report-to directive.");
for (Map.Entry<Class<?>, Directive<? extends DirectiveValue>> entry : other.directives.entrySet()) {
this.intersectDirectivePrivate(entry.getValue());
}

if (this.directives.containsKey(ReferrerDirective.class) || other.directives
.containsKey(ReferrerDirective.class)) {
throw new IllegalArgumentException(
"Cannot merge policies if either policy contains a referrer directive.");
}
this.optimise();
}

public void union(@Nonnull Policy other) {
checkForMergeValidity(this);
checkForMergeValidity(other);

this.resolveSelf();
other.resolveSelf();

this.expandDefaultSrc();
other.expandDefaultSrc();

other.getDirectives().forEach(strategy);
for (Map.Entry<Class<?>, Directive<? extends DirectiveValue>> entry : other.directives.entrySet()) {
this.unionDirectivePrivate(entry.getValue());
}
this.directives.entrySet().removeIf(entry -> !other.directives.containsKey(entry.getKey()));

this.optimise();
other.optimise();
}

private static void checkForMergeValidity(@Nonnull Policy p) {
if (p.directives.containsKey(ReportUriDirective.class)) {
throw new IllegalArgumentException("Cannot merge policies if either policy contains a report-uri directive.");
}

if (p.directives.containsKey(ReportToDirective.class)) {
throw new IllegalArgumentException("Cannot merge policies if either policy contains a report-to directive.");
}

if (p.directives.containsKey(ReferrerDirective.class)) {
throw new IllegalArgumentException("Cannot merge policies if either policy contains a referrer directive.");
}
}

private void resolveSelf() {
Expand All @@ -106,49 +117,49 @@ private void expandDefaultSrc() {
defaultSources = defaultSrcDirective.values().collect(Collectors.toCollection(LinkedHashSet::new));

if (!this.directives.containsKey(ScriptSrcDirective.class)) {
this.unionDirectivePrivate(new ScriptSrcDirective(defaultSources));
this.directives.put(ScriptSrcDirective.class, new ScriptSrcDirective(defaultSources));
}
if (!this.directives.containsKey(StyleSrcDirective.class)) {
this.unionDirectivePrivate(new StyleSrcDirective(defaultSources));
this.directives.put(StyleSrcDirective.class, new StyleSrcDirective(defaultSources));
}
if (!this.directives.containsKey(ImgSrcDirective.class)) {
this.unionDirectivePrivate(new ImgSrcDirective(defaultSources));
this.directives.put(ImgSrcDirective.class, new ImgSrcDirective(defaultSources));
}
if (!this.directives.containsKey(ChildSrcDirective.class)) {
this.unionDirectivePrivate(new ChildSrcDirective(defaultSources));
this.directives.put(ChildSrcDirective.class, new ChildSrcDirective(defaultSources));
}
if (!this.directives.containsKey(ConnectSrcDirective.class)) {
this.unionDirectivePrivate(new ConnectSrcDirective(defaultSources));
this.directives.put(ConnectSrcDirective.class, new ConnectSrcDirective(defaultSources));
}
if (!this.directives.containsKey(FontSrcDirective.class)) {
this.unionDirectivePrivate(new FontSrcDirective(defaultSources));
this.directives.put(FontSrcDirective.class, new FontSrcDirective(defaultSources));
}
if (!this.directives.containsKey(MediaSrcDirective.class)) {
this.unionDirectivePrivate(new MediaSrcDirective(defaultSources));
this.directives.put(MediaSrcDirective.class, new MediaSrcDirective(defaultSources));
}
if (!this.directives.containsKey(ObjectSrcDirective.class)) {
this.unionDirectivePrivate(new ObjectSrcDirective(defaultSources));
this.directives.put(ObjectSrcDirective.class, new ObjectSrcDirective(defaultSources));
}
if (!this.directives.containsKey(ManifestSrcDirective.class)) {
this.unionDirectivePrivate(new ManifestSrcDirective(defaultSources));
this.directives.put(ManifestSrcDirective.class, new ManifestSrcDirective(defaultSources));
}
if (!this.directives.containsKey(PrefetchSrcDirective.class)) {
this.unionDirectivePrivate(new PrefetchSrcDirective(defaultSources));
this.directives.put(PrefetchSrcDirective.class, new PrefetchSrcDirective(defaultSources));
}
}

// expand child-src
if (this.directives.containsKey(ChildSrcDirective.class) && !this.directives.containsKey(FrameSrcDirective.class)) {
ChildSrcDirective childSrcDirective = this.getDirectiveByType(ChildSrcDirective.class);
Set<SourceExpression> childSources = childSrcDirective.values().collect(Collectors.toCollection(LinkedHashSet::new));
this.unionDirectivePrivate(new FrameSrcDirective(childSources));
this.directives.put(FrameSrcDirective.class, new FrameSrcDirective(childSources));
}

// expand script-src
if (this.directives.containsKey(ScriptSrcDirective.class) && !this.directives.containsKey(WorkerSrcDirective.class)) {
ScriptSrcDirective scriptSrcDirective = this.getDirectiveByType(ScriptSrcDirective.class);
Set<SourceExpression> scriptSources = scriptSrcDirective.values().collect(Collectors.toCollection(LinkedHashSet::new));
this.unionDirectivePrivate(new WorkerSrcDirective(scriptSources));
this.directives.put(WorkerSrcDirective.class, new WorkerSrcDirective(scriptSources));
}
}

Expand Down Expand Up @@ -338,10 +349,6 @@ private <V extends DirectiveValue, T extends Directive<V>> void unionDirectivePr
@SuppressWarnings("unchecked") T oldDirective = (T) this.directives.get(directive.getClass());
if (oldDirective != null) {
oldDirective.union(directive);
} else {
if (!(directive instanceof FetchDirective) || this.containsFetchDirective()) {
this.directives.put(directive.getClass(), directive);
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/test/java/com/shapesecurity/salvation/ParserTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,11 @@ public void testPolicy() {

Policy c = parse("script-src *");
b.union(c);
assertEquals("policy union", "style-src *; script-src *", b.show());
assertEquals("policy union", "", b.show());

Policy d = parse("script-src abc");
b.union(d);
assertEquals("policy union", "style-src *; script-src *", b.show());
assertEquals("policy union", "", b.show());

a.setOrigin(URI.parse("http://qwe.zz:80"));
assertEquals("policy origin", "http://qwe.zz", a.getOrigin().show());
Expand Down
86 changes: 73 additions & 13 deletions src/test/java/com/shapesecurity/salvation/PolicyMergeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

public class PolicyMergeTest extends CSPTest {
Expand Down Expand Up @@ -52,18 +53,18 @@ public void testUnionNonFetchDirectives() {
.parse("default-src; form-action a; frame-ancestors b; navigate-to c", "https://origin2.com");
p1.union(p2);
assertEquals(
"default-src a; form-action a; frame-ancestors b; navigate-to c",
"default-src a",
p1.show());

p1 = Parser.parse("form-action aaa; frame-ancestors bbb; navigate-to ccc", "https://origin1.com");
p2 = Parser.parse("script-src a", "https://origin1.com");
p1.union(p2);
assertEquals("form-action aaa; frame-ancestors bbb; navigate-to ccc", p1.show());
assertEquals("", p1.show());

p1 = Parser.parse("frame-ancestors bbb;", "https://origin1.com");
p2 = Parser.parse("script-src a", "https://origin1.com");
p1.union(p2);
assertEquals("frame-ancestors bbb", p1.show());
assertEquals("", p1.show());

p1 = Parser.parse("", "https://origin1.com");
p2 = Parser.parse("script-src a", "https://origin1.com");
Expand Down Expand Up @@ -580,19 +581,19 @@ public void testUnionChildSrc() {
q = parse("worker-src a");
p.union(q);
p.postProcessOptimisation();
assertEquals("child-src a; worker-src a", p.show());
assertEquals("", p.show());

p = parse("child-src a; ");
q = parse("worker-src a ");
q = parse("frame-src a ");
p.union(q);
p.postProcessOptimisation();
assertEquals("child-src a; worker-src a", p.show());
assertEquals("frame-src a", p.show());

p = parse("child-src a b");
q = parse("child-src; worker-src x; frame-src y");
p.union(q);
p.postProcessOptimisation();
assertEquals("child-src a b; frame-src a b y; worker-src x", p.show());
assertEquals("child-src a b; frame-src a b y", p.show());

p = parse("child-src *; worker-src");
q = parse("child-src; worker-src b");
Expand All @@ -607,7 +608,7 @@ public void testUnionChildSrc() {
p = parse("child-src a");
q = parse("child-src; worker-src b");
p.union(q);
assertEquals("child-src a; worker-src b", p.show());
assertEquals("child-src a", p.show());

p = parse("child-src a; worker-src b");
q = parse("child-src; worker-src c");
Expand All @@ -617,26 +618,36 @@ public void testUnionChildSrc() {
p = parse("child-src; worker-src a; frame-src b");
q = parse("child-src c");
p.union(q);
assertEquals("child-src c; worker-src a; frame-src b c", p.show());
assertEquals("child-src c; frame-src b c", p.show());

p = parse("child-src a; worker-src b");
q = parse("child-src c; frame-src d");
p.union(q);
assertEquals("child-src a c; worker-src b; frame-src a d", p.show());
assertEquals("child-src a c; frame-src a d", p.show());

p = parse("child-src b; worker-src a");
q = parse("child-src a");
p.union(q);
assertEquals("child-src b a; worker-src a", p.show());
assertEquals("child-src b a", p.show());

p = parse("default-src a");
q = parse("worker-src b; frame-src b;");
p.union(q);
assertEquals("default-src a; frame-src a b; worker-src a b", p.show());
assertEquals("frame-src a b; worker-src a b", p.show());
}

@Test
public void testUnionNone() {
Policy x = Parser.parse("frame-ancestors https://foo.bar", "http://example.com");
Policy y = Parser.parse("default-src 'none'", "http://example.com");
x.union(y);
assertEquals("", x.show());

x = Parser.parse("frame-ancestors https://foo.bar", "http://example.com");
y = Parser.parse("default-src 'none'; frame-ancestors;", "http://example.com");
x.union(y);
assertEquals("frame-ancestors https://foo.bar", x.show());

Policy p = parse("frame-ancestors 'none'");
Policy q = parse("frame-ancestors 'self'");
p.union(q);
Expand All @@ -645,7 +656,7 @@ public void testUnionNone() {
p = parse("frame-ancestors 'none' 'none'");
q = parse("frame-ancestors 'self'");
p.union(q);
assertEquals("frame-ancestors 'self'", p.show());
assertEquals("", p.show());

p = parse("frame-ancestors 'self'");
q = parse("frame-ancestors 'none'");
Expand Down Expand Up @@ -697,4 +708,53 @@ public void testUnionNone() {
p.union(q);
assertEquals("script-src *", p.show());
}

@Test
public void testMergeCommutativity() {
String[] policies = new String[] {
"script-src 'self'",
"script-src 'none'",
"script-src a",
"script-src a custom:",
"style-src 'self'",
"style-src 'none'",
"style-src a",
"style-src a custom:",
"default-src 'none'",
"default-src a",
"default-src custom:",
"plugin-types a/b",
"frame-ancestors 'self'",
"frame-ancestors 'none'",
"frame-ancestors a",
"frame-ancestors a custom:",
"frame-ancestors custom:",
"upgrade-insecure-requests",
"script-src a; frame-ancestors b"
};

for (int i = 0; i < policies.length; i++) {
for (int k = 0; k < policies.length; k++) {
Policy pq = parse(policies[i]);
pq.union(parse(policies[k]));

Policy qp = parse(policies[k]);
qp.union(parse(policies[i]));

assertTrue(pq.show() + " ≠ " + qp.show(), pq.equals(qp));
}
}

for (int i = 0; i < policies.length; i++) {
for (int k = 0; k < policies.length; k++) {
Policy pq = parse(policies[i]);
pq.intersect(parse(policies[k]));

Policy qp = parse(policies[k]);
qp.intersect(parse(policies[i]));

assertTrue(pq.show() + " ≠ " + qp.show(), pq.equals(qp));
}
}
}
}

0 comments on commit 36b24e0

Please sign in to comment.