Skip to content

Commit

Permalink
Small bugfix and remove some unnecessary includes
Browse files Browse the repository at this point in the history
  • Loading branch information
lockshaw committed Oct 10, 2024
1 parent 3579ee4 commit 27a0cfe
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 36 deletions.
1 change: 0 additions & 1 deletion lib/kernels/include/kernels/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "op-attrs/datatype.h"
#include "utils/exception.h"
#include "utils/required.h"
#include "utils/variant.h"

namespace FlexFlow {

Expand Down
1 change: 0 additions & 1 deletion lib/kernels/include/kernels/initializer_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "accessor.h"
#include "kernels/cpu.h"
#include "op-attrs/datatype_value.dtg.h"
#include "utils/variant.h"

namespace FlexFlow {

Expand Down
2 changes: 1 addition & 1 deletion lib/pcg/include/pcg/optimizer_attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include "pcg/optimizers/adam_optimizer_attrs.h"
#include "pcg/optimizers/sgd_optimizer_attrs.h"
#include "utils/variant.h"
#include <variant>

namespace FlexFlow {

Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/cow_ptr_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_COW_PTR_T_H

#include "utils/type_traits.h"
#include "utils/unique.h"
#include "utils/variant.h"
#include <memory>
#include <type_traits>

Expand Down
18 changes: 18 additions & 0 deletions lib/utils/include/utils/rapidcheck/variant.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H

#include <variant>
#include <rapidcheck.h>

namespace rc {

template <typename... Ts>
struct Arbitrary<std::variant<Ts...>> {
static Gen<std::variant<Ts...>> arbitrary() {
return gen::oneOf(gen::construct<std::variant<Ts...>>(gen::arbitrary<Ts>())...);
}
};

} // namespace rc

#endif
13 changes: 0 additions & 13 deletions lib/utils/include/utils/unique.h

This file was deleted.

11 changes: 0 additions & 11 deletions lib/utils/include/utils/variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,4 @@ std::optional<VariantOut> cast(VariantIn const &v) {

} // namespace FlexFlow

namespace rc {

template <typename... Ts>
struct Arbitrary<std::variant<Ts...>> {
static Gen<std::variant<Ts...>> arbitrary() {
return gen::oneOf(gen::cast<std::variant<Ts...>>(gen::arbitrary<Ts>())...);
}
};

} // namespace rc

#endif
11 changes: 11 additions & 0 deletions lib/utils/src/utils/rapidcheck/variant.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "utils/rapidcheck/variant.h"

namespace rc {

using T0 = int;
using T1 = std::string;

template
struct Arbitrary<std::variant<T0, T1>>;

} // namespace rc
13 changes: 13 additions & 0 deletions lib/utils/test/src/utils/rapidcheck/variant.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include "utils/rapidcheck/variant.h"
#include <doctest/doctest.h>
#include "test/utils/rapidcheck.h"

using namespace ::FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("Arbitrary<std::variant>") {
RC_SUBCASE("valid type", [](std::variant<int, float> v) {
return std::holds_alternative<int>(v) || std::holds_alternative<float>(v);
});
}
}
7 changes: 0 additions & 7 deletions lib/utils/test/src/utils/variant.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "utils/variant.h"
#include "test/utils/doctest/fmt/optional.h"
#include "test/utils/doctest/fmt/variant.h"
#include "test/utils/rapidcheck.h"
#include <doctest/doctest.h>

using namespace ::FlexFlow;
Expand Down Expand Up @@ -74,10 +73,4 @@ TEST_SUITE(FF_TEST_SUITE) {
// Check the result
CHECK(get<int>(wider_variant) == 42);
}

TEST_CASE("Arbitrary<std::variant>") {
RC_SUBCASE("valid type", [](std::variant<int, float> v) {
return std::holds_alternative<int>(v) || std::holds_alternative<float>(v);
});
}
}

0 comments on commit 27a0cfe

Please sign in to comment.