diff --git a/automod/triggers.go b/automod/triggers.go index ca3e2e3f73..0fef5cb18b 100644 --- a/automod/triggers.go +++ b/automod/triggers.go @@ -996,9 +996,10 @@ func (r *MessageRegexTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstat ///////////////////////////////////////////////////////////// type SpamTriggerData struct { - Treshold int - TimeLimit int - SanitizeText bool + Treshold int + TimeLimit int + SanitizeText bool + CrossChannelMatch bool } var _ MessageTrigger = (*SpamTrigger)(nil) @@ -1045,6 +1046,12 @@ func (spam *SpamTrigger) UserSettings() []*SettingDef { Kind: SettingTypeBool, Default: false, }, + { + Name: "Match duplicates across channels", + Key: "CrossChannelMatch", + Kind: SettingTypeBool, + Default: false, + }, } } @@ -1058,7 +1065,12 @@ func (spam *SpamTrigger) CheckMessage(triggerCtx *TriggerContext, cs *dstate.Cha timeLimit := time.Now().Add(-time.Second * time.Duration(settingsCast.TimeLimit)) - messages := bot.State.GetMessages(cs.GuildID, cs.ID, &dstate.MessagesQuery{ + var channelID int64 = 0 + if !settingsCast.CrossChannelMatch { + channelID = cs.ID + } + + messages := bot.State.GetMessages(cs.GuildID, channelID, &dstate.MessagesQuery{ Limit: 1000, }) diff --git a/lib/dstate/inmemorytracker/accessors.go b/lib/dstate/inmemorytracker/accessors.go index 0176454d25..c3ff5c2086 100644 --- a/lib/dstate/inmemorytracker/accessors.go +++ b/lib/dstate/inmemorytracker/accessors.go @@ -1,6 +1,8 @@ package inmemorytracker import ( + "container/list" + "github.com/botlabs-gg/yagpdb/v2/lib/discordgo" "github.com/botlabs-gg/yagpdb/v2/lib/dstate" ) @@ -139,7 +141,21 @@ func (tracker *InMemoryTracker) GetMessages(guildID int64, channelID int64, quer shard.mu.RLock() defer shard.mu.RUnlock() - messages := shard.messages[channelID] + var messages *list.List + var convert func(*list.Element) *dstate.MessageState + + if channelID == 0 { + messages = shard.guildMessages[guildID] + convert = func(e *list.Element) *dstate.MessageState { + return (*e.Value.(*any)).(*dstate.MessageState) + } + } else { + messages = shard.channelMessages[channelID] + convert = func(e *list.Element) *dstate.MessageState { + return e.Value.(*dstate.MessageState) + } + } + if messages == nil { return nil } @@ -158,7 +174,7 @@ func (tracker *InMemoryTracker) GetMessages(guildID int64, channelID int64, quer i := 0 for e := messages.Back(); e != nil; e = e.Prev() { - cast := e.Value.(*dstate.MessageState) + cast := convert(e) include, cont := checkMessage(query, cast) if include { buf[i] = cast diff --git a/lib/dstate/inmemorytracker/gc.go b/lib/dstate/inmemorytracker/gc.go index 8595fe5492..2ff4492bac 100644 --- a/lib/dstate/inmemorytracker/gc.go +++ b/lib/dstate/inmemorytracker/gc.go @@ -57,37 +57,51 @@ func (shard *ShardTracker) gcGuild(t time.Time, gs *SparseGuildState) { shard.gcGuildChannel(t, gs, v.ID, limitLen, limitAge) } + shard.gcMessageList(t, gs, limitLen, limitAge, shard.guildMessages[gs.Guild.ID], func(elem *list.Element) *dstate.MessageState { + return (*elem.Value.(*any)).(*dstate.MessageState) + }) + if shard.conf.RemoveOfflineMembersAfter > 0 { shard.gcMembers(t, gs, shard.conf.RemoveOfflineMembersAfter) } } func (shard *ShardTracker) gcGuildChannel(t time.Time, gs *SparseGuildState, channel int64, maxLen int, maxAge time.Duration) { - if messages, ok := shard.messages[channel]; ok { - if maxLen > 0 { - overflow := messages.Len() - maxLen - for i := overflow; i > 0; i-- { - messages.Remove(messages.Front()) - } + if messages, ok := shard.channelMessages[channel]; ok { + shard.gcMessageList(t, gs, maxLen, maxAge, messages, func(elem *list.Element) *dstate.MessageState { + return elem.Value.(*dstate.MessageState) + }) + } +} + +func (shard *ShardTracker) gcMessageList(t time.Time, gs *SparseGuildState, maxLen int, maxAge time.Duration, messages *list.List, convert func(*list.Element) *dstate.MessageState) { + if messages == nil { + return + } + + if maxLen > 0 { + overflow := messages.Len() - maxLen + for i := overflow; i > 0; i-- { + messages.Remove(messages.Front()) } + } - if maxAge > 0 { - if oldest := messages.Front(); oldest != nil { - v := oldest.Value.(*dstate.MessageState) - age := t.Sub(v.ParsedCreatedAt) + if maxAge > 0 { + if oldest := messages.Front(); oldest != nil { + v := convert(oldest) + age := t.Sub(v.ParsedCreatedAt) - if age > maxAge { - shard.gcMessagesAge(t, gs, channel, maxAge, messages) - } + if age > maxAge { + shard.gcMessagesAge(t, gs, maxAge, messages, convert) } } } } -func (shard *ShardTracker) gcMessagesAge(t time.Time, gs *SparseGuildState, channel int64, maxAge time.Duration, messages *list.List) { +func (shard *ShardTracker) gcMessagesAge(t time.Time, gs *SparseGuildState, maxAge time.Duration, messages *list.List, convert func(*list.Element) *dstate.MessageState) { toDel := make([]*list.Element, 0, 100) for e := messages.Front(); e != nil; e = e.Next() { - v := e.Value.(*dstate.MessageState) + v := convert(e) age := t.Sub(v.ParsedCreatedAt) if age > maxAge { diff --git a/lib/dstate/inmemorytracker/gc_test.go b/lib/dstate/inmemorytracker/gc_test.go index ab78fbfda6..7334fb52b5 100644 --- a/lib/dstate/inmemorytracker/gc_test.go +++ b/lib/dstate/inmemorytracker/gc_test.go @@ -1,6 +1,7 @@ package inmemorytracker import ( + "container/list" "testing" "time" @@ -35,35 +36,59 @@ func TestGCMessages(t *testing.T) { }) // verify the contents now - verifyMessages(t, state, initialTestChannelID, []int64{10000, 10001}) + verifyChannelMessages(t, state, initialTestChannelID, []int64{10000, 10001}) + verifyGuildMessages(t, state, initialTestGuildID, []int64{10000, 10001}) // add another message that will be GC'd soon state.HandleEvent(testSession, &discordgo.MessageCreate{ Message: createTestMessage(10002, time.Date(2021, 5, 20, 10, 0, 4, 0, time.UTC)), }) - verifyMessages(t, state, initialTestChannelID, []int64{10000, 10001, 10002}) + verifyChannelMessages(t, state, initialTestChannelID, []int64{10000, 10001, 10002}) + verifyGuildMessages(t, state, initialTestGuildID, []int64{10000, 10001, 10002}) // run a gc, verifying max len works shard.gcTick(time.Date(2021, 5, 20, 10, 0, 2, 0, time.UTC), nil) - verifyMessages(t, state, initialTestChannelID, []int64{10001, 10002}) + verifyChannelMessages(t, state, initialTestChannelID, []int64{10001, 10002}) + verifyGuildMessages(t, state, initialTestGuildID, []int64{10001, 10002}) // run a gc verifying max age shard.gcTick(time.Date(2021, 5, 20, 11, 0, 3, 0, time.UTC), nil) - verifyMessages(t, state, initialTestChannelID, []int64{10002}) + verifyChannelMessages(t, state, initialTestChannelID, []int64{10002}) + verifyGuildMessages(t, state, initialTestGuildID, []int64{10002}) // run yet another one because why not shard.gcTick(time.Date(2021, 5, 20, 12, 0, 3, 0, time.UTC), nil) - verifyMessages(t, state, initialTestChannelID, []int64{}) + verifyChannelMessages(t, state, initialTestChannelID, []int64{}) + verifyGuildMessages(t, state, initialTestGuildID, []int64{}) } -func verifyMessages(t *testing.T, state *InMemoryTracker, channelID int64, expectedResult []int64) { +func verifyChannelMessages(t *testing.T, state *InMemoryTracker, channelID int64, expectedResult []int64) { shard := state.getShard(0) - messages, ok := shard.messages[channelID] + messages, ok := shard.channelMessages[channelID] if !ok { t.Fatal("emessages slice not present") } + verifyMessages(t, state, messages, expectedResult, func(e *list.Element) *dstate.MessageState { + return e.Value.(*dstate.MessageState) + }) +} + +func verifyGuildMessages(t *testing.T, state *InMemoryTracker, guildID int64, expectedResult []int64) { + shard := state.getShard(0) + + messages, ok := shard.guildMessages[guildID] + if !ok { + t.Fatal("guild messages slice not present") + } + + verifyMessages(t, state, messages, expectedResult, func(e *list.Element) *dstate.MessageState { + return (*e.Value.(*any)).(*dstate.MessageState) + }) +} + +func verifyMessages(t *testing.T, state *InMemoryTracker, messages *list.List, expectedResult []int64, convert func(*list.Element) *dstate.MessageState) { if messages.Len() != len(expectedResult) { t.Fatalf("mismatched lengths, got: %d, expected: %d", messages.Len(), len(expectedResult)) } @@ -71,7 +96,7 @@ func verifyMessages(t *testing.T, state *InMemoryTracker, channelID int64, expec i := 0 for e := messages.Front(); e != nil; e = e.Next() { - cast := e.Value.(*dstate.MessageState) + cast := convert(e) if cast.ID != expectedResult[i] { t.Fatalf("mismatched result at index [%d]: %d, expected %d", i, cast.ID, expectedResult[i]) } diff --git a/lib/dstate/inmemorytracker/tracker.go b/lib/dstate/inmemorytracker/tracker.go index 03b3174cc1..357927e623 100644 --- a/lib/dstate/inmemorytracker/tracker.go +++ b/lib/dstate/inmemorytracker/tracker.go @@ -157,18 +157,24 @@ type ShardTracker struct { members map[int64]map[int64]*WrappedMember // Key is ChannelID - messages map[int64]*list.List + channelMessages map[int64]*list.List + + // Key is GuildID + // Maintains snapshot of most recent messages in the guild + // from any channel + guildMessages map[int64]*list.List conf TrackerConfig } func newShard(conf TrackerConfig, id int) *ShardTracker { return &ShardTracker{ - shardID: id, - guilds: make(map[int64]*SparseGuildState), - members: make(map[int64]map[int64]*WrappedMember), - messages: make(map[int64]*list.List), - conf: conf, + shardID: id, + guilds: make(map[int64]*SparseGuildState), + members: make(map[int64]map[int64]*WrappedMember), + channelMessages: make(map[int64]*list.List), + guildMessages: make(map[int64]*list.List), + conf: conf, } } @@ -302,6 +308,7 @@ func (shard *ShardTracker) handleGuildCreate(gc *discordgo.GuildCreate) { } shard.guilds[gc.ID] = guildState + shard.guildMessages[gc.ID] = list.New() for _, v := range gc.Members { // problem: the presences in guild does not include a full user object @@ -358,10 +365,11 @@ func (shard *ShardTracker) handleGuildDelete(gd *discordgo.GuildDelete) { } else { if existing, ok := shard.guilds[gd.ID]; ok { for _, v := range existing.Channels { - delete(shard.messages, v.ID) + delete(shard.channelMessages, v.ID) } } + delete(shard.guildMessages, gd.ID) delete(shard.members, gd.ID) delete(shard.guilds, gd.ID) } @@ -427,7 +435,7 @@ func (shard *ShardTracker) handleChannelDelete(c *discordgo.ChannelDelete) { shard.mu.Lock() defer shard.mu.Unlock() - delete(shard.messages, c.ID) + delete(shard.channelMessages, c.ID) gs, ok := shard.guilds[c.GuildID] if !ok { @@ -625,13 +633,19 @@ func (shard *ShardTracker) handleMessageCreate(m *discordgo.MessageCreate) { return } - if cl, ok := shard.messages[m.ChannelID]; ok { - cl.PushBack(dstate.MessageStateFromDgo(m.Message)) + var elem *list.Element + + if cl, ok := shard.channelMessages[m.ChannelID]; ok { + elem = cl.PushBack(dstate.MessageStateFromDgo(m.Message)) } else { cl := list.New() - cl.PushBack(dstate.MessageStateFromDgo(m.Message)) - shard.messages[m.ChannelID] = cl + elem = cl.PushBack(dstate.MessageStateFromDgo(m.Message)) + shard.channelMessages[m.ChannelID] = cl } + + // Insert *list.Element.Value into guildMessages so that we only need to perform + // state changes for the channel lists + shard.guildMessages[m.GuildID].PushBack(&elem.Value) } func (shard *ShardTracker) handleMessageUpdate(m *discordgo.MessageUpdate) { @@ -642,7 +656,7 @@ func (shard *ShardTracker) handleMessageUpdate(m *discordgo.MessageUpdate) { return } - if cl, ok := shard.messages[m.ChannelID]; ok { + if cl, ok := shard.channelMessages[m.ChannelID]; ok { for e := cl.Back(); e != nil; e = e.Prev() { // do something with e.Value cast := e.Value.(*dstate.MessageState) @@ -698,7 +712,7 @@ func (shard *ShardTracker) handleMessageDelete(m *discordgo.MessageDelete) { return } - if cl, ok := shard.messages[m.ChannelID]; ok { + if cl, ok := shard.channelMessages[m.ChannelID]; ok { for e := cl.Back(); e != nil; e = e.Prev() { cast := e.Value.(*dstate.MessageState) @@ -720,7 +734,7 @@ func (shard *ShardTracker) handleMessageDeleteBulk(m *discordgo.MessageDeleteBul return } - if cl, ok := shard.messages[m.ChannelID]; ok { + if cl, ok := shard.channelMessages[m.ChannelID]; ok { for e := cl.Back(); e != nil; e = e.Prev() { cast := e.Value.(*dstate.MessageState) @@ -853,7 +867,8 @@ func (shard *ShardTracker) handleEmojis(e *discordgo.GuildEmojisUpdate) { func (shard *ShardTracker) reset() { shard.guilds = make(map[int64]*SparseGuildState) shard.members = make(map[int64]map[int64]*WrappedMember) - shard.messages = make(map[int64]*list.List) + shard.channelMessages = make(map[int64]*list.List) + shard.guildMessages = make(map[int64]*list.List) } /////////////////// @@ -899,7 +914,7 @@ func (shard *ShardTracker) handleThreadDelete(td *discordgo.ThreadDelete) { } func (shard *ShardTracker) removeThread(guildID int64, threadID int64) { - delete(shard.messages, threadID) + delete(shard.channelMessages, threadID) gs, ok := shard.guilds[guildID] if !ok { diff --git a/lib/dstate/interface.go b/lib/dstate/interface.go index e6ef05b63e..83352c223c 100644 --- a/lib/dstate/interface.go +++ b/lib/dstate/interface.go @@ -30,6 +30,7 @@ type StateTracker interface { // GetMessages returns the messages of the channel, up to limit, you may pass in a pre-allocated buffer to save allocations. // If cap(buf) is less than the needed then a new one will be created and returned // if len(buf) is greater than needed, it will be sliced to the proper length + // If channelID is 0, it will attempt to return the most recent messages from the guild or nil GetMessages(guildID int64, channelID int64, query *MessagesQuery) []*MessageState // Calls f on all members, return true to continue or false to stop