Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add possibility to blacklist commands per channel (as requested by the forge team) #56

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/main/java/com/tterrag/k9/commands/CommandControl.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.HashSet;
import java.util.Set;

import com.google.common.base.Strings;
import com.google.gson.reflect.TypeToken;
import com.tterrag.k9.commands.CommandControl.ControlData;
import com.tterrag.k9.commands.api.Argument;
Expand All @@ -12,6 +13,7 @@
import com.tterrag.k9.util.Requirements;
import com.tterrag.k9.util.Requirements.RequiredType;

import discord4j.common.util.Snowflake;
import discord4j.rest.util.Permission;
import lombok.Value;
import reactor.core.publisher.Mono;
Expand All @@ -31,6 +33,16 @@ public static class ControlData {

private static final Argument<String> ARG_OBJECT = new WordArgument("object", "The object to configure, be it a command name or otherwise", true);

private static final Argument<Snowflake> ARG_CHANNEL = new SimpleArgument<Snowflake>("channel", "The channel to configure", false) {
@Override
public Snowflake parse(String input) {
if (Strings.isNullOrEmpty(input)) {
return null;
}
return Snowflake.of(input);
}
};

public CommandControl() {
super("ctrl", false, ControlData::new);
}
Expand All @@ -49,12 +61,14 @@ public Mono<?> process(CommandContext ctx) {
if (ctx.hasFlag(FLAG_WHITELIST) && ctx.hasFlag(FLAG_BLACKLIST)) {
return ctx.error("Illegal flag combination: Cannot whitelist and blacklist");
}
Snowflake guild = ctx.getGuildId().get();
Snowflake channel = ctx.getArg(ARG_CHANNEL);
if (ctx.hasFlag(FLAG_WHITELIST)) {
return Mono.justOrEmpty(getData(ctx))
return Mono.justOrEmpty(getData(guild, channel))
.doOnNext(data -> data.getCommandBlacklist().remove(ctx.getArg(ARG_OBJECT)))
.then(ctx.reply("Whitelisted command."));
} else if (ctx.hasFlag(FLAG_BLACKLIST)) {
return Mono.justOrEmpty(getData(ctx))
return Mono.justOrEmpty(getData(guild, channel))
.doOnNext(data -> data.getCommandBlacklist().add(ctx.getArg(ARG_OBJECT)))
.then(ctx.reply("Blacklisted command."));
}
Expand Down
6 changes: 2 additions & 4 deletions src/main/java/com/tterrag/k9/commands/CommandTrick.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.tterrag.k9.commands;

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
Expand All @@ -17,13 +16,11 @@
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.TypeAdapter;
import com.google.gson.reflect.TypeToken;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonWriter;
import com.tterrag.k9.K9;
import com.tterrag.k9.commands.CommandTrick.TrickData;
import com.tterrag.k9.commands.api.Argument;
import com.tterrag.k9.commands.api.Command;
Expand Down Expand Up @@ -53,6 +50,7 @@
import lombok.RequiredArgsConstructor;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import reactor.core.publisher.Mono;
import reactor.netty.http.client.HttpClient;

Expand Down Expand Up @@ -169,7 +167,7 @@ protected TrickData handleDelegate(TrickData delegate) {
}

@Override
protected void onLoad(long guild, Map<String, TrickData> data) {
protected void onLoad(Pair<Long, Long> dataKey, Map<String, TrickData> data) {
for (String key : new HashSet<>(data.keySet())) {
if (!Patterns.VALID_TRICK_NAME.matcher(key).matches()) {
TrickData removed = data.remove(key);
Expand Down
22 changes: 18 additions & 4 deletions src/main/java/com/tterrag/k9/commands/api/CommandPersisted.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import discord4j.common.util.Snowflake;
import discord4j.core.object.entity.Guild;
import discord4j.core.object.entity.channel.Channel;
import org.apache.commons.lang3.tuple.Pair;
import reactor.core.publisher.Mono;
import reactor.util.annotation.NonNull;
import reactor.util.annotation.Nullable;
Expand All @@ -32,7 +34,7 @@ public Mono<?> onReady(ReadyContext ctx) {
return super.onReady(ctx)
.then(Mono.fromRunnable(() ->
storage = new GuildStorage<>(id -> {
T ret = newHelper(ctx.getDataFolder(), id, ctx.getGson()).fromJson(getFileName(), getDataType());
T ret = newHelper(ctx.getDataFolder(), id.getKey(), ctx.getGson()).fromJson(getFileName(), getDataType());
onLoad(id, ret);
return ret;
})));
Expand All @@ -41,14 +43,14 @@ public Mono<?> onReady(ReadyContext ctx) {
@Override
public void save(File dataFolder, Gson gson) {
if (storage != null) {
for (Entry<Long, T> e : storage.snapshot().entrySet()) {
SaveHelper<T> helper = newHelper(dataFolder, e.getKey(), gson);
for (Entry<Pair<Long, Long>, T> e : storage.snapshot().entrySet()) {
SaveHelper<T> helper = newHelper(dataFolder, e.getKey().getKey(), gson);
helper.writeJson(getFileName(), e.getValue(), getDataType());
}
}
}

protected void onLoad(long guild, T data) {
protected void onLoad(Pair<Long, Long> key, T data) {
}

private SaveHelper<T> newHelper(File root, long guild, Gson gson) {
Expand All @@ -68,12 +70,24 @@ private String getFileName() {
public final Optional<T> getData(CommandContext ctx) {
return storage.get(ctx);
}

public final Optional<T> getData(CommandContext ctx, boolean useChannel) {
return storage.get(ctx, useChannel);
}

public final T getData(Guild guild) {
return storage.get(guild);
}

public final T getData(Guild guild, Channel channel) {
return storage.get(guild, channel);
}

public final T getData(Snowflake guild) {
return storage.get(guild);
}

public final T getData(Snowflake guild, Snowflake channel) {
return storage.get(guild, channel);
}
}
176 changes: 95 additions & 81 deletions src/main/java/com/tterrag/k9/commands/api/CommandRegistrar.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,105 +64,104 @@ public CommandRegistrar(K9 k9) {
this.k9 = k9;
}

public Mono<ICommand> invokeCommand(MessageCreateEvent evt, String name, String argstr) {
Optional<ICommand> commandReq = findCommand(evt.getGuildId().orElse(null), name);
public Mono<ICommand> invokeCommand(MessageCreateEvent evt, String name, String argstrIn) {
Optional<ICommand> commandReq = findCommand(evt.getGuildId().orElse(null), evt.getMessage().getChannelId(), name);

ICommand command = commandReq.filter(c -> !c.admin() || evt.getMessage().getAuthor().map(this::isAdmin).orElse(false)).orElse(null);
if (command == null) {
return Mono.empty();
}

CommandContext ctx = new CommandContext(k9, evt);

if (!command.requirements().matches(ctx).block()) {
return evt.getMessage().getChannel()
.flatMap(c -> c.createMessage("You do not have permission to use this command!"))
.delayElement(Duration.ofSeconds(5))
.flatMap(m -> m.delete())
.thenReturn(command);
}

// evt.getMessage().getChannel().flatMap(c -> c.type()).subscribe();

argstr = Strings.nullToEmpty(argstr);

Map<Flag, String> flags = new HashMap<>();
Map<Argument<?>, String> args = new HashMap<>();

Map<Character, Flag> keyToFlag = command.getFlags().stream().collect(Collectors.toMap(Flag::name, f -> f));
Map<String, Flag> longKeyToFlag = command.getFlags().stream().collect(Collectors.toMap(Flag::longFormName, f -> f));

Matcher matcher = Patterns.FLAGS.matcher(argstr);
while (matcher.find()) {
String flagname = matcher.group(2);
List<Flag> foundFlags;
if (matcher.group().startsWith("--")) {
foundFlags = Collections.singletonList(longKeyToFlag.get(flagname));
} else if (matcher.group().startsWith("-")) {
foundFlags = Lists.newArrayList(flagname.chars().mapToObj(i -> keyToFlag.get((char) i)).toArray(Flag[]::new));
} else {
continue;
return command.requirements().matches(ctx).flatMap(bool -> {
if (!bool) {
return evt.getMessage().getChannel()
.flatMap(c -> c.createMessage("You do not have permission to use this command!"))
.delayElement(Duration.ofSeconds(5))
.flatMap(m -> m.delete())
.thenReturn(command);
}
if (foundFlags.contains(null)) {
return ctx.reply("Unknown flag(s) \"" + flagname + "\".").thenReturn(command);
}

String toreplace = matcher.group(1) + matcher.group(2);
String argstr = Strings.nullToEmpty(argstrIn);

Map<Flag, String> flags = new HashMap<>();
Map<Argument<?>, String> args = new HashMap<>();

for (int i = 0; i < foundFlags.size(); i++) {
Flag flag = foundFlags.get(i);
String value = null;
if (i == foundFlags.size() - 1) {
if (flag.canHaveValue()) {
value = matcher.group(3);
if (value == null) {
value = matcher.group(4);
Map<Character, Flag> keyToFlag = command.getFlags().stream().collect(Collectors.toMap(Flag::name, f -> f));
Map<String, Flag> longKeyToFlag = command.getFlags().stream().collect(Collectors.toMap(Flag::longFormName, f -> f));

Matcher matcher = Patterns.FLAGS.matcher(argstr);
while (matcher.find()) {
String flagname = matcher.group(2);
List<Flag> foundFlags;
if (matcher.group().startsWith("--")) {
foundFlags = Collections.singletonList(longKeyToFlag.get(flagname));
} else if (matcher.group().startsWith("-")) {
foundFlags = Lists.newArrayList(flagname.chars().mapToObj(i -> keyToFlag.get((char) i)).toArray(Flag[]::new));
} else {
continue;
}
if (foundFlags.contains(null)) {
return ctx.reply("Unknown flag(s) \"" + flagname + "\".").thenReturn(command);
}

String toreplace = matcher.group(1) + matcher.group(2);

for (int i = 0; i < foundFlags.size(); i++) {
Flag flag = foundFlags.get(i);
String value = null;
if (i == foundFlags.size() - 1) {
if (flag.canHaveValue()) {
value = matcher.group(3);
if (value == null) {
value = matcher.group(4);
}
toreplace = matcher.group();
}
toreplace = matcher.group();
}
if (value == null && flag.needsValue()) {
return ctx.reply("Flag \"" + flag.longFormName() + "\" requires a value.").thenReturn(command);
}

flags.put(flag, value == null ? flag.getDefaultValue() : value);
}
if (value == null && flag.needsValue()) {
return ctx.reply("Flag \"" + flag.longFormName() + "\" requires a value.").thenReturn(command);
toreplace = Pattern.quote(toreplace) + "\\s*";
argstr = argstr.replaceFirst(toreplace, "").trim();
matcher.reset(argstr);
}

for (Argument<?> arg : command.getArguments()) {
boolean required = arg.required(flags.keySet());
if (required && argstr.isEmpty()) {
long count = command.getArguments().stream().filter(a -> a.required(flags.keySet())).count();
return ctx.reply("This command requires at least " + count + " argument" + (count > 1 ? "s" : "") + ".").thenReturn(command);
}

flags.put(flag, value == null ? flag.getDefaultValue() : value);
}
toreplace = Pattern.quote(toreplace) + "\\s*";
argstr = argstr.replaceFirst(toreplace, "").trim();
matcher.reset(argstr);
}
matcher = arg.pattern().matcher(argstr);

for (Argument<?> arg : command.getArguments()) {
boolean required = arg.required(flags.keySet());
if (required && argstr.isEmpty()) {
long count = command.getArguments().stream().filter(a -> a.required(flags.keySet())).count();
return ctx.reply("This command requires at least " + count + " argument" + (count > 1 ? "s" : "") + ".").thenReturn(command);
}

matcher = arg.pattern().matcher(argstr);

if (matcher.find()) {
String match = matcher.group();
argstr = argstr.replaceFirst(Pattern.quote(match) + "\\s*", "").trim();
args.put(arg, match);
} else if (required) {
return ctx.reply("Argument " + arg.name() + " does not accept input: " + argstr + " (does not match `" + arg.pattern().pattern() + "`)").thenReturn(command);
if (matcher.find()) {
String match = matcher.group();
argstr = argstr.replaceFirst(Pattern.quote(match) + "\\s*", "").trim();
args.put(arg, match);
} else if (required) {
return ctx.reply("Argument " + arg.name() + " does not accept input: " + argstr + " (does not match `" + arg.pattern().pattern() + "`)").thenReturn(command);
}
}
}

try {
final Mono<?> commandResult = command.process(ctx.withFlags(flags).withArgs(args))
.doOnError(t -> log.error("Exception invoking command: ", t))
.onErrorResume(CommandException.class, t -> ctx.reply("Could not process command: " + t).then(Mono.empty()))
.onErrorResume(ClientException.class, t -> ctx.reply("Discord error processing command: " + t.getStatus() + " - " + t.getErrorResponse().map(e -> e.getFields().toString()).orElse("{}")).then(Mono.empty()))
.onErrorResume(t -> ctx.reply("Unexpected error processing command: " + t).then(Mono.empty()));
return evt.getMessage().getChannel() // Automatic typing indicator
.flatMap(c -> c.typeUntil(commandResult).then())
.thenReturn(command);
} catch (RuntimeException e) {
log.error("Exception invoking command: ", e);
return ctx.reply("Unexpected error processing command: " + e).thenReturn(command); // TODO should this be different?
}
try {
final Mono<?> commandResult = command.process(ctx.withFlags(flags).withArgs(args))
.doOnError(t -> log.error("Exception invoking command: ", t))
.onErrorResume(CommandException.class, t -> ctx.reply("Could not process command: " + t).then(Mono.empty()))
.onErrorResume(ClientException.class, t -> ctx.reply("Discord error processing command: " + t.getStatus() + " - " + t.getErrorResponse().map(e -> e.getFields().toString()).orElse("{}")).then(Mono.empty()))
.onErrorResume(t -> ctx.reply("Unexpected error processing command: " + t).then(Mono.empty()));
return evt.getMessage().getChannel() // Automatic typing indicator
.flatMap(c -> c.typeUntil(commandResult).then())
.thenReturn(command);
} catch (RuntimeException e) {
log.error("Exception invoking command: ", e);
return ctx.reply("Unexpected error processing command: " + e).thenReturn(command); // TODO should this be different?
}
});
}

public boolean isAdmin(User user) {
Expand All @@ -173,13 +172,28 @@ public Optional<ICommand> findCommand(CommandContext ctx, String name) {
return findCommand(ctx.getGuildId().orElse(null), name);
}

public Optional<ICommand> findCommand(CommandContext ctx, String name, boolean useChannel) {
if (useChannel) {
return findCommand(ctx.getGuildId().orElse(null), ctx.getChannelId(), name);
} else {
return findCommand(ctx.getGuildId().orElse(null), name);
}
}

public Optional<ICommand> findCommand(@Nullable Snowflake guild, String name) {
if (guild != null && ctrl.getData(guild).getCommandBlacklist().contains(name)) {
return Optional.empty();
}
return Optional.ofNullable(commands.get(name));
}

public Optional<ICommand> findCommand(@Nullable Snowflake guild, @Nullable Snowflake channel, String name) {
if (channel != null && ctrl.getData(guild, channel).getCommandBlacklist().contains(name)) {
return Optional.empty();
}
return findCommand(guild, name);
}

public void slurpCommands() {
if (!finishedDefaultSlurp) {
slurpCommands("com.tterrag.k9.commands");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public enum IncrementListener {

private static final SaveHelper<Map<String, Long>> saveHelper = new SaveHelper<>(new File("counts"), new Gson(), new HashMap<>());
private static final GuildStorage<Map<String, Long>> counts = new GuildStorage<>(
id -> saveHelper.fromJson(id + ".json", new TypeToken<Map<String, Long>>(){})
id -> saveHelper.fromJson(id.getLeft() + ".json", new TypeToken<Map<String, Long>>(){})
);

public Mono<MessageCreateEvent> onMessage(MessageCreateEvent event) {
Expand Down
Loading