Skip to content

Commit

Permalink
Reduce ram usage for smp (#492)
Browse files Browse the repository at this point in the history
Bench: 6429947
  • Loading branch information
bftjoe authored Dec 15, 2024
1 parent 073910a commit 0101ec3
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 39 deletions.
8 changes: 2 additions & 6 deletions src/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
#include <windows.h>
#endif

Bitboard FileBBMask[8];
Bitboard RankBBMask[8];

// pawn attacks table [side][square]
Bitboard pawn_attacks[2][64];

Expand Down Expand Up @@ -222,15 +219,14 @@ void InitNewGame(ThreadData* td) {
Position* pos = &td->pos;
SearchData* sd = &td->sd;
SearchInfo* info = &td->info;
PvTable* pvTable = &td->pvTable;

CleanHistories(sd);

// Clean the PV Table
for (int index = 0; index < MAXDEPTH + 1; ++index) {
pvTable->pvLength[index] = 0;
pvTable.pvLength[index] = 0;
for (int index2 = 0; index2 < MAXDEPTH + 1; ++index2) {
pvTable->pvArray[index][index2] = NOMOVE;
pvTable.pvArray[index][index2] = NOMOVE;
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ void PrintUciOutput(const int score, const int depth, const ThreadData* td, cons
" nps " << nps << " hashfull "<< GetHashfull() << " time " << GetTimeMs() - td->info.starttime << " pv ";

// loop over the moves within a PV line
for (int count = 0; count < std::max(td->pvTable.pvLength[0], 1); count++) {
for (int count = 0; count < std::max(pvTable.pvLength[0], 1); count++) {
// print PV move
PrintMove(td->pvTable.pvArray[0][count]);
PrintMove(pvTable.pvArray[0][count]);
std::cout << " ";
}

Expand Down Expand Up @@ -270,9 +270,9 @@ void PrintUciOutput(const int score, const int depth, const ThreadData* td, cons
std::cout << std::setw(7) << std::right << std::fixed << static_cast<int>(nps / 1000.0) << "Kn/s" << " ";

// loop over the moves within a PV line
for (int count = 0; count < std::max(td->pvTable.pvLength[0], 1); count++) {
for (int count = 0; count < std::max(pvTable.pvLength[0], 1); count++) {
// print PV move
PrintMove(td->pvTable.pvArray[0][count]);
PrintMove(pvTable.pvArray[0][count]);
std::cout << " ";
}

Expand Down
48 changes: 25 additions & 23 deletions src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,22 @@ bool IsDraw(Position* pos) {
void ClearForSearch(ThreadData* td) {
// Extract data structures from ThreadData
SearchInfo* info = &td->info;
PvTable* pvTable = &td->pvTable;

// Clean the Pv array
std::memset(pvTable, 0, sizeof(td->pvTable));
// Clean the node table
std::memset(td->nodeSpentTable, 0, sizeof(td->nodeSpentTable));
// Reset plies and search info
info->starttime = GetTimeMs();
info->nodes = 0;
info->seldepth = 0;
// Main thread only unpauses any eventual search thread
if (td->id == 0)

// Main thread clears pvTable, nodeSpentTable, and unpauses any eventual search thread
if (td->id == 0) {
// Clean the Pv array
std::memset(&pvTable, 0, sizeof(pvTable));
// Clean the node table
std::memset(nodeSpentTable, 0, sizeof(nodeSpentTable));

for (auto& helper_thread : threads_data)
helper_thread.info.stopped = false;
}
}

// returns a bitboard of all the attacks to a specific square
Expand Down Expand Up @@ -194,8 +196,8 @@ bool SEE(const Position* pos, const int move, const int threshold) {
return side != Color[attacker];
}

Move GetBestMove(const PvTable* pvTable) {
return pvTable->pvArray[0][0];
Move GetBestMove() {
return pvTable.pvArray[0][0];
}

// Starts the search process, this is ideally the point where you can start a multithreaded search
Expand All @@ -222,7 +224,7 @@ void RootSearch(int depth, ThreadData* td, UciOptions* options) {
StopHelperThreads();
// Print final bestmove found
std::cout << "bestmove ";
PrintMove(GetBestMove(&td->pvTable));
PrintMove(GetBestMove());
std::cout << std::endl;
}

Expand Down Expand Up @@ -254,12 +256,12 @@ void SearchPosition(int startDepth, int finalDepth, ThreadData* td, UciOptions*
// Only the main thread handles time related tasks
if (td->id == 0) {
// Keep track of how many times in a row the best move stayed the same
if (GetBestMove(&td->pvTable) == previousBestMove) {
if (GetBestMove() == previousBestMove) {
bestMoveStabilityFactor = std::min(bestMoveStabilityFactor + 1, 4);
}
else {
bestMoveStabilityFactor = 0;
previousBestMove = GetBestMove(&td->pvTable);
previousBestMove = GetBestMove();
}

// Keep track of eval stability
Expand Down Expand Up @@ -371,7 +373,7 @@ int Negamax(int alpha, int beta, int depth, const bool cutNode, ThreadData* td,
Position* pos = &td->pos;
SearchData* sd = &td->sd;
SearchInfo* info = &td->info;
PvTable* pvTable = &td->pvTable;
const bool mainT = td->id == 0;

// Initialize the node
const bool inCheck = pos->getCheckers();
Expand All @@ -384,7 +386,8 @@ int Negamax(int alpha, int beta, int depth, const bool cutNode, ThreadData* td,
const Move excludedMove = ss->excludedMove;

// if we are in a singular search and reusing the same ss entry, we have to guard this statement otherwise the pv length will get reset
pvTable->pvLength[ss->ply] = ss->ply;
if (mainT)
pvTable.pvLength[ss->ply] = ss->ply;

// Check for the highest depth reached in search to report it to the cli
if (ss->ply > info->seldepth)
Expand Down Expand Up @@ -414,7 +417,7 @@ int Negamax(int alpha, int beta, int depth, const bool cutNode, ThreadData* td,
return Quiescence<pvNode>(alpha, beta, td, ss);

// check if more than Maxtime passed and we have to stop
if (td->id == 0 && TimeOver(&td->info)) {
if (mainT && TimeOver(&td->info)) {
StopHelperThreads();
td->info.stopped = true;
return 0;
Expand Down Expand Up @@ -797,9 +800,8 @@ int Negamax(int alpha, int beta, int depth, const bool cutNode, ThreadData* td,

// take move back
UnmakeMove(move, pos);
if ( td->id == 0
&& rootNode)
td->nodeSpentTable[FromTo(move)] += info->nodes - nodesBeforeSearch;
if (mainT && rootNode)
nodeSpentTable[FromTo(move)] += info->nodes - nodesBeforeSearch;

if (info->stopped)
return 0;
Expand All @@ -813,13 +815,13 @@ int Negamax(int alpha, int beta, int depth, const bool cutNode, ThreadData* td,
if (score > alpha) {
bestMove = move;

if (pvNode) {
if (pvNode && mainT) {
// Update the pv table
pvTable->pvArray[ss->ply][ss->ply] = move;
for (int nextPly = ss->ply + 1; nextPly < pvTable->pvLength[ss->ply + 1]; nextPly++) {
pvTable->pvArray[ss->ply][nextPly] = pvTable->pvArray[ss->ply + 1][nextPly];
pvTable.pvArray[ss->ply][ss->ply] = move;
for (int nextPly = ss->ply + 1; nextPly < pvTable.pvLength[ss->ply + 1]; nextPly++) {
pvTable.pvArray[ss->ply][nextPly] = pvTable.pvArray[ss->ply + 1][nextPly];
}
pvTable->pvLength[ss->ply] = pvTable->pvLength[ss->ply + 1];
pvTable.pvLength[ss->ply] = pvTable.pvLength[ss->ply + 1];
}

if (score >= beta) {
Expand Down
9 changes: 5 additions & 4 deletions src/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@ struct PvTable {
Move pvArray[MAXDEPTH + 1][MAXDEPTH + 1];
};

// These 2 tables need to be cleaned after each search. We initialize (and subsequently clean them) elsewhere
inline PvTable pvTable;
inline uint64_t nodeSpentTable[64 * 64];

// a collection of all the data a thread needs to conduct a search
struct ThreadData {
int id = 0;
Position pos;
SearchData sd;
SearchInfo info;
// Since this 2 tables need to be cleaned after each search we just initialize (and subsequently clean them) elsewhere
PvTable pvTable;
uint64_t nodeSpentTable[64 * 64];
int RootDepth;
int nmpPlies;
};
Expand All @@ -67,7 +68,7 @@ template <bool pvNode>
[[nodiscard]] int Quiescence(int alpha, int beta, ThreadData* td, SearchStack* ss);

// Gets best move from PV table
[[nodiscard]] Move GetBestMove(const PvTable* pvTable);
[[nodiscard]] Move GetBestMove();

// inspired by the Weiss engine
[[nodiscard]] bool SEE(const Position* pos, const int move, const int threshold);
Expand Down
4 changes: 2 additions & 2 deletions src/time_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ bool StopEarly(const SearchInfo* info) {
void ScaleTm(ThreadData* td, const int bestMoveStabilityFactor, const int evalStabilityFactor) {
constexpr double bestmoveScale[5] = {2.43, 1.35, 1.09, 0.88, 0.68};
constexpr double evalScale[5] = {1.25, 1.15, 1.00, 0.94, 0.88};
const int bestmove = GetBestMove(&td->pvTable);
const int bestmove = GetBestMove();
// Calculate how many nodes were spent on checking the best move
const double bestMoveNodesFraction = static_cast<double>(td->nodeSpentTable[FromTo(bestmove)]) / static_cast<double>(td->info.nodes);
const double bestMoveNodesFraction = static_cast<double>(nodeSpentTable[FromTo(bestmove)]) / static_cast<double>(td->info.nodes);
const double nodeScalingFactor = (1.52 - bestMoveNodesFraction) * 1.74;
const double bestMoveScalingFactor = bestmoveScale[bestMoveStabilityFactor];
const double evalScalingFactor = evalScale[evalStabilityFactor];
Expand Down

0 comments on commit 0101ec3

Please sign in to comment.