Skip to content

Commit

Permalink
[WebAssembly] getMemoryOpCost and getCastInstrCost (llvm#122896)
Browse files Browse the repository at this point in the history
Add inital implementations of these TTI methods for SIMD types. For
casts, The costing covers the free extensions provided by extmul_low as
well as extend_low. For memory operations we consider the use of
load32_zero and load64_zero, as well as full width v128 loads.
  • Loading branch information
sparker-arm authored Jan 31, 2025
1 parent 4cfbe55 commit 28d7880
Show file tree
Hide file tree
Showing 4 changed files with 693 additions and 2 deletions.
108 changes: 106 additions & 2 deletions llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
//===----------------------------------------------------------------------===//

#include "WebAssemblyTargetTransformInfo.h"

#include "llvm/CodeGen/CostTable.h"
using namespace llvm;

#define DEBUG_TYPE "wasmtti"
Expand Down Expand Up @@ -51,8 +53,7 @@ TypeSize WebAssemblyTTIImpl::getRegisterBitWidth(
InstructionCost WebAssemblyTTIImpl::getArithmeticInstrCost(
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
ArrayRef<const Value *> Args,
const Instruction *CxtI) {
ArrayRef<const Value *> Args, const Instruction *CxtI) {

InstructionCost Cost =
BasicTTIImplBase<WebAssemblyTTIImpl>::getArithmeticInstrCost(
Expand All @@ -78,6 +79,109 @@ InstructionCost WebAssemblyTTIImpl::getArithmeticInstrCost(
return Cost;
}

InstructionCost WebAssemblyTTIImpl::getCastInstrCost(
unsigned Opcode, Type *Dst, Type *Src, TTI::CastContextHint CCH,
TTI::TargetCostKind CostKind, const Instruction *I) {
int ISD = TLI->InstructionOpcodeToISD(Opcode);
auto SrcTy = TLI->getValueType(DL, Src);
auto DstTy = TLI->getValueType(DL, Dst);

if (!SrcTy.isSimple() || !DstTy.isSimple()) {
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
}

if (!ST->hasSIMD128()) {
return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
}

auto DstVT = DstTy.getSimpleVT();
auto SrcVT = SrcTy.getSimpleVT();

if (I && I->hasOneUser()) {
auto *SingleUser = cast<Instruction>(*I->user_begin());
int UserISD = TLI->InstructionOpcodeToISD(SingleUser->getOpcode());

// extmul_low support
if (UserISD == ISD::MUL &&
(ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND)) {
// Free low extensions.
if ((SrcVT == MVT::v8i8 && DstVT == MVT::v8i16) ||
(SrcVT == MVT::v4i16 && DstVT == MVT::v4i32) ||
(SrcVT == MVT::v2i32 && DstVT == MVT::v2i64)) {
return 0;
}
// Will require an additional extlow operation for the intermediate
// i16/i32 value.
if ((SrcVT == MVT::v4i8 && DstVT == MVT::v4i32) ||
(SrcVT == MVT::v2i16 && DstVT == MVT::v2i64)) {
return 1;
}
}
}

// extend_low
static constexpr TypeConversionCostTblEntry ConversionTbl[] = {
{ISD::SIGN_EXTEND, MVT::v2i64, MVT::v2i32, 1},
{ISD::ZERO_EXTEND, MVT::v2i64, MVT::v2i32, 1},
{ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i16, 1},
{ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i16, 1},
{ISD::SIGN_EXTEND, MVT::v8i16, MVT::v8i8, 1},
{ISD::ZERO_EXTEND, MVT::v8i16, MVT::v8i8, 1},
{ISD::SIGN_EXTEND, MVT::v2i64, MVT::v2i16, 2},
{ISD::ZERO_EXTEND, MVT::v2i64, MVT::v2i16, 2},
{ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i8, 2},
{ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i8, 2},
};

if (const auto *Entry =
ConvertCostTableLookup(ConversionTbl, ISD, DstVT, SrcVT)) {
return Entry->Cost;
}

return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
}

InstructionCost WebAssemblyTTIImpl::getMemoryOpCost(
unsigned Opcode, Type *Ty, MaybeAlign Alignment, unsigned AddressSpace,
TTI::TargetCostKind CostKind, TTI::OperandValueInfo OpInfo,
const Instruction *I) {
if (!ST->hasSIMD128() || !isa<FixedVectorType>(Ty)) {
return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
CostKind);
}

int ISD = TLI->InstructionOpcodeToISD(Opcode);
if (ISD != ISD::LOAD) {
return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
CostKind);
}

EVT VT = TLI->getValueType(DL, Ty, true);
// Type legalization can't handle structs
if (VT == MVT::Other)
return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
CostKind);

auto LT = getTypeLegalizationCost(Ty);
if (!LT.first.isValid())
return InstructionCost::getInvalid();

// 128-bit loads are a single instruction. 32-bit and 64-bit vector loads can
// be lowered to load32_zero and load64_zero respectively. Assume SIMD loads
// are twice as expensive as scalar.
unsigned width = VT.getSizeInBits();
switch (width) {
default:
break;
case 32:
case 64:
case 128:
return 2;
}

return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace, CostKind);
}

