diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 02f591e836..69f5318c63 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -13409,25 +13409,6 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF( const QualType pcfRetType = patchConstFunc->getReturnType(); std::vector pcfParams; - - // A lambda for creating a stage input variable and its associated temporary - // variable for function call. Also initializes the temporary variable using - // the contents loaded from the stage input variable. Returns the - // of the temporary variable. - const auto createParmVarAndInitFromStageInputVar = - [this](const ParmVarDecl *param) { - const QualType type = param->getType(); - std::string tempVarName = "param.var." + param->getNameAsString(); - auto paramLoc = param->getLocation(); - auto *tempVar = spvBuilder.addFnVar( - type, paramLoc, tempVarName, param->hasAttr(), - param->hasAttr()); - SpirvInstruction *loadedValue = nullptr; - declIdMapper.createStageInputVar(param, &loadedValue, /*forPCF*/ true); - spvBuilder.createStore(tempVar, loadedValue, paramLoc); - return tempVar; - }; - for (const auto *param : patchConstFunc->parameters()) { // Note: According to the HLSL reference, the PCF takes an InputPatch of // ControlPoints as well as the PatchID (PrimitiveID). This does not @@ -13439,14 +13420,19 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF( pcfParams.push_back(hullMainOutputPatch); } else if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID)) { if (!primitiveId) { - primitiveId = createParmVarAndInitFromStageInputVar(param); + primitiveId = createPCFParmVarAndInitFromStageInputVar(param); } pcfParams.push_back(primitiveId); } else if (hasSemantic(param, hlsl::DXIL::SemanticKind::ViewID)) { if (!viewId) { - viewId = createParmVarAndInitFromStageInputVar(param); + viewId = createPCFParmVarAndInitFromStageInputVar(param); } pcfParams.push_back(viewId); + } else if (param->hasAttr()) { + // Create a temporary function scope variable to pass to the PCF function + // for the output. The value of this variable should be copied to an + // output variable for the param after the function call. + pcfParams.push_back(createFunctionScopeTempFromParameter(param)); } else { emitError("patch constant function parameter '%0' unknown", param->getLocation()) @@ -13459,6 +13445,18 @@ bool SpirvEmitter::processHSEntryPointOutputAndPCF( /*forPCF*/ true)) return false; + // Traverall all of the parameters for the PCF and copy out all of the output + // variables. + for (uint32_t idx = 0; idx < patchConstFunc->parameters().size(); idx++) { + const auto *param = patchConstFunc->parameters()[idx]; + if (param->hasAttr()) { + SpirvInstruction *pcfParam = pcfParams[idx]; + SpirvInstruction *loadedValue = spvBuilder.createLoad( + pcfParam->getAstResultType(), pcfParam, param->getLocation()); + declIdMapper.createStageOutputVar(param, loadedValue, /*forPCF*/ true); + } + } + spvBuilder.createBranch(mergeBB, locEnd); spvBuilder.addSuccessor(mergeBB); spvBuilder.setInsertPoint(mergeBB); @@ -14419,6 +14417,34 @@ void SpirvEmitter::addDerivativeGroupExecutionMode() { spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation()); } +SpirvVariable *SpirvEmitter::createPCFParmVarAndInitFromStageInputVar( + const ParmVarDecl *param) { + const QualType type = param->getType(); + std::string tempVarName = "param.var." + param->getNameAsString(); + auto paramLoc = param->getLocation(); + auto *tempVar = spvBuilder.addFnVar( + type, paramLoc, tempVarName, param->hasAttr(), + param->hasAttr()); + SpirvInstruction *loadedValue = nullptr; + declIdMapper.createStageInputVar(param, &loadedValue, /*forPCF*/ true); + spvBuilder.createStore(tempVar, loadedValue, paramLoc); + return tempVar; +} + +SpirvVariable * +SpirvEmitter::createFunctionScopeTempFromParameter(const ParmVarDecl *param) { + const QualType type = param->getType(); + std::string tempVarName = "param.var." + param->getNameAsString(); + auto paramLoc = param->getLocation(); + auto *tempVar = spvBuilder.addFnVar( + type, paramLoc, tempVarName, param->hasAttr(), + param->hasAttr()); + // SpirvInstruction *loadedValue = + // spvBuilder.createLoad(tempVar->getAstResultType(), tempVar, paramLoc); + // declIdMapper.createStageOutputVar(param, nullptr, /*forPCF*/ true); + return tempVar; +} + bool SpirvEmitter::spirvToolsTrimCapabilities(std::vector *mod, std::string *messages) { spvtools::Optimizer optimizer(featureManager.getTargetEnv()); diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 37791c1aa5..88dbe3b522 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -979,7 +979,6 @@ class SpirvEmitter : public ASTConsumer { /// statement. void processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt); -private: /// Handles the offset argument in the given method call at the given argument /// index. Panics if the argument at the given index does not exist. Writes /// the to either *constOffset or *varOffset, depending on the @@ -1157,7 +1156,6 @@ class SpirvEmitter : public ASTConsumer { const CXXMethodDecl *memberFn, SourceLocation loc); -private: /// \brief Takes a vector of size 4, and returns a vector of size 1 or 2 or 3 /// or 4. Creates a CompositeExtract or VectorShuffle instruction to extract /// a scalar or smaller vector from the beginning of the input vector if @@ -1195,7 +1193,6 @@ class SpirvEmitter : public ASTConsumer { /// execution mode, if it has not already been added. void beginInvocationInterlock(SourceLocation loc, SourceRange range); -private: /// \brief If the given FunctionDecl is not already in the workQueue, creates /// a FunctionInfo object for it, and inserts it into the workQueue. It also /// updates the functionInfoMap with the proper mapping. @@ -1241,6 +1238,15 @@ class SpirvEmitter : public ASTConsumer { /// https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html. void addDerivativeGroupExecutionMode(); + /// Creates an input variable for `param` that will be used by the PCF. The + /// parameter is also added to the PCF. The wrapper function will copy the + /// input variable to the parameter. + SpirvVariable * + createPCFParmVarAndInitFromStageInputVar(const ParmVarDecl *param); + + /// Returns a function scope parameter with the same type as |param|. + SpirvVariable *createFunctionScopeTempFromParameter(const ParmVarDecl *param); + public: /// \brief Wrapper method to create a fatal error message and report it /// in the diagnostic engine associated with this consumer. diff --git a/tools/clang/test/CodeGenSPIRV/hs.const.output-patch.out.hlsl b/tools/clang/test/CodeGenSPIRV/hs.const.output-patch.out.hlsl new file mode 100644 index 0000000000..6bbcdd3764 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/hs.const.output-patch.out.hlsl @@ -0,0 +1,39 @@ +// RUN: %dxc -T hs_6_0 -E Hull -fcgl %s -spirv | FileCheck %s + +struct ControlPoint { float4 position : POSITION; }; + +// CHECK: %param_var_edge = OpVariable %_ptr_Function__arr_float_uint_3 Function +// CHECK: %param_var_inside = OpVariable %_ptr_Function_float Function +// CHECK: %param_var_myFloat = OpVariable %_ptr_Function_float Function +// CHECK: OpFunctionCall %void %HullConst %param_var_edge %param_var_inside %param_var_myFloat +// CHECK: [[edges:%[0-9]+]] = OpLoad %_arr_float_uint_3 %param_var_edge +// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelOuter %uint_0 +// CHECK: [[val:%[0-9]+]] = OpCompositeExtract %float %66 0 +// CHECK: OpStore [[addr]] [[val]] +// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelOuter %uint_1 +// CHECK: [[val:%[0-9]+]] = OpCompositeExtract %float %66 1 +// CHECK: OpStore [[addr]] [[val]] +// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelOuter %uint_2 +// CHECK: [[val:%[0-9]+]] = OpCompositeExtract %float %66 2 +// CHECK: OpStore [[addr]] [[val]] +// CHECK: [[val:%[0-9]+]] = OpLoad %float %param_var_inside +// CHECK: [[addr:%[0-9]+]] = OpAccessChain %_ptr_Output_float %gl_TessLevelInner %uint_0 +// CHECK: OpStore [[addr]] [[val]] +// CHECK: [[val:%[0-9]+]] = OpLoad %float %param_var_myFloat +// CHECK: OpStore %out_var_MY_FLOAT [[val]] + +void HullConst (out float edge [3] : SV_TessFactor, out float inside : SV_InsideTessFactor, out float myFloat : MY_FLOAT) +{ + edge[0] = 2; + edge[1] = 2; + edge[2] = 2; + inside = 2; + myFloat = .2; +} + +[domain("tri")] +[partitioning("fractional_odd")] +[outputtopology("triangle_ccw")] +[patchconstantfunc("HullConst")] +[outputcontrolpoints(3)] +ControlPoint Hull (InputPatch v, uint id : SV_OutputControlPointID) { return v[id]; }