Author: Florian Hahn
Date: 2026-05-26T18:16:33+01:00
New Revision: 80a5207d940bcffa4f4339111dca65a4fcca8926

URL: 
https://github.com/llvm/llvm-project/commit/80a5207d940bcffa4f4339111dca65a4fcca8926
DIFF: 
https://github.com/llvm/llvm-project/commit/80a5207d940bcffa4f4339111dca65a4fcca8926.diff

LOG: [VPlan] Thread scalar types through VPInstruction and VPPhi. (NFC) 
(#199378)

Update VPInstruction and VPPhi to populate VPSingleDefValue's scalar
type. For most opcodes, the scalar type is determined from the operands
via computeScalarTypeForInstruction, which roughly matches the removed
inference code. For some opcodes, like FirstActiveLane, the type must be
provided explicitly.

PR: https://github.com/llvm/llvm-project/pull/199378

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
    llvm/lib/Transforms/Vectorize/VPlan.h
    llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
    llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
    llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
    llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
    llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
    llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
    llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp
    llvm/unittests/Transforms/Vectorize/VPlanTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h 
b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
index 7bc9ca1d04cfa..673868908d899 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -196,9 +196,10 @@ class VPBuilder {
                               const VPIRFlags &Flags = {},
                               const VPIRMetadata &MD = {},
                               DebugLoc DL = DebugLoc::getUnknown(),
-                              const Twine &Name = "") {
+                              const Twine &Name = "",
+                              Type *ResultTy = nullptr) {
     VPInstruction *NewVPInst = tryInsertInstruction(
-        new VPInstruction(Opcode, Operands, Flags, MD, DL, Name));
+        new VPInstruction(Opcode, Operands, Flags, MD, DL, Name, ResultTy));
     NewVPInst->setUnderlyingValue(Inst);
     return NewVPInst;
   }
