diff --git a/albatross/cereal/traits.h b/albatross/cereal/traits.h index ed86730f..7c79f6c5 100644 --- a/albatross/cereal/traits.h +++ b/albatross/cereal/traits.h @@ -18,7 +18,7 @@ namespace albatross { /* - * This little trick was borrowed from cereal, you an think of it as + * This little trick was borrowed from cereal, you can think of it as * a function that will always return false ... but that doesn't * get resolved until template instantiation, which when combined * with a static assert let's you include a static assert that diff --git a/tests/test_traits_core.cc b/tests/test_traits_core.cc index 1b86d758..a9330c7a 100644 --- a/tests/test_traits_core.cc +++ b/tests/test_traits_core.cc @@ -16,6 +16,11 @@ namespace albatross { +TEST(test_traits_core, test_is_vector) { + EXPECT_TRUE(bool(is_vector>::value)); + EXPECT_FALSE(bool(is_vector::value)); +} + struct X {}; struct Y {}; struct Z {}; @@ -23,7 +28,7 @@ struct Z {}; class HasValidFitImpl : public ModelBase { public: Fit fit(const std::vector &, - const MarginalDistribution &) const { + const MarginalDistribution &) const { return {}; }; }; @@ -46,17 +51,17 @@ class HasNonConstFitImpl : public ModelBase { class HasNonConstArgsFitImpl : public ModelBase { public: Fit fit(std::vector &, - const MarginalDistribution &) const { + const MarginalDistribution &) const { return {}; }; Fit fit(const std::vector &, - MarginalDistribution &) const { + MarginalDistribution &) const { return {}; }; Fit fit(std::vector &, - MarginalDistribution &) const { + MarginalDistribution &) const { return {}; }; }; @@ -64,15 +69,15 @@ class HasNonConstArgsFitImpl : public ModelBase { class HasProtectedValidFitImpl : public ModelBase { protected: Fit fit(const std::vector &, - const MarginalDistribution &) const { + const MarginalDistribution &) const { return {}; }; }; -class HasPrivateValidFitImpl : public ModelBase { +class HasPrivateValidFitImpl : public ModelBase { private: Fit fit(const std::vector &, - const MarginalDistribution &) const { + const MarginalDistribution &) const { return {}; }; }; @@ -98,7 +103,7 @@ class HasValidXYFitImpl : public ModelBase { }; Fit fit(const std::vector &, - const MarginalDistribution &) const { + const MarginalDistribution &) const { return {}; }; }; @@ -140,18 +145,25 @@ TEST(test_traits_core, test_fit_type) { fit_type::type>::value)); EXPECT_TRUE(bool(std::is_same, fit_type::type>::value)); + EXPECT_TRUE( + bool(std::is_same, + fit_type::type>::value)); + // EXPECT_FALSE( + // bool(std::is_same, + // fit_type::type>::value)); } -template -struct Base {}; +template struct Base {}; struct Derived : public Base {}; TEST(test_traits_core, test_is_valid_fit_type) { EXPECT_TRUE(bool(is_valid_fit_type>::value)); EXPECT_TRUE(bool(is_valid_fit_type>::value)); - EXPECT_TRUE(bool(is_valid_fit_type, Fit, X>>::value)); - EXPECT_TRUE(bool(is_valid_fit_type, Fit, Y>>::value)); + EXPECT_TRUE( + bool(is_valid_fit_type, Fit, X>>::value)); + EXPECT_TRUE( + bool(is_valid_fit_type, Fit, Y>>::value)); // If a Derived class which inherits from Base has a fit // which returns Fit> consider that a valid fit type. EXPECT_TRUE(bool(is_valid_fit_type, X>>::value)); @@ -161,8 +173,7 @@ TEST(test_traits_core, test_is_valid_fit_type) { EXPECT_FALSE(bool(is_valid_fit_type, Fit>::value)); } -template -struct Adaptable; +template struct Adaptable; template struct Fit, FeatureType> {}; @@ -171,7 +182,7 @@ template struct Adaptable : public ModelBase> { Fit, X> fit(const std::vector &, - const MarginalDistribution &) const { + const MarginalDistribution &) const { return Fit, X>(); } @@ -180,11 +191,10 @@ struct Adaptable : public ModelBase> { * to this class so it can be picked up by ModelBase. */ template ::value, - int>::type = 0> + typename std::enable_if::value, + int>::type = 0> auto fit(const std::vector &features, - const MarginalDistribution &targets) const { + const MarginalDistribution &targets) const { return impl().fit(features, targets); } @@ -193,33 +203,33 @@ struct Adaptable : public ModelBase> { */ ImplType &impl() { return *static_cast(this); } const ImplType &impl() const { return *static_cast(this); } - }; struct Extended : public Adaptable { using Base = Adaptable; - auto fit(const std::vector &, - const MarginalDistribution &targets) const { + auto fit(const std::vector &, const MarginalDistribution &targets) const { std::vector xs = {{}}; return Base::fit(xs, targets); } - Z predict(const std::vector &, - const Fit, X> &, + Z predict(const std::vector &, const Fit, X> &, PredictTypeIdentity) const { return {}; } +}; +struct OtherExtended : public Adaptable { }; TEST(test_traits_core, test_adaptable_fit_type) { EXPECT_TRUE(bool(std::is_base_of, X>, - fit_type::type>::value)); + fit_type::type>::value)); EXPECT_TRUE(bool(std::is_base_of, X>, - fit_type::type>::value)); - EXPECT_TRUE(bool(has_valid_predict, X>, Z>::value)); + fit_type::type>::value)); + EXPECT_TRUE(bool( + has_valid_predict, X>, Z>::value)); } TEST(test_traits_core, test_adaptable_has_valid_fit) { @@ -227,8 +237,14 @@ TEST(test_traits_core, test_adaptable_has_valid_fit) { EXPECT_TRUE(bool(has_valid_fit::value)); EXPECT_TRUE(bool(is_valid_fit_type>::value)); EXPECT_TRUE(bool(is_valid_fit_type>::value)); - EXPECT_TRUE(bool(is_valid_fit_type, X>>::value)); - EXPECT_TRUE(bool(is_valid_fit_type, Y>>::value)); + EXPECT_TRUE( + bool(is_valid_fit_type, X>>::value)); + EXPECT_TRUE( + bool(is_valid_fit_type, Y>>::value)); + EXPECT_FALSE( + bool(is_valid_fit_type, Y>>::value)); + EXPECT_FALSE( + bool(is_valid_fit_type, Y>>::value)); } /* @@ -238,7 +254,7 @@ class HasMeanPredictImpl { public: Eigen::VectorXd predict(const std::vector &, const Fit &, - PredictTypeIdentity) const { + PredictTypeIdentity) const { return Eigen::VectorXd::Zero(0); } }; @@ -246,9 +262,8 @@ class HasMeanPredictImpl { class HasMarginalPredictImpl { public: MarginalDistribution - predict(const std::vector &, - const Fit &, - PredictTypeIdentity) const { + predict(const std::vector &, const Fit &, + PredictTypeIdentity) const { const auto mean = Eigen::VectorXd::Zero(0); return MarginalDistribution(mean); } @@ -257,8 +272,8 @@ class HasMarginalPredictImpl { class HasJointPredictImpl { public: JointDistribution predict(const std::vector &, - const Fit &, - PredictTypeIdentity) const { + const Fit &, + PredictTypeIdentity) const { const auto mean = Eigen::VectorXd::Zero(0); return JointDistribution(mean); } @@ -267,22 +282,21 @@ class HasJointPredictImpl { class HasAllPredictImpls { public: Eigen::VectorXd predict(const std::vector &, - const Fit &, - PredictTypeIdentity) const { + const Fit &, + PredictTypeIdentity) const { return Eigen::VectorXd::Zero(0); } MarginalDistribution - predict(const std::vector &, - const Fit &, - PredictTypeIdentity) const { + predict(const std::vector &, const Fit &, + PredictTypeIdentity) const { const auto mean = Eigen::VectorXd::Zero(0); return MarginalDistribution(mean); } JointDistribution predict(const std::vector &, - const Fit &, - PredictTypeIdentity) const { + const Fit &, + PredictTypeIdentity) const { const auto mean = Eigen::VectorXd::Zero(0); return JointDistribution(mean); } @@ -290,20 +304,36 @@ class HasAllPredictImpls { TEST(test_traits_core, test_has_valid_predict_impl) { + EXPECT_TRUE(bool(has_valid_predict_mean>::value)); + EXPECT_FALSE(bool(has_valid_predict_marginal>::value)); + EXPECT_FALSE(bool(has_valid_predict_joint>::value)); - EXPECT_TRUE(bool(has_valid_predict_mean>::value)); EXPECT_TRUE( - bool(has_valid_predict_marginal>::value)); - EXPECT_TRUE(bool(has_valid_predict_joint>::value)); - EXPECT_TRUE(bool(has_valid_predict_mean>::value)); - EXPECT_TRUE(bool(has_valid_predict_marginal>::value)); - EXPECT_TRUE(bool(has_valid_predict_joint>::value)); + bool(has_valid_predict_marginal>::value)); + EXPECT_FALSE(bool(has_valid_predict_mean>::value)); + EXPECT_FALSE(bool(has_valid_predict_joint>::value)); + + + EXPECT_TRUE(bool(has_valid_predict_joint>::value)); + EXPECT_FALSE(bool(has_valid_predict_mean>::value)); + EXPECT_FALSE(bool(has_valid_predict_marginal>::value)); + + + EXPECT_TRUE(bool(has_valid_predict_mean>::value)); + EXPECT_TRUE(bool(has_valid_predict_marginal>::value)); + EXPECT_TRUE(bool(has_valid_predict_joint>::value)); } class HasName { public: - std::string name() const {return "name";}; + std::string name() const { return "name"; }; }; class HasNoName {}; @@ -313,5 +343,4 @@ TEST(test_traits_covariance, test_has_name) { EXPECT_FALSE(bool(has_name::value)); } - } // namespace albatross