From 5f6b5798c57c4e2dd377e41892ab93f90dd599b5 Mon Sep 17 00:00:00 2001 From: Sergey Shekyan Date: Tue, 18 Dec 2018 14:02:04 -0800 Subject: [PATCH] Fixes #213 and tests commutativity --- .../shapesecurity/salvation/data/Policy.java | 84 ++++++++++-------- .../shapesecurity/salvation/ParserTest.java | 4 +- .../salvation/PolicyMergeTest.java | 86 ++++++++++++++++--- 3 files changed, 121 insertions(+), 53 deletions(-) diff --git a/src/main/java/com/shapesecurity/salvation/data/Policy.java b/src/main/java/com/shapesecurity/salvation/data/Policy.java index 13fb42d2..923d2c56 100644 --- a/src/main/java/com/shapesecurity/salvation/data/Policy.java +++ b/src/main/java/com/shapesecurity/salvation/data/Policy.java @@ -8,7 +8,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -52,31 +51,26 @@ public void setOrigin(@Nonnull Origin origin) { } public void intersect(@Nonnull Policy other) { - this.mergeUsingStrategy(other, this::intersectDirectivePrivate); - } + this.checkForMergeValidity(); + other.checkForMergeValidity(); - public void union(@Nonnull Policy other) { - this.mergeUsingStrategy(other, this::unionDirectivePrivate); - } + this.resolveSelf(); + other.resolveSelf(); - private void mergeUsingStrategy(@Nonnull Policy other, Consumer> strategy) { - if (this.directives.containsKey(ReportUriDirective.class) || other.directives - .containsKey(ReportUriDirective.class)) { - throw new IllegalArgumentException( - "Cannot merge policies if either policy contains a report-uri directive."); - } + this.expandDefaultSrc(); + other.expandDefaultSrc(); - if (this.directives.containsKey(ReportToDirective.class) || other.directives - .containsKey(ReportToDirective.class)) { - throw new IllegalArgumentException( - "Cannot merge policies if either policy contains a report-to directive."); + for (Map.Entry, Directive> entry : other.directives.entrySet()) { + this.intersectDirectivePrivate(entry.getValue()); } - if (this.directives.containsKey(ReferrerDirective.class) || other.directives - .containsKey(ReferrerDirective.class)) { - throw new IllegalArgumentException( - "Cannot merge policies if either policy contains a referrer directive."); - } + this.optimise(); + other.optimise(); + } + + public void union(@Nonnull Policy other) { + this.checkForMergeValidity(); + other.checkForMergeValidity(); this.resolveSelf(); other.resolveSelf(); @@ -84,9 +78,27 @@ private void mergeUsingStrategy(@Nonnull Policy other, Consumer, Directive> entry : other.directives.entrySet()) { + this.unionDirectivePrivate(entry.getValue()); + } + this.directives.entrySet().removeIf(entry -> !other.directives.containsKey(entry.getKey())); this.optimise(); + other.optimise(); + } + + private void checkForMergeValidity() { + if (this.directives.containsKey(ReportUriDirective.class)) { + throw new IllegalArgumentException("Cannot merge policies if either policy contains a report-uri directive."); + } + + if (this.directives.containsKey(ReportToDirective.class)) { + throw new IllegalArgumentException("Cannot merge policies if either policy contains a report-to directive."); + } + + if (this.directives.containsKey(ReferrerDirective.class)) { + throw new IllegalArgumentException("Cannot merge policies if either policy contains a referrer directive."); + } } private void resolveSelf() { @@ -106,34 +118,34 @@ private void expandDefaultSrc() { defaultSources = defaultSrcDirective.values().collect(Collectors.toCollection(LinkedHashSet::new)); if (!this.directives.containsKey(ScriptSrcDirective.class)) { - this.unionDirectivePrivate(new ScriptSrcDirective(defaultSources)); + this.directives.put(ScriptSrcDirective.class, new ScriptSrcDirective(defaultSources)); } if (!this.directives.containsKey(StyleSrcDirective.class)) { - this.unionDirectivePrivate(new StyleSrcDirective(defaultSources)); + this.directives.put(StyleSrcDirective.class, new StyleSrcDirective(defaultSources)); } if (!this.directives.containsKey(ImgSrcDirective.class)) { - this.unionDirectivePrivate(new ImgSrcDirective(defaultSources)); + this.directives.put(ImgSrcDirective.class, new ImgSrcDirective(defaultSources)); } if (!this.directives.containsKey(ChildSrcDirective.class)) { - this.unionDirectivePrivate(new ChildSrcDirective(defaultSources)); + this.directives.put(ChildSrcDirective.class, new ChildSrcDirective(defaultSources)); } if (!this.directives.containsKey(ConnectSrcDirective.class)) { - this.unionDirectivePrivate(new ConnectSrcDirective(defaultSources)); + this.directives.put(ConnectSrcDirective.class, new ConnectSrcDirective(defaultSources)); } if (!this.directives.containsKey(FontSrcDirective.class)) { - this.unionDirectivePrivate(new FontSrcDirective(defaultSources)); + this.directives.put(FontSrcDirective.class, new FontSrcDirective(defaultSources)); } if (!this.directives.containsKey(MediaSrcDirective.class)) { - this.unionDirectivePrivate(new MediaSrcDirective(defaultSources)); + this.directives.put(MediaSrcDirective.class, new MediaSrcDirective(defaultSources)); } if (!this.directives.containsKey(ObjectSrcDirective.class)) { - this.unionDirectivePrivate(new ObjectSrcDirective(defaultSources)); + this.directives.put(ObjectSrcDirective.class, new ObjectSrcDirective(defaultSources)); } if (!this.directives.containsKey(ManifestSrcDirective.class)) { - this.unionDirectivePrivate(new ManifestSrcDirective(defaultSources)); + this.directives.put(ManifestSrcDirective.class, new ManifestSrcDirective(defaultSources)); } if (!this.directives.containsKey(PrefetchSrcDirective.class)) { - this.unionDirectivePrivate(new PrefetchSrcDirective(defaultSources)); + this.directives.put(PrefetchSrcDirective.class, new PrefetchSrcDirective(defaultSources)); } } @@ -141,14 +153,14 @@ private void expandDefaultSrc() { if (this.directives.containsKey(ChildSrcDirective.class) && !this.directives.containsKey(FrameSrcDirective.class)) { ChildSrcDirective childSrcDirective = this.getDirectiveByType(ChildSrcDirective.class); Set childSources = childSrcDirective.values().collect(Collectors.toCollection(LinkedHashSet::new)); - this.unionDirectivePrivate(new FrameSrcDirective(childSources)); + this.directives.put(FrameSrcDirective.class, new FrameSrcDirective(childSources)); } // expand script-src if (this.directives.containsKey(ScriptSrcDirective.class) && !this.directives.containsKey(WorkerSrcDirective.class)) { ScriptSrcDirective scriptSrcDirective = this.getDirectiveByType(ScriptSrcDirective.class); Set scriptSources = scriptSrcDirective.values().collect(Collectors.toCollection(LinkedHashSet::new)); - this.unionDirectivePrivate(new WorkerSrcDirective(scriptSources)); + this.directives.put(WorkerSrcDirective.class, new WorkerSrcDirective(scriptSources)); } } @@ -338,10 +350,6 @@ private > void unionDirectivePr @SuppressWarnings("unchecked") T oldDirective = (T) this.directives.get(directive.getClass()); if (oldDirective != null) { oldDirective.union(directive); - } else { - if (!(directive instanceof FetchDirective) || this.containsFetchDirective()) { - this.directives.put(directive.getClass(), directive); - } } } diff --git a/src/test/java/com/shapesecurity/salvation/ParserTest.java b/src/test/java/com/shapesecurity/salvation/ParserTest.java index 659090fe..778aae48 100644 --- a/src/test/java/com/shapesecurity/salvation/ParserTest.java +++ b/src/test/java/com/shapesecurity/salvation/ParserTest.java @@ -346,11 +346,11 @@ public void testPolicy() { Policy c = parse("script-src *"); b.union(c); - assertEquals("policy union", "style-src *; script-src *", b.show()); + assertEquals("policy union", "", b.show()); Policy d = parse("script-src abc"); b.union(d); - assertEquals("policy union", "style-src *; script-src *", b.show()); + assertEquals("policy union", "", b.show()); a.setOrigin(URI.parse("http://qwe.zz:80")); assertEquals("policy origin", "http://qwe.zz", a.getOrigin().show()); diff --git a/src/test/java/com/shapesecurity/salvation/PolicyMergeTest.java b/src/test/java/com/shapesecurity/salvation/PolicyMergeTest.java index 68fd3420..b20460d0 100644 --- a/src/test/java/com/shapesecurity/salvation/PolicyMergeTest.java +++ b/src/test/java/com/shapesecurity/salvation/PolicyMergeTest.java @@ -15,6 +15,7 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; public class PolicyMergeTest extends CSPTest { @@ -52,18 +53,18 @@ public void testUnionNonFetchDirectives() { .parse("default-src; form-action a; frame-ancestors b; navigate-to c", "https://origin2.com"); p1.union(p2); assertEquals( - "default-src a; form-action a; frame-ancestors b; navigate-to c", + "default-src a", p1.show()); p1 = Parser.parse("form-action aaa; frame-ancestors bbb; navigate-to ccc", "https://origin1.com"); p2 = Parser.parse("script-src a", "https://origin1.com"); p1.union(p2); - assertEquals("form-action aaa; frame-ancestors bbb; navigate-to ccc", p1.show()); + assertEquals("", p1.show()); p1 = Parser.parse("frame-ancestors bbb;", "https://origin1.com"); p2 = Parser.parse("script-src a", "https://origin1.com"); p1.union(p2); - assertEquals("frame-ancestors bbb", p1.show()); + assertEquals("", p1.show()); p1 = Parser.parse("", "https://origin1.com"); p2 = Parser.parse("script-src a", "https://origin1.com"); @@ -580,19 +581,19 @@ public void testUnionChildSrc() { q = parse("worker-src a"); p.union(q); p.postProcessOptimisation(); - assertEquals("child-src a; worker-src a", p.show()); + assertEquals("", p.show()); p = parse("child-src a; "); - q = parse("worker-src a "); + q = parse("frame-src a "); p.union(q); p.postProcessOptimisation(); - assertEquals("child-src a; worker-src a", p.show()); + assertEquals("frame-src a", p.show()); p = parse("child-src a b"); q = parse("child-src; worker-src x; frame-src y"); p.union(q); p.postProcessOptimisation(); - assertEquals("child-src a b; frame-src a b y; worker-src x", p.show()); + assertEquals("child-src a b; frame-src a b y", p.show()); p = parse("child-src *; worker-src"); q = parse("child-src; worker-src b"); @@ -607,7 +608,7 @@ public void testUnionChildSrc() { p = parse("child-src a"); q = parse("child-src; worker-src b"); p.union(q); - assertEquals("child-src a; worker-src b", p.show()); + assertEquals("child-src a", p.show()); p = parse("child-src a; worker-src b"); q = parse("child-src; worker-src c"); @@ -617,26 +618,36 @@ public void testUnionChildSrc() { p = parse("child-src; worker-src a; frame-src b"); q = parse("child-src c"); p.union(q); - assertEquals("child-src c; worker-src a; frame-src b c", p.show()); + assertEquals("child-src c; frame-src b c", p.show()); p = parse("child-src a; worker-src b"); q = parse("child-src c; frame-src d"); p.union(q); - assertEquals("child-src a c; worker-src b; frame-src a d", p.show()); + assertEquals("child-src a c; frame-src a d", p.show()); p = parse("child-src b; worker-src a"); q = parse("child-src a"); p.union(q); - assertEquals("child-src b a; worker-src a", p.show()); + assertEquals("child-src b a", p.show()); p = parse("default-src a"); q = parse("worker-src b; frame-src b;"); p.union(q); - assertEquals("default-src a; frame-src a b; worker-src a b", p.show()); + assertEquals("frame-src a b; worker-src a b", p.show()); } @Test public void testUnionNone() { + Policy x = Parser.parse("frame-ancestors https://foo.bar", "http://example.com"); + Policy y = Parser.parse("default-src 'none'", "http://example.com"); + x.union(y); + assertEquals("", x.show()); + + x = Parser.parse("frame-ancestors https://foo.bar", "http://example.com"); + y = Parser.parse("default-src 'none'; frame-ancestors;", "http://example.com"); + x.union(y); + assertEquals("frame-ancestors https://foo.bar", x.show()); + Policy p = parse("frame-ancestors 'none'"); Policy q = parse("frame-ancestors 'self'"); p.union(q); @@ -645,7 +656,7 @@ public void testUnionNone() { p = parse("frame-ancestors 'none' 'none'"); q = parse("frame-ancestors 'self'"); p.union(q); - assertEquals("frame-ancestors 'self'", p.show()); + assertEquals("", p.show()); p = parse("frame-ancestors 'self'"); q = parse("frame-ancestors 'none'"); @@ -697,4 +708,53 @@ public void testUnionNone() { p.union(q); assertEquals("script-src *", p.show()); } + + @Test + public void testMergeCommutativity() { + String[] policies = new String[] { + "script-src 'self'", + "script-src 'none'", + "script-src a", + "script-src a custom:", + "style-src 'self'", + "style-src 'none'", + "style-src a", + "style-src a custom:", + "default-src 'none'", + "default-src a", + "default-src custom:", + "plugin-types a/b", + "frame-ancestors 'self'", + "frame-ancestors 'none'", + "frame-ancestors a", + "frame-ancestors a custom:", + "frame-ancestors custom:", + "upgrade-insecure-requests", + "script-src a; frame-ancestors b" + }; + + for (int i = 0; i < policies.length; i++) { + for (int k = 0; k < policies.length; k++) { + Policy pq = parse(policies[i]); + pq.union(parse(policies[k])); + + Policy qp = parse(policies[k]); + qp.union(parse(policies[i])); + + assertTrue(pq.show() + " ≠ " + qp.show(), pq.equals(qp)); + } + } + + for (int i = 0; i < policies.length; i++) { + for (int k = 0; k < policies.length; k++) { + Policy pq = parse(policies[i]); + pq.intersect(parse(policies[k])); + + Policy qp = parse(policies[k]); + qp.intersect(parse(policies[i])); + + assertTrue(pq.show() + " ≠ " + qp.show(), pq.equals(qp)); + } + } + } }