Skip to content

Commit

Permalink
[cuDNN] [cuDNN v8 API] Support cuDNN Errata Filter (pytorch#73934)
Browse files Browse the repository at this point in the history
Not originally mentioned in the tracking issue pytorch#58414, but is a nice-to-have feature. In summary, the errata filter allows known problematic kernels to be skipped instead of irrecoverably crashing a CUDA context (e.g., via an illegal memory access) via a JSON file supplied at run time. cuDNN frontend description: https://github.com/NVIDIA/cudnn-frontend#errata-filter

Sample errata filter JSON:
```
{
  "version" : 1,
  "rules" : [
    {
      "rule_id" : "avoid_bad_bwd_data",
      "operation" : "ConvBwdData",
      "engine" : 12,
      "cudnn_version_start" : 8000,
      "cudnn_version_end" : 9000
    }
  ]
}
```
CC @ngimel @zasdfgbnm @ptrblck

Pull Request resolved: pytorch#73934
Approved by: https://github.com/ngimel
  • Loading branch information
eqy authored and pytorchmergebot committed Jun 3, 2022
1 parent c29df68 commit fc66521
Showing 1 changed file with 32 additions and 12 deletions.
44 changes: 32 additions & 12 deletions aten/src/ATen/native/cudnn/Conv_v8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,20 @@ size_t get_available_workspace() {
return max_block_size;
}

static nlohmann::json errata_json_handle;

// Decides whether an execution plan (identified by its tag) must be skipped
// according to the user-supplied errata filter. The errata JSON is loaded
// exactly once, on the first call; if no configuration could be loaded, no
// plan is ever filtered out.
bool plan_errata_exception(const cudnnHandle_t handle, const std::string & executionPlanTag) {
  // One-time lazy load of the errata config.
  // NOTE(review): the empty-string argument presumably selects the default /
  // environment-provided JSON source — confirm against cudnn_frontend docs.
  static const bool has_json = cudnn_frontend::load_from_config(errata_json_handle, "");
  // Short-circuit: without a loaded config there is nothing to check.
  return has_json
      && cudnn_frontend::check_errata(errata_json_handle, executionPlanTag, handle, [](){ return true; });
}

void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::OperationGraph& opGraph, cudnn_frontend::EngineConfigGenerator& generator, const Tensor& x, cudnn_frontend::executionPlans_t& valid_plans, at::DataPtr& workspace_ptr, unsigned int max_plans = 0) {
auto initial_predicate_function = [&](cudnn_frontend::ExecutionPlan const& plan) -> bool {
return false;
return plan_errata_exception(handle, plan.getTag());
};
auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function);
size_t max_block_size = get_available_workspace();
Expand Down Expand Up @@ -407,8 +418,9 @@ auto get_plans_from_find_fused(const cudnnHandle_t handle,


// We only get configs from this stage to avoid building unnecessary plans that are never executed
auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, const Tensor& x, const Tensor& y, const Tensor& w, const CacheKey& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, std::string& opgraph_tag, const Tensor& x, const Tensor& y, const Tensor& w, const CacheKey& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
auto opGraph = build_opgraph(handle, desc, x, y, w, key, padding, stride, dilation);
opgraph_tag = opGraph.getTag();
auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b() ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT;
auto sources = get_generator_sources(desc, x, deterministic, allow_tf32, heuristic_mode);

Expand All @@ -417,8 +429,9 @@ auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendD
return configs;
}

auto get_configs_from_heuristics_fused(const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const CacheKeyFused& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
auto get_configs_from_heuristics_fused(const cudnnHandle_t handle, std::string& opgraph_tag, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const CacheKeyFused& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
auto opGraph = build_opgraph_fused(handle, x, y, w, z, b, alpha, key, padding, stride, dilation);
opgraph_tag = opGraph.getTag();
auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b() ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT;
auto sources = get_generator_sources(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, x, deterministic, allow_tf32, heuristic_mode);

Expand Down Expand Up @@ -455,13 +468,16 @@ void try_plans_fused(cudnn_frontend::executionPlans_t& plans, const CacheKeyFuse
TORCH_CHECK(false, "FIND was unable to find an engine to execute this computation");
}

void try_configs(cudnn_frontend::EngineConfigList& configs, const CacheKey& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w) {
void try_configs(cudnn_frontend::EngineConfigList& configs, const std::string& opgraph_tag, const CacheKey& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w) {
for (auto & config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config)
.setEngineConfig(config, opgraph_tag)
.build();
if (plan_errata_exception(handle, plan.getTag())) {
continue;
}
run_conv_plan(handle, x, y, w, plan);
benchmark_cache.emplace(key, plan);
return;
Expand All @@ -473,13 +489,16 @@ void try_configs(cudnn_frontend::EngineConfigList& configs, const CacheKey& key,
TORCH_CHECK(false, "GET was unable to find an engine to execute this computation");
}

void try_configs_fused(cudnn_frontend::EngineConfigList& configs, const CacheKeyFused& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b) {
void try_configs_fused(cudnn_frontend::EngineConfigList& configs, const std::string& opgraph_tag, const CacheKeyFused& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b) {
for (auto & config : configs) {
try {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(config)
.setEngineConfig(config, opgraph_tag)
.build();
if (plan_errata_exception(handle, plan.getTag())) {
continue;
}
run_conv_plan_fused(handle, x, y, w, z, b, plan);
benchmark_cache_fused.emplace(key, plan);
return;
Expand All @@ -496,7 +515,6 @@ void run_single_conv(const cudnnBackendDescriptorType_t operation,
const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const int64_t groups,
const bool benchmark, const bool deterministic, const bool allow_tf32) {
cudnnHandle_t handle = getCudnnHandle();

CacheKey key;
setCacheKey(key, operation, y, x, w, padding, stride, dilation, groups, deterministic, allow_tf32);
// TODO: is this thread safe if cache is updated? is pointer stale?
Expand All @@ -509,13 +527,14 @@ void run_single_conv(const cudnnBackendDescriptorType_t operation,
cudaGetLastError(); // clear CUDA error
}
}

if (!benchmark) {
std::string opgraph_tag; // extra data needed for errata filter
cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics(handle, operation,
opgraph_tag,
x, y, w, key,
padding, stride, dilation,
deterministic, allow_tf32);
try_configs(configs, key, handle, x, y, w);
try_configs(configs, opgraph_tag, key, handle, x, y, w);
} else {
cudnn_frontend::executionPlans_t plans = get_plans_from_find(handle, operation,
x, y, w, key,
Expand Down Expand Up @@ -544,13 +563,14 @@ void run_fused_conv(const Tensor& x, const Tensor& y, const Tensor& w, const Ten
cudaGetLastError(); // clear CUDA error
}
}

if (!benchmark) {
std::string opgraph_tag; // extra data needed for errata filter
cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics_fused(handle,
opgraph_tag,
x, y, w, z, b, alpha, key,
padding, stride, dilation,
deterministic, allow_tf32);
try_configs_fused(configs, key, handle, x, y, w, z, b);
try_configs_fused(configs, opgraph_tag, key, handle, x, y, w, z, b);
} else {
cudnn_frontend::executionPlans_t plans = get_plans_from_find_fused(handle,
x, y, w, z, b, alpha, key,
Expand Down

0 comments on commit fc66521

Please sign in to comment.