diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index 5fbcd91a06..39da65c3be 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -7,7 +7,6 @@ #include "op-attrs/datatype.h" #include "utils/exception.h" #include "utils/required.h" -#include "utils/variant.h" namespace FlexFlow { diff --git a/lib/kernels/include/kernels/initializer_kernels.h b/lib/kernels/include/kernels/initializer_kernels.h index 52609a303f..9840e457e6 100644 --- a/lib/kernels/include/kernels/initializer_kernels.h +++ b/lib/kernels/include/kernels/initializer_kernels.h @@ -4,7 +4,6 @@ #include "accessor.h" #include "kernels/cpu.h" #include "op-attrs/datatype_value.dtg.h" -#include "utils/variant.h" namespace FlexFlow { diff --git a/lib/pcg/include/pcg/optimizer_attrs.h b/lib/pcg/include/pcg/optimizer_attrs.h index 4bac74b999..3e787503d6 100644 --- a/lib/pcg/include/pcg/optimizer_attrs.h +++ b/lib/pcg/include/pcg/optimizer_attrs.h @@ -3,7 +3,7 @@ #include "pcg/optimizers/adam_optimizer_attrs.h" #include "pcg/optimizers/sgd_optimizer_attrs.h" -#include "utils/variant.h" +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/cow_ptr_t.h b/lib/utils/include/utils/graph/cow_ptr_t.h index 9a655ae072..7aed437136 100644 --- a/lib/utils/include/utils/graph/cow_ptr_t.h +++ b/lib/utils/include/utils/graph/cow_ptr_t.h @@ -2,8 +2,6 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_COW_PTR_T_H #include "utils/type_traits.h" -#include "utils/unique.h" -#include "utils/variant.h" #include #include diff --git a/lib/utils/include/utils/rapidcheck/variant.h b/lib/utils/include/utils/rapidcheck/variant.h new file mode 100644 index 0000000000..1a295e19e1 --- /dev/null +++ b/lib/utils/include/utils/rapidcheck/variant.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H + +#include +#include + +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::oneOf(gen::construct>(gen::arbitrary())...); + } +}; + +} // namespace rc + +#endif diff --git a/lib/utils/include/utils/unique.h b/lib/utils/include/utils/unique.h deleted file mode 100644 index cf6eb39026..0000000000 --- a/lib/utils/include/utils/unique.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UNIQUE_H -#define _FLEXFLOW_UTILS_INCLUDE_UNIQUE_H - -#include - -namespace FlexFlow { -template -std::unique_ptr make_unique(Args &&...args) { - return std::unique_ptr(new T(std::forward(args)...)); -} -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index bb2286a9cd..241d631200 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -213,15 +213,4 @@ std::optional cast(VariantIn const &v) { } // namespace FlexFlow -namespace rc { - -template -struct Arbitrary> { - static Gen> arbitrary() { - return gen::oneOf(gen::cast>(gen::arbitrary())...); - } -}; - -} // namespace rc - #endif diff --git a/lib/utils/src/utils/rapidcheck/variant.cc b/lib/utils/src/utils/rapidcheck/variant.cc new file mode 100644 index 0000000000..5fbc9f6910 --- /dev/null +++ b/lib/utils/src/utils/rapidcheck/variant.cc @@ -0,0 +1,11 @@ +#include "utils/rapidcheck/variant.h" + +namespace rc { + +using T0 = int; +using T1 = std::string; + +template + struct Arbitrary>; + +} // namespace rc diff --git a/lib/utils/test/src/utils/rapidcheck/variant.cc b/lib/utils/test/src/utils/rapidcheck/variant.cc new file mode 100644 index 0000000000..e201c7af8f --- /dev/null +++ b/lib/utils/test/src/utils/rapidcheck/variant.cc @@ -0,0 +1,13 @@ +#include "utils/rapidcheck/variant.h" +#include +#include "test/utils/rapidcheck.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Arbitrary") { + RC_SUBCASE("valid type", [](std::variant v) { + return std::holds_alternative(v) || std::holds_alternative(v); + }); + } +} diff --git a/lib/utils/test/src/utils/variant.cc b/lib/utils/test/src/utils/variant.cc index 36b014b2e3..3f6feadda0 100644 --- a/lib/utils/test/src/utils/variant.cc +++ b/lib/utils/test/src/utils/variant.cc @@ -1,7 +1,6 @@ #include "utils/variant.h" #include "test/utils/doctest/fmt/optional.h" #include "test/utils/doctest/fmt/variant.h" -#include "test/utils/rapidcheck.h" #include using namespace ::FlexFlow; @@ -74,10 +73,4 @@ TEST_SUITE(FF_TEST_SUITE) { // Check the result CHECK(get(wider_variant) == 42); } - - TEST_CASE("Arbitrary") { - RC_SUBCASE("valid type", [](std::variant v) { - return std::holds_alternative(v) || std::holds_alternative(v); - }); - } }