From e537fd86cd63ebe86b1fb7f44eb72ed0bf60b6cc Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Mon, 8 Jan 2024 14:10:47 -0500 Subject: [PATCH] [SPIRV] Implement out variables in patch control functions The documentation for the patch control functions only mentions > The outputs are usually defined by a structure and is identified by HS_CONSTANT_DATA_OUTPUT in this example; the structure depends on the domain type and would be different for triangle or isoline domains. This is why we did not realize that we needed to implement out variables as well. This commit implement the alternate method of having an output. Fixes #3743 --- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 68 +++++++++++++------ tools/clang/lib/SPIRV/SpirvEmitter.h | 12 +++- .../hs.const.output-patch.out.hlsl | 39 +++++++++++ 3 files changed, 95 insertions(+), 24 deletions(-) create mode 100644 tools/clang/test/CodeGenSPIRV/hs.const.output-patch.out.hlsl diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 02f591e836..69f5318c63 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -13409,25 +13409,6 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF( const QualType pcfRetType = patchConstFunc->getReturnType(); std::vector pcfParams; - - // A lambda for creating a stage input variable and its associated temporary - // variable for function call. Also initializes the temporary variable using - // the contents loaded from the stage input variable. Returns the - // of the temporary variable. - const auto createParmVarAndInitFromStageInputVar = - [this](const ParmVarDecl *param) { - const QualType type = param->getType(); - std::string tempVarName = "param.var." + param->getNameAsString(); - auto paramLoc = param->getLocation(); - auto *tempVar = spvBuilder.addFnVar( - type, paramLoc, tempVarName, param->hasAttr(), - param->hasAttr()); - SpirvInstruction *loadedValue = nullptr; - declIdMapper.createStageInputVar(param, &loadedValue, /*forPCF*/ true); - spvBuilder.createStore(tempVar, loadedValue, paramLoc); - return tempVar; - }; - for (const auto *param : patchConstFunc->parameters()) { // Note: According to the HLSL reference, the PCF takes an InputPatch of // ControlPoints as well as the PatchID (PrimitiveID). This does not @@ -13439,14 +13420,19 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF( pcfParams.push_back(hullMainOutputPatch); } else if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID)) { if (!primitiveId) { - primitiveId = createParmVarAndInitFromStageInputVar(param); + primitiveId = createPCFParmVarAndInitFromStageInputVar(param); } pcfParams.push_back(primitiveId); } else if (hasSemantic(param, hlsl::DXIL::SemanticKind::ViewID)) { if (!viewId) { - viewId = createParmVarAndInitFromStageInputVar(param); + viewId = createPCFParmVarAndInitFromStageInputVar(param); } pcfParams.push_back(viewId); + } else if (param->hasAttr()) { + // Create a temporary function scope variable to pass to the PCF function + // for the output. The value of this variable should be copied to an + // output variable for the param after the function call. + pcfParams.push_back(createFunctionScopeTempFromParameter(param)); } else { emitError("patch constant function parameter '%0' unknown", param->getLocation()) @@ -13459,6 +13445,18 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF( /*forPCF*/ true)) return false; + // Traverall all of the parameters for the PCF and copy out all of the output + // variables. + for (uint32_t idx = 0; idx < patchConstFunc->parameters().size(); idx++) { + const auto *param = patchConstFunc->parameters()[idx]; + if (param->hasAttr()) { + SpirvInstruction *pcfParam = pcfParams[idx]; + SpirvInstruction *loadedValue = spvBuilder.createLoad( + pcfParam->getAstResultType(), pcfParam, param->getLocation()); + declIdMapper.createStageOutputVar(param, loadedValue, /*forPCF*/ true); + } + } + spvBuilder.createBranch(mergeBB, locEnd); spvBuilder.addSuccessor(mergeBB); spvBuilder.setInsertPoint(mergeBB); @@ -14419,6 +14417,34 @@ void SpirvEmitter::addDerivativeGroupExecutionMode() { spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation()); } +SpirvVariable *SpirvEmitter::createPCFParmVarAndInitFromStageInputVar( + const ParmVarDecl *param) { + const QualType type = param->getType(); + std::string tempVarName = "param.var." + param->getNameAsString(); + auto paramLoc = param->getLocation(); + auto *tempVar = spvBuilder.addFnVar( + type, paramLoc, tempVarName, param->hasAttr(), + param->hasAttr()); + SpirvInstruction *loadedValue = nullptr; + declIdMapper.createStageInputVar(param, &loadedValue, /*forPCF*/ true); + spvBuilder.createStore(tempVar, loadedValue, paramLoc); + return tempVar; +} + +SpirvVariable * +SpirvEmitter::createFunctionScopeTempFromParameter(const ParmVarDecl *param) { + const QualType type = param->getType(); + std::string tempVarName = "param.var." + param->getNameAsString(); + auto paramLoc = param->getLocation(); + auto *tempVar = spvBuilder.addFnVar( + type, paramLoc, tempVarName, param->hasAttr(), + param->hasAttr()); + // SpirvInstruction *loadedValue = + // spvBuilder.createLoad(tempVar->getAstResultType(), tempVar, paramLoc); + // declIdMapper.createStageOutputVar(param, nullptr, /*forPCF*/ true); + return tempVar; +} + bool SpirvEmitter::spirvToolsTrimCapabilities(std::vector *mod, std::string *messages) { spvtools::Optimizer optimizer(featureManager.getTargetEnv()); diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 37791c1aa5..88dbe3b522 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -979,7 +979,6 @@ class SpirvEmitter : public ASTConsumer { /// statement. void processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt); -private: /// Handles the offset argument in the given method call at the given argument /// index. Panics if the argument at the given index does not exist. Writes /// the to either *constOffset or *varOffset, depending on the @@ -1157,7 +1156,6 @@ class SpirvEmitter : public ASTConsumer { const CXXMethodDecl *memberFn, SourceLocation loc); -private: /// \brief Takes a vector of size 4, and returns a vector of size 1 or 2 or 3 /// or 4. Creates a CompositeExtract or VectorShuffle instruction to extract /// a scalar or smaller vector from the beginning of the input vector if @@ -1195,7 +1193,6 @@ class SpirvEmitter : public ASTConsumer { /// execution mode, if it has not already been added. void beginInvocationInterlock(SourceLocation loc, SourceRange range); -private: /// \brief If the given FunctionDecl is not already in the workQueue, creates /// a FunctionInfo object for it, and inserts it into the workQueue. It also /// updates the functionInfoMap with the proper mapping. @@ -1241,6 +1238,15 @@ class SpirvEmitter : public ASTConsumer { /// https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html. void addDerivativeGroupExecutionMode(); + /// Creates an input variable for `param` that will be used by the PCF. The + /// parameter is also added to the PCF. The wrapper function will copy the + /// input variable to the parameter. + SpirvVariable * + createPCFParmVarAndInitFromStageInputVar(const ParmVarDecl *param); + + /// Returns a function scope parameter with the same type as |param|. + SpirvVariable *createFunctionScopeTempFromParameter(const ParmVarDecl *param); + public: /// \brief Wrapper method to create a fatal error message and report it /// in the diagnostic engine associated with this consumer. diff --git a/tools/clang/test/CodeGenSPIRV/hs.const.output-patch.out.hlsl b/tools/clang/test/CodeGenSPIRV/hs.const.output-patch.out.hlsl new file mode 100644 index 0000000000..6bbcdd3764 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/hs.const.output-patch.out.hlsl @@ -0,0 +1,39 @@ +// RUN: %dxc -T hs_6_0 -E Hull -fcgl %s -spirv | FileCheck %s + +struct ControlPoint { float4 position : POSITION; }; + +// CHECK: %param_var_edge = OpVariable %_ptr_Function__arr_float_uint_3 Function +// CHECK: %param_var_inside = OpVariable %_ptr_Function_float Function +// CHECK: %param_var_myFloat = OpVariable %_ptr_Function_float Function +// CHECK: OpFunctionCall %void %HullConst %param_var_edge %param_var_inside %param_var_myFloat +// CHECK: [[edges:%[0-9]+]] = OpLoad %_arr_float_uint_3 %param_var_edge +// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelOuter %uint_0 +// CHECK: [[val:%[0-9]+]] = OpCompositeExtract %float %66 0 +// CHECK: OpStore [[addr]] [[val]] +// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelOuter %uint_1 +// CHECK: [[val:%[0-9]+]] = OpCompositeExtract %float %66 1 +// CHECK: OpStore [[addr]] [[val]] +// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelOuter %uint_2 +// CHECK: [[val:%[0-9]+]] = OpCompositeExtract %float %66 2 +// CHECK: OpStore [[addr]] [[val]] +// CHECK: [[val:%[0-9]+]] = OpLoad %float %param_var_inside +// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelInner %uint_0 +// CHECK: OpStore [[addr]] [[val]] +// CHECK: [[val:%[0-9]+]] = OpLoad %float %param_var_myFloat +// CHECK: OpStore %out_var_MY_FLOAT [[val]] + +void HullConst (out float edge [3] : SV_TessFactor, out float inside : SV_InsideTessFactor, out float myFloat : MY_FLOAT) +{ + edge[0] = 2; + edge[1] = 2; + edge[2] = 2; + inside = 2; + myFloat = .2; +} + +[domain("tri")] +[partitioning("fractional_odd")] +[outputtopology("triangle_ccw")] +[patchconstantfunc("HullConst")] +[outputcontrolpoints(3)] +ControlPoint Hull (InputPatch v, uint id : SV_OutputControlPointID) { return v[id]; }