[22/N] Fix clang-tidy warnings in jit (pytorch#134829)
Follows pytorch#134537

Pull Request resolved: pytorch#134829
Approved by: https://github.com/ezyang
cyyever authored and pytorchmergebot committed Sep 19, 2024
1 parent b71802f commit 7bbdf87
Showing 26 changed files with 34 additions and 81 deletions.
1 change: 0 additions & 1 deletion torch/csrc/jit/api/module.cpp
@@ -564,7 +564,6 @@ std::string Module::dump_to_str(
   std::stringstream parameters_ss;
   std::stringstream attributes_ss;
   std::stringstream methods_ss;
-  std::stringstream submodules_ss;
 
   for (const NameTensor& p : named_parameters(/*recurse=*/false)) {
     parameters_ss << p.name << " = ";
10 changes: 3 additions & 7 deletions torch/csrc/jit/backends/backend_detail.cpp
@@ -11,9 +11,7 @@
 #include <stack>
 #include <unordered_map>
 
-namespace torch {
-namespace jit {
-namespace detail {
+namespace torch::jit::detail {
 namespace {
 
 /*
@@ -361,7 +359,7 @@ Module codegen_backend_module(
 
     wrapper_method_te.v("def_inputs", def_inputs);
     wrapper_method_te.v("fwd_inputs", fwd_inputs);
-    wrapper_methods.push_back(wrapper_method_ct.format(wrapper_method_te));
+    wrapper_methods.emplace_back(wrapper_method_ct.format(wrapper_method_te));
 
     // If the output type is a single element tuple then add an extra comma
     // to ensure the final output maintains this type.
@@ -408,6 +406,4 @@ Module codegen_backend_module(
 
   return wrapper;
 }
-} // namespace detail
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::detail
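For context: the namespace rewrite above (repeated in several files below) applies the C++17 nested-namespace definition, the fix clang-tidy suggests via its modernize-concat-nested-namespaces check. A minimal standalone sketch of the pattern, with illustrative namespace and function names:

    // Old style: every namespace opened and closed on its own line.
    namespace outer {
    namespace middle {
    namespace inner {
    inline int value() { return 42; }
    } // namespace inner
    } // namespace middle
    } // namespace outer

    // C++17 nested-namespace definition: one declaration, one closing brace.
    namespace outer::middle::inner2 {
    inline int value2() { return 42; }
    } // namespace outer::middle::inner2

Both forms declare the same kind of entity; the second removes two lines of nesting per namespace and keeps the closing comment in sync with the opening declaration.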
8 changes: 3 additions & 5 deletions torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp
@@ -6,8 +6,7 @@
 #include <torch/csrc/jit/mobile/import.h>
 #include <torch/csrc/jit/mobile/module.h>
 
-namespace torch {
-namespace jit {
+namespace torch::jit {
 
 // Implementation of Android NNAPI Backend delegate
 
@@ -107,7 +106,7 @@ class NnapiBackend : public PyTorchBackendInterface {
 
   // Runs once per model initialization
   // Cannot be moved to compile(), because init() requires actual inputs
-  void init(c10::IValue handle, c10::List<at::Tensor> inputs) {
+  void init(const c10::IValue& handle, const c10::List<at::Tensor>& inputs) {
     TORCH_CHECK(comp_ == nullptr);
     auto dict = handle.toGenericDict();
 
@@ -134,5 +133,4 @@ constexpr auto backend_name = "nnapi";
 static auto cls = torch::jit::backend<NnapiBackend>(backend_name);
 } // namespace
 
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit
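The init() signature change above is the usual fix for clang-tidy's performance-unnecessary-value-param: types like c10::IValue and c10::List are reference-counted handles, so taking them by value costs atomic refcount traffic on every call. A minimal sketch with a hypothetical refcounted handle (Handle is a stand-in, not a PyTorch type):

    #include <cstddef>
    #include <memory>
    #include <vector>

    // Hypothetical stand-in for a refcounted handle such as c10::IValue.
    struct Handle {
      std::shared_ptr<std::vector<int>> data; // copying bumps an atomic refcount
    };

    // By value: each call copies the shared_ptr (atomic increment + decrement).
    std::size_t sizeByValue(Handle h) { return h.data->size(); }

    // By const reference: no copy, no refcount traffic. This is the form the
    // check suggests when the parameter is neither mutated nor moved from.
    std::size_t sizeByRef(const Handle& h) { return h.data->size(); }

The same reasoning applies to the const at::Device and ExprHandle parameter changes later in this commit.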
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/fuser/compiler.cpp
@@ -201,7 +201,7 @@ std::shared_ptr<FusedKernel> compileKernel(
     const KernelSpec& spec,
     const ArgSpec& arg_spec,
     const std::vector<int64_t>& map_size,
-    const at::Device device) {
+    const at::Device& device) {
   const std::vector<TensorDesc>& input_desc = arg_spec.descs();
 
   auto graph = spec.graph()->copy();
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/fuser/compiler.h
@@ -25,7 +25,7 @@ TORCH_API std::shared_ptr<FusedKernel> compileKernel(
     const KernelSpec& spec,
     const ArgSpec& arg_spec,
     const std::vector<int64_t>& map_size,
-    const at::Device device);
+    const at::Device& device);
 
 TORCH_API size_t nCompiledKernels();
 
10 changes: 2 additions & 8 deletions torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp
@@ -11,10 +11,7 @@
 #include <iostream>
 #include <string>
 
-namespace torch {
-namespace jit {
-namespace fuser {
-namespace cpu {
+namespace torch::jit::fuser::cpu {
 
 #ifdef _MSC_VER
 static const std::string getTempPath() {
@@ -357,7 +354,4 @@ static std::shared_ptr<FusedKernel> createFusionKernel(
 }
 
 RegisterFusionBackend reg(DeviceType::CPU, createFusionKernel);
-} // namespace cpu
-} // namespace fuser
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::fuser::cpu
5 changes: 0 additions & 5 deletions torch/csrc/jit/mobile/compatibility/model_compatibility.cpp
@@ -284,10 +284,6 @@ std::unordered_set<std::string> _get_mobile_model_contained_types(
 std::unordered_set<std::string> _get_mobile_model_contained_types(
     const std::vector<IValue>& bytecode_ivalues) {
   std::unordered_set<std::string> contained_types;
-  // To avoid parsing same type twice, declare $parsed_type_names_records and
-  // use type name (string, ex: "Dict[int, Tuple[Tensor, Tensor, Tensor]]") as
-  // the hash to record which types are parsed.
-  std::unordered_set<std::string> parsed_type_names_records;
   for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
     const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
     auto type_table_tuple =
@@ -299,7 +295,6 @@ std::unordered_set<std::string> _get_mobile_model_contained_types(
     // for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
     std::vector<std::string> type_name_list;
     for (const auto& type_definition : type_table) {
-      std::unordered_set<std::string> type_tokens;
       std::string type_name = type_definition.toStringRef();
       type_name_list.emplace_back(type_name);
     }
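Most of the remaining deletions in this commit are like the two above: locals that were declared but never read. Clang reports these as unused-variable warnings (surfaced through clang-tidy as clang-diagnostic-unused-variable), and for container types the dead declaration also constructs and destroys an object for nothing. An illustrative sketch (the function and names are hypothetical):

    #include <cstddef>
    #include <string>
    #include <unordered_set>

    std::size_t countDistinctChars(const std::string& s) {
      // Flagged: declared but never used. Deleting it also drops the cost
      // of constructing (and destroying) an empty hash set on every call.
      std::unordered_set<std::string> parsed_type_names_records;

      std::unordered_set<char> seen;
      for (char c : s) {
        seen.insert(c);
      }
      return seen.size();
    }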
1 change: 0 additions & 1 deletion torch/csrc/jit/passes/create_functional_graphs.cpp
@@ -80,7 +80,6 @@ struct FunctionalGraphSlicer {
         graph_->createWithSubgraph(prim::FunctionalGraph)
             ->insertBefore(block->return_node());
     auto reverse_iter = block->nodes().reverse();
-    std::vector<Value*> graph_outputs;
     for (auto it = reverse_iter.begin(); it != reverse_iter.end();) {
       Node* n = *it++;
 
1 change: 0 additions & 1 deletion torch/csrc/jit/passes/lower_tuples.cpp
@@ -40,7 +40,6 @@ static void flattenTupleInLoopParams(Node* n, size_t index) {
   Block* block = n->blocks().at(0);
   Node* block_node = n;
 
-  std::vector<Value*> new_node_inputs = {};
   auto new_construct_node =
       block->prependNode(block->owningGraph()->create(prim::TupleConstruct));
   for (size_t j = 0; j < tt->elements().size(); ++j) {
2 changes: 0 additions & 2 deletions torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
@@ -513,7 +513,6 @@ void ReplicateChooseQParamsQuantDequant(std::shared_ptr<Graph>& graph) {
   Node* pattern_choose_qparam = choose_qparam_val->node();
 
   std::vector<DynamicQuantOps> nodes_to_rewrite;
-  std::vector<Node*> choose_qparam_nodes_to_rewrite;
   for (const Match& match : matches) {
     Node* matched_dequantize = match.nodes_map.at(pattern_dequant);
     Node* matched_quantize = match.nodes_map.at(pattern_quant);
@@ -1557,7 +1556,6 @@ QuantOpParams InsertQuantDeQuantHelper::insertCalculateQParams(
       "getQSchemeAndParamMap expects the corresponding observer for ",
       v->debugName(),
       " exists.");
-  std::vector<Value*> qparams_graph_values;
   QuantOpParams quant_op_params;
 
   TORCH_CHECK(
7 changes: 0 additions & 7 deletions torch/csrc/jit/passes/quantization/quantization_patterns.h
@@ -808,13 +808,6 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
        "%count_include_pad",
        "%divisor_override"});
 
-  std::string common_general_value_op = R"(
-     %r_scale : float = aten::q_scale(%a_quant)
-     %r_zero_point : int = aten::q_zero_point(%a_quant)
-     %r_dtype : int = prim::dtype(%a_quant)
-     %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
-     return (%r_quant) )";
-
   auto avg_pool3d = getInputTensorQParamOpFusionInfo(
       "aten::avg_pool3d",
       {"%kernel_size",
1 change: 0 additions & 1 deletion
@@ -59,7 +59,6 @@ std::unordered_set<std::string> RegisterPrePackParams(
   int64_t uid = 0; // int + method name gives unique identifier
   auto graph = m.get_method(method_name).graph();
   std::stack<Block*> blocks_to_visit;
-  std::unordered_set<Node*> nodes_to_delete;
   blocks_to_visit.push(graph->block());
   std::string attr_name_base =
       attr_prefix + "_" + method_name + "_ondevice_ptq_packed_weight_";
1 change: 0 additions & 1 deletion torch/csrc/jit/passes/utils/subgraph_utils.cpp
@@ -133,7 +133,6 @@ void mergeSubgraph(Node* mergeTo, Node* mergeFrom) {
   }
   ++it;
 
-  std::vector<Node*> merged_nodes;
   while (it != end_it) {
     Node* node = *it;
     ++it;
2 changes: 0 additions & 2 deletions torch/csrc/jit/runtime/static/impl.cpp
@@ -399,8 +399,6 @@ ManagedTensorRanges::ManagedTensorRanges(
     const AliasDb& alias_db,
     const c10::FastSet<const Value*>& managed_tensor_values) {
   const std::vector<Node*> nodes(block.nodes().begin(), block.nodes().end());
-  const c10::FastSet<const Value*> graph_inputs(
-      block.inputs().begin(), block.inputs().end());
 
   const auto num_nodes = static_cast<uint32_t>(nodes.size());
   for (const auto i : c10::irange(num_nodes)) {
1 change: 0 additions & 1 deletion torch/csrc/jit/serialization/export_bytecode.cpp
@@ -149,7 +149,6 @@ mobile::Code compileGraphToMobileCode(
 
   // operator names
   std::vector<std::string> method_names;
-  std::vector<int64_t> op_debug_handles;
   int next_new_op_index = 0;
 
   auto op_to_specified_args = code.op_to_num_specified_args();
1 change: 0 additions & 1 deletion torch/csrc/jit/serialization/flatbuffer_serializer.cpp
@@ -518,7 +518,6 @@ flatbuffers::Offset<mobile::serialization::ObjectType> FlatbufferSerializer::
   } else {
     size_t num_attr = class_ptr->numAttributes();
     std::vector<flatbuffers::Offset<flatbuffers::String>> names;
-    std::vector<uint32_t> type_index;
     for (size_t i = 0; i < num_attr; ++i) {
       names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i)));
     }
7 changes: 3 additions & 4 deletions torch/csrc/jit/tensorexpr/codegen.h
@@ -244,13 +244,12 @@ class RegisterCodeGen {
     RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance();
     codegen_list.AddStmtFactoryMethod(
         name,
-        [](StmtPtr stmt,
+        [](const StmtPtr& stmt,
            const std::vector<CodeGen::BufferArg>& params,
           at::Device device,
           const std::string& kernel_func_name) {
-          std::unique_ptr<CodeGen> method(
-              new CodeGenType(stmt, params, device, kernel_func_name));
-          return method;
+          return std::make_unique<CodeGenType>(
+              stmt, params, device, kernel_func_name);
         });
   }
 };
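The lambda body above swaps an explicit new for std::make_unique, per clang-tidy's modernize-make-unique: the factory avoids a raw new expression, and returning the result directly removes the named temporary. A minimal sketch (Widget and both factories are illustrative):

    #include <memory>
    #include <string>
    #include <utility>

    struct Widget {
      Widget(std::string name, int id) : name_(std::move(name)), id_(id) {}
      std::string name_;
      int id_;
    };

    // Before: raw new wrapped in a named unique_ptr, then returned.
    std::unique_ptr<Widget> makeWidgetOld() {
      std::unique_ptr<Widget> w(new Widget("w", 1)); // modernize-make-unique fires here
      return w;
    }

    // After: construct in place and return; no raw new, no named temporary.
    std::unique_ptr<Widget> makeWidgetNew() {
      return std::make_unique<Widget>("w", 1);
    }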
2 changes: 1 addition & 1 deletion torch/csrc/jit/tensorexpr/eval.cpp
@@ -1044,7 +1044,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
           v->buffer_var()->name_hint());
     }
     buffer_mapping_[b] = buffer->data();
-    internal_buffers_.insert(std::make_pair(b, std::move(buffer)));
+    internal_buffers_.emplace(std::move(b), std::move(buffer));
   }
 
   void visit(const PlacementAllocatePtr& v) override {
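The insert-to-emplace change above follows clang-tidy's modernize-use-emplace: insert(std::make_pair(...)) first materializes a pair and then moves it into the container, while emplace forwards its arguments so the node is constructed in place; moving the key (std::move(b)) additionally avoids copying it. A minimal sketch of the same pattern (the map and names are illustrative):

    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <utility>

    int main() {
      std::unordered_map<std::string, std::unique_ptr<int>> internal_buffers;
      std::string b = "buf0";
      auto buffer = std::make_unique<int>(7);

      // Before (modernize-use-emplace fires): builds a temporary pair and
      // copies the key, then moves the pair into the map.
      //   internal_buffers.insert(std::make_pair(b, std::move(buffer)));

      // After: arguments are forwarded to the pair constructed in the node,
      // and std::move(b) avoids a copy of the key string as well.
      internal_buffers.emplace(std::move(b), std::move(buffer));
      return 0;
    }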
2 changes: 1 addition & 1 deletion torch/csrc/jit/tensorexpr/expr.cpp
@@ -552,7 +552,7 @@ bool Buf::is_stride_one(int cur_dim) const {
   return exprEquals(strides_[cur_dim], alloc<LongImm>(1));
 }
 
-ExprHandle expr_to_vec(ExprHandle v, int lanes) {
+ExprHandle expr_to_vec(const ExprHandle& v, int lanes) {
   if (lanes == 1) {
     return v;
   } else {
2 changes: 1 addition & 1 deletion torch/csrc/jit/tensorexpr/expr.h
@@ -488,6 +488,6 @@ TORCH_API ExprHandle Relu(const ExprHandle& v1);
 TORCH_API ExprHandle
 ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f);
 
-TORCH_API ExprHandle expr_to_vec(ExprHandle v, int lanes);
+TORCH_API ExprHandle expr_to_vec(const ExprHandle& v, int lanes);
 
 } // namespace torch::jit::tensorexpr
2 changes: 0 additions & 2 deletions torch/csrc/jit/tensorexpr/ir_simplifier.cpp
@@ -2885,7 +2885,6 @@ ExprPtr SimplifierUnderContext::mutate(const DivPtr& v) {
   ExprPtr lhs = v->lhs();
   ExprPtr rhs = v->rhs();
 
-  std::ostringstream oss;
   if (auto ret = distributeDiv(lhs, rhs, var_bound_info_)) {
     GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret);
     return ret->accept_mutator(this);
@@ -3005,7 +3004,6 @@ ExprPtr SimplifierUnderContext::mutate(const ModPtr& v) {
   ExprPtr lhs = v->lhs();
   ExprPtr rhs = v->rhs();
 
-  std::ostringstream oss;
   if (auto ret = distributeMod(lhs, rhs, var_bound_info_)) {
     GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret);
     return ret->accept_mutator(this);
2 changes: 0 additions & 2 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -984,7 +984,6 @@ TensorExprKernel::BackendType TensorExprKernel::inferBackendTypeFromDevice(
 // we use the debug names in printing cuda code, they need to be removed
 // of characters that can't be used in a variable identifier
 void TensorExprKernel::genInputDebugNames() {
-  std::unordered_map<std::string, const torch::jit::Value*> name_to_value;
   std::unordered_set<std::string> name_set;
   std::unordered_map<const torch::jit::Value*, std::string> value_to_name;
   for (const torch::jit::Value* input : graph_->inputs()) {
@@ -1747,7 +1746,6 @@ void TensorExprKernel::compile() {
       VarPtr v = t.buf()->base_handle();
       scalars_[output] = VarHandle(v);
       block->append_stmt(t.stmt());
-      std::vector<ExprPtr> dims;
       BufHandle buf(
           "scalar_" + sanitizeName(output->debugName()), {}, v->dtype());
       StmtPtr store = Store::make(buf, {}, ExprHandle(v));
4 changes: 2 additions & 2 deletions torch/csrc/jit/tensorexpr/loopnest_randomization.cpp
@@ -480,7 +480,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
         }
 
         int index = rand() % (int)all_nested_loops.size();
-        auto nested_loops = all_nested_loops.at(index);
+        auto const& nested_loops = all_nested_loops.at(index);
         if (nested_loops.size() < 2) {
           break;
         }
@@ -554,7 +554,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) {
 
         // Randomly pick a set of consecutive loops to flatten
         int index = rand() % (int)all_nested_loops.size();
-        auto nested_loops = all_nested_loops.at(index);
+        auto const& nested_loops = all_nested_loops.at(index);
 
         // Generate a good history message
         std::vector<std::string> indices;
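In both hunks above, at() returns a reference to a stored vector of loop pointers; binding that to plain auto deep-copies the vector, while auto const& merely aliases it. clang-tidy reports copies like this, when the copy is never modified, via performance-unnecessary-copy-initialization. A minimal sketch (the data and names are illustrative):

    #include <vector>

    int main() {
      std::vector<std::vector<int>> all_nested_loops = {{1, 2, 3}, {4, 5}};
      const int index = 1;

      // auto deduces std::vector<int>, so the whole element is copied.
      auto copied = all_nested_loops.at(index);

      // auto const& deduces const std::vector<int>&: no allocation, no copy.
      auto const& aliased = all_nested_loops.at(index);

      return static_cast<int>(copied.size() + aliased.size());
    }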
1 change: 0 additions & 1 deletion torch/csrc/jit/tensorexpr/operators/reduction.cpp
@@ -146,7 +146,6 @@ Tensor computeMax(
   }
   BufHandle ResultBuf("max", outputShape, dtype);
   BufHandle InputBuf = std::get<BufHandle>(inputs[0]);
-  std::vector<ExprHandle> max_dims_expr;
   auto max_dim = std::get<int64_t>(inputs[1]);
   auto keep_dim = std::get<bool>(inputs[2]);
   return Tensor(