Skip to content

Commit

Permalink
Adds ability for automod to detect duplicate messages across channels (
Browse files Browse the repository at this point in the history
…#1663)

* Changing in-memory tracker to support global guild messages

* Updates tracker to store guild message lists

* Adds cross-channel option for duplicate messages

* Switching inmemorytracker to use original list representation; adds inmemory guild message snapshot
  • Loading branch information
KTStephano authored Jun 16, 2024
1 parent 6b7cc8f commit 0afcbf7
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 46 deletions.
20 changes: 16 additions & 4 deletions automod/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -996,9 +996,10 @@ func (r *MessageRegexTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstat
/////////////////////////////////////////////////////////////

type SpamTriggerData struct {
Treshold int
TimeLimit int
SanitizeText bool
Treshold int
TimeLimit int
SanitizeText bool
CrossChannelMatch bool
}

var _ MessageTrigger = (*SpamTrigger)(nil)
Expand Down Expand Up @@ -1045,6 +1046,12 @@ func (spam *SpamTrigger) UserSettings() []*SettingDef {
Kind: SettingTypeBool,
Default: false,
},
{
Name: "Match duplicates across channels",
Key: "CrossChannelMatch",
Kind: SettingTypeBool,
Default: false,
},
}
}

Expand All @@ -1058,7 +1065,12 @@ func (spam *SpamTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.Cha

timeLimit := time.Now().Add(-time.Second * time.Duration(settingsCast.TimeLimit))

messages := bot.State.GetMessages(cs.GuildID, cs.ID, &dstate.MessagesQuery{
var channelID int64 = 0
if !settingsCast.CrossChannelMatch {
channelID = cs.ID
}

messages := bot.State.GetMessages(cs.GuildID, channelID, &dstate.MessagesQuery{
Limit: 1000,
})

Expand Down
20 changes: 18 additions & 2 deletions lib/dstate/inmemorytracker/accessors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package inmemorytracker

import (
"container/list"

"github.com/botlabs-gg/yagpdb/v2/lib/discordgo"
"github.com/botlabs-gg/yagpdb/v2/lib/dstate"
)
Expand Down Expand Up @@ -139,7 +141,21 @@ func (tracker *InMemoryTracker) GetMessages(guildID int64, channelID int64, quer
shard.mu.RLock()
defer shard.mu.RUnlock()

messages := shard.messages[channelID]
var messages *list.List
var convert func(*list.Element) *dstate.MessageState

if channelID == 0 {
messages = shard.guildMessages[guildID]
convert = func(e *list.Element) *dstate.MessageState {
return (*e.Value.(*any)).(*dstate.MessageState)
}
} else {
messages = shard.channelMessages[channelID]
convert = func(e *list.Element) *dstate.MessageState {
return e.Value.(*dstate.MessageState)
}
}

if messages == nil {
return nil
}
Expand All @@ -158,7 +174,7 @@ func (tracker *InMemoryTracker) GetMessages(guildID int64, channelID int64, quer

i := 0
for e := messages.Back(); e != nil; e = e.Prev() {
cast := e.Value.(*dstate.MessageState)
cast := convert(e)
include, cont := checkMessage(query, cast)
if include {
buf[i] = cast
Expand Down
44 changes: 29 additions & 15 deletions lib/dstate/inmemorytracker/gc.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,37 +57,51 @@ func (shard *ShardTracker) gcGuild(t time.Time, gs *SparseGuildState) {
shard.gcGuildChannel(t, gs, v.ID, limitLen, limitAge)
}

shard.gcMessageList(t, gs, limitLen, limitAge, shard.guildMessages[gs.Guild.ID], func(elem *list.Element) *dstate.MessageState {
return (*elem.Value.(*any)).(*dstate.MessageState)
})

if shard.conf.RemoveOfflineMembersAfter > 0 {
shard.gcMembers(t, gs, shard.conf.RemoveOfflineMembersAfter)
}
}

func (shard *ShardTracker) gcGuildChannel(t time.Time, gs *SparseGuildState, channel int64, maxLen int, maxAge time.Duration) {
if messages, ok := shard.messages[channel]; ok {
if maxLen > 0 {
overflow := messages.Len() - maxLen
for i := overflow; i > 0; i-- {
messages.Remove(messages.Front())
}
if messages, ok := shard.channelMessages[channel]; ok {
shard.gcMessageList(t, gs, maxLen, maxAge, messages, func(elem *list.Element) *dstate.MessageState {
return elem.Value.(*dstate.MessageState)
})
}
}

func (shard *ShardTracker) gcMessageList(t time.Time, gs *SparseGuildState, maxLen int, maxAge time.Duration, messages *list.List, convert func(*list.Element) *dstate.MessageState) {
if messages == nil {
return
}

if maxLen > 0 {
overflow := messages.Len() - maxLen
for i := overflow; i > 0; i-- {
messages.Remove(messages.Front())
}
}

if maxAge > 0 {
if oldest := messages.Front(); oldest != nil {
v := oldest.Value.(*dstate.MessageState)
age := t.Sub(v.ParsedCreatedAt)
if maxAge > 0 {
if oldest := messages.Front(); oldest != nil {
v := convert(oldest)
age := t.Sub(v.ParsedCreatedAt)

if age > maxAge {
shard.gcMessagesAge(t, gs, channel, maxAge, messages)
}
if age > maxAge {
shard.gcMessagesAge(t, gs, maxAge, messages, convert)
}
}
}
}

func (shard *ShardTracker) gcMessagesAge(t time.Time, gs *SparseGuildState, channel int64, maxAge time.Duration, messages *list.List) {
func (shard *ShardTracker) gcMessagesAge(t time.Time, gs *SparseGuildState, maxAge time.Duration, messages *list.List, convert func(*list.Element) *dstate.MessageState) {
toDel := make([]*list.Element, 0, 100)
for e := messages.Front(); e != nil; e = e.Next() {
v := e.Value.(*dstate.MessageState)
v := convert(e)
age := t.Sub(v.ParsedCreatedAt)

if age > maxAge {
Expand Down
41 changes: 33 additions & 8 deletions lib/dstate/inmemorytracker/gc_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package inmemorytracker

import (
"container/list"
"testing"
"time"

Expand Down Expand Up @@ -35,43 +36,67 @@ func TestGCMessages(t *testing.T) {
})

// verify the contents now
verifyMessages(t, state, initialTestChannelID, []int64{10000, 10001})
verifyChannelMessages(t, state, initialTestChannelID, []int64{10000, 10001})
verifyGuildMessages(t, state, initialTestGuildID, []int64{10000, 10001})

// add another message that will be GC'd soon
state.HandleEvent(testSession, &discordgo.MessageCreate{
Message: createTestMessage(10002, time.Date(2021, 5, 20, 10, 0, 4, 0, time.UTC)),
})
verifyMessages(t, state, initialTestChannelID, []int64{10000, 10001, 10002})
verifyChannelMessages(t, state, initialTestChannelID, []int64{10000, 10001, 10002})
verifyGuildMessages(t, state, initialTestGuildID, []int64{10000, 10001, 10002})

// run a gc, verifying max len works
shard.gcTick(time.Date(2021, 5, 20, 10, 0, 2, 0, time.UTC), nil)
verifyMessages(t, state, initialTestChannelID, []int64{10001, 10002})
verifyChannelMessages(t, state, initialTestChannelID, []int64{10001, 10002})
verifyGuildMessages(t, state, initialTestGuildID, []int64{10001, 10002})

// run a gc verifying max age
shard.gcTick(time.Date(2021, 5, 20, 11, 0, 3, 0, time.UTC), nil)
verifyMessages(t, state, initialTestChannelID, []int64{10002})
verifyChannelMessages(t, state, initialTestChannelID, []int64{10002})
verifyGuildMessages(t, state, initialTestGuildID, []int64{10002})

// run yet another one because why not
shard.gcTick(time.Date(2021, 5, 20, 12, 0, 3, 0, time.UTC), nil)
verifyMessages(t, state, initialTestChannelID, []int64{})
verifyChannelMessages(t, state, initialTestChannelID, []int64{})
verifyGuildMessages(t, state, initialTestGuildID, []int64{})
}

func verifyMessages(t *testing.T, state *InMemoryTracker, channelID int64, expectedResult []int64) {
func verifyChannelMessages(t *testing.T, state *InMemoryTracker, channelID int64, expectedResult []int64) {
shard := state.getShard(0)

messages, ok := shard.messages[channelID]
messages, ok := shard.channelMessages[channelID]
if !ok {
t.Fatal("emessages slice not present")
}

verifyMessages(t, state, messages, expectedResult, func(e *list.Element) *dstate.MessageState {
return e.Value.(*dstate.MessageState)
})
}

func verifyGuildMessages(t *testing.T, state *InMemoryTracker, guildID int64, expectedResult []int64) {
shard := state.getShard(0)

messages, ok := shard.guildMessages[guildID]
if !ok {
t.Fatal("guild messages slice not present")
}

verifyMessages(t, state, messages, expectedResult, func(e *list.Element) *dstate.MessageState {
return (*e.Value.(*any)).(*dstate.MessageState)
})
}

func verifyMessages(t *testing.T, state *InMemoryTracker, messages *list.List, expectedResult []int64, convert func(*list.Element) *dstate.MessageState) {
if messages.Len() != len(expectedResult) {
t.Fatalf("mismatched lengths, got: %d, expected: %d", messages.Len(), len(expectedResult))
}

i := 0
for e := messages.Front(); e != nil; e = e.Next() {

cast := e.Value.(*dstate.MessageState)
cast := convert(e)
if cast.ID != expectedResult[i] {
t.Fatalf("mismatched result at index [%d]: %d, expected %d", i, cast.ID, expectedResult[i])
}
Expand Down
49 changes: 32 additions & 17 deletions lib/dstate/inmemorytracker/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,24 @@ type ShardTracker struct {
members map[int64]map[int64]*WrappedMember

// Key is ChannelID
messages map[int64]*list.List
channelMessages map[int64]*list.List

// Key is GuildID
// Maintains snapshot of most recent messages in the guild
// from any channel
guildMessages map[int64]*list.List

conf TrackerConfig
}

func newShard(conf TrackerConfig, id int) *ShardTracker {
return &ShardTracker{
shardID: id,
guilds: make(map[int64]*SparseGuildState),
members: make(map[int64]map[int64]*WrappedMember),
messages: make(map[int64]*list.List),
conf: conf,
shardID: id,
guilds: make(map[int64]*SparseGuildState),
members: make(map[int64]map[int64]*WrappedMember),
channelMessages: make(map[int64]*list.List),
guildMessages: make(map[int64]*list.List),
conf: conf,
}
}

Expand Down Expand Up @@ -302,6 +308,7 @@ func (shard *ShardTracker) handleGuildCreate(gc *discordgo.GuildCreate) {
}

shard.guilds[gc.ID] = guildState
shard.guildMessages[gc.ID] = list.New()

for _, v := range gc.Members {
// problem: the presences in guild does not include a full user object
Expand Down Expand Up @@ -358,10 +365,11 @@ func (shard *ShardTracker) handleGuildDelete(gd *discordgo.GuildDelete) {
} else {
if existing, ok := shard.guilds[gd.ID]; ok {
for _, v := range existing.Channels {
delete(shard.messages, v.ID)
delete(shard.channelMessages, v.ID)
}
}

delete(shard.guildMessages, gd.ID)
delete(shard.members, gd.ID)
delete(shard.guilds, gd.ID)
}
Expand Down Expand Up @@ -427,7 +435,7 @@ func (shard *ShardTracker) handleChannelDelete(c *discordgo.ChannelDelete) {
shard.mu.Lock()
defer shard.mu.Unlock()

delete(shard.messages, c.ID)
delete(shard.channelMessages, c.ID)

gs, ok := shard.guilds[c.GuildID]
if !ok {
Expand Down Expand Up @@ -625,13 +633,19 @@ func (shard *ShardTracker) handleMessageCreate(m *discordgo.MessageCreate) {
return
}

if cl, ok := shard.messages[m.ChannelID]; ok {
cl.PushBack(dstate.MessageStateFromDgo(m.Message))
var elem *list.Element

if cl, ok := shard.channelMessages[m.ChannelID]; ok {
elem = cl.PushBack(dstate.MessageStateFromDgo(m.Message))
} else {
cl := list.New()
cl.PushBack(dstate.MessageStateFromDgo(m.Message))
shard.messages[m.ChannelID] = cl
elem = cl.PushBack(dstate.MessageStateFromDgo(m.Message))
shard.channelMessages[m.ChannelID] = cl
}

// Insert *list.Element.Value into guildMessages so that we only need to perform
// state changes for the channel lists
shard.guildMessages[m.GuildID].PushBack(&elem.Value)
}

func (shard *ShardTracker) handleMessageUpdate(m *discordgo.MessageUpdate) {
Expand All @@ -642,7 +656,7 @@ func (shard *ShardTracker) handleMessageUpdate(m *discordgo.MessageUpdate) {
return
}

if cl, ok := shard.messages[m.ChannelID]; ok {
if cl, ok := shard.channelMessages[m.ChannelID]; ok {
for e := cl.Back(); e != nil; e = e.Prev() {
// do something with e.Value
cast := e.Value.(*dstate.MessageState)
Expand Down Expand Up @@ -698,7 +712,7 @@ func (shard *ShardTracker) handleMessageDelete(m *discordgo.MessageDelete) {
return
}

if cl, ok := shard.messages[m.ChannelID]; ok {
if cl, ok := shard.channelMessages[m.ChannelID]; ok {
for e := cl.Back(); e != nil; e = e.Prev() {
cast := e.Value.(*dstate.MessageState)

Expand All @@ -720,7 +734,7 @@ func (shard *ShardTracker) handleMessageDeleteBulk(m *discordgo.MessageDeleteBul
return
}

if cl, ok := shard.messages[m.ChannelID]; ok {
if cl, ok := shard.channelMessages[m.ChannelID]; ok {
for e := cl.Back(); e != nil; e = e.Prev() {
cast := e.Value.(*dstate.MessageState)

Expand Down Expand Up @@ -853,7 +867,8 @@ func (shard *ShardTracker) handleEmojis(e *discordgo.GuildEmojisUpdate) {
func (shard *ShardTracker) reset() {
shard.guilds = make(map[int64]*SparseGuildState)
shard.members = make(map[int64]map[int64]*WrappedMember)
shard.messages = make(map[int64]*list.List)
shard.channelMessages = make(map[int64]*list.List)
shard.guildMessages = make(map[int64]*list.List)
}

///////////////////
Expand Down Expand Up @@ -899,7 +914,7 @@ func (shard *ShardTracker) handleThreadDelete(td *discordgo.ThreadDelete) {
}

func (shard *ShardTracker) removeThread(guildID int64, threadID int64) {
delete(shard.messages, threadID)
delete(shard.channelMessages, threadID)

gs, ok := shard.guilds[guildID]
if !ok {
Expand Down
1 change: 1 addition & 0 deletions lib/dstate/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type StateTracker interface {
// GetMessages returns the messages of the channel, up to limit, you may pass in a pre-allocated buffer to save allocations.
// If cap(buf) is less than the needed then a new one will be created and returned
// if len(buf) is greater than needed, it will be sliced to the proper length
// If channelID is 0, it will attempt to return the most recent messages from the guild or nil
GetMessages(guildID int64, channelID int64, query *MessagesQuery) []*MessageState

// Calls f on all members, return true to continue or false to stop
Expand Down

0 comments on commit 0afcbf7

Please sign in to comment.