InstructionCost
WebAssemblyTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
TTI::TargetCostKind CostKind,
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr);

InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
TTI::CastContextHint CCH,
TTI::TargetCostKind CostKind,
const Instruction *I = nullptr);
InstructionCost getMemoryOpCost(
unsigned Opcode, Type *Src, MaybeAlign Alignment, unsigned AddressSpace,
TTI::TargetCostKind CostKind,
TTI::OperandValueInfo OpInfo = {TTI::OK_AnyValue, TTI::OP_None},
const Instruction *I = nullptr);
using BaseT::getVectorInstrCost;
InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
TTI::TargetCostKind CostKind,
Expand Down
251 changes: 251 additions & 0 deletions llvm/test/CodeGen/WebAssembly/int-mac-reduction-loops.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
; RUN: opt -mattr=+simd128 -passes=loop-vectorize %s | llc -mtriple=wasm32 -mattr=+simd128 -verify-machineinstrs -o - | FileCheck %s

target triple = "wasm32"

define hidden i32 @i32_mac_s8(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
; CHECK-LABEL: i32_mac_s8:
; CHECK: v128.load32_zero 0:p2align=0
; CHECK: i16x8.extend_low_i8x16_s
; CHECK: v128.load32_zero 0:p2align=0
; CHECK: i16x8.extend_low_i8x16_s
; CHECK: i32x4.extmul_low_i16x8_s
; CHECK: i32x4.add
entry:
%cmp7.not = icmp eq i32 %N, 0
br i1 %cmp7.not, label %for.cond.cleanup, label %for.body

for.cond.cleanup: ; preds = %for.body, %entry
%res.0.lcssa = phi i32 [ 0, %entry ], [ %add, %for.body ]
ret i32 %res.0.lcssa

for.body: ; preds = %entry, %for.body
%i.09 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
%res.08 = phi i32 [ %add, %for.body ], [ 0, %entry ]
%arrayidx = getelementptr inbounds i8, ptr %a, i32 %i.09
%0 = load i8, ptr %arrayidx, align 1
%conv = sext i8 %0 to i32
%arrayidx1 = getelementptr inbounds i8, ptr %b, i32 %i.09
%1 = load i8, ptr %arrayidx1, align 1
%conv2 = sext i8 %1 to i32
%mul = mul nsw i32 %conv2, %conv
%add = add nsw i32 %mul, %res.08
%inc = add nuw i32 %i.09, 1
%exitcond.not = icmp eq i32 %inc, %N
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
}

define hidden i32 @i32_mac_s16(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
; CHECK-LABEL: i32_mac_s16:
; CHECK: i32x4.load16x4_s 0:p2align=1
; CHECK: i32x4.load16x4_s 0:p2align=1
; CHECK: i32x4.mul
; CHECK: i32x4.add
entry:
%cmp7.not = icmp eq i32 %N, 0
br i1 %cmp7.not, label %for.cond.cleanup, label %for.body

for.cond.cleanup: ; preds = %for.body, %entry
%res.0.lcssa = phi i32 [ 0, %entry ], [ %add, %for.body ]
ret i32 %res.0.lcssa

for.body: ; preds = %entry, %for.body
%i.09 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
%res.08 = phi i32 [ %add, %for.body ], [ 0, %entry ]
%arrayidx = getelementptr inbounds i16, ptr %a, i32 %i.09
%0 = load i16, ptr %arrayidx, align 2
%conv = sext i16 %0 to i32
%arrayidx1 = getelementptr inbounds i16, ptr %b, i32 %i.09
%1 = load i16, ptr %arrayidx1, align 2
%conv2 = sext i16 %1 to i32
%mul = mul nsw i32 %conv2, %conv
%add = add nsw i32 %mul, %res.08
%inc = add nuw i32 %i.09, 1
%exitcond.not = icmp eq i32 %inc, %N
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
}

