Skip to content

Commit

Permalink
Added random selection policy
Browse files Browse the repository at this point in the history
  • Loading branch information
AlanKerstjens committed Dec 16, 2023
1 parent b082a6c commit a54c4d3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
9 changes: 6 additions & 3 deletions source/MoleculeAutoCorrect.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,17 +618,18 @@ struct Constant {
}; // ! MoleculeAutoCorrect::Policy::Objective namespace


// Virtually all policies can be expressed as a GreedyPolicy, provided that we
// adjust the objective on which the GreedyPolicy selects and tries to maximize.
enum class Type {
BFS,
Familiarity,
DistanceNormalizedFamiliarity,
Astar,
UCT,
MLR
MLR,
Random
};

// Virtually all policies can be expressed as a GreedyPolicy, provided that we
// adjust the objective on which the GreedyPolicy selects and tries to maximize.
struct BFS : GreedyPolicy<Vertex> {
BFS(const RDKit::ROMol& root_molecule) :
GreedyPolicy<Vertex>(Objective::TopologicalSimilarity(root_molecule)) {};
Expand Down Expand Up @@ -675,6 +676,8 @@ struct ObjectivePreservation : GreedyPolicy<Vertex> {
Objective::Wrapper(Objective::MoleculeObjective(objective))) {};
};

typedef RandomPolicy<Vertex> Random;

struct Dummy : GreedyPolicy<Vertex> {
Dummy(double x) : GreedyPolicy<Vertex>(Objective::Constant(x)) {};
};
Expand Down
37 changes: 37 additions & 0 deletions source/TreeSearch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <boost/graph/adjacency_list.hpp>
#include <boost/dynamic_bitset.hpp>
#include <concepts>
#include <random>
#include <queue>
#include <cmath>

Expand Down Expand Up @@ -377,5 +378,41 @@ class UpperConfidenceTree {
};
};

template <class Vertex>
class RandomPolicy {
typedef TreeSearch<Vertex>::Tree Tree;
std::mt19937 prng;

public:
RandomPolicy() {
std::random_device rd;
prng.seed(rd());
};

RandomPolicy(std::mt19937::result_type seed) {
prng.seed(seed);
};

std::pair<typename Tree::vertex_descriptor, bool> operator()(
const TreeSearch<Vertex>& tree_search) {
const Tree& tree = tree_search.GetTree();
const auto& vertex_expandable = tree_search.GetExpandableVerticesMask();
std::size_t n_vertices = boost::num_vertices(tree);
if (n_vertices < 2) {
return {0, n_vertices ? vertex_expandable[0] : false};
};
std::uniform_int_distribution<std::size_t> distribution (0, n_vertices - 1);
std::size_t vertex_idx = distribution(prng);
for (std::size_t i = 0; i < n_vertices; ++i) {
if (vertex_expandable[vertex_idx]) {
return {vertex_idx, true};
};
if (++vertex_idx >= n_vertices) {
vertex_idx = 0;
};
};
return {0, false};
};
};

#endif // !_TREE_SEARCH_HPP_

0 comments on commit a54c4d3

Please sign in to comment.