[ Weight ] Add Var32 Tensor in Weight.
We will add a Var32 tensor if the variable weight is not full
precision (FP32). This enables the weight update to run in full precision,
and only the apply-gradient process uses this tensor. Therefore, the
lifespan of this tensor should be "ApplyGradient". A minimal sketch of the
intended update flow is shown below.

. Modify TensorPool to generate the weight considering mixed precision.
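
The following is a minimal sketch of that update flow, assuming a low-precision working weight paired with an FP32 master copy (the Var32 tensor). `WeightPair` and `applyGradientWithMaster` are hypothetical names for illustration, not nntrainer APIs, and both buffers are plain `float` to keep the sketch portable.

```cpp
// Minimal, self-contained sketch of a mixed-precision weight update.
// Both buffers are plain float here; in nntrainer the working weight would be
// FP16 while var32 is the FP32 master copy used only during ApplyGradient.
#include <cstddef>
#include <vector>

struct WeightPair {
  std::vector<float> var;   // low-precision weight used in forward/backward
  std::vector<float> var32; // FP32 master copy, lifespan limited to ApplyGradient
};

// Apply the gradient to the FP32 master copy, then write the result back into
// the working weight (this is where the down-cast would happen in real code).
void applyGradientWithMaster(WeightPair &w, const std::vector<float> &grad,
                             float learning_rate) {
  for (std::size_t i = 0; i < w.var32.size(); ++i) {
    w.var32[i] -= learning_rate * grad[i]; // full-precision accumulation
    w.var[i] = w.var32[i];                 // cast down to the working precision
  }
}
```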

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <[email protected]>
jijoongmoon committed May 7, 2024
1 parent 3af2259 commit 8505065
Showing 6 changed files with 154 additions and 26 deletions.
13 changes: 12 additions & 1 deletion nntrainer/graph/network_graph.cpp
@@ -1557,7 +1557,18 @@ void NetworkGraph::requestOptimizerVariable(
std::vector<TensorDim> dims = cb(dim);
w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables(
dims, w->getName(), TensorLifespan::MAX_LIFESPAN,
w->isGradientClipByGlobalNorm(), Tensor::Initializer::ZEROS));
w->isGradientClipByGlobalNorm(), w->isMixedPrecision(),
Tensor::Initializer::ZEROS));

if (dim.getDataType() != ml::train::TensorDim::DataType::FP32) {
for (auto &dim : dims)
dim.setDataType(ml::train::TensorDim::DataType::FP32);
w->setOptimizerVariables32(
tensor_manager->requestWeightOptimizerVariables(
dims, w->getName(), TensorLifespan::MAX_LIFESPAN,
w->isGradientClipByGlobalNorm(), w->isMixedPrecision(),
Tensor::Initializer::ZEROS));
}
}
}
}
1 change: 1 addition & 0 deletions nntrainer/layers/layer_node.h
Expand Up @@ -487,6 +487,7 @@ class LayerNode final : public ml::train::Layer, public GraphNode {
const std::vector<TensorDim> getOutputDimensions() const;
/**
* @brief Get the Weight object
* currently, only the unit tests use this function.
*
* @param idx Identifier of the weight
* @return Weight& Reference to the weight
31 changes: 26 additions & 5 deletions nntrainer/tensor/manager.cpp
@@ -414,7 +414,7 @@ std::vector<Weight *> Manager::requestWeights(
// var_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
}

Tensor *var = nullptr, *grad = nullptr;
Tensor *var = nullptr, *grad = nullptr, *var32 = nullptr;
bool is_dependent = !shared_names.empty();
if (is_dependent) {
/// shared_name is used and the original name is discarded
@@ -431,6 +431,17 @@
grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix,
dim_g, grad_exec_order, grad_ls,
Tensor::Initializer::ZEROS);

if (var->getDataType() != ml::train::TensorDim::DataType::FP32) {
TensorDim var32_dim(dim_v);
var32_dim.setDataType(ml::train::TensorDim::DataType::FP32);
std::vector<unsigned int> var32_exec_order;
var32_exec_order.push_back(TensorPool::PERSIST_END_ORDER);

var32 = weight_pool.requestOrExtend(shared_name + ":var32", var32_dim,
var32_exec_order, var_ls,
Tensor::Initializer::ZEROS);
}
}
} else {
/** case requesting fresh weights */
@@ -448,11 +459,21 @@
grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim_g,
grad_exec_order, grad_ls,
Tensor::Initializer::ZEROS, is_wgrad);
if (var->getDataType() != ml::train::TensorDim::DataType::FP32) {
TensorDim var32_dim(dim_v);
var32_dim.setDataType(ml::train::TensorDim::DataType::FP32);
std::vector<unsigned int> var32_exec_order;
var32_exec_order.push_back(TensorPool::PERSIST_END_ORDER);
var32 =
weight_pool.request(name + ":var32", var32_dim, var32_exec_order,
var_ls, Tensor::Initializer::ZEROS);
}
}
}

weights_v2.emplace_back(std::make_unique<Weight>(
var, grad, w_reg, w_reg_const, decay, is_dependent, clip_by_global_norm));
weights_v2.emplace_back(
std::make_unique<Weight>(var, grad, var32, w_reg, w_reg_const, decay,
is_dependent, clip_by_global_norm));
}

std::transform(weights_v2.begin() + current_size, weights_v2.end(),
@@ -668,15 +689,15 @@ bool Manager::isSecondLastAccess(const std::string &name,
*/
std::vector<Tensor *> Manager::requestWeightOptimizerVariables(
const std::vector<TensorDim> &dims, const std::string &name,
const TensorLifespan &lifespan, bool is_grad_clip,
const TensorLifespan &lifespan, bool is_grad_clip, bool is_mixed_precision,
Tensor::Initializer initializer) {

std::vector<Tensor *> ret;
ret.reserve(dims.size());

std::vector<unsigned int> exec;
exec.reserve(1);
if (is_grad_clip) {
if (is_grad_clip || is_mixed_precision) {
exec.emplace_back(TensorPool::PERSIST_END_ORDER);
} else {
exec.emplace_back(getMinMaxTensorExecutionOrder(name, true).second);
2 changes: 1 addition & 1 deletion nntrainer/tensor/manager.h
@@ -224,7 +224,7 @@ class Manager {
*/
std::vector<Tensor *> requestWeightOptimizerVariables(
const std::vector<TensorDim> &dims, const std::string &name,
const TensorLifespan &lifespan, bool is_grad_clip,
const TensorLifespan &lifespan, bool is_grad_clip, bool is_mixed_precision,
Tensor::Initializer initializer = Tensor::Initializer::NONE);

/**
83 changes: 83 additions & 0 deletions nntrainer/tensor/weight.cpp
@@ -34,6 +34,28 @@ Weight::Weight(const TensorDim &dim, const Tensor::Initializer init,
throw std::invalid_argument("Weight initializer cannot be none");
if (regularizer == WeightRegularizer::UNKNOWN)
throw std::invalid_argument("Weight regularizer unknown");

std::string var32_suffix = ":fp32";
std::string var32_name = name + var32_suffix;

/**
* @note We assume that if the weight data type is not FP32, an FP32 copy of
* the weight is necessary to maintain accuracy.
* Another data type could serve this purpose; if such a case ever needs to be
* supported, the code below should be updated.
*
* Also, loss_scale is not used in Weight yet, but it is kept for later use.
*/

if (train && dim.getDataType() != ml::train::TensorDim::DataType::FP32) {
TensorDim var32_dim(dim);
var32_dim.setDataType(ml::train::TensorDim::DataType::FP32);

var32 = std::make_shared<Tensor>(var32_dim, alloc_now_, init, var32_name);
} else {
var32 = std::make_shared<Tensor>(var32_name);
}
}

Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g,
@@ -52,6 +74,67 @@ Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g,
throw std::invalid_argument("Weight initializer cannot be none");
if (regularizer == WeightRegularizer::UNKNOWN)
throw std::invalid_argument("Weight regularizer unknown");

std::string var32_suffix = ":fp32";
std::string var32_name = name + var32_suffix;

if (train && dim_v.getDataType() != ml::train::TensorDim::DataType::FP32) {
TensorDim var32_dim(dim_v);
var32_dim.setDataType(ml::train::TensorDim::DataType::FP32);
std::string var32_suffix = ":fp32";
std::string var32_name = name + var32_suffix;

var32 = std::make_shared<Tensor>(var32_dim, alloc_now_, init, var32_name);
} else {
var32 = std::make_shared<Tensor>(var32_name);
}
}

Weight::Weight(const Tensor &v, const Tensor &g, const std::string &n,
bool is_dependent, unsigned int output_axis_) :
Var_Grad(v, g, n, is_dependent),
regularizer(WeightRegularizer::NONE),
regularizer_constant(1.0f),
decay(0.0f),
clip_by_global_norm(0.0f),
output_axis(output_axis_),
loss_scale(0.0) {

std::string var32_suffix = ":fp32";
std::string var32_name = n + var32_suffix;

/**
* @note We assume here that the Weight is created with variable and gradient
* tensors. It is not a copy or clone, so we do need to create var32 if it is
* trainable. For now, we have not seen a case that creates a weight with var,
* grad and var32, but we will add such a weight constructor if the case arises.
*/

if (!g.empty() && v.getDataType() != ml::train::TensorDim::DataType::FP32) {
TensorDim var32_dim(v.getDim());
var32_dim.setDataType(ml::train::TensorDim::DataType::FP32);

var32 = std::make_shared<Tensor>(var32_dim, true, Tensor::Initializer::NONE,
var32_name);
} else {
var32 = std::make_shared<Tensor>(var32_name);
}
}

Weight::Weight(Tensor *v, Tensor *g, Tensor *v32, const WeightRegularizer reg,
const float reg_const, const float decay, bool is_dependent,
const float max_norm, unsigned int output_axis_,
float loss_scale_) :
Var_Grad(v, g, is_dependent),
regularizer(reg),
regularizer_constant(reg_const),
decay(decay),
clip_by_global_norm(max_norm),
output_axis(output_axis_),
loss_scale(loss_scale_),
var32(std::shared_ptr<Tensor>(v32, [](void *) {})) {
if (!v32)
var32 = std::make_shared<Tensor>();
}

} // namespace nntrainer
50 changes: 31 additions & 19 deletions nntrainer/tensor/weight.h
@@ -124,34 +124,22 @@ class Weight : public Var_Grad {
* if the owner of these tensors free the tensors.
*/
explicit Weight(const Tensor &v, const Tensor &g, const std::string &n = "",
bool is_dependent = false, unsigned int output_axis_ = 3) :
Var_Grad(v, g, n, is_dependent),
regularizer(WeightRegularizer::NONE),
regularizer_constant(1.0f),
decay(0.0f),
clip_by_global_norm(0.0f),
output_axis(output_axis_),
loss_scale(0.0) {}
bool is_dependent = false, unsigned int output_axis_ = 3);

/**
* @brief Construct a new Weight object
*
* @param v ptr to already created variable tensor
* @param g ptr to already created gradient tensor
* @param v32 ptr to already created variable32 tensor
* @param reg Regularizer for the weight
* @param reg_const Constant multiplier for regularizer
*/
explicit Weight(Tensor *v, Tensor *g, const WeightRegularizer reg,
const float reg_const, const float decay,
bool is_dependent = false, const float max_norm = 0.0f,
unsigned int output_axis_ = 3, float loss_scale_ = 0.0f) :
Var_Grad(v, g, is_dependent),
regularizer(reg),
regularizer_constant(reg_const),
decay(decay),
clip_by_global_norm(max_norm),
output_axis(output_axis_),
loss_scale(loss_scale_) {}
explicit Weight(Tensor *v, Tensor *g, Tensor *v32,
const WeightRegularizer reg, const float reg_const,
const float decay, bool is_dependent = false,
const float max_norm = 0.0f, unsigned int output_axis_ = 3,
float loss_scale_ = 0.0f);

/**
* @brief Swap for weight
Expand All @@ -170,6 +158,7 @@ class Weight : public Var_Grad {
swap(lhs.output_axis, rhs.output_axis);
swap(lhs.opt_vars, rhs.opt_vars);
swap(lhs.loss_scale, rhs.loss_scale);
swap(lhs.var32, rhs.var32);
}

/**
@@ -213,6 +202,8 @@ w.var = std::make_shared<Tensor>(this->var->clone());
w.var = std::make_shared<Tensor>(this->var->clone());
if (!this->grad->empty())
w.grad = std::make_shared<Tensor>(this->grad->clone());
if (!this->var32->empty())
w.var32 = std::make_shared<Tensor>(this->var32->clone());

return w;
}
@@ -230,6 +221,16 @@ opt_vars = tensors;
opt_vars = tensors;
}

/**
* @brief Add optimizer variables32
* We assume that if the data type of the weight is not FP32, then the
* OptimizerVariable32 needs to be set to maintain accuracy.
* @param tensors OptimizerVariable32 Tensor list
*/
void setOptimizerVariables32(std::vector<Tensor *> tensors) {
opt_vars32 = tensors;
}

/**
* @brief Get optimizer variable reference
* @param idx Index of the optimizer variable to get
@@ -316,6 +317,16 @@ return clip_by_global_norm > epsilon;
return clip_by_global_norm > epsilon;
}

/**
* @brief Check if the variable type is not full precision
*
* @return true if it is not full precision
* @return false otherwise
*/
bool isMixedPrecision() const {
return var->getDataType() != ml::train::TensorDim::DataType::FP32;
}

/**
* @brief clip the gradient value based on the given global norm
*
@@ -338,6 +349,7 @@ unsigned int output_axis;
unsigned int output_axis;
float loss_scale;
std::vector<Tensor *> opt_vars; /**< optimizer variables */
std::vector<Tensor *> opt_vars32; /**< optimizer variables stored in FP32 */
std::shared_ptr<Tensor> var32;

/**
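
As a usage note, the snippet below is a hypothetical, test-style sketch of the new constructor signature and the isMixedPrecision() check. It is not part of this commit; it assumes the nntrainer headers, an FP16-capable build, and a made-up function name for illustration.

```cpp
// Hypothetical usage of the new Weight constructor taking a var32 tensor.
#include <tensor.h>
#include <weight.h>

using namespace nntrainer;

void constructMixedPrecisionWeight() {
  TensorDim dim(1, 1, 4, 4);
  dim.setDataType(ml::train::TensorDim::DataType::FP16);

  TensorDim dim32(dim);
  dim32.setDataType(ml::train::TensorDim::DataType::FP32);

  Tensor var(dim, true, Tensor::Initializer::ZEROS, "w:var");
  Tensor grad(dim, true, Tensor::Initializer::ZEROS, "w:grad");
  Tensor var32(dim32, true, Tensor::Initializer::ZEROS, "w:var32");

  // New signature: variable, gradient, FP32 master variable, then the
  // regularizer settings as before.
  Weight w(&var, &grad, &var32, WeightRegularizer::NONE, 1.0f, 0.0f);

  // With an FP16 variable, the weight reports mixed precision.
  bool mixed = w.isMixedPrecision();
  (void)mixed;
}
```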