define hidden i64 @i64_mac_s16(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
; CHECK-LABEL: i64_mac_s16:
; CHECK: v128.load32_zero 0:p2align=1
; CHECK: i32x4.extend_low_i16x8_s
; CHECK: v128.load32_zero 0:p2align=1
; CHECK: i32x4.extend_low_i16x8_s
; CHECK: i64x2.extmul_low_i32x4_s
; CHECK: i64x2.add
entry:
%cmp7.not = icmp eq i32 %N, 0
br i1 %cmp7.not, label %for.cond.cleanup, label %for.body

for.cond.cleanup: ; preds = %for.body, %entry
%res.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
ret i64 %res.0.lcssa

for.body: ; preds = %entry, %for.body
%i.09 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
%res.08 = phi i64 [ %add, %for.body ], [ 0, %entry ]
%arrayidx = getelementptr inbounds i16, ptr %a, i32 %i.09
%0 = load i16, ptr %arrayidx, align 2
%conv = sext i16 %0 to i64
%arrayidx1 = getelementptr inbounds i16, ptr %b, i32 %i.09
%1 = load i16, ptr %arrayidx1, align 2
%conv2 = sext i16 %1 to i64
%mul = mul nsw i64 %conv2, %conv
%add = add nsw i64 %mul, %res.08
%inc = add nuw i32 %i.09, 1
%exitcond.not = icmp eq i32 %inc, %N
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
}

define hidden i64 @i64_mac_s32(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
; CHECK-LABEL: i64_mac_s32:
; CHECK: v128.load64_zero 0:p2align=2
; CHECK: v128.load64_zero 0:p2align=2
; CHECK: i32x4.mul
; CHECK: i64x2.extend_low_i32x4_s
; CHECK: i64x2.add
entry:
%cmp6.not = icmp eq i32 %N, 0
br i1 %cmp6.not, label %for.cond.cleanup, label %for.body

for.cond.cleanup: ; preds = %for.body, %entry
%res.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
ret i64 %res.0.lcssa

for.body: ; preds = %entry, %for.body
%i.08 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
%res.07 = phi i64 [ %add, %for.body ], [ 0, %entry ]
%arrayidx = getelementptr inbounds i32, ptr %a, i32 %i.08
%0 = load i32, ptr %arrayidx, align 4
%arrayidx1 = getelementptr inbounds i32, ptr %b, i32 %i.08
%1 = load i32, ptr %arrayidx1, align 4
%mul = mul i32 %1, %0
%conv = sext i32 %mul to i64
%add = add i64 %res.07, %conv
%inc = add nuw i32 %i.08, 1
%exitcond.not = icmp eq i32 %inc, %N
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
}

define hidden i32 @i32_mac_u8(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
; CHECK-LABEL: i32_mac_u8:
; CHECK: v128.load32_zero 0:p2align=0
; CHECK: i16x8.extend_low_i8x16_u
; CHECK: v128.load32_zero 0:p2align=0
; CHECK: i16x8.extend_low_i8x16_u
; CHECK: i32x4.extmul_low_i16x8_u
; CHECK: i32x4.add
entry:
%cmp7.not = icmp eq i32 %N, 0
br i1 %cmp7.not, label %for.cond.cleanup, label %for.body

for.cond.cleanup: ; preds = %for.body, %entry
%res.0.lcssa = phi i32 [ 0, %entry ], [ %add, %for.body ]
ret i32 %res.0.lcssa

for.body: ; preds = %entry, %for.body
%i.09 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
%res.08 = phi i32 [ %add, %for.body ], [ 0, %entry ]
%arrayidx = getelementptr inbounds i8, ptr %a, i32 %i.09
%0 = load i8, ptr %arrayidx, align 1
%conv = zext i8 %0 to i32
%arrayidx1 = getelementptr inbounds i8, ptr %b, i32 %i.09
%1 = load i8, ptr %arrayidx1, align 1
%conv2 = zext i8 %1 to i32
%mul = mul nuw nsw i32 %conv2, %conv
%add = add i32 %mul, %res.08
%inc = add nuw i32 %i.09, 1
%exitcond.not = icmp eq i32 %inc, %N
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
}

