Commit

optimize sample loader
This is done by replacing the chan-based fake iterator with a modern (real) iterator.
umputun committed Jan 8, 2025
1 parent 4941fff commit 8213019
Showing 2 changed files with 214 additions and 27 deletions.
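
For context, the change swaps a goroutine-plus-channel pseudo-iterator for a Go 1.23 iter.Seq function iterator. A minimal sketch of the two patterns, using hypothetical tokensChan/tokensSeq helpers rather than the detector's actual code:

package main

import (
	"fmt"
	"iter"
)

// tokensChan is the old pattern: a goroutine feeds a channel. If the consumer
// stops ranging early, the goroutine blocks on the send and leaks.
func tokensChan(items []string) <-chan string {
	ch := make(chan string)
	go func() {
		defer close(ch)
		for _, it := range items {
			ch <- it
		}
	}()
	return ch
}

// tokensSeq is the new pattern: a plain function iterator. No goroutine and no
// channel; an early break in the consumer makes yield return false.
func tokensSeq(items []string) iter.Seq[string] {
	return func(yield func(string) bool) {
		for _, it := range items {
			if !yield(it) {
				return
			}
		}
	}
}

func main() {
	for t := range tokensSeq([]string{"win", "free", "iphone"}) {
		fmt.Println(t)
	}
}

Both variants are consumed with the same for ... range loop; the function-iterator form (range over func) requires Go 1.23 or later.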
43 changes: 21 additions & 22 deletions lib/tgspam/detector.go
@@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"iter"
"log"
"math"
"net/http"
@@ -355,14 +356,14 @@ func (d *Detector) LoadSamples(exclReader io.Reader, spamReaders, hamReaders []i
d.classifier.reset()

// excluded tokens should be loaded before spam samples to exclude them from spam tokenization
for t := range d.tokenChan(exclReader) {
for t := range d.tokenIterator(exclReader) {
d.excludedTokens = append(d.excludedTokens, strings.ToLower(t))
}
lr := LoadResult{ExcludedTokens: len(d.excludedTokens)}

// load spam samples and update the classifier with them
docs := []document{}
for token := range d.tokenChan(spamReaders...) {
for token := range d.tokenIterator(spamReaders...) {
tokenizedSpam := d.tokenize(token)
d.tokenizedSpam = append(d.tokenizedSpam, tokenizedSpam) // add to list of samples
tokens := make([]string, 0, len(tokenizedSpam))
@@ -374,7 +375,7 @@ func (d *Detector) LoadSamples(exclReader io.Reader, spamReaders, hamReaders []i
}

// load ham samples and update the classifier with them
for token := range d.tokenChan(hamReaders...) {
for token := range d.tokenIterator(hamReaders...) {
tokenizedSpam := d.tokenize(token)
tokens := make([]string, 0, len(tokenizedSpam))
for token := range tokenizedSpam {
@@ -394,7 +395,7 @@ func (d *Detector) LoadStopWords(readers ...io.Reader) (LoadResult, error) {
defer d.lock.Unlock()

d.stopWords = []string{}
for t := range d.tokenChan(readers...) {
for t := range d.tokenIterator(readers...) {
d.stopWords = append(d.stopWords, strings.ToLower(t))
}
return LoadResult{StopWords: len(d.stopWords)}, nil
@@ -423,7 +424,7 @@ func (d *Detector) updateSample(msg string, upd SampleUpdater, sc spamClass) err

// load samples and update the classifier with them
docs := []document{}
for token := range d.tokenChan(bytes.NewBufferString(msg)) {
for token := range d.tokenIterator(bytes.NewBufferString(msg)) {
tokenizedSample := d.tokenize(token)
tokens := make([]string, 0, len(tokenizedSample))
for token := range tokenizedSample {
@@ -435,14 +436,10 @@ func (d *Detector) updateSample(msg string, upd SampleUpdater, sc spamClass) err
return nil
}

// tokenChan parses readers and returns a channel of tokens.
// tokenIterator parses readers and returns an iterator of tokens.
// Supports one token per line, or a line of comma-separated quoted tokens
func (d *Detector) tokenChan(readers ...io.Reader) <-chan string {
resCh := make(chan string)

go func() {
defer close(resCh)

func (d *Detector) tokenIterator(readers ...io.Reader) iter.Seq[string] {
return func(yield func(string) bool) {
for _, reader := range readers {
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
@@ -453,25 +450,27 @@ func (d *Detector) tokenChan(readers ...io.Reader) <-chan string {
for _, token := range lineTokens {
cleanToken := strings.Trim(token, " \"\n\r\t")
if cleanToken != "" {
resCh <- cleanToken
if !yield(cleanToken) {
return
}
}
}
} else {
// each line with a single token
cleanToken := strings.Trim(line, " \n\r\t")
if cleanToken != "" {
if !yield(cleanToken) {
return
}
}
continue
}
// each line with a single token
cleanToken := strings.Trim(line, " \n\r\t")
if cleanToken != "" {
resCh <- cleanToken
}
}

if err := scanner.Err(); err != nil {
log.Printf("[WARN] failed to read tokens, error=%v", err)
}
}
}()

return resCh
}
}

// tokenize takes a string and returns a map where the keys are unique words (tokens)
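
The if !yield(cleanToken) checks above are how early termination propagates with iter.Seq: when the consumer breaks out of its range loop, yield returns false and the producer simply returns instead of blocking on a channel send. A small self-contained sketch of that mechanism, with hypothetical firstN/letters helpers that are not part of this change:

package main

import (
	"fmt"
	"iter"
)

// firstN collects the first n values from seq (n assumed > 0); breaking out of
// the range loop makes yield return false inside the producer.
func firstN[T any](seq iter.Seq[T], n int) []T {
	out := make([]T, 0, n)
	for v := range seq {
		out = append(out, v)
		if len(out) >= n {
			break
		}
	}
	return out
}

// letters is a tiny producer written in the same style as tokenIterator above.
func letters() iter.Seq[string] {
	return func(yield func(string) bool) {
		for _, s := range []string{"a", "b", "c", "d"} {
			if !yield(s) {
				return // consumer stopped early
			}
		}
	}
}

func main() {
	fmt.Println(firstN(letters(), 2)) // [a b]
}
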
198 changes: 193 additions & 5 deletions lib/tgspam/detector_test.go
@@ -1143,6 +1143,83 @@ func TestDetector_ApprovedUsers(t *testing.T) {

}

func TestDetector_LoadSamples(t *testing.T) {
t.Run("basic loading", func(t *testing.T) {
d := NewDetector(Config{})
spamSamples := strings.NewReader("win free iPhone\nlottery prize xyz XyZ")
hamSamples := strings.NewReader("hello world\nhow are you\nhave a good day")
exclSamples := strings.NewReader("xyz")

lr, err := d.LoadSamples(exclSamples, []io.Reader{spamSamples}, []io.Reader{hamSamples})

require.NoError(t, err)
assert.Equal(t, 1, lr.ExcludedTokens)
assert.Equal(t, 2, lr.SpamSamples)
assert.Equal(t, 3, lr.HamSamples)

// verify excluded tokens
assert.Contains(t, d.excludedTokens, "xyz")

// verify tokenized spam samples
assert.Len(t, d.tokenizedSpam, 2)
assert.Contains(t, d.tokenizedSpam[0], "win")
assert.Contains(t, d.tokenizedSpam[1], "lottery")

// verify classifier learning
assert.Equal(t, 5, d.classifier.nAllDocument)
assert.Contains(t, d.classifier.learningResults, "win")
assert.Contains(t, d.classifier.learningResults["win"], spamClass("spam"))
assert.Contains(t, d.classifier.learningResults, "world")
assert.Contains(t, d.classifier.learningResults["world"], spamClass("ham"))

// verify excluded tokens in learning results
assert.NotContains(t, d.classifier.learningResults, "xyz", "excluded token should not be in learning results")
assert.NotContains(t, d.classifier.learningResults, "XyZ", "excluded token should not be in learning results")
})

t.Run("empty samples", func(t *testing.T) {
d := NewDetector(Config{})
exclSamples := strings.NewReader("")
spamSamples := strings.NewReader("")
hamSamples := strings.NewReader("")

lr, err := d.LoadSamples(exclSamples, []io.Reader{spamSamples}, []io.Reader{hamSamples})

require.NoError(t, err)
assert.Equal(t, 0, lr.ExcludedTokens)
assert.Equal(t, 0, lr.SpamSamples)
assert.Equal(t, 0, lr.HamSamples)
assert.Equal(t, 0, d.classifier.nAllDocument)
})

t.Run("multiple readers", func(t *testing.T) {
d := NewDetector(Config{})
exclSamples := strings.NewReader(`"xy", "z", "the"`)
spamSamples1 := strings.NewReader("win free iPhone")
spamSamples2 := strings.NewReader("lottery prize xyz")
hamsSamples1 := strings.NewReader("hello world\nhow are you\nhave a good day")
hamsSamples2 := strings.NewReader("some other text\nwith more words")

lr, err := d.LoadSamples(
exclSamples,
[]io.Reader{spamSamples1, spamSamples2},
[]io.Reader{hamsSamples1, hamsSamples2},
)

require.NoError(t, err)
assert.Equal(t, 3, lr.ExcludedTokens)
assert.Equal(t, d.excludedTokens, []string{"xy", "z", "the"})
assert.Equal(t, 2, lr.SpamSamples)
assert.Equal(t, 5, lr.HamSamples)
t.Logf("Learning results: %+v", d.classifier.learningResults)
assert.Equal(t, 7, d.classifier.nAllDocument)
assert.Contains(t, d.classifier.learningResults["win"], spamClass("spam"))
assert.Contains(t, d.classifier.learningResults["prize"], spamClass("spam"))
assert.Contains(t, d.classifier.learningResults["world"], spamClass("ham"))
assert.Contains(t, d.classifier.learningResults["some"], spamClass("ham"))
})
}

func TestDetector_tokenize(t *testing.T) {
tests := []struct {
name string
@@ -1166,7 +1243,7 @@ func TestDetector_tokenize(t *testing.T) {
}
}

func TestDetector_tokenChan(t *testing.T) {
func TestDetector_tokenIterator(t *testing.T) {
tests := []struct {
name string
input string
@@ -1184,7 +1261,7 @@ func TestDetector_tokenChan(t *testing.T) {
d := Detector{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ch := d.tokenChan(bytes.NewBufferString(tt.input))
ch := d.tokenIterator(bytes.NewBufferString(tt.input))
res := []string{}
for token := range ch {
res = append(res, token)
@@ -1194,14 +1271,15 @@ func TestDetector_tokenChan(t *testing.T) {
}
}

func TestDetector_tokenChanMultipleReaders(t *testing.T) {
func TestDetector_tokenIteratorMultipleReaders(t *testing.T) {
d := Detector{}
ch := d.tokenChan(bytes.NewBufferString("hello\nworld"), bytes.NewBufferString("something, new"))
ch := d.tokenIterator(bytes.NewBufferString("hello\nworld"), bytes.NewBufferString("something, new"))
res := []string{}
for token := range ch {
res = append(res, token)
}
assert.Equal(t, []string{"hello", "world", "something, new"}, res)
sort.Strings(res)
assert.Equal(t, []string{"hello", "something, new", "world"}, res)
}

func TestCleanText(t *testing.T) {
@@ -1318,3 +1396,113 @@ func Test_cleanEmoji(t *testing.T) {
})
}
}

func BenchmarkTokenize(b *testing.B) {
d := &Detector{
excludedTokens: []string{"the", "and", "or", "but", "in", "on", "at", "to"},
}

tests := []struct {
name string
text string
}{
{
name: "Short_NoExcluded",
text: "hello world test message",
},
{
name: "Short_WithExcluded",
text: "the quick brown fox and the lazy dog",
},
{
name: "Medium_Mixed",
text: strings.Repeat("hello world and test message with some excluded tokens ", 10),
},
{
name: "Long_MixedWithPunct",
text: strings.Repeat("hello, world! test? message. with!! some... excluded tokens!!! ", 50),
},
{
name: "WithEmoji",
text: "hello 👋 world 🌍 test 🧪 message 📝 with emoji 😊",
},
{
name: "RealWorldSample",
text: "🔥 EXCLUSIVE OFFER! Don't miss out on this amazing deal. Buy now and get 50% OFF! Limited time offer. Click here: http://example.com #deal #shopping #discount",
},
}

for _, tc := range tests {
b.Run(tc.name, func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = d.tokenize(tc.text)
}
})
}
}

func BenchmarkLoadSamples(b *testing.B) {
makeReader := func(lines []string) io.Reader {
return strings.NewReader(strings.Join(lines, "\n"))
}

tests := []struct {
name string
spam []string
ham []string
excluded []string
}{
{
name: "Small",
spam: []string{"spam message 1", "buy now spam 2", "spam offer 3"},
ham: []string{"hello world", "normal message", "how are you"},
excluded: []string{"the", "and", "or"},
},
{
name: "Medium",
spam: []string{"spam message 1", "buy now spam 2", "spam offer 3", "urgent offer", "free money"},
ham: []string{"hello world", "normal message", "how are you", "meeting tomorrow", "project update"},
excluded: []string{"the", "and", "or", "but", "in", "on", "at"},
},
{
name: "Large_RealWorld",
// use actual spam samples from your data
spam: []string{
"Здравствуйте Мы занимаемая новым видом заработка в интернете Наша сфера даст вам опыт, знания",
"У кого нет карты карты Тинькофф? Можете оформить по моей ссылке и получите 500р от меня",
"😀😀😀 Для тeх ктo ищeт дoпoлнительный доход предлагаю перспективный и прибыльный зaрaботok",
},
ham: []string{
"When is our next meeting?",
"Here's the project update you requested",
"Thanks for the feedback, I'll review it",
},
excluded: []string{"the", "and", "or", "but", "in", "on", "at", "to", "for", "with"},
},
}

for _, tc := range tests {
b.Run(tc.name, func(b *testing.B) {
d := NewDetector(Config{})
spamReader := makeReader(tc.spam)
hamReader := makeReader(tc.ham)
exclReader := makeReader(tc.excluded)

b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// need to rewind readers for each iteration
spamReader = makeReader(tc.spam)
hamReader = makeReader(tc.ham)
exclReader = makeReader(tc.excluded)

_, err := d.LoadSamples(exclReader, []io.Reader{spamReader}, []io.Reader{hamReader})
if err != nil {
b.Fatal(err)
}
}
})
}
}

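The added tests and benchmarks run with the standard toolchain, e.g. go test -run TestDetector_LoadSamples ./lib/tgspam/... and go test -bench . -benchmem ./lib/tgspam/... (exact paths depend on the checkout); note that the iter package used above requires Go 1.23 or later.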