@@ -225,15 +226,23 @@ class VPBuilder {
   VPInstruction *createFirstActiveLane(ArrayRef<VPValue *> Masks,
                                        DebugLoc DL = DebugLoc::getUnknown(),
                                        const Twine &Name = "") {
+    // Assume that the maximum possible number of elements in a vector fits
+    // within the index type for the default address space.
+    VPlan &Plan = getPlan();
+    Type *IndexTy = Plan.getDataLayout().getIndexType(Plan.getContext(), 0);
     return tryInsertInstruction(new VPInstruction(
-        VPInstruction::FirstActiveLane, Masks, {}, {}, DL, Name));
+        VPInstruction::FirstActiveLane, Masks, {}, {}, DL, Name, IndexTy));
   }
 
   VPInstruction *createLastActiveLane(ArrayRef<VPValue *> Masks,
                                       DebugLoc DL = DebugLoc::getUnknown(),
                                       const Twine &Name = "") {
-    return tryInsertInstruction(new VPInstruction(VPInstruction::LastActiveLane,
-                                                  Masks, {}, {}, DL, Name));
+    // Assume that the maximum possible number of elements in a vector fits
+    // within the index type for the default address space.
+    VPlan &Plan = getPlan();
+    Type *IndexTy = Plan.getDataLayout().getIndexType(Plan.getContext(), 0);
+    return tryInsertInstruction(new VPInstruction(
+        VPInstruction::LastActiveLane, Masks, {}, {}, DL, Name, IndexTy));
   }
 
   VPInstruction *createOverflowingOp(
@@ -358,8 +367,10 @@ class VPBuilder {
 
   VPPhi *createScalarPhi(ArrayRef<VPValue *> IncomingValues,
                          DebugLoc DL = DebugLoc::getUnknown(),
-                         const Twine &Name = "", const VPIRFlags &Flags = {}) {
-    return tryInsertInstruction(new VPPhi(IncomingValues, Flags, DL, Name));
+                         const Twine &Name = "", const VPIRFlags &Flags = {},
+                         Type *ResultTy = nullptr) {
+    return tryInsertInstruction(
+        new VPPhi(IncomingValues, Flags, DL, Name, ResultTy));
   }
 
   VPWidenPHIRecipe *createWidenPhi(ArrayRef<VPValue *> IncomingValues,

diff  --git a/llvm/lib/Transforms/Vectorize/VPlan.h 
b/llvm/lib/Transforms/Vectorize/VPlan.h
index ea9d80a6afc22..114b87a86410e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -604,6 +604,10 @@ class LLVM_ABI_FOR_TEST VPRecipeBase
 /// types.
 LLVM_ABI Type *getScalarTypeOrInfer(VPValue *V);
 
+/// Compute the scalar result type for an IR \p Opcode given \p Operands.
+LLVM_ABI Type *computeScalarTypeForInstruction(unsigned Opcode,
+                                               ArrayRef<VPValue *> Operands);
+
 /// VPSingleDefRecipe is a base class for recipes that model a sequence of one
 /// or more output IR that define a single result VPValue. Note that
 /// VPSingleDefRecipe must inherit from VPRecipeBase before VPSingleDefValue.
@@ -1393,15 +1397,19 @@ class LLVM_ABI_FOR_TEST VPInstruction : public 
VPRecipeWithIRFlags,
 public:
   VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
                 const VPIRFlags &Flags = {}, const VPIRMetadata &MD = {},
-                DebugLoc DL = DebugLoc::getUnknown(), const Twine &Name = "");
+                DebugLoc DL = DebugLoc::getUnknown(), const Twine &Name = "",
+                Type *ResultTy = nullptr);
 
   VP_CLASSOF_IMPL(VPRecipeBase::VPInstructionSC)
 
-  VPInstruction *clone() override { return cloneWithOperands(operands()); }
+  VPInstruction *clone() override {
+    return cloneWithOperands(operands(), getScalarType());
+  }
 
-  VPInstruction *cloneWithOperands(ArrayRef<VPValue *> NewOperands) {
+  VPInstruction *cloneWithOperands(ArrayRef<VPValue *> NewOperands,
+                                   Type *ResultTy = nullptr) {
     auto *New = new VPInstruction(Opcode, NewOperands, *this, *this,
-                                  getDebugLoc(), Name);
+                                  getDebugLoc(), Name, ResultTy);
     if (getUnderlyingValue())
       New->setUnderlyingValue(getUnderlyingInstr());
     return New;
@@ -1521,18 +1529,15 @@ class LLVM_ABI_FOR_TEST VPInstruction : public 
VPRecipeWithIRFlags,
 /// directly determine the result type. Note that there is no separate recipe ID
 /// for VPInstructionWithType; it shares the same ID as VPInstruction and is
 /// distinguished purely by the opcode.
+/// TODO: Merge with VPInstruction, now that VPRecipeValue provides the type.
 class VPInstructionWithType : public VPInstruction {
-  /// Scalar result type produced by the recipe.
-  Type *ResultTy;
-
 public:
   VPInstructionWithType(unsigned Opcode, ArrayRef<VPValue *> Operands,
                         Type *ResultTy, const VPIRFlags &Flags = {},
                         const VPIRMetadata &Metadata = {},
                         DebugLoc DL = DebugLoc::getUnknown(),
                         const Twine &Name = "")
-      : VPInstruction(Opcode, Operands, Flags, Metadata, DL, Name),
-        ResultTy(ResultTy) {}
+      : VPInstruction(Opcode, Operands, Flags, Metadata, DL, Name, ResultTy) {}
 
   static inline bool classof(const VPRecipeBase *R) {
     // VPInstructionWithType are VPInstructions with specific opcodes requiring
@@ -1575,7 +1580,7 @@ class VPInstructionWithType : public VPInstruction {
     return 0;
   }
 
-  Type *getResultType() const { return ResultTy; }
+  Type *getResultType() const { return getScalarType(); }
 
 protected:
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -1652,8 +1657,9 @@ class VPPhiAccessors {
 
 struct LLVM_ABI_FOR_TEST VPPhi : public VPInstruction, public VPPhiAccessors {
   VPPhi(ArrayRef<VPValue *> Operands, const VPIRFlags &Flags, DebugLoc DL,
-        const Twine &Name = "")
-      : VPInstruction(Instruction::PHI, Operands, Flags, {}, DL, Name) {}
+        const Twine &Name = "", Type *ResultTy = nullptr)
+      : VPInstruction(Instruction::PHI, Operands, Flags, {}, DL, Name,
+                      ResultTy) {}
 
   static inline bool classof(const VPUser *U) {
     auto *VPI = dyn_cast<VPInstruction>(U);

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp 
b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 40cce07557ab5..58ac22b955fb8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -35,122 +35,6 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const 
VPBlendRecipe *R) {
   return ResTy;
 }
 
-Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
-  // Set the result type from the first operand, check if the types for all
-  // other operands match and cache them.
-  auto SetResultTyFromOp = [this, R]() {
-    Type *ResTy = inferScalarType(R->getOperand(0));
-    unsigned NumOperands = R->getNumOperandsWithoutMask();
-    for (unsigned Op = 1; Op != NumOperands; ++Op) {
-      VPValue *OtherV = R->getOperand(Op);
-      assert(inferScalarType(OtherV) == ResTy &&
-             "different types inferred for different operands");
-      CachedTypes[OtherV] = ResTy;
-    }
-    return ResTy;
-  };
-
-  unsigned Opcode = R->getOpcode();
-  if (Instruction::isBinaryOp(Opcode) || Instruction::isUnaryOp(Opcode))
-    return SetResultTyFromOp();
-
-  switch (Opcode) {
-  case Instruction::PHI:
-    for (VPValue *Op : R->operands()) {
-      if (auto *VIR = dyn_cast<VPIRValue>(Op))
-        return VIR->getType();
-      if (auto *Ty = CachedTypes.lookup(Op))
-        return Ty;
-    }
-  LLVM_FALLTHROUGH;
-  case Instruction::ExtractElement:
-  case Instruction::InsertElement:
-  case Instruction::Freeze:
-  case VPInstruction::Broadcast:
-  case VPInstruction::ComputeReductionResult:
-  case VPInstruction::ExitingIVValue:
-  case VPInstruction::ExtractLastLane:
-  case VPInstruction::ExtractPenultimateElement:
-  case VPInstruction::ExtractLastPart:
-  case VPInstruction::ExtractLastActive:
-  case VPInstruction::PtrAdd:
-  case VPInstruction::WidePtrAdd:
-  case VPInstruction::ReductionStartVector:
-  case VPInstruction::ResumeForEpilogue:
-  case VPInstruction::Reverse:
-    return inferScalarType(R->getOperand(0));
-  case Instruction::Select: {
-    Type *ResTy = inferScalarType(R->getOperand(1));
-    VPValue *OtherV = R->getOperand(2);
-    assert(inferScalarType(OtherV) == ResTy &&
-           "different types inferred for different operands");
-    CachedTypes[OtherV] = ResTy;
-    return ResTy;
-  }
-  case Instruction::ICmp:
-  case Instruction::FCmp:
-  case VPInstruction::ActiveLaneMask:
-    assert(inferScalarType(R->getOperand(0)) ==
-               inferScalarType(R->getOperand(1)) &&
-           "different types inferred for different operands");
-    return IntegerType::get(Ctx, 1);
-  case VPInstruction::ExplicitVectorLength:
-    return Type::getIntNTy(Ctx, 32);
-  case VPInstruction::FirstOrderRecurrenceSplice:
-  case VPInstruction::Not:
-  case VPInstruction::CalculateTripCountMinusVF:
-  case VPInstruction::CanonicalIVIncrementForPart:
-  case VPInstruction::AnyOf:
-  case VPInstruction::BuildStructVector:
-  case VPInstruction::BuildVector:
-  case VPInstruction::Unpack:
-    return SetResultTyFromOp();
-  case VPInstruction::ExtractLane:
-    return inferScalarType(R->getOperand(1));
-  case VPInstruction::FirstActiveLane:
-  case VPInstruction::LastActiveLane:
-    // Assume that the maximum possible number of elements in a vector fits
-    // within the index type for the default address space.
-    return DL.getIndexType(Ctx, 0);
-  case VPInstruction::LogicalAnd:
-  case VPInstruction::LogicalOr:
-    assert(inferScalarType(R->getOperand(0))->isIntegerTy(1) &&
-           inferScalarType(R->getOperand(1))->isIntegerTy(1) &&
-           "LogicalAnd/Or operands should be bool");
-    return IntegerType::get(Ctx, 1);
-  case VPInstruction::MaskedCond:
-    assert(inferScalarType(R->getOperand(0))->isIntegerTy(1));
-    return IntegerType::get(Ctx, 1);
-  case VPInstruction::BranchOnCond:
-  case VPInstruction::BranchOnTwoConds:
-  case VPInstruction::BranchOnCount:
-  case Instruction::Store:
-  case Instruction::Switch:
-    return Type::getVoidTy(Ctx);
-  case Instruction::Load:
-    return cast<LoadInst>(R->getUnderlyingValue())->getType();
-  case Instruction::Alloca:
-    return cast<AllocaInst>(R->getUnderlyingValue())->getType();
-  case Instruction::Call: {
-    unsigned CallIdx = R->getNumOperandsWithoutMask() - 1;
-    return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
-        ->getReturnType();
-  }
-  case Instruction::GetElementPtr:
-    return inferScalarType(R->getOperand(0));
-  case Instruction::ExtractValue:
-    return cast<ExtractValueInst>(R->getUnderlyingValue())->getType();
-  default:
-    break;
-  }
-  // Type inference not implemented for opcode.
-  LLVM_DEBUG({
-    dbgs() << "LV: Found unhandled opcode for: ";
-    R->getVPSingleValue()->dump();
-  });
-  llvm_unreachable("Unhandled opcode!");
-}
-
 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
   unsigned Opcode = R->getOpcode();
   if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) ||
@@ -260,7 +144,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
           VPScalarIVStepsRecipe, VPWidenCanonicalIVRecipe, VPWidenCastRecipe,
           VPWidenIntrinsicRecipe, VPWidenGEPRecipe, VPVectorPointerRecipe,
           VPVectorEndPointerRecipe, VPWidenCallRecipe, VPWidenLoadRecipe,
-          VPWidenLoadEVLRecipe, VPDerivedIVRecipe, VPHeaderPHIRecipe>(V)) {
+          VPWidenLoadEVLRecipe, VPDerivedIVRecipe, VPHeaderPHIRecipe,
+          VPInstruction>(V)) {
     Type *Ty = V->getScalarType();
     assert(Ty && "Scalar type must be set by recipe construction");
     return Ty;
@@ -268,10 +153,7 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
 
   Type *ResultTy =
       TypeSwitch<const VPRecipeBase *, Type *>(V->getDefiningRecipe())
-          // VPInstructionWithType must be handled before VPInstruction.
-          .Case<VPInstructionWithType>(
-              [](const auto *R) { return R->getResultType(); })
-          .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe>(
+          .Case<VPBlendRecipe, VPWidenRecipe, VPReplicateRecipe>(
               [this](const auto *R) { return inferScalarTypeForRecipe(R); })
           .Case([this](const VPReductionRecipe *R) {
             return inferScalarType(R->getChainOp());

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h 
b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 8832c85b1dd02..8f4bb5219d0ca 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -22,7 +22,6 @@ namespace llvm {
 class LLVMContext;
 class VPValue;
 class VPBlendRecipe;
-class VPInstruction;
 class VPWidenRecipe;
 class VPReplicateRecipe;
 class VPRecipeBase;
@@ -48,7 +47,6 @@ class VPTypeAnalysis {
   const DataLayout &DL;
 
   Type *inferScalarTypeForRecipe(const VPBlendRecipe *R);
-  Type *inferScalarTypeForRecipe(const VPInstruction *R);
   Type *inferScalarTypeForRecipe(const VPWidenRecipe *R);
   Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R);
 

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp 
b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index 01aa0c16799a8..bc08873dd665a 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -225,8 +225,8 @@ void 
PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB,
       // Phi node's operands may not have been visited at this point. We create
       // an empty VPInstruction that we will fix once the whole plain CFG has
       // been built.
-      NewR =
-          VPIRBuilder.createScalarPhi({}, Phi->getDebugLoc(), "vec.phi", *Phi);
+      NewR = VPIRBuilder.createScalarPhi({}, Phi->getDebugLoc(), "vec.phi",
+                                         *Phi, Phi->getType());
       NewR->setUnderlyingValue(Phi);
       if (isHeaderBB(Phi->getParent(), LI->getLoopFor(Phi->getParent()))) {
         // Header phis need to be fixed after the VPBB for the latch has been
@@ -275,9 +275,9 @@ void 
PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB,
       } else {
         // Build VPInstruction for any arbitrary Instruction without specific
         // representation in VPlan.
-        NewR =
-            VPIRBuilder.createNaryOp(Inst->getOpcode(), VPOperands, Inst,
-                                     VPIRFlags(*Inst), MD, Inst->getDebugLoc());
+        NewR = VPIRBuilder.createNaryOp(
+            Inst->getOpcode(), VPOperands, Inst, VPIRFlags(*Inst), MD,
+            Inst->getDebugLoc(), "", Inst->getType());
       }
     }
 

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp 
b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index ab5713a2dde11..b06ebadc479d5 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -437,10 +437,111 @@ VPExpandSCEVRecipe::VPExpandSCEVRecipe(const SCEV *Expr)
     : VPSingleDefRecipe(VPRecipeBase::VPExpandSCEVSC, {}, Expr->getType()),
       Expr(Expr) {}
 
+/// For call VPInstruction operands, return the operand index of the called
+/// function. The function is either the last operand (for unmasked calls) or
+/// the second-to-last operand (for masked calls).
+static unsigned getCalledFnOperandIndex(ArrayRef<VPValue *> Operands) {
+  unsigned NumOps = Operands.size();
+  auto *LastOp = dyn_cast<VPIRValue>(Operands[NumOps - 1]);
+  if (LastOp && isa<Function>(LastOp->getValue()))
+    return NumOps - 1;
+  assert(isa<Function>(cast<VPIRValue>(Operands[NumOps - 2])->getValue()) &&
+         "expected function operand");
+  return NumOps - 2;
+}
+
+/// For call VPInstruction operands, return the called function.
+static Function *getCalledFunction(ArrayRef<VPValue *> Operands) {
+  unsigned Idx = getCalledFnOperandIndex(Operands);
+  return cast<Function>(cast<VPIRValue>(Operands[Idx])->getValue());
+}
+
+Type *llvm::computeScalarTypeForInstruction(unsigned Opcode,
+                                            ArrayRef<VPValue *> Operands) {
+  assert(!Operands.empty() &&
+         "zero-operand VPInstruction opcodes must pass explicit ResultTy");
+  // Assert operand \p Idx (if present and typed) has type \p ExpectedTy.
+  [[maybe_unused]] auto AssertOperandType = [&Operands](unsigned Idx,
+                                                        Type *ExpectedTy) {
+    if (!ExpectedTy || Operands.size() <= Idx)
+      return;
+    Type *OpTy = getScalarTypeOrInfer(Operands[Idx]);
+    assert((!OpTy || OpTy == ExpectedTy) &&
+           "different types inferred for different operands");
+  };
+
+  Type *Op0Ty = getScalarTypeOrInfer(Operands[0]);
+  LLVMContext &Ctx = Op0Ty->getContext();
+  switch (Opcode) {
+  case VPInstruction::BranchOnCond:
+  case VPInstruction::BranchOnTwoConds:
+  case VPInstruction::BranchOnCount:
+  case Instruction::Store:
+  case Instruction::Switch:
+    return Type::getVoidTy(Ctx);
+  case Instruction::ICmp:
+  case Instruction::FCmp:
+  case VPInstruction::ActiveLaneMask:
+    AssertOperandType(1, Op0Ty);
+    return IntegerType::get(Ctx, 1);
+  case VPInstruction::LogicalAnd:
+  case VPInstruction::LogicalOr:
+  case VPInstruction::MaskedCond:
+    assert((!Op0Ty || Op0Ty->isIntegerTy(1)) && "expected bool operand");
+    AssertOperandType(1, Op0Ty);
+    return IntegerType::get(Ctx, 1);
+  case VPInstruction::ExplicitVectorLength:
+    return IntegerType::get(Ctx, 32);
+  case Instruction::Select: {
+    Type *Op1Ty = getScalarTypeOrInfer(Operands[1]);
+    AssertOperandType(2, Op1Ty);
+    return Op1Ty;
+  }
+  case VPInstruction::ExtractLane: {
+    assert(Operands.size() >= 2 && "ExtractLane requires a lane operand and "
+                                   "at least one source vector operand");
+    Type *Op1Ty = getScalarTypeOrInfer(Operands[1]);
+    for (unsigned Idx = 2; Idx != Operands.size(); ++Idx)
+      AssertOperandType(Idx, Op1Ty);
+    return Op1Ty;
+  }
+  case Instruction::ExtractValue:
+  case VPInstruction::FirstActiveLane:
+  case VPInstruction::LastActiveLane:
+  case Instruction::Load:
+  case Instruction::Alloca:
+    llvm_unreachable("type must be passed explicitly");
+  case Instruction::Call:
+    return getCalledFunction(Operands)->getReturnType();
+  default:
+    break;
+  }
+
+  // Opcodes that require all operands to share the same scalar type as the
+  // result.
+  bool AllOperandsSameType =
+      Instruction::isBinaryOp(Opcode) ||
+      is_contained({VPInstruction::FirstOrderRecurrenceSplice,
+                    VPInstruction::CalculateTripCountMinusVF,
+                    VPInstruction::CanonicalIVIncrementForPart,
+                    VPInstruction::AnyOf, VPInstruction::BuildVector,
+                    VPInstruction::BuildStructVector},
+                   Opcode);
+  if (AllOperandsSameType)
+    for (unsigned Idx = 1; Idx != Operands.size(); ++Idx)
+      AssertOperandType(Idx, Op0Ty);
+
+  return Op0Ty;
+}
+
 VPInstruction::VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands,
                              const VPIRFlags &Flags, const VPIRMetadata &MD,
-                             DebugLoc DL, const Twine &Name)
-    : VPRecipeWithIRFlags(VPRecipeBase::VPInstructionSC, Operands, Flags, DL),
+                             DebugLoc DL, const Twine &Name, Type *ResultTy)
+    : VPRecipeWithIRFlags(
+          VPRecipeBase::VPInstructionSC, Operands,
+          ResultTy ? ResultTy
+                   : computeScalarTypeForInstruction(Opcode, Operands),
+          Flags, DL),
       VPIRMetadata(MD), Opcode(Opcode), Name(Name.str()) {
   assert(flagsValidForOpcode(getOpcode()) &&
          "Set flags not supported for the provided opcode");
@@ -452,27 +553,6 @@ VPInstruction::VPInstruction(unsigned Opcode, 
ArrayRef<VPValue *> Operands,
          "number of operands does not match opcode");
 }
 
-/// For call VPInstructions, return the operand index of the called function.
-/// The function is either the last operand (for unmasked calls) or the
-/// second-to-last operand (for masked calls).
-static unsigned getCalledFnOperandIndex(const VPInstruction &VPI) {
-  assert(VPI.getOpcode() == Instruction::Call && "must be a call");
-  unsigned NumOps = VPI.getNumOperands();
-  auto *LastOp = dyn_cast<VPIRValue>(VPI.getOperand(NumOps - 1));
-  if (LastOp && isa<Function>(LastOp->getValue()))
-    return NumOps - 1;
-  assert(
-      isa<Function>(cast<VPIRValue>(VPI.getOperand(NumOps - 2))->getValue()) &&
-      "expected function operand");
-  return NumOps - 2;
-}
-
-/// For call VPInstructions, return the called function.
-static Function *getCalledFunction(const VPInstruction &VPI) {
-  unsigned Idx = getCalledFnOperandIndex(VPI);
-  return cast<Function>(cast<VPIRValue>(VPI.getOperand(Idx))->getValue());
-}
-
 unsigned VPInstruction::getNumOperandsForOpcode() const {
   if (Instruction::isUnaryOp(Opcode) || Instruction::isCast(Opcode))
     return 1;
@@ -521,7 +601,8 @@ unsigned VPInstruction::getNumOperandsForOpcode() const {
   case VPInstruction::ReductionStartVector:
     return 3;
   case Instruction::Call:
-    return getCalledFnOperandIndex(*this) + 1;
+    return getCalledFnOperandIndex(ArrayRef<VPValue *>(op_begin(), op_end())) +
+           1;
   case Instruction::GetElementPtr:
   case Instruction::PHI:
   case Instruction::Switch:
@@ -1420,7 +1501,8 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() 
const {
   case VPInstruction::Unpack:
     return false;
   case Instruction::Call:
-    return !getCalledFunction(*this)->doesNotAccessMemory();
+    return !getCalledFunction(ArrayRef<VPValue *>(op_begin(), op_end()))
+                ->doesNotAccessMemory();
   default:
     return true;
   }
@@ -1617,6 +1699,7 @@ void VPInstruction::printRecipe(raw_ostream &O, const 
Twine &Indent,
 #endif
 
 void VPInstructionWithType::execute(VPTransformState &State) {
+  Type *ResultTy = getResultType();
   if (Instruction::isCast(getOpcode())) {
     Value *Op = State.get(getOperand(0), VPLane(0));
     Value *Cast = State.Builder.CreateCast(Instruction::CastOps(getOpcode()),
@@ -1653,6 +1736,7 @@ void VPInstructionWithType::printRecipe(raw_ostream &O, 
const Twine &Indent,
   printAsOperand(O, SlotTracker);
   O << " = ";
 
+  Type *ResultTy = getResultType();
   switch (getOpcode()) {
   case VPInstruction::WideIVStep:
     O << "wide-iv-step ";
@@ -1678,8 +1762,7 @@ void VPInstructionWithType::printRecipe(raw_ostream &O, 
const Twine &Indent,
 #endif
 
 void VPPhi::execute(VPTransformState &State) {
-  PHINode *NewPhi = State.Builder.CreatePHI(
-      State.TypeAnalysis.inferScalarType(this), 2, getName());
+  PHINode *NewPhi = State.Builder.CreatePHI(getScalarType(), 2, getName());
   unsigned NumIncoming = getNumIncoming();
   // Detect header phis: the parent block dominates its second incoming block
   // (the latch). Those IR incoming values have not been generated yet and need
@@ -3117,7 +3200,8 @@ VPExpressionRecipe::VPExpressionRecipe(
       if (Def && ExpressionRecipesAsSetOfUsers.contains(Def))
         continue;
       addOperand(Op);
-      LiveInPlaceholders.push_back(new VPSymbolicValue(nullptr));
+      LiveInPlaceholders.push_back(
+          new VPSymbolicValue(getScalarTypeOrInfer(Op)));
     }
   }
 

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp 
b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 42a410c916471..fe6f036c63a92 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -1032,9 +1032,8 @@ static VPValue *optimizeEarlyExitInductionUser(VPlan 
&Plan,
 
   DebugLoc DL = ExtractR->getDebugLoc();
   VPValue *FirstActiveLane = B.createFirstActiveLane(Mask, DL);
-  Type *FirstActiveLaneType = TypeInfo.inferScalarType(FirstActiveLane);
-  FirstActiveLane = B.createScalarZExtOrTrunc(FirstActiveLane, CanonicalIVType,
-                                              FirstActiveLaneType, DL);
+  FirstActiveLane = B.createScalarZExtOrTrunc(
+      FirstActiveLane, CanonicalIVType, FirstActiveLane->getScalarType(), DL);
   VPValue *EndValue = B.createAdd(CanonicalIV, FirstActiveLane, DL);
 
   // `getOptimizableIVOf()` always returns the pre-incremented IV, so if it
@@ -4093,7 +4092,7 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan 
&Plan) {
 
         // Subtract 1 to get the last active lane.
         VPValue *One =
-            Plan.getConstantInt(TypeInfo.inferScalarType(FirstInactiveLane), 1);
+            Plan.getConstantInt(FirstInactiveLane->getScalarType(), 1);
         VPValue *LastLane =
             Builder.createSub(FirstInactiveLane, One,
                               LastActiveL->getDebugLoc(), "last.active.lane");

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp 
b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
index 5071aad2d0287..d129842d11be1 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
@@ -258,6 +258,20 @@ bool VPlanVerifier::verifyRecipeTypes(const VPRecipeBase 
&R) const {
   case VPRecipeBase::VPFirstOrderRecurrencePHISC:
     return CheckOperandTypes() &&
            CheckScalarType(getScalarTypeOrInfer(R.getOperand(0)));
+  case VPRecipeBase::VPInstructionSC: {
+    auto *VPI = cast<VPInstruction>(&R);
+    if (isa<VPInstructionWithType>(VPI) ||
+        is_contained(ArrayRef<unsigned>{Instruction::ExtractValue,
+                                        VPInstruction::FirstActiveLane,
+                                        VPInstruction::LastActiveLane,
+                                        Instruction::Load, Instruction::Alloca,
+                                        Instruction::Call},
+                     VPI->getOpcode()))
+      return true;
+    SmallVector<VPValue *, 4> Ops(VPI->operandsWithoutMask());
+    return CheckScalarType(
+        computeScalarTypeForInstruction(VPI->getOpcode(), Ops));
+  }
   default:
     return true;
   }

diff  --git a/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp 
b/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp
index f66f327a32d95..de49dd6b57e61 100644
--- a/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp
@@ -236,25 +236,32 @@ TEST_F(VPDominatorTreeTest, DominanceRegionsTest) {
     //  VPBB2
     //
     VPlan &Plan = getPlan();
+    IntegerType *Int32 = IntegerType::get(C, 32);
     VPBasicBlock *R1BB1 = Plan.createVPBasicBlock("R1BB1");
-    VPInstruction *R1BB1I = new VPInstruction(VPInstruction::VScale, {});
+    VPInstruction *R1BB1I =
+        new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
     R1BB1->appendRecipe(R1BB1I);
     VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("R1BB2");
-    VPInstruction *R1BB2I = new VPInstruction(VPInstruction::VScale, {});
+    VPInstruction *R1BB2I =
+        new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
     R1BB2->appendRecipe(R1BB2I);
     VPBasicBlock *R1BB3 = Plan.createVPBasicBlock("R1BB3");
-    VPInstruction *R1BB3I = new VPInstruction(VPInstruction::VScale, {});
+    VPInstruction *R1BB3I =
+        new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
     R1BB3->appendRecipe(R1BB3I);
     VPRegionBlock *R1 = Plan.createReplicateRegion(R1BB1, R1BB3, "R1");
 
     VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("R2BB1");
-    VPInstruction *R2BB1I = new VPInstruction(VPInstruction::VScale, {});
+    VPInstruction *R2BB1I =
+        new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
     R2BB1->appendRecipe(R2BB1I);
     VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("R2BB2");
-    VPInstruction *R2BB2I = new VPInstruction(VPInstruction::VScale, {});
+    VPInstruction *R2BB2I =
+        new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
     R2BB2->appendRecipe(R2BB2I);
     VPBasicBlock *R2BB3 = Plan.createVPBasicBlock("R2BB3");
-    VPInstruction *R2BB3I = new VPInstruction(VPInstruction::VScale, {});
+    VPInstruction *R2BB3I =
+        new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
     R2BB3->appendRecipe(R2BB3I);
     VPRegionBlock *R2 = Plan.createReplicateRegion(R2BB1, R2BB3, "R2");
     R2BB2->setParent(R2);

diff  --git a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp 
b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
index 6b1e005f4de57..b4ba5f8c50ae4 100644
--- a/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/VPlanTest.cpp
@@ -73,9 +73,13 @@ define void @f(i32 %x) {
 }
 
 TEST_F(VPInstructionTest, insertBefore) {
-  VPInstruction *I1 = new VPInstruction(VPInstruction::StepVector, {});
-  VPInstruction *I2 = new VPInstruction(VPInstruction::VScale, {});
-  VPInstruction *I3 = new VPInstruction(VPInstruction::StepVector, {});
+  IntegerType *Int32 = IntegerType::get(C, 32);
+  VPInstruction *I1 =
+      new VPInstructionWithType(VPInstruction::StepVector, {}, Int32);
+  VPInstruction *I2 =
+      new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
+  VPInstruction *I3 =
+      new VPInstructionWithType(VPInstruction::StepVector, {}, Int32);
 
   VPBasicBlock &VPBB1 = *getPlan().createVPBasicBlock("");
   VPBB1.appendRecipe(I1);
@@ -88,9 +92,13 @@ TEST_F(VPInstructionTest, insertBefore) {
 }
 
 TEST_F(VPInstructionTest, eraseFromParent) {
-  VPInstruction *I1 = new VPInstruction(VPInstruction::StepVector, {});
-  VPInstruction *I2 = new VPInstruction(VPInstruction::VScale, {});
-  VPInstruction *I3 = new VPInstruction(VPInstruction::StepVector, {});
+  IntegerType *Int32 = IntegerType::get(C, 32);
+  VPInstruction *I1 =
+      new VPInstructionWithType(VPInstruction::StepVector, {}, Int32);
+  VPInstruction *I2 =
+      new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
+  VPInstruction *I3 =
+      new VPInstructionWithType(VPInstruction::StepVector, {}, Int32);
 
   VPBasicBlock &VPBB1 = *getPlan().createVPBasicBlock("");
   VPBB1.appendRecipe(I1);
@@ -108,9 +116,13 @@ TEST_F(VPInstructionTest, eraseFromParent) {
 }
 
 TEST_F(VPInstructionTest, moveAfter) {
-  VPInstruction *I1 = new VPInstruction(VPInstruction::StepVector, {});
-  VPInstruction *I2 = new VPInstruction(VPInstruction::VScale, {});
-  VPInstruction *I3 = new VPInstruction(VPInstruction::StepVector, {});
+  IntegerType *Int32 = IntegerType::get(C, 32);
+  VPInstruction *I1 =
+      new VPInstructionWithType(VPInstruction::StepVector, {}, Int32);
+  VPInstruction *I2 =
+      new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
+  VPInstruction *I3 =
+      new VPInstructionWithType(VPInstruction::StepVector, {}, Int32);
 
   VPBasicBlock &VPBB1 = *getPlan().createVPBasicBlock("");
   VPBB1.appendRecipe(I1);
@@ -121,8 +133,10 @@ TEST_F(VPInstructionTest, moveAfter) {
 
   CHECK_ITERATOR(VPBB1, I2, I1, I3);
 
-  VPInstruction *I4 = new VPInstruction(VPInstruction::VScale, {});
-  VPInstruction *I5 = new VPInstruction(VPInstruction::StepVector, {});
+  VPInstruction *I4 =
+      new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
+  VPInstruction *I5 =
+      new VPInstructionWithType(VPInstruction::StepVector, {}, Int32);
   VPBasicBlock &VPBB2 = *getPlan().createVPBasicBlock("");
   VPBB2.appendRecipe(I4);
   VPBB2.appendRecipe(I5);
@@ -135,9 +149,13 @@ TEST_F(VPInstructionTest, moveAfter) {
 }
 
 TEST_F(VPInstructionTest, moveBefore) {
-  VPInstruction *I1 = new VPInstruction(VPInstruction::StepVector, {});
-  VPInstruction *I2 = new VPInstruction(VPInstruction::VScale, {});
-  VPInstruction *I3 = new VPInstruction(VPInstruction::StepVector, {});
+  IntegerType *Int32 = IntegerType::get(C, 32);
+  VPInstruction *I1 =
+      new VPInstructionWithType(VPInstruction::StepVector, {}, Int32);
+  VPInstruction *I2 =
+      new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
+  VPInstruction *I3 =
+      new VPInstructionWithType(VPInstruction::StepVector, {}, Int32);
 
   VPBasicBlock &VPBB1 = *getPlan().createVPBasicBlock("");
   VPBB1.appendRecipe(I1);
@@ -148,8 +166,10 @@ TEST_F(VPInstructionTest, moveBefore) {
 
   CHECK_ITERATOR(VPBB1, I2, I1, I3);
 
-  VPInstruction *I4 = new VPInstruction(VPInstruction::VScale, {});
-  VPInstruction *I5 = new VPInstruction(VPInstruction::StepVector, {});
+  VPInstruction *I4 =
+      new VPInstructionWithType(VPInstruction::VScale, {}, Int32);
+  VPInstruction *I5 =
+      new VPInstructionWithType(VPInstruction::StepVector, {}, Int32);
   VPBasicBlock &VPBB2 = *getPlan().createVPBasicBlock("");
   VPBB2.appendRecipe(I4);
   VPBB2.appendRecipe(I5);
@@ -794,7 +814,8 @@ TEST_F(VPBasicBlockTest, reassociateBlocks) {
 
 TEST_F(VPBasicBlockTest, splitAtEnd) {
   VPlan &Plan = getPlan();
-  VPInstruction *VPI = new VPInstruction(VPInstruction::StepVector, {});
+  VPInstruction *VPI = new VPInstructionWithType(VPInstruction::StepVector, {},
+                                                 IntegerType::get(C, 32));
   VPBasicBlock *VPBB = Plan.createVPBasicBlock("VPBB1", VPI);
   VPBlockUtils::connectBlocks(Plan.getEntry(), VPBB);
   VPBlockUtils::connectBlocks(VPBB, Plan.getScalarHeader());
@@ -1763,8 +1784,9 @@ TEST(VPDoubleValueDefTest, traverseUseLists) {
 
   // Create a new VPRecipeBase which defines 2 values and has 2 operands.
   LLVMContext C;
-  VPInstruction Op0(VPInstruction::StepVector, {});
-  VPInstruction Op1(VPInstruction::VScale, {});
+  IntegerType *Int32 = IntegerType::get(C, 32);
+  VPInstructionWithType Op0(VPInstruction::StepVector, {}, Int32);
+  VPInstructionWithType Op1(VPInstruction::VScale, {}, Int32);
   VPDoubleValueDef DoubleValueDef({&Op0, &Op1}, IntegerType::get(C, 32));
 
   // Create a new users of the defined values.
@@ -1820,7 +1842,8 @@ TEST_F(VPInstructionTest, VPSymbolicValueMaterialization) {
 
   // Create a recipe that uses VF.
   VPValue *VF = &Plan.getVF();
-  VPInstruction *I = new VPInstruction(VPInstruction::StepVector, {});
+  VPInstruction *I = new VPInstructionWithType(VPInstruction::StepVector, {},
+                                               IntegerType::get(C, 32));
   VPBasicBlock &VPBB = *Plan.createVPBasicBlock("");
   VPBB.appendRecipe(I);
   I->addOperand(VF);
@@ -1838,8 +1861,8 @@ TEST_F(VPUtilsTest, IsUniformAcrossVFsAndUFsForSingleScalarOpcodes) {
   VPlan &Plan = getPlan();
 
   // isSingleScalar opcode without operands.
-  std::unique_ptr<VPInstruction> VScale(
-      new VPInstruction(VPInstruction::VScale, {}));
+  std::unique_ptr<VPInstruction> VScale(new VPInstructionWithType(
+      VPInstruction::VScale, {}, IntegerType::get(C, 32)));
   EXPECT_TRUE(vputils::isUniformAcrossVFsAndUFs(VScale.get()));
 
   // isSingleScalar opcode with a uniform operand.
@@ -1849,13 +1872,14 @@ TEST_F(VPUtilsTest, IsUniformAcrossVFsAndUFsForSingleScalarOpcodes) {
 
   // isVectorToScalar opcode with a uniform operand.
   std::unique_ptr<VPInstruction> FirstActiveLane(
-      new VPInstruction(VPInstruction::FirstActiveLane, {&Plan.getVF()}));
+      new VPInstructionWithType(VPInstruction::FirstActiveLane, {&Plan.getVF()},
+                                IntegerType::get(C, 32)));
   EXPECT_TRUE(vputils::isUniformAcrossVFsAndUFs(FirstActiveLane.get()));
 
   // StepVector produces a distinct value per lane and is non-uniform; use it
   // as the non-single-scalar operand in the negative cases below.
-  std::unique_ptr<VPInstruction> StepVector(
-      new VPInstruction(VPInstruction::StepVector, {}));
+  std::unique_ptr<VPInstruction> StepVector(new VPInstructionWithType(
+      VPInstruction::StepVector, {}, IntegerType::get(C, 32)));
   EXPECT_FALSE(vputils::isUniformAcrossVFsAndUFs(StepVector.get()));
 
   // isSingleScalar opcode with a non-single-scalar operand.
@@ -1865,7 +1889,8 @@ TEST_F(VPUtilsTest, IsUniformAcrossVFsAndUFsForSingleScalarOpcodes) {
 
   // isVectorToScalar opcode with a non-single-scalar operand.
   std::unique_ptr<VPInstruction> FirstActiveLaneNonUniform(
-      new VPInstruction(VPInstruction::FirstActiveLane, {StepVector.get()}));
+      new VPInstructionWithType(VPInstruction::FirstActiveLane,
+                                {StepVector.get()}, IntegerType::get(C, 32)));
   EXPECT_FALSE(
       vputils::isUniformAcrossVFsAndUFs(FirstActiveLaneNonUniform.get()));
 }
@@ -1880,7 +1905,8 @@ TEST_F(VPInstructionTest, VPSymbolicValueAddUserAfterMaterialization) {
   EXPECT_TRUE(Plan.getVF().isMaterialized());
 
   // Adding a new user to a materialized value should crash.
-  VPInstruction *I = new VPInstruction(VPInstruction::StepVector, {});
+  VPInstruction *I = new VPInstructionWithType(VPInstruction::StepVector, {},
+                                               IntegerType::get(C, 32));
   VPBasicBlock &VPBB = *Plan.createVPBasicBlock("");
   VPBB.appendRecipe(I);
   EXPECT_DEATH(I->addOperand(VF), "accessing materialized symbolic value");
@@ -1891,17 +1917,17 @@ TEST_F(VPRecipeTest, UFVScaleUserBeforeMaterialization) {
   VPlan &Plan = getPlan();
   VPBasicBlock *Header = Plan.createVPBasicBlock("vector.header");
   VPBasicBlock *Latch = Plan.createVPBasicBlock("vector.latch");
-  VPRegionBlock *LoopRegion =
-      Plan.createLoopRegion(Type::getInt32Ty(C), DebugLoc::getUnknown(),
-                            "vector.loop", Header, Latch);
+  VPValue *UF = &Plan.getUF();
+  Type *IVTy = UF->getScalarType();
+  VPRegionBlock *LoopRegion = Plan.createLoopRegion(
+      IVTy, DebugLoc::getUnknown(), "vector.loop", Header, Latch);
   VPBlockUtils::connectBlocks(Header, Latch);
   VPBlockUtils::connectBlocks(Plan.getEntry(), LoopRegion);
   VPBlockUtils::connectBlocks(LoopRegion, Plan.getScalarHeader());
 
-  auto *VScale = new VPInstruction(VPInstruction::VScale, {});
+  auto *VScale = new VPInstructionWithType(VPInstruction::VScale, {}, IVTy);
   Plan.getVectorPreheader()->appendRecipe(VScale);
 
-  VPValue *UF = &Plan.getUF();
   auto *Step = new VPInstruction(Instruction::Mul, {VScale, UF},
                                  VPIRFlags::getDefaultFlags(Instruction::Mul));
   Plan.getVectorPreheader()->appendRecipe(Step);


        
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to