define hidden i32 @i32_mac_u16(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
; CHECK-LABEL: i32_mac_u16:
; CHECK: i32x4.load16x4_u 0:p2align=1
; CHECK: i32x4.load16x4_u 0:p2align=1
; CHECK: i32x4.mul
; CHECK: i32x4.add
entry:
%cmp7.not = icmp eq i32 %N, 0
br i1 %cmp7.not, label %for.cond.cleanup, label %for.body

for.cond.cleanup: ; preds = %for.body, %entry
%res.0.lcssa = phi i32 [ 0, %entry ], [ %add, %for.body ]
ret i32 %res.0.lcssa

for.body: ; preds = %entry, %for.body
%i.09 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
%res.08 = phi i32 [ %add, %for.body ], [ 0, %entry ]
%arrayidx = getelementptr inbounds i16, ptr %a, i32 %i.09
%0 = load i16, ptr %arrayidx, align 2
%conv = zext i16 %0 to i32
%arrayidx1 = getelementptr inbounds i16, ptr %b, i32 %i.09
%1 = load i16, ptr %arrayidx1, align 2
%conv2 = zext i16 %1 to i32
%mul = mul nuw nsw i32 %conv2, %conv
%add = add i32 %mul, %res.08
%inc = add nuw i32 %i.09, 1
%exitcond.not = icmp eq i32 %inc, %N
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
}

define hidden i64 @i64_mac_u16(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
; CHECK-LABEL: i64_mac_u16:
; CHECK: v128.load32_zero 0:p2align=1
; CHECK: i32x4.extend_low_i16x8_u
; CHECK: v128.load32_zero 0:p2align=1
; CHECK: i32x4.extend_low_i16x8_u
; CHECK: i64x2.extmul_low_i32x4_u
; CHECK: i64x2.add
entry:
%cmp8.not = icmp eq i32 %N, 0
br i1 %cmp8.not, label %for.cond.cleanup, label %for.body

for.cond.cleanup: ; preds = %for.body, %entry
%res.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
ret i64 %res.0.lcssa

for.body: ; preds = %entry, %for.body
%i.010 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
%res.09 = phi i64 [ %add, %for.body ], [ 0, %entry ]
%arrayidx = getelementptr inbounds i16, ptr %a, i32 %i.010
%0 = load i16, ptr %arrayidx, align 2
%conv = zext i16 %0 to i64
%arrayidx1 = getelementptr inbounds i16, ptr %b, i32 %i.010
%1 = load i16, ptr %arrayidx1, align 2
%conv2 = zext i16 %1 to i64
%mul = mul nuw nsw i64 %conv2, %conv
%add = add i64 %mul, %res.09
%inc = add nuw i32 %i.010, 1
%exitcond.not = icmp eq i32 %inc, %N
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
}

define hidden i64 @i64_mac_u32(ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %N) {
; CHECK-LABEL: i64_mac_u32:
; CHECK: v128.load64_zero 0:p2align=2
; CHECK: v128.load64_zero 0:p2align=2
; CHECK: i32x4.mul
; CHECK: i64x2.extend_low_i32x4_u
; CHECK: i64x2.add
entry:
%cmp6.not = icmp eq i32 %N, 0
br i1 %cmp6.not, label %for.cond.cleanup, label %for.body

for.cond.cleanup: ; preds = %for.body, %entry
%res.0.lcssa = phi i64 [ 0, %entry ], [ %add, %for.body ]
ret i64 %res.0.lcssa

for.body: ; preds = %entry, %for.body
%i.08 = phi i32 [ %inc, %for.body ], [ 0, %entry ]
%res.07 = phi i64 [ %add, %for.body ], [ 0, %entry ]
%arrayidx = getelementptr inbounds i32, ptr %a, i32 %i.08
%0 = load i32, ptr %arrayidx, align 4
%arrayidx1 = getelementptr inbounds i32, ptr %b, i32 %i.08
%1 = load i32, ptr %arrayidx1, align 4
%mul = mul i32 %1, %0
%conv = zext i32 %mul to i64
%add = add i64 %res.07, %conv
%inc = add nuw i32 %i.08, 1
%exitcond.not = icmp eq i32 %inc, %N
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
}
Loading

0 comments on commit 28d7880

Please sign in to comment.