Skip to content

Commit

Permalink
AutoDA: Kill the program if it's likely that the user set the wrong l…
Browse files Browse the repository at this point in the history
…anguage.
  • Loading branch information
Mysticial committed Jan 31, 2025
1 parent e0af2bf commit a60032b
Show file tree
Hide file tree
Showing 21 changed files with 216 additions and 35 deletions.
1 change: 1 addition & 0 deletions SerialPrograms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ file(GLOB MAIN_SOURCES
Source/CommonTools/Audio/SpectrogramMatcher.cpp
Source/CommonTools/Audio/SpectrogramMatcher.h
Source/CommonTools/DetectionDebouncer.h
Source/CommonTools/FailureWatchdog.h
Source/CommonTools/ImageMatch/CroppedImageDictionaryMatcher.cpp
Source/CommonTools/ImageMatch/CroppedImageDictionaryMatcher.h
Source/CommonTools/ImageMatch/ExactImageDictionaryMatcher.cpp
Expand Down
1 change: 1 addition & 0 deletions SerialPrograms/SerialPrograms.pro
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,7 @@ HEADERS += \
Source/CommonTools/Audio/AudioTemplateCache.h \
Source/CommonTools/Audio/SpectrogramMatcher.h \
Source/CommonTools/DetectionDebouncer.h \
Source/CommonTools/FailureWatchdog.h \
Source/CommonTools/GlobalInferenceRunner.h \
Source/CommonTools/ImageMatch/CroppedImageDictionaryMatcher.h \
Source/CommonTools/ImageMatch/ExactImageDictionaryMatcher.h \
Expand Down
112 changes: 112 additions & 0 deletions SerialPrograms/Source/CommonTools/FailureWatchdog.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/* Failure Watchdog
*
* From: https://github.com/PokemonAutomation/Arduino-Source
*
*/

#ifndef PokemonAutomation_CommonTools_FailureWatchdog_H
#define PokemonAutomation_CommonTools_FailureWatchdog_H

#include "Common/Cpp/AbstractLogger.h"
#include "Common/Cpp/Time.h"
#include "Common/Cpp/Exceptions.h"

namespace PokemonAutomation{





class FailureWatchdog{
public:
FailureWatchdog(
Logger& logger,
std::string failure_message,
uint64_t min_count = 5,
double min_success_rate = 0.5,
std::chrono::seconds time_limit = std::chrono::seconds(120)
)
: m_logger(logger)
, m_failure_message(std::move(failure_message))
, m_min_count(min_count)
, m_min_success_rate(min_success_rate)
, m_time_limit(time_limit)
{
restart();
}
void restart(){
m_expiration = current_time() + m_time_limit;
m_expired = false;
m_successes = 0;
m_total = 0;
}

void push_result(bool success){
m_successes += success ? 1 : 0;
m_total++;
if (success || m_expired){
return;
}

WallClock current = current_time();
if (current >= m_expiration){
m_expired = true;
}

if (m_total < m_min_count){
return;
}

double threshold = (double)m_total * m_min_success_rate;
if ((double)m_successes >= threshold){
return;
}


throw UserSetupError(m_logger, m_failure_message);
}


private:
Logger& m_logger;
std::string m_failure_message;
uint64_t m_min_count;
double m_min_success_rate;
WallDuration m_time_limit;
WallClock m_expiration;
bool m_expired;

uint64_t m_successes;
uint64_t m_total;
};





class OcrFailureWatchdog : public FailureWatchdog{
public:
OcrFailureWatchdog(
Logger& logger,
std::string failure_message = "Too many text recognition errors. Did you set the correct language?",
uint64_t min_count = 5,
double min_success_rate = 0.5,
std::chrono::seconds time_limit = std::chrono::seconds(120)
)
: FailureWatchdog(
logger,
std::move(failure_message),
min_count,
min_success_rate,
time_limit
)
{}
};






}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ void KeyboardInputController::thread_loop(){
break;
}


// If state is neutral, just issue a stop.
if (neutral){
if (try_stop_commands()){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,6 @@ class WatchdogTest1 : public WatchdogCallback{







void TestProgramComputer::program(ProgramEnvironment& env, CancellableScope& scope){
using namespace Kernels;
using namespace NintendoSwitch;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*
*/

#include "Common/Cpp/Containers/FixedLimitVector.tpp"
#include "CommonFramework/ImageTypes/ImageViewRGB32.h"
#include "CommonFramework/Tools/ErrorDumper.h"
#include "CommonFramework/Notifications/ProgramInfo.h"
Expand Down Expand Up @@ -36,6 +37,35 @@ namespace PokemonSwSh{
namespace MaxLairInternal{



AdventureRuntime::~AdventureRuntime() = default;
AdventureRuntime::AdventureRuntime(
FixedLimitVector<ConsoleHandle>& consoles,
const size_t p_host_index,
const Consoles& p_console_settings,
const EndBattleDecider& p_actions,
const bool p_go_home_when_done,
HostingSettings& p_hosting_settings,
EventNotificationOption& p_notification_status,
EventNotificationOption& p_notification_shiny,
Stats& p_session_stats
)
: host_index(p_host_index)
, console_settings(p_console_settings)
, actions(p_actions)
, go_home_when_done(p_go_home_when_done)
, hosting_settings(p_hosting_settings)
, notification_status(p_notification_status)
, notification_shiny(p_notification_shiny)
, ocr_watchdog(p_console_settings.active_consoles())
, session_stats(p_session_stats)
{
for (size_t c = 0; c < p_console_settings.active_consoles(); c++){
ocr_watchdog.emplace_back(consoles[c].logger());
}
}


StateMachineAction run_state_iteration(
AdventureRuntime& runtime, size_t console_index,
ProgramEnvironment& env,
Expand Down Expand Up @@ -94,6 +124,7 @@ StateMachineAction run_state_iteration(
console_index,
stream, context,
global_state,
runtime.ocr_watchdog[console_index],
runtime.console_settings[console_index]
);
return StateMachineAction::KEEP_GOING;
Expand Down Expand Up @@ -124,7 +155,9 @@ StateMachineAction run_state_iteration(
stream.log("Current State: Move Select");
return run_move_select(
env, console_index,
stream, context, global_state,
stream, context,
runtime.ocr_watchdog[console_index],
global_state,
runtime.console_settings[console_index],
battle_menu.dmaxed(),
battle_menu.cheer()
Expand All @@ -136,6 +169,7 @@ StateMachineAction run_state_iteration(
env, console_index,
stream, context,
runtime.console_settings[console_index].language,
runtime.ocr_watchdog[console_index],
global_state,
decider
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
#ifndef PokemonAutomation_PokemonSwSh_MaxLair_StateMachine_H
#define PokemonAutomation_PokemonSwSh_MaxLair_StateMachine_H

#include "Common/Cpp/Concurrency/SpinLock.h"
#include "CommonFramework/Tools/VideoStream.h"
#include "CommonFramework/Tools/ProgramEnvironment.h"
#include "Common/Cpp/Concurrency/SpinLock.h"
#include "CommonTools/FailureWatchdog.h"
#include "NintendoSwitch/Controllers/NintendoSwitch_Controller.h"
#include "NintendoSwitch/NintendoSwitch_ConsoleHandle.h"
#include "PokemonSwSh/Inference/PokemonSwSh_QuantityReader.h"
#include "PokemonSwSh/MaxLair/Options/PokemonSwSh_MaxLair_Options.h"
#include "PokemonSwSh/MaxLair/Options/PokemonSwSh_MaxLair_Options_Consoles.h"
Expand Down Expand Up @@ -52,7 +54,9 @@ struct ConsoleRuntime{
};

struct AdventureRuntime{
~AdventureRuntime();
AdventureRuntime(
FixedLimitVector<ConsoleHandle>& consoles,
const size_t p_host_index,
const Consoles& p_console_settings,
const EndBattleDecider& p_actions,
Expand All @@ -61,16 +65,7 @@ struct AdventureRuntime{
EventNotificationOption& p_notification_status,
EventNotificationOption& p_notification_shiny,
Stats& p_session_stats
)
: host_index(p_host_index)
, console_settings(p_console_settings)
, actions(p_actions)
, go_home_when_done(p_go_home_when_done)
, hosting_settings(p_hosting_settings)
, notification_status(p_notification_status)
, notification_shiny(p_notification_shiny)
, session_stats(p_session_stats)
{}
);

const size_t host_index;
const Consoles& console_settings;
Expand All @@ -79,6 +74,9 @@ struct AdventureRuntime{
HostingSettings& hosting_settings;
EventNotificationOption& notification_status;
EventNotificationOption& notification_shiny;

FixedLimitVector<OcrFailureWatchdog> ocr_watchdog;

Stats& session_stats;

PathStats path_stats;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,13 @@ bool BattleMenuDetector::detect(const ImageViewRGB32& screen){



BattleMenuReader::BattleMenuReader(VideoOverlay& overlay, Language language)
BattleMenuReader::BattleMenuReader(
VideoOverlay& overlay,
Language language,
OcrFailureWatchdog& ocr_watchdog
)
: m_language(language)
, m_ocr_watchdog(ocr_watchdog)
, m_opponent_name(overlay, {0.3, 0.010, 0.4, 0.10}, COLOR_BLUE)
, m_summary_opponent_name(overlay, {0.200, 0.100, 0.300, 0.065}, COLOR_BLUE)
, m_summary_opponent_types(overlay, {0.200, 0.170, 0.300, 0.050}, COLOR_BLUE)
Expand Down Expand Up @@ -243,7 +248,7 @@ std::set<std::string> BattleMenuReader::read_opponent(
for (size_t c = 0; c < 3; c++){
screen = feed.snapshot();
ImageViewRGB32 image = extract_box_reference(screen, m_opponent_name);
result = read_pokemon_name(logger, m_language, image);
result = read_pokemon_name(logger, m_language, m_ocr_watchdog, image);
if (!result.empty()){
return result;
}
Expand Down Expand Up @@ -316,7 +321,7 @@ std::set<std::string> BattleMenuReader::read_opponent_in_summary(Logger& logger,
ImageViewRGB32 name = extract_box_reference(screen, m_summary_opponent_name);

// We can use a weaker threshold here since we are cross-checking with the type.
name_slugs = read_pokemon_name(logger, m_language, name, -1.0);
name_slugs = read_pokemon_name(logger, m_language, m_ocr_watchdog, name, -1.0);
}

// See if there's anything in common between the slugs that match the type
Expand Down Expand Up @@ -381,6 +386,7 @@ std::set<std::string> BattleMenuReader::read_opponent_in_summary(Logger& logger,
std::string BattleMenuReader::read_own_mon(Logger& logger, const ImageViewRGB32& screen) const{
return read_pokemon_name_sprite(
logger,
m_ocr_watchdog,
screen,
m_own_sprite,
m_own_name, m_language,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "CommonFramework/Language.h"
#include "CommonFramework/Logging/Logger.h"
#include "CommonFramework/VideoPipeline/VideoOverlayScopes.h"
#include "CommonTools/FailureWatchdog.h"
#include "CommonTools/InferenceCallbacks/VisualInferenceCallback.h"
#include "PokemonSwSh/MaxLair/Framework/PokemonSwSh_MaxLair_State.h"

Expand Down Expand Up @@ -56,7 +57,11 @@ class BattleMenuDetector : public VisualInferenceCallback{

class BattleMenuReader{
public:
BattleMenuReader(VideoOverlay& overlay, Language language);
BattleMenuReader(
VideoOverlay& overlay,
Language language,
OcrFailureWatchdog& ocr_watchdog
);

std::set<std::string> read_opponent(
Logger& logger, CancellableScope& scope,
Expand All @@ -74,6 +79,7 @@ class BattleMenuReader{

private:
Language m_language;
OcrFailureWatchdog& m_ocr_watchdog;
OverlayBoxScope m_opponent_name;
OverlayBoxScope m_summary_opponent_name;
OverlayBoxScope m_summary_opponent_types;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ std::string read_boss_sprite(VideoStream& stream){

std::set<std::string> read_pokemon_name(
Logger& logger, Language language,
OcrFailureWatchdog& ocr_watchdog,
const ImageViewRGB32& image,
double max_log10p
){
Expand All @@ -123,8 +124,10 @@ std::set<std::string> read_pokemon_name(
);
// result.log(logger);
if (result.results.empty()){
ocr_watchdog.push_result(false);
return {};
}
ocr_watchdog.push_result(true);

// Convert OCR slugs to MaxLair name slugs.
std::set<std::string> ret;
Expand Down Expand Up @@ -301,6 +304,7 @@ std::string read_pokemon_sprite_with_item(

std::string read_pokemon_name_sprite(
Logger& logger,
OcrFailureWatchdog& ocr_watchdog,
const ImageViewRGB32& screen,
const ImageFloatBox& sprite_box,
const ImageFloatBox& name_box, Language language,
Expand All @@ -313,7 +317,7 @@ std::string read_pokemon_name_sprite(
ImageViewRGB32 image = extract_box_reference(screen, name_box);

std::set<std::string> ocr_slugs;
for (const std::string& slug : read_pokemon_name(logger, language, image)){
for (const std::string& slug : read_pokemon_name(logger, language, ocr_watchdog, image)){
// Only include candidates that are valid rental Pokemon.
auto iter = RENTALS.find(slug);
if (iter != RENTALS.end()){
Expand Down
Loading

0 comments on commit a60032b

Please sign in to comment.