Skip to content

Commit

Permalink
Add some fail cases for traits, fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
akleeman committed Mar 11, 2019
1 parent f599bc0 commit b157d35
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 51 deletions.
2 changes: 1 addition & 1 deletion albatross/cereal/traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
namespace albatross {

/*
* This little trick was borrowed from cereal, you an think of it as
* This little trick was borrowed from cereal, you can think of it as
* a function that will always return false ... but that doesn't
* get resolved until template instantiation, which when combined
* with a static assert let's you include a static assert that
Expand Down
129 changes: 79 additions & 50 deletions tests/test_traits_core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@

namespace albatross {

TEST(test_traits_core, test_is_vector) {
EXPECT_TRUE(bool(is_vector<std::vector<double>>::value));
EXPECT_FALSE(bool(is_vector<double>::value));
}

struct X {};
struct Y {};
struct Z {};

class HasValidFitImpl : public ModelBase<HasValidFitImpl> {
public:
Fit<HasValidFitImpl, X> fit(const std::vector<X> &,
const MarginalDistribution &) const {
const MarginalDistribution &) const {
return {};
};
};
Expand All @@ -46,33 +51,33 @@ class HasNonConstFitImpl : public ModelBase<HasNonConstFitImpl> {
class HasNonConstArgsFitImpl : public ModelBase<HasNonConstFitImpl> {
public:
Fit<HasNonConstArgsFitImpl, X> fit(std::vector<X> &,
const MarginalDistribution &) const {
const MarginalDistribution &) const {
return {};
};

Fit<HasNonConstArgsFitImpl, X> fit(const std::vector<X> &,
MarginalDistribution &) const {
MarginalDistribution &) const {
return {};
};

Fit<HasNonConstArgsFitImpl, X> fit(std::vector<X> &,
MarginalDistribution &) const {
MarginalDistribution &) const {
return {};
};
};

class HasProtectedValidFitImpl : public ModelBase<HasNonConstFitImpl> {
protected:
Fit<HasProtectedValidFitImpl, X> fit(const std::vector<X> &,
const MarginalDistribution &) const {
const MarginalDistribution &) const {
return {};
};
};

class HasPrivateValidFitImpl : public ModelBase<HasPrivateValidFitImpl> {
class HasPrivateValidFitImpl : public ModelBase<HasPrivateValidFitImpl> {
private:
Fit<HasPrivateValidFitImpl, X> fit(const std::vector<X> &,
const MarginalDistribution &) const {
const MarginalDistribution &) const {
return {};
};
};
Expand All @@ -98,7 +103,7 @@ class HasValidXYFitImpl : public ModelBase<HasValidXYFitImpl> {
};

Fit<HasValidXYFitImpl, Y> fit(const std::vector<Y> &,
const MarginalDistribution &) const {
const MarginalDistribution &) const {
return {};
};
};
Expand Down Expand Up @@ -140,18 +145,25 @@ TEST(test_traits_core, test_fit_type) {
fit_type<HasValidXYFitImpl, X>::type>::value));
EXPECT_TRUE(bool(std::is_same<Fit<HasValidXYFitImpl, Y>,
fit_type<HasValidXYFitImpl, Y>::type>::value));
EXPECT_TRUE(
bool(std::is_same<Fit<HasValidAndInvalidFitImpl, X>,
fit_type<HasValidAndInvalidFitImpl, X>::type>::value));
// EXPECT_FALSE(
// bool(std::is_same<Fit<HasPrivateValidFitImpl, X>,
// fit_type<HasPrivateValidFitImpl, X>::type>::value));
}

template <typename T>
struct Base {};
template <typename T> struct Base {};

struct Derived : public Base<Derived> {};

TEST(test_traits_core, test_is_valid_fit_type) {
EXPECT_TRUE(bool(is_valid_fit_type<Derived, Fit<Derived, X>>::value));
EXPECT_TRUE(bool(is_valid_fit_type<Derived, Fit<Derived, Y>>::value));
EXPECT_TRUE(bool(is_valid_fit_type<Base<Derived>, Fit<Base<Derived>, X>>::value));
EXPECT_TRUE(bool(is_valid_fit_type<Base<Derived>, Fit<Base<Derived>, Y>>::value));
EXPECT_TRUE(
bool(is_valid_fit_type<Base<Derived>, Fit<Base<Derived>, X>>::value));
EXPECT_TRUE(
bool(is_valid_fit_type<Base<Derived>, Fit<Base<Derived>, Y>>::value));
// If a Derived class which inherits from Base<Derived> has a fit
// which returns Fit<Base<Derived>> consider that a valid fit type.
EXPECT_TRUE(bool(is_valid_fit_type<Derived, Fit<Base<Derived>, X>>::value));
Expand All @@ -161,8 +173,7 @@ TEST(test_traits_core, test_is_valid_fit_type) {
EXPECT_FALSE(bool(is_valid_fit_type<Base<Derived>, Fit<Derived, Y>>::value));
}

template <typename T>
struct Adaptable;
template <typename T> struct Adaptable;

template <typename T, typename FeatureType>
struct Fit<Adaptable<T>, FeatureType> {};
Expand All @@ -171,7 +182,7 @@ template <typename ImplType>
struct Adaptable : public ModelBase<Adaptable<ImplType>> {

Fit<Adaptable<ImplType>, X> fit(const std::vector<X> &,
const MarginalDistribution &) const {
const MarginalDistribution &) const {
return Fit<Adaptable<ImplType>, X>();
}

Expand All @@ -180,11 +191,10 @@ struct Adaptable : public ModelBase<Adaptable<ImplType>> {
* to this class so it can be picked up by ModelBase.
*/
template <typename FeatureType,
typename std::enable_if<
has_valid_fit<ImplType, FeatureType>::value,
int>::type = 0>
typename std::enable_if<has_valid_fit<ImplType, FeatureType>::value,
int>::type = 0>
auto fit(const std::vector<FeatureType> &features,
const MarginalDistribution &targets) const {
const MarginalDistribution &targets) const {
return impl().fit(features, targets);
}

Expand All @@ -193,42 +203,48 @@ struct Adaptable : public ModelBase<Adaptable<ImplType>> {
*/
ImplType &impl() { return *static_cast<ImplType *>(this); }
const ImplType &impl() const { return *static_cast<const ImplType *>(this); }

};

struct Extended : public Adaptable<Extended> {

using Base = Adaptable<Extended>;

auto fit(const std::vector<Y> &,
const MarginalDistribution &targets) const {
auto fit(const std::vector<Y> &, const MarginalDistribution &targets) const {
std::vector<X> xs = {{}};
return Base::fit(xs, targets);
}

Z predict(const std::vector<Y> &,
const Fit<Adaptable<Extended>, X> &,
Z predict(const std::vector<Y> &, const Fit<Adaptable<Extended>, X> &,
PredictTypeIdentity<Z>) const {
return {};
}
};

struct OtherExtended : public Adaptable<OtherExtended> {
};

TEST(test_traits_core, test_adaptable_fit_type) {
EXPECT_TRUE(bool(std::is_base_of<Fit<Adaptable<Extended>, X>,
fit_type<Extended, Y>::type>::value));
fit_type<Extended, Y>::type>::value));
EXPECT_TRUE(bool(std::is_base_of<Fit<Adaptable<Extended>, X>,
fit_type<Extended, X>::type>::value));
EXPECT_TRUE(bool(has_valid_predict<Extended, Y, Fit<Adaptable<Extended>, X>, Z>::value));
fit_type<Extended, X>::type>::value));
EXPECT_TRUE(bool(
has_valid_predict<Extended, Y, Fit<Adaptable<Extended>, X>, Z>::value));
}

