diff --git a/include/qnets/templ/TemplNet.hpp b/include/qnets/templ/TemplNet.hpp index 3122dd8..4069059 100644 --- a/include/qnets/templ/TemplNet.hpp +++ b/include/qnets/templ/TemplNet.hpp @@ -203,7 +203,7 @@ class TemplNet template - typename std::enable_if::type _computeInputGradients() + typename std::enable_if::type _computeInputGradients() { throw std::runtime_error("[TemplNet::_processOrigInput] Original input derivatives require provided input-to-orig derivatives."); } @@ -245,7 +245,7 @@ class TemplNet } template - typename std::enable_if::type _processOrigInput() + typename std::enable_if::type _processOrigInput() { // feed original input std::get<0>(_layers).ForwardInput(_input, dflags); @@ -253,13 +253,22 @@ class TemplNet if (this->hasD1()) { this->_computeInputGradients(); } } + template + typename std::enable_if::type _processOrigInput() + { + // feed original input + std::get<0>(_layers).ForwardInput(_input, dflags); + this->_propagateLayers(); + } + template typename std::enable_if::type _processOrigInput() { throw std::runtime_error("[TemplNet::_processOrigInput] Original input can't be fed directly, because it differs in size from network input."); } - void _processDerivInput(const ValueT orig_d1[], const ValueT orig_d2[]) + template + typename std::enable_if::type _processDerivInput(const ValueT orig_d1[], const ValueT orig_d2[]) { // feed derived network input std::get<0>(_layers).ForwardLayer(_input.data(), orig_d1, orig_d2, dflags); @@ -267,6 +276,14 @@ class TemplNet if (this->hasD1()) { this->_computeInputGradients(orig_d1); } } + template + typename std::enable_if::type _processDerivInput(const ValueT orig_d1[], const ValueT orig_d2[]) + { + // feed derived network input + std::get<0>(_layers).ForwardLayer(_input.data(), orig_d1, orig_d2, dflags); + this->_propagateLayers(); + } + public: explicit constexpr TemplNet(DynamicDFlags init_dflags = DynamicDFlags{DCONF}): _out_begins(tupl::make_fcont>(_layers, [](const auto &layer) { return &layer.out().front(); })),