Skip to content

Commit

Permalink
Merge pull request #83 from akleeman/crtp_model_base
Browse files Browse the repository at this point in the history
CRTP: ModelBase
  • Loading branch information
akleeman authored Mar 7, 2019
2 parents 9747537 + cf74a7e commit 93898b6
Showing 7 changed files with 563 additions and 327 deletions.
77 changes: 50 additions & 27 deletions albatross/core/declarations.h
Original file line number Diff line number Diff line change
@@ -13,13 +13,6 @@
#ifndef ALBATROSS_CORE_DECLARATIONS_H
#define ALBATROSS_CORE_DECLARATIONS_H

#include <functional>
#include <map>
#include <memory>
#include <vector>

#include <Eigen/Core>

namespace Eigen {

template <typename _Scalar, int SizeAtCompileTime>
@@ -31,15 +24,32 @@ namespace albatross {
/*
* Model
*/
template <typename FeatureType> class RegressionModel;
template <typename ModelType> class ModelBase;

template <typename FeatureType> struct RegressionDataset;
template <typename FeatureType> struct RegressionFold;
template <typename FeatureType, typename FitType>
class SerializableRegressionModel;

template <typename FeatureType>
using RegressionModelCreator =
std::function<std::unique_ptr<RegressionModel<FeatureType>>()>;
template <typename T> struct PredictTypeIdentity;

template <typename ModelType, typename FeatureType, typename FitType> class Prediction;

template <typename ModelType, typename FitType> class FitModel;

template <typename ModelType, typename FeatureType=void> class Fit {};

/*
* Parameter Handling
*/
class Prior;
struct Parameter;

using ParameterKey = std::string;
// If you change the way these are stored, be sure there's
// a corresponding cereal type included or you'll get some
// really impressive compilation errors.
using ParameterPrior = std::shared_ptr<Prior>;
using ParameterValue = double;

using ParameterStore = std::map<ParameterKey, Parameter>;

/*
* Distributions
@@ -51,27 +61,40 @@ using DiagonalMatrixXd =
Eigen::SerializableDiagonalMatrix<double, Eigen::Dynamic>;
using MarginalDistribution = Distribution<DiagonalMatrixXd>;

/*
* Models
*/
template <typename CovarianceFunc, typename ImplType>
class GaussianProcessBase;

template <typename CovarianceFunc>
class GaussianProcessRegression;

struct NullLeastSquaresImpl {};

template <typename ImplType = NullLeastSquaresImpl>
class LeastSquares;



/*
* Cross Validation
*/
using FoldIndices = std::vector<std::size_t>;
using FoldName = std::string;
using FoldIndexer = std::map<FoldName, FoldIndices>;
template <typename FeatureType>
using IndexerFunction =
using FoldIndices = std::vector<std::size_t>;
using FoldName = std::string;
using FoldIndexer = std::map<FoldName, FoldIndices>;

template <typename FeatureType>
using IndexerFunction =
std::function<FoldIndexer(const RegressionDataset<FeatureType> &)>;

template <typename ModelType>
class CrossValidation;

/*
* RANSAC
*/
template <typename ModelType, typename FeatureType> class GenericRansac;
template <typename FeatureType, typename ModelType>
std::unique_ptr<GenericRansac<ModelType, FeatureType>>
make_generic_ransac_model(ModelType *model, double inlier_threshold,
std::size_t min_inliers,
std::size_t random_sample_size,
std::size_t max_iterations,
const IndexerFunction<FeatureType> &indexer_function);
template <typename ModelType, typename FeatureType> class Ransac;
}

#endif
44 changes: 44 additions & 0 deletions albatross/core/fit_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (C) 2019 Swift Navigation Inc.
* Contact: Swift Navigation <dev@swiftnav.com>
*
* This source is subject to the license found in the file 'LICENSE' which must
* be distributed together with this source. All other rights reserved.
*
* THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND,
* EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE.
*/

#ifndef ALBATROSS_CORE_FIT_MODEL_H
#define ALBATROSS_CORE_FIT_MODEL_H

namespace albatross {

template <typename ModelType, typename Fit>
class FitModel {
public:
template <typename X, typename Y, typename Z>
friend class Prediction;

static_assert(std::is_move_constructible<Fit>::value,
"Fit type must be move constructible to avoid unexpected copying.");

FitModel(const ModelType &model,
Fit &&fit)
: model_(model), fit_(std::move(fit)) {}

template <typename PredictFeatureType>
Prediction<ModelType, PredictFeatureType, Fit>
get_prediction(const std::vector<PredictFeatureType> &features) const {
return Prediction<ModelType, PredictFeatureType, Fit>(*this, features);
}

private:
const ModelType model_;
const Fit fit_;

};

}
#endif
Loading

0 comments on commit 93898b6

Please sign in to comment.