TEST(test_traits_core, test_adaptable_has_valid_fit) {
EXPECT_FALSE(bool(has_valid_fit<Extended, X>::value));
EXPECT_TRUE(bool(has_valid_fit<Extended, Y>::value));
EXPECT_TRUE(bool(is_valid_fit_type<Extended, Fit<Extended, X>>::value));
EXPECT_TRUE(bool(is_valid_fit_type<Extended, Fit<Extended, Y>>::value));
EXPECT_TRUE(bool(is_valid_fit_type<Extended, Fit<Adaptable<Extended>, X>>::value));
EXPECT_TRUE(bool(is_valid_fit_type<Extended, Fit<Adaptable<Extended>, Y>>::value));
EXPECT_TRUE(
bool(is_valid_fit_type<Extended, Fit<Adaptable<Extended>, X>>::value));
EXPECT_TRUE(
bool(is_valid_fit_type<Extended, Fit<Adaptable<Extended>, Y>>::value));
EXPECT_FALSE(
bool(is_valid_fit_type<OtherExtended, Fit<Adaptable<Extended>, Y>>::value));
EXPECT_FALSE(
bool(is_valid_fit_type<Extended, Fit<Adaptable<OtherExtended>, Y>>::value));
}

/*
Expand All @@ -238,17 +254,16 @@ class HasMeanPredictImpl {
public:
Eigen::VectorXd predict(const std::vector<X> &,
const Fit<HasMeanPredictImpl> &,
PredictTypeIdentity<Eigen::VectorXd>) const {
PredictTypeIdentity<Eigen::VectorXd>) const {
return Eigen::VectorXd::Zero(0);
}
};

class HasMarginalPredictImpl {
public:
MarginalDistribution
predict(const std::vector<X> &,
const Fit<HasMarginalPredictImpl> &,
PredictTypeIdentity<MarginalDistribution>) const {
predict(const std::vector<X> &, const Fit<HasMarginalPredictImpl> &,
PredictTypeIdentity<MarginalDistribution>) const {
const auto mean = Eigen::VectorXd::Zero(0);
return MarginalDistribution(mean);
}
Expand All @@ -257,8 +272,8 @@ class HasMarginalPredictImpl {
class HasJointPredictImpl {
public:
JointDistribution predict(const std::vector<X> &,
const Fit<HasJointPredictImpl> &,
PredictTypeIdentity<JointDistribution>) const {
const Fit<HasJointPredictImpl> &,
PredictTypeIdentity<JointDistribution>) const {
const auto mean = Eigen::VectorXd::Zero(0);
return JointDistribution(mean);
}
Expand All @@ -267,43 +282,58 @@ class HasJointPredictImpl {
class HasAllPredictImpls {
public:
Eigen::VectorXd predict(const std::vector<X> &,
const Fit<HasAllPredictImpls> &,
PredictTypeIdentity<Eigen::VectorXd>) const {
const Fit<HasAllPredictImpls> &,
PredictTypeIdentity<Eigen::VectorXd>) const {
return Eigen::VectorXd::Zero(0);
}

MarginalDistribution
predict(const std::vector<X> &,
const Fit<HasAllPredictImpls> &,
PredictTypeIdentity<MarginalDistribution>) const {
predict(const std::vector<X> &, const Fit<HasAllPredictImpls> &,
PredictTypeIdentity<MarginalDistribution>) const {
const auto mean = Eigen::VectorXd::Zero(0);
return MarginalDistribution(mean);
}

JointDistribution predict(const std::vector<X> &,
const Fit<HasAllPredictImpls> &,
PredictTypeIdentity<JointDistribution>) const {
const Fit<HasAllPredictImpls> &,
PredictTypeIdentity<JointDistribution>) const {
const auto mean = Eigen::VectorXd::Zero(0);
return JointDistribution(mean);
}
};

TEST(test_traits_core, test_has_valid_predict_impl) {

EXPECT_TRUE(bool(has_valid_predict_mean<HasMeanPredictImpl, X,
Fit<HasMeanPredictImpl>>::value));
EXPECT_FALSE(bool(has_valid_predict_marginal<HasMeanPredictImpl, X, Fit<HasMeanPredictImpl>>::value));
EXPECT_FALSE(bool(has_valid_predict_joint<HasMeanPredictImpl, X, Fit<HasMeanPredictImpl>>::value));


EXPECT_TRUE(bool(has_valid_predict_mean<HasMeanPredictImpl, X, Fit<HasMeanPredictImpl>>::value));
EXPECT_TRUE(
bool(has_valid_predict_marginal<HasMarginalPredictImpl, X, Fit<HasMarginalPredictImpl>>::value));
EXPECT_TRUE(bool(has_valid_predict_joint<HasJointPredictImpl, X, Fit<HasJointPredictImpl>>::value));
EXPECT_TRUE(bool(has_valid_predict_mean<HasAllPredictImpls, X, Fit<HasAllPredictImpls>>::value));
EXPECT_TRUE(bool(has_valid_predict_marginal<HasAllPredictImpls, X, Fit<HasAllPredictImpls>>::value));
EXPECT_TRUE(bool(has_valid_predict_joint<HasAllPredictImpls, X, Fit<HasAllPredictImpls>>::value));
bool(has_valid_predict_marginal<HasMarginalPredictImpl, X,
Fit<HasMarginalPredictImpl>>::value));
EXPECT_FALSE(bool(has_valid_predict_mean<HasMarginalPredictImpl, X, Fit<HasMarginalPredictImpl>>::value));
EXPECT_FALSE(bool(has_valid_predict_joint<HasMarginalPredictImpl, X, Fit<HasMarginalPredictImpl>>::value));


EXPECT_TRUE(bool(has_valid_predict_joint<HasJointPredictImpl, X,
Fit<HasJointPredictImpl>>::value));
EXPECT_FALSE(bool(has_valid_predict_mean<HasJointPredictImpl, X, Fit<HasJointPredictImpl>>::value));
EXPECT_FALSE(bool(has_valid_predict_marginal<HasJointPredictImpl, X, Fit<HasJointPredictImpl>>::value));


EXPECT_TRUE(bool(has_valid_predict_mean<HasAllPredictImpls, X,
Fit<HasAllPredictImpls>>::value));
EXPECT_TRUE(bool(has_valid_predict_marginal<HasAllPredictImpls, X,
Fit<HasAllPredictImpls>>::value));
EXPECT_TRUE(bool(has_valid_predict_joint<HasAllPredictImpls, X,
Fit<HasAllPredictImpls>>::value));
}

class HasName {
public:
std::string name() const {return "name";};
std::string name() const { return "name"; };
};

class HasNoName {};
Expand All @@ -313,5 +343,4 @@ TEST(test_traits_covariance, test_has_name) {
EXPECT_FALSE(bool(has_name<HasNoName>::value));
}


} // namespace albatross

0 comments on commit b157d35

Please sign in to comment.