diff --git a/sanitize.go b/sanitize.go index d90f93d..59c2bbc 100644 --- a/sanitize.go +++ b/sanitize.go @@ -51,12 +51,12 @@ var ( // It returns a HTML string that has been sanitized by the policy or an empty // string if an error has occurred (most likely as a consequence of extremely // malformed input) -func (p *Policy) Sanitize(s string) string { +func (p *Policy) Sanitize(s string, filters ...TokenReader) string { if strings.TrimSpace(s) == "" { return s } - return p.sanitize(strings.NewReader(s)).String() + return p.sanitize(strings.NewReader(s), filters...).String() } // SanitizeBytes takes a []byte that contains a HTML fragment or document and applies @@ -65,12 +65,12 @@ func (p *Policy) Sanitize(s string) string { // It returns a []byte containing the HTML that has been sanitized by the policy // or an empty []byte if an error has occurred (most likely as a consequence of // extremely malformed input) -func (p *Policy) SanitizeBytes(b []byte) []byte { +func (p *Policy) SanitizeBytes(b []byte, filters ...TokenReader) []byte { if len(bytes.TrimSpace(b)) == 0 { return b } - return p.sanitize(bytes.NewReader(b)).Bytes() + return p.sanitize(bytes.NewReader(b), filters...).Bytes() } // SanitizeReader takes an io.Reader that contains a HTML fragment or document @@ -78,12 +78,12 @@ func (p *Policy) SanitizeBytes(b []byte) []byte { // // It returns a bytes.Buffer containing the HTML that has been sanitized by the // policy. Errors during sanitization will merely return an empty result. -func (p *Policy) SanitizeReader(r io.Reader) *bytes.Buffer { - return p.sanitize(r) +func (p *Policy) SanitizeReader(r io.Reader, filters ...TokenReader) *bytes.Buffer { + return p.sanitize(r, filters...) } // Performs the actual sanitization process. -func (p *Policy) sanitize(r io.Reader) *bytes.Buffer { +func (p *Policy) sanitize(r io.Reader, filters ...TokenReader) *bytes.Buffer { // It is possible that the developer has created the policy via: // p := bluemonday.Policy{} @@ -100,10 +100,16 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer { skipClosingTag bool closingTagToSkipStack []string mostRecentlyStartedToken string + reader TokenReader ) - tokenizer := html.NewTokenizer(r) - reader := tokenizerReader{tokenizer} + // Chain together TokenReader filters + reader = &tokenizerReader{html.NewTokenizer(r)} + for _, f := range filters { + f.Source(reader) + reader = f + } + for { token, err := reader.Token() if token == nil { diff --git a/token.go b/token.go index 339dd70..384df21 100644 --- a/token.go +++ b/token.go @@ -32,6 +32,7 @@ package bluemonday import "golang.org/x/net/html" type TokenReader interface { + Source(source TokenReader) Token() (*html.Token, error) } @@ -48,3 +49,7 @@ func (r *tokenizerReader) Token() (*html.Token, error) { token := r.Tokenizer.Token() return &token, nil } + +// Source is a no-op for tokenizerReader +func (r *tokenizerReader) Source(TokenReader) { +} diff --git a/token_test.go b/token_test.go new file mode 100644 index 0000000..24a51d3 --- /dev/null +++ b/token_test.go @@ -0,0 +1,48 @@ +package bluemonday + +import ( + "testing" + + "golang.org/x/net/html" + "golang.org/x/net/html/atom" +) + +type testRemoverReader struct { + source TokenReader + tagAtom atom.Atom +} + +func (r *testRemoverReader) Token() (*html.Token, error) { + t, err := r.source.Token() + if err != nil { + return t, err + } + if (t.Type == html.StartTagToken || t.Type == html.EndTagToken) && t.DataAtom == r.tagAtom { + // Skip bold, return next token + return r.source.Token() + } + return t, nil +} + +func (r *testRemoverReader) Source(s TokenReader) { + r.source = s +} + +func TestTokenReader(t *testing.T) { + p := UGCPolicy() + + input := "

A bold statement.

" + want := "

A bold statement.

" + got := p.Sanitize(input) + if got != want { + t.Errorf("got: %q, want: %q", got, want) + } + + removeBold := &testRemoverReader{tagAtom: atom.B} + input = "

A bold statement.

" + want = "

A bold statement.

" + got = p.Sanitize(input, removeBold) + if got != want { + t.Errorf("got: %q, want: %q", got, want) + } +}