//===-- RISCVLegalizerInfo.cpp ----------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// \file /// This file implements the targeting of the Machinelegalizer class for RISC-V. /// \todo This should be generated by TableGen. //===----------------------------------------------------------------------===// #include "RISCVLegalizerInfo.h" #include "MCTargetDesc/RISCVMatInt.h" #include "RISCVMachineFunctionInfo.h" #include "RISCVSubtarget.h" #include "llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h" #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/TargetOpcodes.h" #include "llvm/CodeGen/ValueTypes.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Type.h" using namespace llvm; using namespace LegalityPredicates; using namespace LegalizeMutations; // Is this type supported by scalar FP arithmetic operations given the current // subtarget. static LegalityPredicate typeIsScalarFPArith(unsigned TypeIdx, const RISCVSubtarget &ST) { return [=, &ST](const LegalityQuery &Query) { return Query.Types[TypeIdx].isScalar() && ((ST.hasStdExtZfh() && Query.Types[TypeIdx].getSizeInBits() == 16) || (ST.hasStdExtF() && Query.Types[TypeIdx].getSizeInBits() == 32) || (ST.hasStdExtD() && Query.Types[TypeIdx].getSizeInBits() == 64)); }; } static LegalityPredicate typeIsLegalIntOrFPVec(unsigned TypeIdx, std::initializer_list IntOrFPVecTys, const RISCVSubtarget &ST) { LegalityPredicate P = [=, &ST](const LegalityQuery &Query) { return ST.hasVInstructions() && (Query.Types[TypeIdx].getScalarSizeInBits() != 64 || ST.hasVInstructionsI64()) && (Query.Types[TypeIdx].getElementCount().getKnownMinValue() != 1 || ST.getELen() == 64); }; return all(typeInSet(TypeIdx, IntOrFPVecTys), P); } static LegalityPredicate typeIsLegalBoolVec(unsigned TypeIdx, std::initializer_list BoolVecTys, const RISCVSubtarget &ST) { LegalityPredicate P = [=, &ST](const LegalityQuery &Query) { return ST.hasVInstructions() && (Query.Types[TypeIdx].getElementCount().getKnownMinValue() != 1 || ST.getELen() == 64); }; return all(typeInSet(TypeIdx, BoolVecTys), P); } RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST) : STI(ST), XLen(STI.getXLen()), sXLen(LLT::scalar(XLen)) { const LLT sDoubleXLen = LLT::scalar(2 * XLen); const LLT p0 = LLT::pointer(0, XLen); const LLT s1 = LLT::scalar(1); const LLT s8 = LLT::scalar(8); const LLT s16 = LLT::scalar(16); const LLT s32 = LLT::scalar(32); const LLT s64 = LLT::scalar(64); const LLT nxv1s1 = LLT::scalable_vector(1, s1); const LLT nxv2s1 = LLT::scalable_vector(2, s1); const LLT nxv4s1 = LLT::scalable_vector(4, s1); const LLT nxv8s1 = LLT::scalable_vector(8, s1); const LLT nxv16s1 = LLT::scalable_vector(16, s1); const LLT nxv32s1 = LLT::scalable_vector(32, s1); const LLT nxv64s1 = LLT::scalable_vector(64, s1); const LLT nxv1s8 = LLT::scalable_vector(1, s8); const LLT nxv2s8 = LLT::scalable_vector(2, s8); const LLT nxv4s8 = LLT::scalable_vector(4, s8); const LLT nxv8s8 = LLT::scalable_vector(8, s8); const LLT nxv16s8 = LLT::scalable_vector(16, s8); const LLT nxv32s8 = LLT::scalable_vector(32, s8); const LLT nxv64s8 = LLT::scalable_vector(64, s8); const LLT nxv1s16 = LLT::scalable_vector(1, s16); const LLT nxv2s16 = LLT::scalable_vector(2, s16); const LLT nxv4s16 = LLT::scalable_vector(4, s16); const LLT nxv8s16 = LLT::scalable_vector(8, s16); const LLT nxv16s16 = LLT::scalable_vector(16, s16); const LLT nxv32s16 = LLT::scalable_vector(32, s16); const LLT nxv1s32 = LLT::scalable_vector(1, s32); const LLT nxv2s32 = LLT::scalable_vector(2, s32); const LLT nxv4s32 = LLT::scalable_vector(4, s32); const LLT nxv8s32 = LLT::scalable_vector(8, s32); const LLT nxv16s32 = LLT::scalable_vector(16, s32); const LLT nxv1s64 = LLT::scalable_vector(1, s64); const LLT nxv2s64 = LLT::scalable_vector(2, s64); const LLT nxv4s64 = LLT::scalable_vector(4, s64); const LLT nxv8s64 = LLT::scalable_vector(8, s64); using namespace TargetOpcode; auto BoolVecTys = {nxv1s1, nxv2s1, nxv4s1, nxv8s1, nxv16s1, nxv32s1, nxv64s1}; auto IntOrFPVecTys = {nxv1s8, nxv2s8, nxv4s8, nxv8s8, nxv16s8, nxv32s8, nxv64s8, nxv1s16, nxv2s16, nxv4s16, nxv8s16, nxv16s16, nxv32s16, nxv1s32, nxv2s32, nxv4s32, nxv8s32, nxv16s32, nxv1s64, nxv2s64, nxv4s64, nxv8s64}; getActionDefinitionsBuilder({G_ADD, G_SUB, G_AND, G_OR, G_XOR}) .legalFor({s32, sXLen}) .legalIf(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST)) .widenScalarToNextPow2(0) .clampScalar(0, s32, sXLen); getActionDefinitionsBuilder( {G_UADDE, G_UADDO, G_USUBE, G_USUBO}).lower(); getActionDefinitionsBuilder({G_SADDO, G_SSUBO}).minScalar(0, sXLen).lower(); // TODO: Use Vector Single-Width Saturating Instructions for vector types. getActionDefinitionsBuilder({G_UADDSAT, G_SADDSAT, G_USUBSAT, G_SSUBSAT}) .lower(); auto &ShiftActions = getActionDefinitionsBuilder({G_ASHR, G_LSHR, G_SHL}); if (ST.is64Bit()) ShiftActions.customFor({{s32, s32}}); ShiftActions.legalFor({{s32, s32}, {s32, sXLen}, {sXLen, sXLen}}) .widenScalarToNextPow2(0) .clampScalar(1, s32, sXLen) .clampScalar(0, s32, sXLen) .minScalarSameAs(1, 0) .widenScalarToNextPow2(1); auto &ExtActions = getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT}) .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST), typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST))); if (ST.is64Bit()) { ExtActions.legalFor({{sXLen, s32}}); getActionDefinitionsBuilder(G_SEXT_INREG) .customFor({sXLen}) .maxScalar(0, sXLen) .lower(); } else { getActionDefinitionsBuilder(G_SEXT_INREG).maxScalar(0, sXLen).lower(); } ExtActions.customIf(typeIsLegalBoolVec(1, BoolVecTys, ST)) .maxScalar(0, sXLen); // Merge/Unmerge for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) { auto &MergeUnmergeActions = getActionDefinitionsBuilder(Op); unsigned BigTyIdx = Op == G_MERGE_VALUES ? 0 : 1; unsigned LitTyIdx = Op == G_MERGE_VALUES ? 1 : 0; if (XLen == 32 && ST.hasStdExtD()) { MergeUnmergeActions.legalIf( all(typeIs(BigTyIdx, s64), typeIs(LitTyIdx, s32))); } MergeUnmergeActions.widenScalarToNextPow2(LitTyIdx, XLen) .widenScalarToNextPow2(BigTyIdx, XLen) .clampScalar(LitTyIdx, sXLen, sXLen) .clampScalar(BigTyIdx, sXLen, sXLen); } getActionDefinitionsBuilder({G_FSHL, G_FSHR}).lower(); auto &RotateActions = getActionDefinitionsBuilder({G_ROTL, G_ROTR}); if (ST.hasStdExtZbb() || ST.hasStdExtZbkb()) { RotateActions.legalFor({{s32, sXLen}, {sXLen, sXLen}}); // Widen s32 rotate amount to s64 so SDAG patterns will match. if (ST.is64Bit()) RotateActions.widenScalarIf(all(typeIs(0, s32), typeIs(1, s32)), changeTo(1, sXLen)); } RotateActions.lower(); getActionDefinitionsBuilder(G_BITREVERSE).maxScalar(0, sXLen).lower(); getActionDefinitionsBuilder(G_BITCAST).legalIf( all(LegalityPredicates::any(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST), typeIsLegalBoolVec(0, BoolVecTys, ST)), LegalityPredicates::any(typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST), typeIsLegalBoolVec(1, BoolVecTys, ST)))); auto &BSWAPActions = getActionDefinitionsBuilder(G_BSWAP); if (ST.hasStdExtZbb() || ST.hasStdExtZbkb()) BSWAPActions.legalFor({sXLen}).clampScalar(0, sXLen, sXLen); else BSWAPActions.maxScalar(0, sXLen).lower(); auto &CountZerosActions = getActionDefinitionsBuilder({G_CTLZ, G_CTTZ}); auto &CountZerosUndefActions = getActionDefinitionsBuilder({G_CTLZ_ZERO_UNDEF, G_CTTZ_ZERO_UNDEF}); if (ST.hasStdExtZbb()) { CountZerosActions.legalFor({{s32, s32}, {sXLen, sXLen}}) .clampScalar(0, s32, sXLen) .widenScalarToNextPow2(0) .scalarSameSizeAs(1, 0); } else { CountZerosActions.maxScalar(0, sXLen).scalarSameSizeAs(1, 0).lower(); CountZerosUndefActions.maxScalar(0, sXLen).scalarSameSizeAs(1, 0); } CountZerosUndefActions.lower(); auto &CTPOPActions = getActionDefinitionsBuilder(G_CTPOP); if (ST.hasStdExtZbb()) { CTPOPActions.legalFor({{s32, s32}, {sXLen, sXLen}}) .clampScalar(0, s32, sXLen) .widenScalarToNextPow2(0) .scalarSameSizeAs(1, 0); } else { CTPOPActions.maxScalar(0, sXLen).scalarSameSizeAs(1, 0).lower(); } auto &ConstantActions = getActionDefinitionsBuilder(G_CONSTANT); ConstantActions.legalFor({s32, p0}); if (ST.is64Bit()) ConstantActions.customFor({s64}); ConstantActions.widenScalarToNextPow2(0).clampScalar(0, s32, sXLen); // TODO: transform illegal vector types into legal vector type getActionDefinitionsBuilder( {G_IMPLICIT_DEF, G_CONSTANT_FOLD_BARRIER, G_FREEZE}) .legalFor({s32, sXLen, p0}) .legalIf(typeIsLegalBoolVec(0, BoolVecTys, ST)) .legalIf(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST)) .widenScalarToNextPow2(0) .clampScalar(0, s32, sXLen); getActionDefinitionsBuilder(G_ICMP) .legalFor({{sXLen, sXLen}, {sXLen, p0}}) .legalIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST), typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST))) .widenScalarOrEltToNextPow2OrMinSize(1, 8) .clampScalar(1, sXLen, sXLen) .clampScalar(0, sXLen, sXLen); auto &SelectActions = getActionDefinitionsBuilder(G_SELECT) .legalFor({{s32, sXLen}, {p0, sXLen}}) .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST), typeIsLegalBoolVec(1, BoolVecTys, ST))); if (XLen == 64 || ST.hasStdExtD()) SelectActions.legalFor({{s64, sXLen}}); SelectActions.widenScalarToNextPow2(0) .clampScalar(0, s32, (XLen == 64 || ST.hasStdExtD()) ? s64 : s32) .clampScalar(1, sXLen, sXLen); auto &LoadStoreActions = getActionDefinitionsBuilder({G_LOAD, G_STORE}) .legalForTypesWithMemDesc({{s32, p0, s8, 8}, {s32, p0, s16, 16}, {s32, p0, s32, 32}, {p0, p0, sXLen, XLen}}); auto &ExtLoadActions = getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD}) .legalForTypesWithMemDesc({{s32, p0, s8, 8}, {s32, p0, s16, 16}}); if (XLen == 64) { LoadStoreActions.legalForTypesWithMemDesc({{s64, p0, s8, 8}, {s64, p0, s16, 16}, {s64, p0, s32, 32}, {s64, p0, s64, 64}}); ExtLoadActions.legalForTypesWithMemDesc( {{s64, p0, s8, 8}, {s64, p0, s16, 16}, {s64, p0, s32, 32}}); } else if (ST.hasStdExtD()) { LoadStoreActions.legalForTypesWithMemDesc({{s64, p0, s64, 64}}); } LoadStoreActions.clampScalar(0, s32, sXLen).lower(); ExtLoadActions.widenScalarToNextPow2(0).clampScalar(0, s32, sXLen).lower(); getActionDefinitionsBuilder({G_PTR_ADD, G_PTRMASK}).legalFor({{p0, sXLen}}); getActionDefinitionsBuilder(G_PTRTOINT) .legalFor({{sXLen, p0}}) .clampScalar(0, sXLen, sXLen); getActionDefinitionsBuilder(G_INTTOPTR) .legalFor({{p0, sXLen}}) .clampScalar(1, sXLen, sXLen); getActionDefinitionsBuilder(G_BRCOND).legalFor({sXLen}).minScalar(0, sXLen); getActionDefinitionsBuilder(G_BRJT).legalFor({{p0, sXLen}}); getActionDefinitionsBuilder(G_BRINDIRECT).legalFor({p0}); getActionDefinitionsBuilder(G_PHI) .legalFor({p0, sXLen}) .widenScalarToNextPow2(0) .clampScalar(0, sXLen, sXLen); getActionDefinitionsBuilder({G_GLOBAL_VALUE, G_JUMP_TABLE, G_CONSTANT_POOL}) .legalFor({p0}); if (ST.hasStdExtZmmul()) { getActionDefinitionsBuilder(G_MUL) .legalFor({s32, sXLen}) .widenScalarToNextPow2(0) .clampScalar(0, s32, sXLen); // clang-format off getActionDefinitionsBuilder({G_SMULH, G_UMULH}) .legalFor({sXLen}) .lower(); // clang-format on getActionDefinitionsBuilder({G_SMULO, G_UMULO}).minScalar(0, sXLen).lower(); } else { getActionDefinitionsBuilder(G_MUL) .libcallFor({sXLen, sDoubleXLen}) .widenScalarToNextPow2(0) .clampScalar(0, sXLen, sDoubleXLen); getActionDefinitionsBuilder({G_SMULH, G_UMULH}).lowerFor({sXLen}); getActionDefinitionsBuilder({G_SMULO, G_UMULO}) .minScalar(0, sXLen) // Widen sXLen to sDoubleXLen so we can use a single libcall to get // the low bits for the mul result and high bits to do the overflow // check. .widenScalarIf(typeIs(0, sXLen), LegalizeMutations::changeTo(0, sDoubleXLen)) .lower(); } if (ST.hasStdExtM()) { getActionDefinitionsBuilder({G_UDIV, G_SDIV, G_UREM, G_SREM}) .legalFor({s32, sXLen}) .libcallFor({sDoubleXLen}) .clampScalar(0, s32, sDoubleXLen) .widenScalarToNextPow2(0); } else { getActionDefinitionsBuilder({G_UDIV, G_SDIV, G_UREM, G_SREM}) .libcallFor({sXLen, sDoubleXLen}) .clampScalar(0, sXLen, sDoubleXLen) .widenScalarToNextPow2(0); } // TODO: Use libcall for sDoubleXLen. getActionDefinitionsBuilder({G_UDIVREM, G_SDIVREM}).lower(); auto &AbsActions = getActionDefinitionsBuilder(G_ABS); if (ST.hasStdExtZbb()) AbsActions.customFor({s32, sXLen}).minScalar(0, sXLen); AbsActions.lower(); auto &MinMaxActions = getActionDefinitionsBuilder({G_UMAX, G_UMIN, G_SMAX, G_SMIN}); if (ST.hasStdExtZbb()) MinMaxActions.legalFor({sXLen}).minScalar(0, sXLen); MinMaxActions.lower(); getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE, G_MEMSET}).libcall(); getActionDefinitionsBuilder(G_DYN_STACKALLOC).lower(); // FP Operations getActionDefinitionsBuilder({G_FADD, G_FSUB, G_FMUL, G_FDIV, G_FMA, G_FNEG, G_FABS, G_FSQRT, G_FMAXNUM, G_FMINNUM}) .legalIf(typeIsScalarFPArith(0, ST)); getActionDefinitionsBuilder(G_FREM) .libcallFor({s32, s64}) .minScalar(0, s32) .scalarize(0); getActionDefinitionsBuilder(G_FCOPYSIGN) .legalIf(all(typeIsScalarFPArith(0, ST), typeIsScalarFPArith(1, ST))); // FIXME: Use Zfhmin. getActionDefinitionsBuilder(G_FPTRUNC).legalIf( [=, &ST](const LegalityQuery &Query) -> bool { return (ST.hasStdExtD() && typeIs(0, s32)(Query) && typeIs(1, s64)(Query)) || (ST.hasStdExtZfh() && typeIs(0, s16)(Query) && typeIs(1, s32)(Query)) || (ST.hasStdExtZfh() && ST.hasStdExtD() && typeIs(0, s16)(Query) && typeIs(1, s64)(Query)); }); getActionDefinitionsBuilder(G_FPEXT).legalIf( [=, &ST](const LegalityQuery &Query) -> bool { return (ST.hasStdExtD() && typeIs(0, s64)(Query) && typeIs(1, s32)(Query)) || (ST.hasStdExtZfh() && typeIs(0, s32)(Query) && typeIs(1, s16)(Query)) || (ST.hasStdExtZfh() && ST.hasStdExtD() && typeIs(0, s64)(Query) && typeIs(1, s16)(Query)); }); getActionDefinitionsBuilder(G_FCMP) .legalIf(all(typeIs(0, sXLen), typeIsScalarFPArith(1, ST))) .clampScalar(0, sXLen, sXLen); // TODO: Support vector version of G_IS_FPCLASS. getActionDefinitionsBuilder(G_IS_FPCLASS) .customIf(all(typeIs(0, s1), typeIsScalarFPArith(1, ST))); getActionDefinitionsBuilder(G_FCONSTANT) .legalIf(typeIsScalarFPArith(0, ST)) .lowerFor({s32, s64}); getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI}) .legalIf(all(typeInSet(0, {s32, sXLen}), typeIsScalarFPArith(1, ST))) .widenScalarToNextPow2(0) .clampScalar(0, s32, sXLen) .libcall(); getActionDefinitionsBuilder({G_SITOFP, G_UITOFP}) .legalIf(all(typeIsScalarFPArith(0, ST), typeInSet(1, {s32, sXLen}))) .widenScalarToNextPow2(1) .clampScalar(1, s32, sXLen); // FIXME: We can do custom inline expansion like SelectionDAG. // FIXME: Legal with Zfa. getActionDefinitionsBuilder({G_FCEIL, G_FFLOOR}) .libcallFor({s32, s64}); getActionDefinitionsBuilder(G_VASTART).customFor({p0}); // va_list must be a pointer, but most sized types are pretty easy to handle // as the destination. getActionDefinitionsBuilder(G_VAARG) // TODO: Implement narrowScalar and widenScalar for G_VAARG for types // outside the [s32, sXLen] range. .clampScalar(0, s32, sXLen) .lowerForCartesianProduct({s32, sXLen, p0}, {p0}); getActionDefinitionsBuilder(G_VSCALE) .clampScalar(0, sXLen, sXLen) .customFor({sXLen}); auto &SplatActions = getActionDefinitionsBuilder(G_SPLAT_VECTOR) .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST), typeIs(1, sXLen))) .customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST), typeIs(1, s1))); // Handle case of s64 element vectors on RV32. If the subtarget does not have // f64, then try to lower it to G_SPLAT_VECTOR_SPLIT_64_VL. If the subtarget // does have f64, then we don't know whether the type is an f64 or an i64, // so mark the G_SPLAT_VECTOR as legal and decide later what to do with it, // depending on how the instructions it consumes are legalized. They are not // legalized yet since legalization is in reverse postorder, so we cannot // make the decision at this moment. if (XLen == 32) { if (ST.hasVInstructionsF64() && ST.hasStdExtD()) SplatActions.legalIf(all( typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64))); else if (ST.hasVInstructionsI64()) SplatActions.customIf(all( typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64))); } SplatActions.clampScalar(1, sXLen, sXLen); getLegacyLegalizerInfo().computeTables(); } bool RISCVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, MachineInstr &MI) const { Intrinsic::ID IntrinsicID = cast(MI).getIntrinsicID(); switch (IntrinsicID) { default: return false; case Intrinsic::vacopy: { // vacopy arguments must be legal because of the intrinsic signature. // No need to check here. MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); MachineFunction &MF = *MI.getMF(); const DataLayout &DL = MIRBuilder.getDataLayout(); LLVMContext &Ctx = MF.getFunction().getContext(); Register DstLst = MI.getOperand(1).getReg(); LLT PtrTy = MRI.getType(DstLst); // Load the source va_list Align Alignment = DL.getABITypeAlign(getTypeForLLT(PtrTy, Ctx)); MachineMemOperand *LoadMMO = MF.getMachineMemOperand( MachinePointerInfo(), MachineMemOperand::MOLoad, PtrTy, Alignment); auto Tmp = MIRBuilder.buildLoad(PtrTy, MI.getOperand(2), *LoadMMO); // Store the result in the destination va_list MachineMemOperand *StoreMMO = MF.getMachineMemOperand( MachinePointerInfo(), MachineMemOperand::MOStore, PtrTy, Alignment); MIRBuilder.buildStore(Tmp, DstLst, *StoreMMO); MI.eraseFromParent(); return true; } } } bool RISCVLegalizerInfo::legalizeShlAshrLshr( MachineInstr &MI, MachineIRBuilder &MIRBuilder, GISelChangeObserver &Observer) const { assert(MI.getOpcode() == TargetOpcode::G_ASHR || MI.getOpcode() == TargetOpcode::G_LSHR || MI.getOpcode() == TargetOpcode::G_SHL); MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); // If the shift amount is a G_CONSTANT, promote it to a 64 bit type so the // imported patterns can select it later. Either way, it will be legal. Register AmtReg = MI.getOperand(2).getReg(); auto VRegAndVal = getIConstantVRegValWithLookThrough(AmtReg, MRI); if (!VRegAndVal) return true; // Check the shift amount is in range for an immediate form. uint64_t Amount = VRegAndVal->Value.getZExtValue(); if (Amount > 31) return true; // This will have to remain a register variant. auto ExtCst = MIRBuilder.buildConstant(LLT::scalar(64), Amount); Observer.changingInstr(MI); MI.getOperand(2).setReg(ExtCst.getReg(0)); Observer.changedInstr(MI); return true; } bool RISCVLegalizerInfo::legalizeVAStart(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const { // Stores the address of the VarArgsFrameIndex slot into the memory location assert(MI.getOpcode() == TargetOpcode::G_VASTART); MachineFunction *MF = MI.getParent()->getParent(); RISCVMachineFunctionInfo *FuncInfo = MF->getInfo(); int FI = FuncInfo->getVarArgsFrameIndex(); LLT AddrTy = MIRBuilder.getMRI()->getType(MI.getOperand(0).getReg()); auto FINAddr = MIRBuilder.buildFrameIndex(AddrTy, FI); assert(MI.hasOneMemOperand()); MIRBuilder.buildStore(FINAddr, MI.getOperand(0).getReg(), *MI.memoperands()[0]); MI.eraseFromParent(); return true; } bool RISCVLegalizerInfo::shouldBeInConstantPool(APInt APImm, bool ShouldOptForSize) const { assert(APImm.getBitWidth() == 32 || APImm.getBitWidth() == 64); int64_t Imm = APImm.getSExtValue(); // All simm32 constants should be handled by isel. // NOTE: The getMaxBuildIntsCost call below should return a value >= 2 making // this check redundant, but small immediates are common so this check // should have better compile time. if (isInt<32>(Imm)) return false; // We only need to cost the immediate, if constant pool lowering is enabled. if (!STI.useConstantPoolForLargeInts()) return false; RISCVMatInt::InstSeq Seq = RISCVMatInt::generateInstSeq(Imm, STI); if (Seq.size() <= STI.getMaxBuildIntsCost()) return false; // Optimizations below are disabled for opt size. If we're optimizing for // size, use a constant pool. if (ShouldOptForSize) return true; // // Special case. See if we can build the constant as (ADD (SLLI X, C), X) do // that if it will avoid a constant pool. // It will require an extra temporary register though. // If we have Zba we can use (ADD_UW X, (SLLI X, 32)) to handle cases where // low and high 32 bits are the same and bit 31 and 63 are set. unsigned ShiftAmt, AddOpc; RISCVMatInt::InstSeq SeqLo = RISCVMatInt::generateTwoRegInstSeq(Imm, STI, ShiftAmt, AddOpc); return !(!SeqLo.empty() && (SeqLo.size() + 2) <= STI.getMaxBuildIntsCost()); } bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI, MachineIRBuilder &MIB) const { const LLT XLenTy(STI.getXLenVT()); Register Dst = MI.getOperand(0).getReg(); // We define our scalable vector types for lmul=1 to use a 64 bit known // minimum size. e.g. . VLENB is in bytes so we calculate // vscale as VLENB / 8. static_assert(RISCV::RVVBitsPerBlock == 64, "Unexpected bits per block!"); if (STI.getRealMinVLen() < RISCV::RVVBitsPerBlock) // Support for VLEN==32 is incomplete. return false; // We assume VLENB is a multiple of 8. We manually choose the best shift // here because SimplifyDemandedBits isn't always able to simplify it. uint64_t Val = MI.getOperand(1).getCImm()->getZExtValue(); if (isPowerOf2_64(Val)) { uint64_t Log2 = Log2_64(Val); if (Log2 < 3) { auto VLENB = MIB.buildInstr(RISCV::G_READ_VLENB, {XLenTy}, {}); MIB.buildLShr(Dst, VLENB, MIB.buildConstant(XLenTy, 3 - Log2)); } else if (Log2 > 3) { auto VLENB = MIB.buildInstr(RISCV::G_READ_VLENB, {XLenTy}, {}); MIB.buildShl(Dst, VLENB, MIB.buildConstant(XLenTy, Log2 - 3)); } else { MIB.buildInstr(RISCV::G_READ_VLENB, {Dst}, {}); } } else if ((Val % 8) == 0) { // If the multiplier is a multiple of 8, scale it down to avoid needing // to shift the VLENB value. auto VLENB = MIB.buildInstr(RISCV::G_READ_VLENB, {XLenTy}, {}); MIB.buildMul(Dst, VLENB, MIB.buildConstant(XLenTy, Val / 8)); } else { auto VLENB = MIB.buildInstr(RISCV::G_READ_VLENB, {XLenTy}, {}); auto VScale = MIB.buildLShr(XLenTy, VLENB, MIB.buildConstant(XLenTy, 3)); MIB.buildMul(Dst, VScale, MIB.buildConstant(XLenTy, Val)); } MI.eraseFromParent(); return true; } // Custom-lower extensions from mask vectors by using a vselect either with 1 // for zero/any-extension or -1 for sign-extension: // (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0) // Note that any-extension is lowered identically to zero-extension. bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI, MachineIRBuilder &MIB) const { unsigned Opc = MI.getOpcode(); assert(Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_ANYEXT); MachineRegisterInfo &MRI = *MIB.getMRI(); Register Dst = MI.getOperand(0).getReg(); Register Src = MI.getOperand(1).getReg(); LLT DstTy = MRI.getType(Dst); int64_t ExtTrueVal = Opc == TargetOpcode::G_SEXT ? -1 : 1; LLT DstEltTy = DstTy.getElementType(); auto SplatZero = MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, 0)); auto SplatTrue = MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, ExtTrueVal)); MIB.buildSelect(Dst, Src, SplatTrue, SplatZero); MI.eraseFromParent(); return true; } /// Return the type of the mask type suitable for masking the provided /// vector type. This is simply an i1 element type vector of the same /// (possibly scalable) length. static LLT getMaskTypeFor(LLT VecTy) { assert(VecTy.isVector()); ElementCount EC = VecTy.getElementCount(); return LLT::vector(EC, LLT::scalar(1)); } /// Creates an all ones mask suitable for masking a vector of type VecTy with /// vector length VL. static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) { LLT MaskTy = getMaskTypeFor(VecTy); return MIB.buildInstr(RISCV::G_VMSET_VL, {MaskTy}, {VL}); } /// Gets the two common "VL" operands: an all-ones mask and the vector length. /// VecTy is a scalable vector type. static std::pair buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) { LLT VecTy = Dst.getLLTTy(MRI); assert(VecTy.isScalableVector() && "Expecting scalable container type"); Register VL(RISCV::X0); MachineInstrBuilder Mask = buildAllOnesMask(VecTy, VL, MIB, MRI); return {Mask, VL}; } static MachineInstrBuilder buildSplatPartsS64WithVL(const DstOp &Dst, const SrcOp &Passthru, Register Lo, Register Hi, Register VL, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) { // TODO: If the Hi bits of the splat are undefined, then it's fine to just // splat Lo even if it might be sign extended. I don't think we have // introduced a case where we're build a s64 where the upper bits are undef // yet. // Fall back to a stack store and stride x0 vector load. // TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64. This is done in // preprocessDAG in SDAG. return MIB.buildInstr(RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst}, {Passthru, Lo, Hi, VL}); } static MachineInstrBuilder buildSplatSplitS64WithVL(const DstOp &Dst, const SrcOp &Passthru, const SrcOp &Scalar, Register VL, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) { assert(Scalar.getLLTTy(MRI) == LLT::scalar(64) && "Unexpected VecTy!"); auto Unmerge = MIB.buildUnmerge(LLT::scalar(32), Scalar); return buildSplatPartsS64WithVL(Dst, Passthru, Unmerge.getReg(0), Unmerge.getReg(1), VL, MIB, MRI); } // Lower splats of s1 types to G_ICMP. For each mask vector type, we have a // legal equivalently-sized i8 type, so we can use that as a go-between. // Splats of s1 types that have constant value can be legalized as VMSET_VL or // VMCLR_VL. bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const { assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR); MachineRegisterInfo &MRI = *MIB.getMRI(); Register Dst = MI.getOperand(0).getReg(); Register SplatVal = MI.getOperand(1).getReg(); LLT VecTy = MRI.getType(Dst); LLT XLenTy(STI.getXLenVT()); // Handle case of s64 element vectors on rv32 if (XLenTy.getSizeInBits() == 32 && VecTy.getElementType().getSizeInBits() == 64) { auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI); buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB, MRI); MI.eraseFromParent(); return true; } // All-zeros or all-ones splats are handled specially. MachineInstr &SplatValMI = *MRI.getVRegDef(SplatVal); if (isAllOnesOrAllOnesSplat(SplatValMI, MRI)) { auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second; MIB.buildInstr(RISCV::G_VMSET_VL, {Dst}, {VL}); MI.eraseFromParent(); return true; } if (isNullOrNullSplat(SplatValMI, MRI)) { auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second; MIB.buildInstr(RISCV::G_VMCLR_VL, {Dst}, {VL}); MI.eraseFromParent(); return true; } // Handle non-constant mask splat (i.e. not sure if it's all zeros or all // ones) by promoting it to an s8 splat. LLT InterEltTy = LLT::scalar(8); LLT InterTy = VecTy.changeElementType(InterEltTy); auto ZExtSplatVal = MIB.buildZExt(InterEltTy, SplatVal); auto And = MIB.buildAnd(InterEltTy, ZExtSplatVal, MIB.buildConstant(InterEltTy, 1)); auto LHS = MIB.buildSplatVector(InterTy, And); auto ZeroSplat = MIB.buildSplatVector(InterTy, MIB.buildConstant(InterEltTy, 0)); MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat); MI.eraseFromParent(); return true; } bool RISCVLegalizerInfo::legalizeCustom( LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const { MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; GISelChangeObserver &Observer = Helper.Observer; MachineFunction &MF = *MI.getParent()->getParent(); switch (MI.getOpcode()) { default: // No idea what to do. return false; case TargetOpcode::G_ABS: return Helper.lowerAbsToMaxNeg(MI); // TODO: G_FCONSTANT case TargetOpcode::G_CONSTANT: { const Function &F = MF.getFunction(); // TODO: if PSI and BFI are present, add " || // llvm::shouldOptForSize(*CurMBB, PSI, BFI)". bool ShouldOptForSize = F.hasOptSize() || F.hasMinSize(); const ConstantInt *ConstVal = MI.getOperand(1).getCImm(); if (!shouldBeInConstantPool(ConstVal->getValue(), ShouldOptForSize)) return true; return Helper.lowerConstant(MI); } case TargetOpcode::G_SHL: case TargetOpcode::G_ASHR: case TargetOpcode::G_LSHR: return legalizeShlAshrLshr(MI, MIRBuilder, Observer); case TargetOpcode::G_SEXT_INREG: { // Source size of 32 is sext.w. int64_t SizeInBits = MI.getOperand(2).getImm(); if (SizeInBits == 32) return true; return Helper.lower(MI, 0, /* Unused hint type */ LLT()) == LegalizerHelper::Legalized; } case TargetOpcode::G_IS_FPCLASS: { Register GISFPCLASS = MI.getOperand(0).getReg(); Register Src = MI.getOperand(1).getReg(); const MachineOperand &ImmOp = MI.getOperand(2); MachineIRBuilder MIB(MI); // Turn LLVM IR's floating point classes to that in RISC-V, // by simply rotating the 10-bit immediate right by two bits. APInt GFpClassImm(10, static_cast(ImmOp.getImm())); auto FClassMask = MIB.buildConstant(sXLen, GFpClassImm.rotr(2).zext(XLen)); auto ConstZero = MIB.buildConstant(sXLen, 0); auto GFClass = MIB.buildInstr(RISCV::G_FCLASS, {sXLen}, {Src}); auto And = MIB.buildAnd(sXLen, GFClass, FClassMask); MIB.buildICmp(CmpInst::ICMP_NE, GISFPCLASS, And, ConstZero); MI.eraseFromParent(); return true; } case TargetOpcode::G_VASTART: return legalizeVAStart(MI, MIRBuilder); case TargetOpcode::G_VSCALE: return legalizeVScale(MI, MIRBuilder); case TargetOpcode::G_ZEXT: case TargetOpcode::G_SEXT: case TargetOpcode::G_ANYEXT: return legalizeExt(MI, MIRBuilder); case TargetOpcode::G_SPLAT_VECTOR: return legalizeSplatVector(MI, MIRBuilder); } llvm_unreachable("expected switch to return"); }