From a50e77ff38f445a41b7275cc496bf310f276c2cc Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Wed, 20 May 2026 19:06:41 +0200 Subject: [PATCH] refactorone --- README.md | 40 + src/PIM/Common/CMakeLists.txt | 1 + src/PIM/Common/IR/AddressAnalysis.cpp | 46 + src/PIM/Common/IR/ConstantUtils.cpp | 62 ++ src/PIM/Common/IR/ConstantUtils.hpp | 28 + src/PIM/Common/IR/CoreBlockUtils.cpp | 4 + src/PIM/Common/IR/WeightUtils.cpp | 41 +- src/PIM/Common/PimCommon.hpp | 1 + src/PIM/Compiler/PimBatchEmission.cpp | 161 ++-- src/PIM/Compiler/PimBatchEmission.hpp | 3 + src/PIM/Compiler/PimCodeGen.cpp | 84 +- src/PIM/Compiler/PimWeightEmitter.cpp | 16 +- .../Common/ComputeRegionBuilder.hpp | 21 +- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 13 +- .../ONNXToSpatial/Patterns/Math/Gemm.cpp | 91 +- .../Conversion/ONNXToSpatial/PostPatterns.cpp | 122 +-- .../Conversion/ONNXToSpatial/PostPatterns.hpp | 4 - .../BatchCoreLoweringPatterns.cpp | 118 ++- .../Conversion/SpatialToPim/CMakeLists.txt | 2 + .../SpatialToPim/ChannelLoweringPatterns.cpp | 44 +- .../SpatialToPim/ComputeLikeRegionUtils.cpp | 7 +- .../SpatialToPim/CoreLoweringPatterns.cpp | 115 ++- .../SpatialToPim/CoreLoweringPatterns.hpp | 2 + .../GlobalTensorMaterialization.cpp | 40 +- .../SpatialToPim/ReturnPathNormalization.cpp | 116 ++- .../SpatialToPim/ReturnPathNormalization.hpp | 6 + .../Conversion/SpatialToPim/SpatialToPim.td | 4 +- .../SpatialToPim/SpatialToPimPass.cpp | 46 +- src/PIM/Dialect/Pim/Pim.td | 50 +- src/PIM/Dialect/Pim/PimOps.cpp | 33 + src/PIM/Dialect/Pim/PimOpsAsm.cpp | 211 ++++- src/PIM/Dialect/Pim/PimOpsVerify.cpp | 102 +- .../OpBufferizationInterfaces.cpp | 93 +- src/PIM/Dialect/Spatial/Channels.cpp | 90 +- src/PIM/Dialect/Spatial/Spatial.td | 145 ++- src/PIM/Dialect/Spatial/SpatialOps.cpp | 64 ++ src/PIM/Dialect/Spatial/SpatialOps.hpp | 2 + src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp | 370 +++----- src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp | 347 ++++--- .../MaterializeMergeSchedule.cpp | 884 +++++++++++++++--- .../MergeComputeNodesPass.cpp | 56 +- .../MergeComputeNodes/PostMergeCompaction.cpp | 302 +++++- .../MergeComputeNodes/RegularOpCompaction.cpp | 291 ++++-- .../Scheduling/ComputeGraph.cpp | 62 +- .../Scheduling/ComputeInstanceUtils.cpp | 139 ++- .../Scheduling/ComputeInstanceUtils.hpp | 6 +- .../HostConstantFolding/Patterns/Subview.cpp | 34 +- .../MaterializeHostConstantsPass.cpp | 38 +- src/PIM/Pass/PimCodegen/VerificationPass.cpp | 48 +- validation/validate.py | 2 +- 50 files changed, 3420 insertions(+), 1187 deletions(-) create mode 100644 src/PIM/Common/IR/ConstantUtils.cpp create mode 100644 src/PIM/Common/IR/ConstantUtils.hpp diff --git a/README.md b/README.md index b576cb4..b696389 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,46 @@ validate.py \ --crossbar-size 2048 --crossbar-count 256 --core-count 1000 ``` +Each validation run writes debugging artifacts into the benchmark's workspace +directory (for example `validation/operations/gemm/small/`): +- `inputs/` — generated input CSVs used for the run. +- `outputs/` — reference outputs dumped by the native ONNX runner. +- `raptor/` — compiler artifacts: + `*.onnx.mlir`, `dialects/spatial0.mlir`, `dialects/spatial1_dcp_merged.mlir`, + `dialects/pim0.mlir`, `dialects/pim1_buff.mlir`, `dialects/pim2_coalesced.mlir`, + `dialects/pim3_folded.mlir`, `dialects/pim4_materialized.mlir`, + `pim/config.json`, `pim/core_*.pim`, `pim/memory.bin`, and reports under + `raptor/reports/` such as `dcp_merge_report.txt`, + `memory_report.txt`, and `static_memory_coalescing_report.txt`. +- `runner/` — generated reference runner source, build tree, and shared library. +- `simulation/out.bin` — raw simulator output dump used for output comparison. + +That means you usually do not need to rerun standalone `--EmitSpatial` or +`--EmitPim` commands while debugging validation failures: the per-pass dialect +dumps are already available under `raptor/dialects/`. + +The validator does not currently expose a simulator tracing flag, but once a +validation has produced `raptor/pim/` you can rerun the simulator manually with +tracing enabled: + +```bash +cd backend-simulators/pim/pim-simulator +cargo run --no-default-features --features tracing --release \ + --package pim-simulator --bin pim-simulator -- \ + -f /path/to/workspace/raptor/pim \ + -o /path/to/workspace/simulation/out.bin \ + -d ,,,,... +``` + +With `--features tracing`, the simulator writes per-core traces as +`simulation/TraceCore0`, `simulation/TraceCore1`, ... next to `simulation/out.bin`. +The validator normally computes the `-d` dump ranges from `raptor/pim/config.json` +and the model output shapes. If you need a clean slate before rerunning, use: + +```bash +validate.py --clean +``` + Available networks under `validation/networks/`: `vgg16`, `yolo11n`. Available operations under `validation/operations/`: `add`, `conv`, `div`, `gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`, diff --git a/src/PIM/Common/CMakeLists.txt b/src/PIM/Common/CMakeLists.txt index 0dce626..8f62537 100644 --- a/src/PIM/Common/CMakeLists.txt +++ b/src/PIM/Common/CMakeLists.txt @@ -1,5 +1,6 @@ add_pim_library(OMPimCommon IR/AddressAnalysis.cpp + IR/ConstantUtils.cpp IR/CoreBlockUtils.cpp IR/EntryPointUtils.cpp IR/ShapeUtils.cpp diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index 69a4ca2..4257ac8 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -1,5 +1,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" @@ -55,6 +57,47 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow } llvm::FailureOr resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge); +llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge); + +static llvm::FailureOr resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp, + const StaticValueKnowledge* knowledge) { + auto getGlobalOp = loadOp.getMemRef().getDefiningOp(); + if (!getGlobalOp) + return mlir::failure(); + + auto moduleOp = loadOp->getParentOfType(); + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue()) + return mlir::failure(); + + auto denseAttr = mlir::dyn_cast(*globalOp.getInitialValue()); + auto globalType = mlir::dyn_cast(getGlobalOp.getType()); + if (!denseAttr || !globalType || !globalType.hasStaticShape()) + return mlir::failure(); + + auto elementType = denseAttr.getElementType(); + if (!elementType.isIndex() && !elementType.isInteger()) + return mlir::failure(); + + llvm::SmallVector indices; + indices.reserve(loadOp.getIndices().size()); + for (mlir::Value index : loadOp.getIndices()) { + auto resolvedIndex = resolveIndexValueImpl(index, knowledge); + if (failed(resolvedIndex)) + return mlir::failure(); + indices.push_back(*resolvedIndex); + } + + if (indices.size() != static_cast(globalType.getRank())) + return mlir::failure(); + + auto strides = computeRowMajorStrides(globalType.getShape()); + int64_t linearIndex = linearizeIndex(indices, strides); + if (linearIndex < 0 || linearIndex >= globalType.getNumElements()) + return mlir::failure(); + + return denseAttr.getValues()[linearIndex].getSExtValue(); +} llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { value = resolveAlias(value, knowledge); @@ -126,6 +169,9 @@ llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticVa return static_cast(static_cast(*lhs) % static_cast(*rhs)); } + if (auto loadOp = mlir::dyn_cast(definingOp)) + return resolveConstantGlobalLoad(loadOp, knowledge); + return mlir::failure(); } diff --git a/src/PIM/Common/IR/ConstantUtils.cpp b/src/PIM/Common/IR/ConstantUtils.cpp new file mode 100644 index 0000000..91d0989 --- /dev/null +++ b/src/PIM/Common/IR/ConstantUtils.cpp @@ -0,0 +1,62 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +#include "ConstantUtils.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +Block* getHostConstantBlock(Operation* anchorOp) { + assert(anchorOp && "expected a valid anchor operation"); + + for (Operation* current = anchorOp; current; current = current->getParentOp()) + if (isa(current)) + return current->getBlock(); + + if (auto funcOp = anchorOp->getParentOfType()) + return &funcOp.getBody().front(); + if (auto moduleOp = anchorOp->getParentOfType()) + return moduleOp.getBody(); + return anchorOp->getBlock(); +} + +Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) { + assert(anchorOp && "expected a valid anchor operation"); + Block* hostBlock = getHostConstantBlock(anchorOp); + for (Operation& op : *hostBlock) { + auto constantOp = dyn_cast(&op); + if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value) + continue; + return constantOp.getResult(); + } + + auto* arithDialect = anchorOp->getContext()->getOrLoadDialect(); + return folder.getOrCreateConstant(hostBlock, arithDialect, value, type); +} + +Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) { + return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder); +} + +Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) { + Builder builder(anchorOp->getContext()); + return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder); +} + +Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) { + Builder builder(anchorOp->getContext()); + return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder); +} + +Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) { + Builder builder(anchorOp->getContext()); + return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/ConstantUtils.hpp b/src/PIM/Common/IR/ConstantUtils.hpp new file mode 100644 index 0000000..4754a7d --- /dev/null +++ b/src/PIM/Common/IR/ConstantUtils.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/FoldUtils.h" + +namespace onnx_mlir { + +mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp); + +mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp, + mlir::Attribute value, + mlir::Type type, + mlir::OperationFolder& folder); + +mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder); + +mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder); + +mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder); + +mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder); + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/CoreBlockUtils.cpp b/src/PIM/Common/IR/CoreBlockUtils.cpp index a5cc241..be78ba1 100644 --- a/src/PIM/Common/IR/CoreBlockUtils.cpp +++ b/src/PIM/Common/IR/CoreBlockUtils.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" @@ -30,6 +31,9 @@ walkPimCoreBlock(mlir::Block& block, for (mlir::Operation& op : block) { if (mlir::isa(op) || isCoreStaticAddressOp(&op)) continue; + if (auto loadOp = mlir::dyn_cast(op); + loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge))) + continue; if (auto forOp = mlir::dyn_cast(op)) { mlir::Block& loopBody = forOp.getRegion().front(); diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp index d7079b9..fbeff3e 100644 --- a/src/PIM/Common/IR/WeightUtils.cpp +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -21,12 +21,13 @@ namespace { template bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { + mlir::Value weightArg = parentOp.getWeightArgument(weightIndex); bool found = false; parentOp.walk([&](mlir::Operation* op) { if (auto mvmOp = mlir::dyn_cast(op)) - found |= mvmOp.getWeightIndex() == weightIndex; + found |= mvmOp.getWeight() == weightArg; else if (auto vmmOp = mlir::dyn_cast(op)) - found |= vmmOp.getWeightIndex() == weightIndex; + found |= vmmOp.getWeight() == weightArg; }); return found; } @@ -35,13 +36,18 @@ template void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref callback) { auto weights = parentOp.getWeights(); llvm::SmallSet visited; - auto walkWeightIndex = [&](unsigned weightIndex) { - if (weightIndex < weights.size() && visited.insert(weightIndex).second) - callback(parentOp->getOpOperand(weightIndex)); + auto walkWeight = [&](mlir::Value weight) { + for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) { + if (parentOp.getWeightArgument(weightIndex) != weight) + continue; + if (visited.insert(weightIndex).second) + callback(parentOp->getOpOperand(weightIndex)); + break; + } }; - parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); }); - parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); }); + parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); }); + parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); }); } } // namespace @@ -90,18 +96,21 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_refwalk([&](pim::PimCoreOp coreOp) { coreOp.walk([&](pim::PimVMMOp vmmOp) { - auto weights = coreOp.getWeights(); - unsigned weightIndex = vmmOp.getWeightIndex(); - if (weightIndex < weights.size()) - callback(coreOp->getOpOperand(weightIndex)); + for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) + if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) { + callback(coreOp->getOpOperand(weightIndex)); + break; + } }); }); root->walk([&](pim::PimCoreBatchOp coreBatchOp) { - auto weights = coreBatchOp.getWeights(); - for (auto weight : weights) - for (mlir::OpOperand& use : weight.getUses()) - if (use.getOwner() == coreBatchOp.getOperation()) - callback(use); + coreBatchOp.walk([&](pim::PimVMMOp vmmOp) { + for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex) + if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) { + callback(coreBatchOp->getOpOperand(weightIndex)); + break; + } + }); }); } diff --git a/src/PIM/Common/PimCommon.hpp b/src/PIM/Common/PimCommon.hpp index 6880012..d7f6b6b 100644 --- a/src/PIM/Common/PimCommon.hpp +++ b/src/PIM/Common/PimCommon.hpp @@ -12,6 +12,7 @@ #include "llvm/ADT/StringRef.h" #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" diff --git a/src/PIM/Compiler/PimBatchEmission.cpp b/src/PIM/Compiler/PimBatchEmission.cpp index 656a57b..6bbf0d1 100644 --- a/src/PIM/Compiler/PimBatchEmission.cpp +++ b/src/PIM/Compiler/PimBatchEmission.cpp @@ -1,7 +1,11 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" +#include "llvm/ADT/StringRef.h" + #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" @@ -24,113 +28,132 @@ static SmallVector getLaneChunkCoreIds(ArrayRef coreIds, size_ return laneCoreIds; } -static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) { - IRRewriter rewriter(scalarCore.getContext()); - SmallVector batchOps; - scalarCore.walk([&](Operation* op) { - if (isa(op)) { - batchOps.push_back(op); - } - }); +static void cloneScalarizedLaneBody(OpBuilder& builder, + pim::PimCoreBatchOp coreBatchOp, + unsigned lane, + OperationFolder& constantFolder) { + Block& oldBlock = coreBatchOp.getBody().front(); + size_t laneCount = static_cast(coreBatchOp.getLaneCount()); + size_t weightCount = coreBatchOp.getWeights().size(); - for (Operation* op : batchOps) { - rewriter.setInsertionPoint(op); + IRMapping mapper; + for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) { + if (blockArg.getType().isIndex()) { + mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast(lane), constantFolder)); + continue; + } + + if (argIndex <= weightCount) { + mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]); + continue; + } + + size_t inputIndex = argIndex - 1 - weightCount; + assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range"); + mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]); + } + + for (Operation& op : oldBlock) { + if (isa(op)) + continue; if (auto sendBatchOp = dyn_cast(op)) { - pim::PimSendOp::create(rewriter, - sendBatchOp.getLoc(), - sendBatchOp.getInput(), - sendBatchOp.getSizeAttr(), - rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane])); - rewriter.eraseOp(op); + Operation* anchorOp = builder.getInsertionBlock()->getParentOp(); + pim::PimSendOp::create( + builder, + sendBatchOp.getLoc(), + mapper.lookup(sendBatchOp.getInput()), + sendBatchOp.getSizeAttr(), + getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder)); continue; } if (auto sendTensorBatchOp = dyn_cast(op)) { pim::PimSendTensorOp::create( - rewriter, + builder, sendTensorBatchOp.getLoc(), - sendTensorBatchOp.getInput(), - rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane))); - rewriter.eraseOp(op); + mapper.lookup(sendTensorBatchOp.getInput()), + builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane))); continue; } if (auto receiveBatchOp = dyn_cast(op)) { - auto scalarReceive = - pim::PimReceiveOp::create(rewriter, - receiveBatchOp.getLoc(), - receiveBatchOp.getOutput().getType(), - receiveBatchOp.getOutputBuffer(), - receiveBatchOp.getSizeAttr(), - rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane])); - rewriter.replaceOp(op, scalarReceive->getResults()); + Operation* anchorOp = builder.getInsertionBlock()->getParentOp(); + auto scalarReceive = pim::PimReceiveOp::create( + builder, + receiveBatchOp.getLoc(), + receiveBatchOp.getOutput().getType(), + mapper.lookup(receiveBatchOp.getOutputBuffer()), + receiveBatchOp.getSizeAttr(), + getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder)); + mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput()); continue; } if (auto receiveTensorBatchOp = dyn_cast(op)) { auto scalarReceive = pim::PimReceiveTensorOp::create( - rewriter, + builder, receiveTensorBatchOp.getLoc(), receiveTensorBatchOp.getOutput().getType(), - receiveTensorBatchOp.getOutputBuffer(), - rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane))); - rewriter.replaceOp(op, scalarReceive->getResults()); + mapper.lookup(receiveTensorBatchOp.getOutputBuffer()), + builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane))); + mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput()); continue; } - auto memcpBatchOp = cast(op); - auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter, - memcpBatchOp.getLoc(), - memcpBatchOp.getOutput().getType(), - memcpBatchOp.getDeviceTarget(), - memcpBatchOp.getHostSource(), - memcpBatchOp.getDeviceTargetOffsetAttr(), - memcpBatchOp.getHostSourceOffsetAttr(), - memcpBatchOp.getSizeAttr()); - rewriter.replaceOp(op, scalarCopy->getResults()); + if (auto memcpBatchOp = dyn_cast(op)) { + auto scalarCopy = pim::PimMemCopyHostToDevOp::create( + builder, + memcpBatchOp.getLoc(), + memcpBatchOp.getOutput().getType(), + getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder), + getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder), + mapper.lookup(memcpBatchOp.getDeviceTarget()), + mapper.lookup(memcpBatchOp.getHostSource()), + memcpBatchOp.getSizeAttr()); + mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput()); + continue; + } + + Operation* cloned = builder.clone(op, mapper); + for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapper.map(originalResult, clonedResult); } } } // namespace -LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, - unsigned lane, - llvm::function_ref callback) { +LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp, + ArrayRef lanes, + llvm::function_ref callback) { + assert(!lanes.empty() && "expected at least one batch lane"); + OwningOpRef scratchModule = ModuleOp::create(coreBatchOp.getLoc()); OpBuilder builder(scratchModule->getContext()); + OperationFolder constantFolder(scratchModule->getContext()); builder.setInsertionPointToStart(scratchModule->getBody()); - size_t laneCount = static_cast(coreBatchOp.getLaneCount()); - size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount; - SmallVector laneWeights; - laneWeights.reserve(weightsPerLane); - for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) - laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]); - + SmallVector weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end()); auto coreIds = getBatchCoreIds(coreBatchOp); - auto scalarCore = pim::PimCoreOp::create( - builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane])); + int32_t coreId = coreIds[lanes.front()]; + for (unsigned lane : lanes) + assert(coreIds[lane] == coreId && "all grouped lanes must target the same core"); + + auto scalarCore = + pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId)); Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end()); - IRMapping mapper; - if (coreBatchOp.getBody().front().getNumArguments() == 1) - mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]); - builder.setInsertionPointToEnd(block); - for (Operation& op : coreBatchOp.getBody().front()) { - Operation* cloned = builder.clone(op, mapper); - for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) - mapper.map(originalResult, clonedResult); - } - + for (unsigned lane : lanes) + cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder); if (block->empty() || !isa(block->back())) pim::PimHaltOp::create(builder, coreBatchOp.getLoc()); - scalarizeBatchOpsInCore(scalarCore, laneCount, lane); return callback(scalarCore); } +LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, + unsigned lane, + llvm::function_ref callback) { + return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef {lane}, callback); +} + } // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimBatchEmission.hpp b/src/PIM/Compiler/PimBatchEmission.hpp index 62c4797..1977d55 100644 --- a/src/PIM/Compiler/PimBatchEmission.hpp +++ b/src/PIM/Compiler/PimBatchEmission.hpp @@ -9,5 +9,8 @@ namespace onnx_mlir { mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, unsigned lane, llvm::function_ref callback); +mlir::LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp, + llvm::ArrayRef lanes, + llvm::function_ref callback); } // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index ba6f288..64e12df 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -41,15 +41,23 @@ using namespace mlir; using namespace onnx_mlir; using namespace onnx_mlir::compact_asm; +static size_t getElementTypeSizeInBytes(mlir::Type elementType) { + if (elementType.isIndex()) + return sizeof(int64_t); + if (elementType.isIntOrFloat()) + return elementType.getIntOrFloatBitWidth() / 8; + llvm_unreachable("unsupported shaped element type"); +} + static size_t getValueSizeInBytes(mlir::Value value) { auto type = cast(value.getType()); - return type.getNumElements() * type.getElementTypeBitWidth() / 8; + return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()); } MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { auto type = cast(value.getType()); assert("Only static shape is supported" && type.hasStaticShape()); - size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8; + size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()); MemEntry memEntry = {0, allocSize}; return &memEntries.emplace_back(memEntry, value).first; } @@ -398,20 +406,28 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_ } void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const { + auto deviceTargetOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge); + auto hostSourceOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge); + assert(succeeded(deviceTargetOffset) && succeeded(hostSourceOffset) + && "pim.memcp_hd offsets must be statically resolvable during codegen"); emitMemCopyOp("ld", addressOf(loadOp.getDeviceTarget(), knowledge), - loadOp.getDeviceTargetOffset(), + *deviceTargetOffset, addressOf(loadOp.getHostSource(), knowledge), - loadOp.getHostSourceOffset(), + *hostSourceOffset, loadOp.getSize()); } void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const { + auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge); + auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge); + assert(succeeded(hostTargetOffset) && succeeded(deviceSourceOffset) + && "pim.memcp_dh offsets must be statically resolvable during codegen"); emitMemCopyOp("st", addressOf(storeOp.getHostTarget(), knowledge), - storeOp.getHostTargetOffset(), + *hostTargetOffset, addressOf(storeOp.getDeviceSource(), knowledge), - storeOp.getDeviceSourceOffset(), + *deviceSourceOffset, storeOp.getSize()); } @@ -426,8 +442,9 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledg } void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const { - emitCommunicationOp( - "recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize()); + auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge); + assert(succeeded(sourceCoreId) && "pim.receive source core id must be statically resolvable during codegen"); + emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize()); } void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, @@ -439,7 +456,9 @@ void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, } void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const { - emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize()); + auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge); + assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen"); + emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize()); } void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const { @@ -728,12 +747,19 @@ std::string getMemorySizeAsString(size_t size) { static SmallVector getUsedWeightIndices(Block& block) { SmallVector indices; - auto addIndex = [&](unsigned weightIndex) { - if (!llvm::is_contained(indices, weightIndex)) - indices.push_back(weightIndex); + auto coreOp = dyn_cast(block.getParentOp()); + auto addWeight = [&](mlir::Value weight) { + if (!coreOp) + return; + for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) { + if (coreOp.getWeightArgument(weightIndex) != weight) + continue; + if (!llvm::is_contained(indices, weightIndex)) + indices.push_back(weightIndex); + return; + } }; - - block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); + block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); }); llvm::sort(indices); return indices; } @@ -795,6 +821,15 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp, /// fully resolved before the JSON instructions are emitted. /// Returns the number of emitted instructions, or -1 on failure. static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { + auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional { + auto coreOp = vmmOp->getParentOfType(); + if (!coreOp) + return std::nullopt; + for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) + if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) + return weightIndex; + return std::nullopt; + }; size_t processedOperations = 0; auto result = walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) { @@ -814,8 +849,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge); else if (auto concatOp = dyn_cast(op)) coreCodeGen.codeGenConcatOp(concatOp, knowledge); - else if (auto vmmOp = dyn_cast(op)) - coreCodeGen.codeGenMVMLikeOp(vmmOp.getWeightIndex(), vmmOp, true, knowledge); + else if (auto vmmOp = dyn_cast(op)) { + auto weightIndex = resolveWeightIndex(vmmOp); + if (!weightIndex) + return failure(); + coreCodeGen.codeGenMVMLikeOp(*weightIndex, vmmOp, true, knowledge); + } else if (auto transposeOp = dyn_cast(op)) coreCodeGen.codeGenTransposeOp(transposeOp, knowledge); else if (auto vvaddOp = dyn_cast(op)) @@ -1004,10 +1043,19 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: reportedCoreIds.reserve(batchCoreIds.size()); MemoryReportRow batchRow; std::optional batchPerCoreRow; + llvm::DenseMap> lanesByCoreId; + SmallVector orderedOriginalCoreIds; for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) { + size_t originalCoreId = static_cast(batchCoreIds[lane]); + auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId); + if (inserted) + orderedOriginalCoreIds.push_back(originalCoreId); + it->second.push_back(lane); + } + + for (size_t originalCoreId : orderedOriginalCoreIds) { OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess; - if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) { - size_t originalCoreId = static_cast(batchCoreIds[lane]); + if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) { size_t coreId = emittedCoreIds.lookup(originalCoreId); reportedCoreIds.push_back(static_cast(coreId)); MemoryReportRow laneRow; diff --git a/src/PIM/Compiler/PimWeightEmitter.cpp b/src/PIM/Compiler/PimWeightEmitter.cpp index 783eea5..94ea43a 100644 --- a/src/PIM/Compiler/PimWeightEmitter.cpp +++ b/src/PIM/Compiler/PimWeightEmitter.cpp @@ -128,12 +128,20 @@ FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value SmallVector getUsedWeightIndices(Block& block) { SmallVector indices; - auto addIndex = [&](unsigned weightIndex) { - if (!llvm::is_contained(indices, weightIndex)) - indices.push_back(weightIndex); + auto coreOp = dyn_cast(block.getParentOp()); + auto addWeight = [&](mlir::Value weight) { + if (!coreOp) + return; + for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) { + if (coreOp.getWeightArgument(weightIndex) != weight) + continue; + if (!llvm::is_contained(indices, weightIndex)) + indices.push_back(weightIndex); + return; + } }; - block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); + block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); }); llvm::sort(indices); return indices; } diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp index 4ffce3e..e78d067 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp @@ -18,13 +18,17 @@ namespace detail { inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); } +inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) { + return mlir::ValueRange(block->getArguments()).drop_front(weightCount); +} + template decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence) { return std::forward(fn)(block->getArgument(Is)...); } template -decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef values, std::index_sequence) { +decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence) { return std::forward(fn)(values[Is]...); } @@ -85,6 +89,8 @@ auto createSpatCompute(RewriterT& rewriter, auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); + for (mlir::Value weight : weights) + block->addArgument(weight.getType(), loc); for (mlir::Value input : inputs) block->addArgument(input.getType(), loc); @@ -93,14 +99,15 @@ auto createSpatCompute(RewriterT& rewriter, using BodyResult = detail::InvokeWithBlockArgsResultT, std::make_index_sequence>; if constexpr (std::is_same_v) { - detail::invokeWithBlockArgs(std::forward(body), block, std::make_index_sequence {}); + detail::invokeWithValues( + std::forward(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence {}); rewriter.setInsertionPointAfter(computeOp); return computeOp; } else { - auto bodyResult = - detail::invokeWithBlockArgs(std::forward(body), block, std::make_index_sequence {}); + auto bodyResult = detail::invokeWithValues( + std::forward(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence {}); if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); @@ -123,6 +130,8 @@ auto createSpatCompute(RewriterT& rewriter, auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto* block = new mlir::Block(); + for (mlir::Value weight : weights) + block->addArgument(weight.getType(), loc); for (mlir::Value input : inputs) block->addArgument(input.getType(), loc); @@ -131,13 +140,13 @@ auto createSpatCompute(RewriterT& rewriter, using BodyResult = detail::InvokeWithValueRangeResultT>; if constexpr (std::is_same_v) { - std::forward(body)(detail::getBlockArgs(block)); + std::forward(body)(detail::getInputBlockArgs(block, weights.size())); rewriter.setInsertionPointAfter(computeOp); return computeOp; } else { - auto bodyResult = std::forward(body)(detail::getBlockArgs(block)); + auto bodyResult = std::forward(body)(detail::getInputBlockArgs(block, weights.size())); if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index c004bea..87da1e9 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -44,7 +44,8 @@ static void populateEmptyFunction(func::FuncOp funcOp) { IRRewriter rewriter(funcOp.getContext()); IRMapping mapper; SmallVector computes(funcOp.getOps()); - if (!computes.empty()) + SmallVector computeBatches(funcOp.getOps()); + if (!computes.empty() || !computeBatches.empty()) return; auto returnOp = cast(funcOp.getFunctionBody().front().getTerminator()); @@ -190,16 +191,6 @@ void ONNXToSpatialPass::runOnOperation() { tensor::TensorDialect, arith::ArithDialect, scf::SCFDialect>(); - earlyPostTarget.addDynamicallyLegalOp( - [](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); }); - - RewritePatternSet earlyPostPatterns(ctx); - populateEarlyPostPatterns(earlyPostPatterns, ctx); - if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) { - moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks"); - signalPassFailure(); - return; - } PassManager cleanupPM(ctx); cleanupPM.addPass(createCanonicalizerPass()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 11a8f66..ec79393 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -402,24 +402,37 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++) weights.push_back(bTiles[outSliceId][coreId][aSliceId]); - auto computeOp = createSpatCompute( - rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult { - SmallVector vmmOutputs; - vmmOutputs.reserve(aHSlicesArgs.size()); - for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) - vmmOutputs.push_back( - spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg)); - if (vmmOutputs.empty()) { - gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); - return failure(); - } + auto computeOp = + spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, aHSlices[coreId]); + SmallVector blockArgTypes; + SmallVector blockArgLocs; + blockArgTypes.reserve(weights.size() + aHSlices[coreId].size()); + blockArgLocs.reserve(weights.size() + aHSlices[coreId].size()); + for (Value weight : weights) { + blockArgTypes.push_back(weight.getType()); + blockArgLocs.push_back(gemmLoc); + } + for (Value input : aHSlices[coreId]) { + blockArgTypes.push_back(input.getType()); + blockArgLocs.push_back(gemmLoc); + } + Block* body = + rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + rewriter.setInsertionPointToEnd(body); - Value partialVmmSum = sumTensors(vmmOutputs, rewriter); - spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); - return success(); - }); - if (failed(computeOp)) + SmallVector vmmOutputs; + vmmOutputs.reserve(aHSlices[coreId].size()); + for (auto aHSliceId : llvm::seq(0, aHSlices[coreId].size())) + vmmOutputs.push_back(spatial::SpatVMMOp::create( + rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId))); + if (vmmOutputs.empty()) { + gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); return failure(); + } + + Value partialVmmSum = sumTensors(vmmOutputs, rewriter); + spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); + rewriter.setInsertionPointAfter(computeOp); partialResults.push_back(computeOp->getResult(0)); } @@ -530,37 +543,47 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp, sharedBias = c; } - SmallVector aSlices = materializeBatchRowSlices(a, aType, rewriter, loc); - auto aSliceType = cast(aSlices.front().getType()); - auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType()); - SmallVector resultTypes(static_cast(numOutRows), outRowType); - SmallVector weights(static_cast(numOutRows), b); - + auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType()); auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, - TypeRange(resultTypes), + TypeRange {outType}, rewriter.getI32IntegerAttr(static_cast(numOutRows)), - ValueRange(weights), - ValueRange(aSlices)); + ValueRange {b}, + ValueRange {a}); - Block* body = rewriter.createBlock( - &batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector(1, loc)); + SmallVector blockArgTypes {rewriter.getIndexType(), bType, aType, outType}; + SmallVector blockArgLocs(4, loc); + Block* body = + rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); rewriter.setInsertionPointToEnd(body); - Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult(); + Value lane = batchOp.getLaneArgument(); + Value weight = batchOp.getWeightArgument(0); + Value packedInput = batchOp.getInputArgument(0); + Value packedOutput = batchOp.getOutputArgument(0); + + SmallVector inputOffsets {lane, rewriter.getIndexAttr(0)}; + SmallVector inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))}; + SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value row = + tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides) + .getResult(); + + Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult(); Value laneResult = vmmResult; if (sharedBias) laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult(); - spatial::SpatYieldOp::create(rewriter, loc, laneResult); + auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc); + rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + SmallVector outputOffsets {lane, rewriter.getIndexAttr(0)}; + SmallVector outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))}; + tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, + unitStrides); rewriter.setInsertionPointAfter(batchOp); - SmallVector laneResults(batchOp->result_begin(), batchOp->result_end()); - auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) { - spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args)); - }); - rewriter.replaceOp(gemmOp, concatComputeOp); + rewriter.replaceOp(gemmOp, batchOp.getResults()); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp index 09b6459..8b09290 100644 --- a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp @@ -35,58 +35,15 @@ template static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) { Block& block = compute.getBody().front(); for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { - if (inputIdx >= block.getNumArguments()) - continue; if (!isWeightLikeComputeOperand(input)) continue; - if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx))) + if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx))) continue; return true; } return false; } -// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily. -struct FoldSingleLaneComputeBatchPattern : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override { - if (batchOp.getLaneCount() != 1) - return rewriter.notifyMatchFailure(batchOp, "requires a single lane"); - - auto loc = batchOp.getLoc(); - rewriter.setInsertionPoint(batchOp); - auto computeOp = - spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs()); - computeOp.getProperties().setOperandSegmentSizes( - {static_cast(batchOp.getWeights().size()), static_cast(batchOp.getInputs().size())}); - - Block& templateBlock = batchOp.getBody().front(); - SmallVector blockArgTypes; - SmallVector blockArgLocs; - blockArgTypes.reserve(templateBlock.getNumArguments()); - blockArgLocs.reserve(templateBlock.getNumArguments()); - for (BlockArgument arg : templateBlock.getArguments()) { - blockArgTypes.push_back(arg.getType()); - blockArgLocs.push_back(loc); - } - - auto* newBlock = - rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - IRMapping mapper; - for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments())) - mapper.map(oldArg, newArg); - - rewriter.setInsertionPointToEnd(newBlock); - for (Operation& op : templateBlock) - rewriter.clone(op, mapper); - - batchOp->replaceAllUsesWith(computeOp->getResults()); - rewriter.eraseOp(batchOp); - return success(); - } -}; - // Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs. struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -96,11 +53,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern= oldBlock.getNumArguments()) - continue; if (!isWeightLikeComputeOperand(input)) continue; - if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx))) + if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx))) continue; promoteInput[inputIdx] = true; needsRewrite = true; @@ -131,8 +86,16 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern newBlockArgTypes; + SmallVector newBlockArgLocs; + for (Value weight : newWeights) { + newBlockArgTypes.push_back(weight.getType()); + newBlockArgLocs.push_back(weight.getLoc()); + } + llvm::append_range(newBlockArgTypes, newInputTypes); + llvm::append_range(newBlockArgLocs, newInputLocs); auto* newBlock = - rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs); + rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs); newCompute.getProperties().setOperandSegmentSizes( {static_cast(newWeights.size()), static_cast(newInputs.size())}); rewriter.setInsertionPointToStart(newBlock); @@ -141,14 +104,17 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePatterngetArgument(newInputIdx++)); + mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++)); continue; } - auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper); + auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper); if (failed(clonedValue)) return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand"); mapper.map(oldArg, *clonedValue); @@ -180,11 +146,9 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern= oldBlock.getNumArguments()) - continue; if (!isWeightLikeComputeOperand(input)) continue; - if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx))) + if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx))) continue; promoteInput[inputIdx] = true; needsRewrite = true; @@ -220,8 +184,25 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern(compute.getLaneCount())), newWeights, newInputs); - auto* newBlock = - rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs); + SmallVector newBlockArgTypes; + SmallVector newBlockArgLocs; + newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults()); + newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults()); + newBlockArgTypes.push_back(compute.getLaneArgument().getType()); + newBlockArgLocs.push_back(compute.getLaneArgument().getLoc()); + for (Value weight : newWeights) { + newBlockArgTypes.push_back(weight.getType()); + newBlockArgLocs.push_back(weight.getLoc()); + } + llvm::append_range(newBlockArgTypes, newInputTypes); + llvm::append_range(newBlockArgLocs, newInputLocs); + for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) { + newBlockArgTypes.push_back(resultType); + newBlockArgLocs.push_back(compute.getOutputArgument(resultIndex).getLoc()); + } + + auto* newBlock = rewriter.createBlock( + &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs); newCompute.getProperties().setOperandSegmentSizes( {static_cast(newWeights.size()), static_cast(newInputs.size())}); rewriter.setInsertionPointToStart(newBlock); @@ -230,31 +211,28 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePatterngetArgument(newInputIdx++)); + mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++)); continue; } - auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper); + auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper); if (failed(clonedValue)) return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand"); mapper.map(oldArg, *clonedValue); } + for (auto resultIndex : llvm::seq(0, compute.getNumResults())) + mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex)); - for (Operation& op : oldBlock.without_terminator()) + for (Operation& op : oldBlock) rewriter.clone(op, mapper); - auto oldYield = cast(oldBlock.getTerminator()); - SmallVector newYieldOperands; - newYieldOperands.reserve(oldYield.getOutputs().size()); - for (Value operand : oldYield.getOutputs()) { - auto mapped = mapper.lookupOrNull(operand); - newYieldOperands.push_back(mapped ? cast(mapped) : operand); - } - spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands); - rewriter.replaceOp(compute, newCompute.getResults()); return success(); } @@ -262,10 +240,6 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern(ctx); -} - void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add(ctx); } @@ -277,8 +251,6 @@ void annotateWeightsConstants(func::FuncOp funcOp) { }); } -bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; } - bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); } bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); } diff --git a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp index 6a1b4bd..8c14f7f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp @@ -7,14 +7,10 @@ namespace onnx_mlir { -bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp); - bool requiresPostRewrite(spatial::SpatCompute computeOp); bool requiresPostRewrite(spatial::SpatComputeBatch computeOp); -void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void annotateWeightsConstants(mlir::func::FuncOp funcOp); diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 58a0c66..bb3f1da 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" @@ -17,6 +18,37 @@ namespace { static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast(spatialCoreId); } +static FailureOr getConstantI32Value(Value value) { + APInt constantValue; + if (!matchPattern(value, m_ConstantInt(&constantValue))) + return failure(); + return static_cast(constantValue.getSExtValue()); +} + +static FailureOr> getConstantI32Values(ValueRange values) { + SmallVector constants; + constants.reserve(values.size()); + for (Value value : values) { + FailureOr constantValue = getConstantI32Value(value); + if (failed(constantValue)) + return failure(); + constants.push_back(*constantValue); + } + return constants; +} + +static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { + if (isa(op)) + return operandIndex == 2; + return false; +} + +static bool isUsedOnlyAsExplicitHostOperand(Value value) { + return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) { + return isExplicitHostOperand(use.getOwner(), use.getOperandNumber()); + }); +} + static SmallVector getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) { if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); @@ -28,27 +60,30 @@ static SmallVector getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co return coreIds; } -static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp, - IRMapping& mapper, - IRRewriter& rewriter) { - SmallVector targetCoreIds; - targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size()); - for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds()) - targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); +static LogicalResult lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp, + IRMapping& mapper, + IRRewriter& rewriter) { + FailureOr> targetCoreIds = getConstantI32Values(sendTensorBatchOp.getTargetCoreIds()); + if (failed(targetCoreIds)) + return sendTensorBatchOp.emitOpError("expected constant targetCoreIds"); + for (int32_t& targetCoreId : *targetCoreIds) + targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId); pim::PimSendTensorBatchOp::create(rewriter, sendTensorBatchOp.getLoc(), mapper.lookup(sendTensorBatchOp.getInput()), - rewriter.getDenseI32ArrayAttr(targetCoreIds)); + rewriter.getDenseI32ArrayAttr(*targetCoreIds)); + return success(); } -static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp, - IRMapping& mapper, - IRRewriter& rewriter) { - SmallVector sourceCoreIds; - sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size()); - for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds()) - sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId)); +static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp, + IRMapping& mapper, + IRRewriter& rewriter) { + FailureOr> sourceCoreIds = getConstantI32Values(receiveTensorBatchOp.getSourceCoreIds()); + if (failed(sourceCoreIds)) + return receiveTensorBatchOp.emitOpError("expected constant sourceCoreIds"); + for (int32_t& sourceCoreId : *sourceCoreIds) + sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId); auto outputType = cast(receiveTensorBatchOp.getOutput().getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType); @@ -56,24 +91,26 @@ static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatc receiveTensorBatchOp.getLoc(), outputBuffer.getType(), outputBuffer, - rewriter.getDenseI32ArrayAttr(sourceCoreIds)) + rewriter.getDenseI32ArrayAttr(*sourceCoreIds)) .getOutput(); mapper.map(receiveTensorBatchOp.getOutput(), received); + return success(); } } // namespace LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) { - if (computeBatchOp.getNumResults() != 0) - return computeBatchOp.emitOpError( - "batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results"); - Location loc = computeBatchOp.getLoc(); Block& oldBlock = computeBatchOp.getBody().front(); - auto oldYield = cast(oldBlock.getTerminator()); - if (oldYield.getNumOperands() != 0) - return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield"); + if (computeBatchOp.getNumResults() != 0) + return computeBatchOp.emitOpError( + "batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results; " + "materialize explicit communication before lowering to PIM"); + + auto oldYield = dyn_cast(oldBlock.getTerminator()); + if (!oldYield || oldYield.getNumOperands() != 0) + return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield"); SmallVector coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId); SmallVector batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); @@ -102,7 +139,12 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& IRMapping mapper; rewriter.setInsertionPointToStart(newBlock); - for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) { + mapper.map(computeBatchOp.getLaneArgument(), coreBatchOp.getLaneArgument()); + for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) + mapper.map(computeBatchOp.getWeightArgument(weightIndex), coreBatchOp.getWeightArgument(weightIndex)); + for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) { + BlockArgument oldArg = computeBatchOp.getInputArgument(inputIndex); + BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex); auto newArgType = cast(newArg.getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType); auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, @@ -142,20 +184,31 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& continue; if (auto sendBatchOp = dyn_cast(op)) { + FailureOr> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds()); + if (failed(targetCoreIds)) + return sendBatchOp.emitOpError("expected constant targetCoreIds"); + for (int32_t& targetCoreId : *targetCoreIds) + targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId); pim::PimSendBatchOp::create(rewriter, loc, mapper.lookup(sendBatchOp.getInput()), getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())), - sendBatchOp.getTargetCoreIdsAttr()); + rewriter.getDenseI32ArrayAttr(*targetCoreIds)); continue; } if (auto sendTensorBatchOp = dyn_cast(op)) { - lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter); + if (failed(lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter))) + return failure(); continue; } if (auto receiveBatchOp = dyn_cast(op)) { + FailureOr> sourceCoreIds = getConstantI32Values(receiveBatchOp.getSourceCoreIds()); + if (failed(sourceCoreIds)) + return receiveBatchOp.emitOpError("expected constant sourceCoreIds"); + for (int32_t& sourceCoreId : *sourceCoreIds) + sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId); auto outputType = cast(receiveBatchOp.getOutput().getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType); auto received = pim::PimReceiveBatchOp::create(rewriter, @@ -163,14 +216,15 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& outputBuffer.getType(), outputBuffer, getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()), - receiveBatchOp.getSourceCoreIdsAttr()) + rewriter.getDenseI32ArrayAttr(*sourceCoreIds)) .getOutput(); mapper.map(receiveBatchOp.getOutput(), received); continue; } if (auto receiveTensorBatchOp = dyn_cast(op)) { - lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter); + if (failed(lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter))) + return failure(); continue; } @@ -178,6 +232,10 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& if (isa_and_present(toTensorOp.getBuffer().getDefiningOp())) { Operation* cloned = rewriter.clone(op, mapper); auto clonedTensor = cloned->getResult(0); + if (isUsedOnlyAsExplicitHostOperand(toTensorOp.getResult())) { + mapper.map(toTensorOp.getResult(), clonedTensor); + continue; + } auto clonedType = cast(clonedTensor.getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType); auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, @@ -194,9 +252,11 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& } } - for (Value operand : op.getOperands()) { + for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) { if (!isa(operand.getType()) || mapper.contains(operand)) continue; + if (isExplicitHostOperand(&op, operandIndex)) + continue; Operation* definingOp = operand.getDefiningOp(); if (definingOp && definingOp->getBlock() == &oldBlock) diff --git a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt index a0bb5d2..1d004ca 100644 --- a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt @@ -22,6 +22,8 @@ add_pim_library(OMSpatialToPim LINK_LIBS PUBLIC MLIRSCFDialect + MLIRSCFUtils + MLIRTransformUtils MLIRTosaDialect OMCompilerOptions OMPimCommon diff --git a/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.cpp index bb526be..88de311 100644 --- a/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" @@ -12,15 +13,24 @@ namespace { static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; } +static FailureOr> getConstantI32Values(ValueRange values) { + SmallVector constants; + constants.reserve(values.size()); + for (Value value : values) { + APInt constantValue; + if (!matchPattern(value, m_ConstantInt(&constantValue))) + return failure(); + constants.push_back(static_cast(constantValue.getSExtValue())); + } + return constants; +} + struct ChannelSendLowering : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override { - pim::PimSendOp::create(rewriter, - op.getLoc(), - op.getInput(), - getTensorSizeInBytesAttr(rewriter, op.getInput()), - rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId()))); + pim::PimSendOp::create( + rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId()); rewriter.eraseOp(op); return success(); } @@ -42,7 +52,7 @@ struct ChannelReceiveLowering : OpRewritePattern op.getResult().getType(), outputBuffer, getTensorSizeInBytesAttr(rewriter, op.getResult()), - rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId()))) + op.getSourceCoreId()) .getOutput(); rewriter.replaceOp(op, received); return success(); @@ -53,11 +63,12 @@ struct ChannelSendTensorLowering : OpRewritePattern targetCoreIds; - targetCoreIds.reserve(op.getTargetCoreIds().size()); - for (int32_t targetCoreId : op.getTargetCoreIds()) - targetCoreIds.push_back(toPimCoreId(targetCoreId)); - pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds)); + FailureOr> targetCoreIds = getConstantI32Values(op.getTargetCoreIds()); + if (failed(targetCoreIds)) + return rewriter.notifyMatchFailure(op, "expected constant targetCoreIds"); + for (int32_t& targetCoreId : *targetCoreIds) + targetCoreId = toPimCoreId(targetCoreId); + pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(*targetCoreIds)); rewriter.eraseOp(op); return success(); } @@ -67,16 +78,17 @@ struct ChannelReceiveTensorLowering : OpRewritePattern sourceCoreIds; - sourceCoreIds.reserve(op.getSourceCoreIds().size()); - for (int32_t sourceCoreId : op.getSourceCoreIds()) - sourceCoreIds.push_back(toPimCoreId(sourceCoreId)); + FailureOr> sourceCoreIds = getConstantI32Values(op.getSourceCoreIds()); + if (failed(sourceCoreIds)) + return rewriter.notifyMatchFailure(op, "expected constant sourceCoreIds"); + for (int32_t& sourceCoreId : *sourceCoreIds) + sourceCoreId = toPimCoreId(sourceCoreId); auto outputType = cast(op.getOutput().getType()); Value outputBuffer = tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult(); Value received = pim::PimReceiveTensorOp::create( - rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds)) + rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(*sourceCoreIds)) .getOutput(); rewriter.replaceOp(op, received); return success(); diff --git a/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp index a94a7b6..bc1e577 100644 --- a/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp +++ b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp @@ -29,7 +29,10 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter, unsigned inputIndex, Value replacement) { Block& body = owner->getRegion(0).front(); - BlockArgument bodyArgument = body.getArgument(inputIndex); + BlockArgument bodyArgument = isa(owner) + ? cast(owner).getInputArgument(inputIndex) + : cast(owner).getInputArgument(inputIndex); + unsigned bodyArgIndex = bodyArgument.getArgNumber(); rewriter.startOpModification(owner); bodyArgument.replaceAllUsesWith(replacement); @@ -37,7 +40,7 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter, compute.getInputsMutable().erase(inputIndex); else cast(owner).getInputsMutable().erase(inputIndex); - body.eraseArgument(inputIndex); + body.eraseArgument(bodyArgIndex); rewriter.finalizeOpModification(owner); } diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index a1cba21..bf33989 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -3,6 +3,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" @@ -27,7 +28,8 @@ static bool isChannelUseChainOp(Operation* op) { pim::PimTransposeOp>(op); } -static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) { +static void +cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) { for (Value operand : op->getOperands()) { if (mapping.lookupOrNull(operand)) continue; @@ -36,7 +38,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri if (!definingOp) continue; - if (!isa(definingOp)) + if (auto constantOp = dyn_cast(definingOp)) { + mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder)); + continue; + } + + if (!isa(definingOp)) continue; Operation* clonedOp = rewriter.clone(*definingOp, mapping); @@ -48,6 +55,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast(spatialCoreId); } +static FailureOr> getConstantI32Values(ValueRange values) { + SmallVector constants; + constants.reserve(values.size()); + for (Value value : values) { + APInt constantValue; + if (!matchPattern(value, m_ConstantInt(&constantValue))) + return failure(); + constants.push_back(static_cast(constantValue.getSExtValue())); + } + return constants; +} + static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) { if (auto spatialCoreIdAttr = computeOp->getAttrOfType(onnx_mlir::kCoreIdAttrName)) return static_cast(spatialCoreIdAttr.getInt()); @@ -92,7 +111,9 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, return success(); } -static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) { +static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, + IRRewriter& rewriter, + OperationFolder& constantFolder) { if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1) return false; if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) { @@ -101,7 +122,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute return false; Block& block = computeOp.getBody().front(); - if (block.getNumArguments() != 0) + if (block.getNumArguments() != computeOp.getWeights().size()) return false; auto yieldOp = dyn_cast(block.getTerminator()); @@ -110,8 +131,10 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute rewriter.setInsertionPoint(computeOp); IRMapping mapping; + for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) + mapping.map(computeOp.getWeightArgument(weightIndex), weight); for (Operation& op : block.without_terminator()) { - cloneMappedHelperOperands(&op, mapping, rewriter); + cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder); Operation* clonedOp = rewriter.clone(op, mapping); for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) mapping.map(originalResult, newResult); @@ -133,7 +156,7 @@ void markOpToRemove(CoreLoweringState& state, Operation* op) { LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) { Location loc = computeOp->getLoc(); - if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter)) + if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, state.constantFolder)) return success(); SmallVector helperChain; @@ -143,21 +166,42 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& auto& block = computeOp.getRegion().front(); auto yieldOp = cast(block.getTerminator()); - for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) { - auto receiveOp = dyn_cast_or_null(computeOp.getInputs()[argIndex].getDefiningOp()); - if (!receiveOp || blockArg.use_empty()) + for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) { + BlockArgument blockArg = computeOp.getInputArgument(inputIndex); + auto receiveOp = dyn_cast_or_null(input.getDefiningOp()); + if (receiveOp && !blockArg.use_empty()) { + rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); + auto outputType = cast(blockArg.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg); + Value received = + PimReceiveOp::create( + rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId()) + .getOutput(); + blockArg.replaceAllUsesWith(received); + markOpToRemove(state, receiveOp); continue; + } - rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); - auto outputType = cast(blockArg.getType()); - auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); - auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg); - auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId())); - Value received = PimReceiveOp::create( - rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) - .getOutput(); - blockArg.replaceAllUsesWith(received); - markOpToRemove(state, receiveOp); + auto receiveTensorOp = dyn_cast_or_null(input.getDefiningOp()); + if (receiveTensorOp && !blockArg.use_empty()) { + FailureOr> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds()); + if (failed(sourceCoreIds)) + return receiveTensorOp.emitOpError("expected constant sourceCoreIds"); + for (int32_t& sourceCoreId : *sourceCoreIds) + sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId); + rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); + auto outputType = cast(blockArg.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType); + Value received = PimReceiveTensorOp::create(rewriter, + receiveTensorOp.getLoc(), + outputBuffer.getType(), + outputBuffer, + rewriter.getDenseI32ArrayAttr(*sourceCoreIds)) + .getOutput(); + blockArg.replaceAllUsesWith(received); + markOpToRemove(state, receiveTensorOp); + } } if (computeOp.getNumResults() != yieldOp.getNumOperands()) @@ -197,11 +241,36 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId))); + rewriter.setInsertionPointToStart(&block); auto& coreOpBlocks = coreOp.getBody().getBlocks(); - for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) - if (!blockArg.use_empty()) - blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]); - block.eraseArguments(0, block.getNumArguments()); + for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) { + BlockArgument blockArg = computeOp.getInputArgument(inputIndex); + if (blockArg.use_empty()) + continue; + + if (auto constantOp = input.getDefiningOp()) { + blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, state.constantFolder)); + continue; + } + + auto inputType = dyn_cast(input.getType()); + if (!inputType) + return computeOp.emitOpError("expected shaped compute input during pim.core lowering"); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType); + auto copied = + PimMemCopyHostToDevOp::create(rewriter, + loc, + outputBuffer.getType(), + getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder), + getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder), + outputBuffer, + input, + getTensorSizeInBytesAttr(rewriter, input)) + .getOutput(); + blockArg.replaceAllUsesWith(copied); + } + if (!computeOp.getInputs().empty()) + block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size()); coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks()); Block* tempComputeBlock = new Block(); computeOp.getBody().push_back(tempComputeBlock); diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp index 74304ed..7e7d214 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp @@ -1,6 +1,7 @@ #pragma once #include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/FoldUtils.h" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -11,6 +12,7 @@ struct CoreLoweringState { size_t& nextCoreId; llvm::SmallVectorImpl& outputTensors; llvm::SmallVectorImpl& operationsToRemove; + mlir::OperationFolder& constantFolder; }; void markOpToRemove(CoreLoweringState& state, mlir::Operation* op); diff --git a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp index 2e6b456..7caf4ed 100644 --- a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp +++ b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp @@ -76,8 +76,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern(uses.getOwner())) { auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber()); if (!inputIndex) return failure(); - auto BBArgIndex = *inputIndex; - auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); + auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex); if (BBArgValue.use_empty()) continue; @@ -108,7 +106,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern mapSpatComputeToConst; + Value hostConstant = constantOp.getResult(); for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) { auto constUsers = constUses.getOwner(); @@ -264,40 +262,22 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePatterngetResult(0)); + replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, hostConstant); } else if (auto spatComputeBatch = llvm::dyn_cast(constUsers)) { auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber()); if (!inputIndex) return failure(); auto BBArgIndex = *inputIndex; - rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); - auto newConst = rewriter.clone(*constantOp); - - replaceAndEraseDirectComputeLikeInput( - rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0)); + replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, hostConstant); } - else if (auto parent = constUsers->getParentOfType()) { - if (!mapSpatComputeToConst.contains(parent)) { - rewriter.setInsertionPoint(&parent.getBody().front().front()); - auto newConst = rewriter.clone(*constantOp); - mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)}); - } - constUses.set(mapSpatComputeToConst[parent.getOperation()]); + else if (constUsers->getParentOfType()) { + constUses.set(hostConstant); } else { auto batchParent = constUsers->getParentOfType(); assert(batchParent && "Global Constant used direcly not within a compute"); - if (!mapSpatComputeToConst.contains(batchParent.getOperation())) { - rewriter.setInsertionPoint(&batchParent.getBody().front().front()); - auto newConst = rewriter.clone(*constantOp); - mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)}); - } - constUses.set(mapSpatComputeToConst[batchParent.getOperation()]); + constUses.set(hostConstant); } } } diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index 3a5f755..681bf1f 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -6,8 +6,10 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Transforms/FoldUtils.h" #include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -318,7 +320,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef sourceIndice return success(); } -static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) { +static void +cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) { for (Value operand : op->getOperands()) { if (mapping.lookupOrNull(operand)) continue; @@ -327,7 +330,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri if (!definingOp) continue; - if (!isa(definingOp)) + if (auto constantOp = dyn_cast(definingOp)) { + mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder)); + continue; + } + + if (!isa(definingOp)) continue; Operation* clonedOp = rewriter.clone(*definingOp, mapping); @@ -337,15 +345,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri } } -static void -cloneHelperChain(Value sourceValue, ArrayRef helperChain, IRRewriter& rewriter, Value& clonedValue) { +static void cloneHelperChain(Value sourceValue, + ArrayRef helperChain, + IRRewriter& rewriter, + OperationFolder& constantFolder, + Value& clonedValue) { IRMapping mapping; mapping.map(sourceValue, sourceValue); clonedValue = sourceValue; rewriter.setInsertionPointAfterValue(sourceValue); for (Operation* op : helperChain) { - cloneMappedHelperOperands(op, mapping, rewriter); + cloneMappedHelperOperands(op, mapping, rewriter, constantFolder); Operation* clonedOp = rewriter.clone(*op, mapping); for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) mapping.map(originalResult, newResult); @@ -360,14 +371,19 @@ static Value emitHostCopy(IRRewriter& rewriter, Value sourceValue, int32_t hostTargetOffset, int32_t deviceSourceOffset, - int32_t sizeInBytes) { + int32_t sizeInBytes, + OperationFolder& constantFolder) { + Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp(); + assert(anchorOp && "expected a concrete op anchor for return-path host copy constants"); + Value hostTargetOffsetValue = getOrCreateHostIndexConstant(anchorOp, hostTargetOffset, constantFolder); + Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(anchorOp, deviceSourceOffset, constantFolder); return PimMemCopyDevToHostOp::create(rewriter, loc, outputTensor.getType(), + hostTargetOffsetValue, + deviceSourceOffsetValue, outputTensor, sourceValue, - rewriter.getI32IntegerAttr(hostTargetOffset), - rewriter.getI32IntegerAttr(deviceSourceOffset), rewriter.getI32IntegerAttr(sizeInBytes)) .getOutput(); } @@ -411,69 +427,84 @@ void addReturnOutputBuffers(func::ReturnOp returnOp, } } -ReturnPathLoweringResult lowerComputeResultReturnPath( - spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) { - Location loc = computeOp->getLoc(); - auto yieldType = cast(yieldValue.getType()); +ReturnPathLoweringResult lowerProducedValueReturnPath( + Operation* producerOp, Value producedValue, Value storedValue, ReturnPathState& state, IRRewriter& rewriter) { + Location loc = producerOp->getLoc(); + OperationFolder constantFolder(producerOp->getContext()); + auto storedTensorType = cast(storedValue.getType()); - if (auto returnUse = analyzeReturnUse(result)) { - Value storedValue = yieldValue; - cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue); + if (auto returnUse = analyzeReturnUse(producedValue)) { + Value currentStoredValue = storedValue; + cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue); for (Operation* op : returnUse->helperChain) markOpToRemove(state, op); - auto storedType = cast(storedValue.getType()); + auto storedType = cast(currentStoredValue.getType()); size_t elementSize = storedType.getElementTypeBitWidth() / 8; - if (auto storedOp = storedValue.getDefiningOp()) + if (auto storedOp = currentStoredValue.getDefiningOp()) rewriter.setInsertionPointAfter(storedOp); Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc); - emitHostCopy( - rewriter, loc, outputTensor, storedValue, 0, 0, static_cast(storedType.getNumElements() * elementSize)); + emitHostCopy(rewriter, + loc, + outputTensor, + currentStoredValue, + 0, + 0, + static_cast(storedType.getNumElements() * elementSize), + constantFolder); return ReturnPathLoweringResult::Handled; } - auto resultUses = result.getUses(); + auto resultUses = producedValue.getUses(); if (rangeLength(resultUses) == 1) { OpOperand& resultUse = *resultUses.begin(); Operation* resultUser = resultUse.getOwner(); if (isa(resultUser)) { size_t resultIndexInReturn = resultUse.getOperandNumber(); - size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8; - rewriter.setInsertionPointAfterValue(yieldValue); + size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8; + rewriter.setInsertionPointAfterValue(storedValue); Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc); - emitHostCopy( - rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast(yieldType.getNumElements() * elementSize)); + emitHostCopy(rewriter, + loc, + outputTensor, + storedValue, + 0, + 0, + static_cast(storedTensorType.getNumElements() * elementSize), + constantFolder); return ReturnPathLoweringResult::Handled; } } - if (auto concatReturnUse = analyzeConcatReturnUse(result)) { - size_t elementSize = yieldType.getElementTypeBitWidth() / 8; + if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) { + size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8; for (Operation* concatOp : concatReturnUse->concatChain) markOpToRemove(state, concatOp); if (concatReturnUse->helperChain.empty()) { - rewriter.setInsertionPointAfterValue(yieldValue); + rewriter.setInsertionPointAfterValue(storedValue); Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc); auto outputType = cast(outputTensor.getType()); int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape()); emitHostCopy(rewriter, loc, outputTensor, - yieldValue, + storedValue, static_cast(flatOffset * elementSize), 0, - static_cast(yieldType.getNumElements() * elementSize)); + static_cast(storedTensorType.getNumElements() * elementSize), + constantFolder); return ReturnPathLoweringResult::Handled; } - auto storedType = dyn_cast(yieldValue.getType()); + auto storedType = dyn_cast(storedValue.getType()); if (!storedType) { - computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering"); + producerOp->emitOpError( + "has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering"); return ReturnPathLoweringResult::Failure; } - rewriter.setInsertionPointAfterValue(yieldValue); + rewriter.setInsertionPointAfterValue(storedValue); Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc); auto outputType = cast(outputTensor.getType()); for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) { @@ -484,7 +515,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath( SmallVector destinationIndices; if (failed(mapIndicesThroughHelperChain( sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) { - computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering"); + producerOp->emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering"); return ReturnPathLoweringResult::Failure; } @@ -503,7 +534,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath( auto scalarTensorType = RankedTensorType::get(SmallVector(storedType.getRank(), 1), storedType.getElementType()); auto elementSlice = tensor::ExtractSliceOp::create( - rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides); + rewriter, loc, scalarTensorType, storedValue, extractOffsets, extractSizes, extractStrides); rewriter.setInsertionPointAfter(elementSlice); int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape()); @@ -513,7 +544,8 @@ ReturnPathLoweringResult lowerComputeResultReturnPath( elementSlice.getResult(), static_cast(destinationFlatOffset * elementSize), 0, - static_cast(elementSize)); + static_cast(elementSize), + constantFolder); } return ReturnPathLoweringResult::Handled; } @@ -521,6 +553,11 @@ ReturnPathLoweringResult lowerComputeResultReturnPath( return ReturnPathLoweringResult::NotReturnPath; } +ReturnPathLoweringResult lowerComputeResultReturnPath( + spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) { + return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, state, rewriter); +} + void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) { auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void { if (!op) @@ -569,7 +606,16 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite markOpToRemove(state, concatOp); for (Value operand : concatOp.getInputs()) markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); + return; } + + if (auto receiveOp = dyn_cast(op)) { + markOpToRemove(state, receiveOp); + return; + } + + if (auto receiveTensorOp = dyn_cast(op)) + markOpToRemove(state, receiveTensorOp); }; SmallVector originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp index 6a1c78c..fe86724 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp @@ -32,6 +32,12 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute compu ReturnPathState& state, mlir::IRRewriter& rewriter); +ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp, + mlir::Value producedValue, + mlir::Value storedValue, + ReturnPathState& state, + mlir::IRRewriter& rewriter); + void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index d79dd66..4d35f5e 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -16,8 +16,8 @@ def onnxToPimTranspose : Pat< >; def spatToPimVMM : Pat< - (SpatVMMOp:$srcOpRes $weightIndex, $vector), - (PimVMMOp $weightIndex, $vector, + (SpatVMMOp:$srcOpRes $weight, $vector), + (PimVMMOp $weight, $vector, (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 56176d7..7957262 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -3,6 +3,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" @@ -12,6 +13,8 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/StringRef.h" @@ -104,23 +107,34 @@ static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc IntegerAttr {}); } -static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) { +static Value createZeroedDeviceHVector(IRRewriter& rewriter, + Location loc, + RankedTensorType tensorType, + OperationFolder& constantFolder) { auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType); auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType); auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); - auto zeroAttr = rewriter.getI32IntegerAttr(0); + auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder); auto sizeAttr = rewriter.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(tensorType))); if (outputBuffer->getParentOfType()) - return PimMemCopyHostToDevBatchOp::create( - rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr) + return PimMemCopyHostToDevBatchOp::create(rewriter, + loc, + tensorType, + outputBuffer, + zeroValue, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + sizeAttr) .getOutput(); - return PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr) + return PimMemCopyHostToDevOp::create( + rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr) .getOutput(); } -static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) { +static Value +padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) { auto vectorType = cast(vector.getType()); ArrayRef shape = vectorType.getShape(); assert(isHVectorShape(shape) && "expected a horizontal vector"); @@ -131,7 +145,7 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V auto paddedType = RankedTensorType::get( {shape[0], static_cast(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding()); - Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType); + Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder); auto zeroAttr = rewriter.getI32IntegerAttr(0); auto sizeAttr = rewriter.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(vectorType))); return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput(); @@ -151,6 +165,7 @@ void SpatialToPimPass::runOnOperation() { func::FuncOp funcOp = *entryFunc; IRRewriter rewriter(&getContext()); + OperationFolder constantFolder(&getContext()); ConversionTarget target(*ctx); target.addLegalDialect(funcOp.front().getTerminator()); + auto returnOp = cast(funcOp.front().getTerminator()); addReturnOutputBuffers(returnOp, rewriter, outputTensors); + ReturnPathState returnPathState {outputTensors, operationsToRemove}; if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) { funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering"); signalPassFailure(); return; } - CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove}; + CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder}; for (auto computeOp : funcOp.getOps()) { markOpToRemove(computeOp); if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) { @@ -251,7 +267,6 @@ void SpatialToPimPass::runOnOperation() { } enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); - ReturnPathState returnPathState {outputTensors, operationsToRemove}; replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState); SmallVector pendingRemovals(operationsToRemove.begin(), operationsToRemove.end()); @@ -302,6 +317,7 @@ void SpatialToPimPass::runOnOperation() { } void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { + OperationFolder constantFolder(funcOp.getContext()); funcOp.walk([&](PimVMMOp vmmOp) { auto outputType = cast(vmmOp.getOutput().getType()); ArrayRef outputShape = outputType.getShape(); @@ -309,7 +325,7 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I assert(outputShape[1] <= static_cast(crossbarSize) && "output width must fit in one crossbar"); rewriter.setInsertionPoint(vmmOp); - Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput()); + Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder); auto paddedOutputType = RankedTensorType::get( {outputShape[0], static_cast(crossbarSize)}, outputType.getElementType(), outputType.getEncoding()); Value paddedOutputBuffer = outputShape[1] == static_cast(crossbarSize) @@ -336,10 +352,13 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { Location loc = funcOp.getLoc(); + OperationFolder constantFolder(funcOp.getContext()); auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) { auto tensorType = cast(inputTensor.getType()); Type elementType = tensorType.getElementType(); + if (!elementType.isIntOrFloat()) + return; size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; rewriter.setInsertionPointAfter(inputTensor.getDefiningOp()); @@ -349,10 +368,11 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu rewriter, loc, tensorType, + getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder), + getOrCreateHostIndexConstant( + deviceTensor.getOperation(), static_cast(elementsOffset * elementByteSize), constantFolder), deviceTensor, inputTensor, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(static_cast(elementsOffset * elementByteSize)), rewriter.getI32IntegerAttr(static_cast(tensorType.getNumElements() * elementByteSize))); rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp}); diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 07022a2..2ade4fd 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -2,6 +2,7 @@ #define PIM_DIALECT_H include "mlir/IR/OpBase.td" +include "mlir/IR/OpAsmInterface.td" include "mlir/IR/AttrTypeBase.td" include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -24,7 +25,8 @@ def PimTensor : // Execution //===----------------------------------------------------------------------===// -def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> { +def PimCoreOp : PimOp<"core", [SingleBlock, + DeclareOpInterfaceMethods]> { let summary = "Execute a block on a PIM core"; let regions = (region SizedRegion<1>:$body); @@ -34,12 +36,16 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> { I32Attr:$coreId ); - let assemblyFormat = [{ - `(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)` + let extraClassDeclaration = [{ + ::mlir::BlockArgument getWeightArgument(unsigned idx); }]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; } -def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> { +def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { let summary = "Execute equivalent batched core bodies"; let regions = (region SizedRegion<1>:$body); @@ -50,6 +56,13 @@ def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSi Variadic:$inputs ); + let extraClassDeclaration = [{ + ::mlir::BlockArgument getLaneArgument(); + ::mlir::BlockArgument getWeightArgument(unsigned idx); + ::mlir::BlockArgument getInputArgument(unsigned idx); + }]; + + let hasVerifier = 1; let hasCustomAssemblyFormat = 1; } @@ -81,11 +94,11 @@ def PimSendOp : PimOp<"send", []> { let arguments = (ins PimTensor:$input, I32Attr:$size, - I32Attr:$targetCoreId + Index:$targetCoreId ); let assemblyFormat = [{ - `(` $input `)` attr-dict `:` type($input) `->` `(` `)` + `(` $input `,` $targetCoreId `)` attr-dict `:` type($input) `->` `(` `)` }]; } @@ -131,7 +144,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> { let arguments = (ins PimTensor:$outputBuffer, I32Attr:$size, - I32Attr:$sourceCoreId + Index:$sourceCoreId ); let results = (outs @@ -145,7 +158,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> { }]; let assemblyFormat = [{ - `(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output) + `(` $outputBuffer `,` $sourceCoreId `)` attr-dict `:` type($outputBuffer) `->` type($output) }]; } @@ -219,10 +232,10 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> { let summary = "Copy a memory region from host memory into device memory"; let arguments = (ins + Index:$deviceTargetOffset, + Index:$hostSourceOffset, PimTensor:$deviceTarget, PimTensor:$hostSource, - I32Attr:$deviceTargetOffset, - I32Attr:$hostSourceOffset, I32Attr:$size ); @@ -237,7 +250,9 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> { }]; let assemblyFormat = [{ - `(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output) + `[` $deviceTargetOffset `,` $hostSourceOffset `]` + `(` $deviceTarget `,` $hostSource `)` attr-dict + `:` type($deviceTarget) `,` type($hostSource) `->` type($output) }]; } @@ -271,10 +286,10 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> { let summary = "Copy a memory region from device memory into host memory"; let arguments = (ins + Index:$hostTargetOffset, + Index:$deviceSourceOffset, PimTensor:$hostTarget, PimTensor:$deviceSource, - I32Attr:$hostTargetOffset, - I32Attr:$deviceSourceOffset, I32Attr:$size ); @@ -289,7 +304,9 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> { }]; let assemblyFormat = [{ - `(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output) + `[` $hostTargetOffset `,` $deviceSourceOffset `]` + `(` $hostTarget `,` $deviceSource `)` attr-dict + `:` type($hostTarget) `,` type($deviceSource) `->` type($output) }]; } @@ -374,7 +391,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> { let summary = "Vector-matrix multiplication: c = a * b"; let arguments = (ins - I32Attr:$weightIndex, + PimTensor:$weight, PimTensor:$input, PimTensor:$outputBuffer ); @@ -391,7 +408,8 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> { let hasVerifier = 1; let assemblyFormat = [{ - `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) + `[` $weight `]` `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($weight) `,` type($input) `,` + type($outputBuffer) `)` `->` type($output) }]; } diff --git a/src/PIM/Dialect/Pim/PimOps.cpp b/src/PIM/Dialect/Pim/PimOps.cpp index 5168fda..57bb1bc 100644 --- a/src/PIM/Dialect/Pim/PimOps.cpp +++ b/src/PIM/Dialect/Pim/PimOps.cpp @@ -1,8 +1,41 @@ #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include + +using namespace mlir; + namespace onnx_mlir { namespace pim { +BlockArgument PimCoreOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); } + +void PimCoreOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { + if (region.empty()) + return; + + for (unsigned index = 0; index < getWeights().size(); ++index) + setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str()); +} + +BlockArgument PimCoreBatchOp::getLaneArgument() { return getBody().front().getArgument(0); } + +BlockArgument PimCoreBatchOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); } + +BlockArgument PimCoreBatchOp::getInputArgument(unsigned idx) { + return getBody().front().getArgument(1 + getWeights().size() + idx); +} + +void PimCoreBatchOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { + if (region.empty()) + return; + + setNameFn(getLaneArgument(), "lane"); + for (unsigned index = 0; index < getWeights().size(); ++index) + setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str()); + for (unsigned index = 0; index < getInputs().size(); ++index) + setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str()); +} + void PimDialect::initialize() { addOperations< #define GET_OP_LIST diff --git a/src/PIM/Dialect/Pim/PimOpsAsm.cpp b/src/PIM/Dialect/Pim/PimOpsAsm.cpp index db2f6d0..283773f 100644 --- a/src/PIM/Dialect/Pim/PimOpsAsm.cpp +++ b/src/PIM/Dialect/Pim/PimOpsAsm.cpp @@ -20,6 +20,80 @@ static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef arguments) { + printer << "("; + for (auto [index, argument] : llvm::enumerate(arguments)) { + if (index != 0) + printer << ", "; + printer.printOperand(argument); + } + printer << ")"; +} + +static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl& arguments) { + if (parser.parseLParen()) + return failure(); + if (succeeded(parser.parseOptionalRParen())) + return success(); + + OpAsmParser::Argument argument; + if (parser.parseArgument(argument)) + return failure(); + arguments.push_back(argument); + while (succeeded(parser.parseOptionalComma())) { + if (parser.parseArgument(argument)) + return failure(); + arguments.push_back(argument); + } + return parser.parseRParen(); +} + +static void printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) { + printCompressedValueList(printer, arguments, delimiter); + printer << " = "; + printCompressedValueList(printer, operands, delimiter); +} + +static ParseResult parseBoundValueList(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& arguments, + SmallVectorImpl& operands) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) { + if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) + return failure(); + return success(); + } + + if (parseOneCompressedArgumentEntry(parser, arguments)) + return failure(); + while (succeeded(parser.parseOptionalComma())) + if (parseOneCompressedArgumentEntry(parser, arguments)) + return failure(); + + auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult { + switch (currentDelimiter) { + case ListDelimiter::Paren: + return parser.parseRParen(); + case ListDelimiter::Square: + return parser.parseRSquare(); + } + llvm_unreachable("unsupported delimiter"); + }; + if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) + return failure(); + return success(); +} + static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef coreIds) { printer << " " << keyword << " "; printCompressedIntegerList(printer, coreIds); @@ -33,15 +107,76 @@ static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keywor } // namespace -void PimCoreBatchOp::print(OpAsmPrinter& printer) { - printer << " lanes " << getLaneCount() << " "; - size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast(getLaneCount()) : 0; - if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane)) - printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Paren); - else - printCompressedValueList(printer, getWeights(), ListDelimiter::Paren); +void PimCoreOp::print(OpAsmPrinter& printer) { + SmallVector weightArgs; + weightArgs.reserve(getWeights().size()); + for (unsigned index = 0; index < getWeights().size(); ++index) + weightArgs.push_back(getWeightArgument(index)); + printer << " "; - printCompressedValueList(printer, getInputs(), ListDelimiter::Square); + printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); + printer << " coreId " << getCoreId(); + printer.printOptionalAttrDict((*this)->getAttrs(), {getCoreIdAttrName().getValue()}); + printer << " : "; + printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); + printer << " -> () "; + printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); +} + +ParseResult PimCoreOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector weightArgs; + SmallVector weights; + SmallVector weightTypes; + int32_t coreId = 0; + + if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)) + return failure(); + + bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id"); + if (hasCoreId && parser.parseInteger(coreId)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedRepeatedList( + parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) + || parser.parseArrow() || parser.parseLParen() || parser.parseRParen()) + return failure(); + + if (weights.size() != weightTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); + if (weightArgs.size() != weights.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match"); + if (hasCoreId && result.attributes.get("coreId")) + return parser.emitError(parser.getCurrentLocation(), + "coreId cannot be specified both positionally and in attr-dict"); + + if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + + if (hasCoreId) + result.addAttribute("coreId", getI32Attr(parser, coreId)); + + Region* body = result.addRegion(); + applyArgumentTypes(weightTypes, weightArgs); + return parser.parseRegion(*body, weightArgs); +} + +void PimCoreBatchOp::print(OpAsmPrinter& printer) { + printer << " "; + printer.printOperand(getLaneArgument()); + printer << " = 0 to " << getLaneCount() << " "; + + SmallVector weightArgs; + weightArgs.reserve(getWeights().size()); + for (unsigned index = 0; index < getWeights().size(); ++index) + weightArgs.push_back(getWeightArgument(index)); + printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); + printer << " "; + SmallVector inputArgs; + inputArgs.reserve(getInputs().size()); + for (unsigned index = 0; index < getInputs().size(); ++index) + inputArgs.push_back(getInputArgument(index)); + printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef()); @@ -49,51 +184,57 @@ void PimCoreBatchOp::print(OpAsmPrinter& printer) { printer.printOptionalAttrDict( (*this)->getAttrs(), {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName}); - printer << " "; - printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); printer << " : "; - if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane)) - printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Paren); - else - printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Paren); + printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); printer << " "; - printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Square); - printer << " -> ()"; + printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); + printer << " -> () "; + printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); } ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) { + int64_t lowerBound = 0; int32_t laneCount = 0; + OpAsmParser::Argument laneArg; + SmallVector weightArgs; + SmallVector inputArgs; + SmallVector regionArgs; SmallVector weights; SmallVector inputs; SmallVector weightTypes; SmallVector inputTypes; SmallVector coreIds; - if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount) - || parseCompressedOrTupleOperandList(parser, ListDelimiter::Paren, weights) - || parseCompressedOperandList(parser, ListDelimiter::Square, inputs)) + if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound) + || parser.parseKeyword("to") || parser.parseInteger(laneCount)) + return failure(); + if (lowerBound != 0) + return parser.emitError(parser.getCurrentLocation(), "core_batch currently requires a zero lower bound"); + + if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights) + || parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs)) return failure(); - bool hasCoreIds = succeeded(parser.parseOptionalKeyword("coreIds")); + bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids"); if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - Region* body = result.addRegion(); - if (parser.parseRegion(*body)) - return failure(); - - if (parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Paren, weightTypes) - || parseCompressedTypeList(parser, ListDelimiter::Square, inputTypes) || parser.parseArrow() - || parser.parseLParen() || parser.parseRParen()) + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedRepeatedList( + parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) + || parseCompressedRepeatedList( + parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) + || parser.parseArrow() || parser.parseLParen() || parser.parseRParen()) return failure(); if (weights.size() != weightTypes.size()) return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); + if (weightArgs.size() != weights.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match"); if (inputs.size() != inputTypes.size()) return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (inputArgs.size() != inputs.size()) + return parser.emitError(parser.getCurrentLocation(), "number of input bindings and input operands must match"); if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName)) return parser.emitError(parser.getCurrentLocation(), "coreIds cannot be specified both positionally and in attr-dict"); @@ -110,7 +251,15 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) { || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) { return failure(); } - return success(); + + Region* body = result.addRegion(); + laneArg.type = builder.getIndexType(); + regionArgs.push_back(laneArg); + applyArgumentTypes(weightTypes, weightArgs); + llvm::append_range(regionArgs, weightArgs); + applyArgumentTypes(inputTypes, inputArgs); + llvm::append_range(regionArgs, inputArgs); + return parser.parseRegion(*body, regionArgs); } void PimYieldOp::print(OpAsmPrinter& printer) { diff --git a/src/PIM/Dialect/Pim/PimOpsVerify.cpp b/src/PIM/Dialect/Pim/PimOpsVerify.cpp index e9ce9cf..ad7139a 100644 --- a/src/PIM/Dialect/Pim/PimOpsVerify.cpp +++ b/src/PIM/Dialect/Pim/PimOpsVerify.cpp @@ -1,5 +1,7 @@ #include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Block.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/Support/LogicalResult.h" @@ -14,6 +16,52 @@ namespace pim { namespace { +static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { + if (isa(op)) + return operandIndex == 3; + if (isa(op)) + return operandIndex == 1; + if (isa(op)) + return operandIndex == 2; + return false; +} + +static Region* getParentRegion(Value value) { + if (auto blockArgument = dyn_cast(value)) + return blockArgument.getParentRegion(); + Operation* definingOp = value.getDefiningOp(); + return definingOp ? definingOp->getParentRegion() : nullptr; +} + +static bool isDefinedInsideRegion(Value value, Region& region) { + Region* parentRegion = getParentRegion(value); + return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion)); +} + +static bool isConstantExternalValue(Value value) { + Operation* definingOp = value.getDefiningOp(); + return definingOp && definingOp->hasTrait(); +} + +static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) { + bool hasFailure = false; + region.walk([&](Operation* op) { + for (OpOperand& operand : op->getOpOperands()) { + Value value = operand.get(); + if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value) + || isExplicitHostOperand(op, operand.getOperandNumber())) + continue; + + InFlightDiagnostic diagnostic = + ownerOp->emitOpError() << kind << " body may only directly reference external constants"; + diagnostic.attachNote(op->getLoc()) + << "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName(); + hasFailure = true; + } + }); + return success(!hasFailure); +} + static bool haveSameShapedContainerKind(Type lhs, Type rhs) { return (isa(lhs) && isa(rhs)) || (isa(lhs) && isa(rhs)); } @@ -78,24 +126,46 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef coreI return success(); } -static FailureOr> getWeightShapeForVMM(Operation* op, size_t weightIndex) { - if (auto coreOp = op->getParentOfType()) { - if (weightIndex >= coreOp.getWeights().size()) - return failure(); - return cast(coreOp.getWeights()[weightIndex].getType()).getShape(); - } - - if (auto coreBatchOp = op->getParentOfType()) { - if (weightIndex >= coreBatchOp.getWeights().size()) - return failure(); - return cast(coreBatchOp.getWeights()[weightIndex].getType()).getShape(); - } - - return failure(); +static FailureOr> getWeightShapeForVMM(Value weight) { + auto shapedType = dyn_cast(weight.getType()); + if (!shapedType) + return failure(); + return shapedType.getShape(); } } // namespace +LogicalResult PimCoreOp::verify() { + Block& block = getBody().front(); + if (block.getNumArguments() != getWeights().size()) + return emitError("core body must have one block argument per weight"); + for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { + if (getWeightArgument(weightIndex).getType() != weight.getType()) + return emitError("core weight block argument types must match weight operand types exactly"); + } + return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core"); +} + +LogicalResult PimCoreBatchOp::verify() { + if (getLaneCount() <= 0) + return emitError("laneCount must be positive"); + Block& block = getBody().front(); + unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size(); + if (block.getNumArguments() != expectedArgCount) + return emitError("core_batch body must have lane, weight, and input block arguments"); + if (!getLaneArgument().getType().isIndex()) + return emitError("core_batch first block argument must have index type"); + for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { + if (getWeightArgument(weightIndex).getType() != weight.getType()) + return emitError("core_batch weight block argument types must match weight operand types exactly"); + } + for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { + if (getInputArgument(inputIndex).getType() != input.getType()) + return emitError("core_batch input block argument types must match input operand types exactly"); + } + return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch"); +} + LogicalResult PimSendTensorOp::verify() { return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor"); } @@ -126,9 +196,9 @@ LogicalResult PimVMMOp::verify() { getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match"))) return failure(); - auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex()); + auto matrixShapeOpt = getWeightShapeForVMM(getWeight()); if (failed(matrixShapeOpt)) - return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex"); + return emitError("weight must be a shaped value"); ArrayRef matrixShape = *matrixShapeOpt; auto vectorType = dyn_cast(getInput().getType()); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 1e724a8..1bf7c48 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -38,10 +38,10 @@ struct MemCopyHostToDevOpInterface replaceOpWithNewBufferizedOp(rewriter, memCopyHostToDevOp, deviceTargetMemRef.getType(), + memCopyHostToDevOp.getDeviceTargetOffset(), + memCopyHostToDevOp.getHostSourceOffset(), deviceTargetMemRef, hostSourceMemRef, - memCopyHostToDevOp.getDeviceTargetOffsetAttr(), - memCopyHostToDevOp.getHostSourceOffsetAttr(), memCopyHostToDevOp.getSizeAttr()); return success(); } @@ -96,10 +96,10 @@ struct MemCopyDevToHostOpInterface replaceOpWithNewBufferizedOp(rewriter, memCopyDevToHostOp, hostTargetMemRef.getType(), + memCopyDevToHostOp.getHostTargetOffset(), + memCopyDevToHostOp.getDeviceSourceOffset(), hostTargetMemRef, deviceSourceMemRef, - memCopyDevToHostOp.getHostTargetOffsetAttr(), - memCopyDevToHostOp.getDeviceSourceOffsetAttr(), memCopyDevToHostOp.getSizeAttr()); return success(); } @@ -151,12 +151,8 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel(rewriter, - op, - outputBufferOpt->getType(), - *outputBufferOpt, - receiveOp.getSizeAttr(), - receiveOp.getSourceCoreIdAttr()); + replaceOpWithNewBufferizedOp( + rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId()); return success(); } }; @@ -302,7 +298,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModelgetLoc(), rewriter), sendOp.getSizeAttr(), - sendOp.getTargetCoreIdAttr()); + sendOp.getTargetCoreId()); return success(); } }; @@ -368,6 +364,37 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel(op); + auto bbArg = dyn_cast(value); + if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front()) + return {}; + + unsigned weightIndex = bbArg.getArgNumber(); + return { + {&coreOp->getOpOperand(weightIndex), BufferRelation::Equivalent} + }; + } + + bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; } + + FailureOr getBufferType(Operation* op, + Value value, + const BufferizationOptions& options, + const BufferizationState& state, + SmallVector& invocationStack) const { + auto coreOp = cast(op); + auto bbArg = dyn_cast(value); + if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front()) + return failure(); + + Value tiedWeight = coreOp.getWeights()[bbArg.getArgNumber()]; + if (auto memRefType = dyn_cast(tiedWeight.getType())) + return memRefType; + + return bufferization::getBufferType(tiedWeight, options, state, invocationStack); + } + LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, @@ -375,7 +402,10 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel(op); bool alreadyBufferized = - llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa(weight.getType()); }); + llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa(weight.getType()); }) + && llvm::all_of(coreOp.getBody().front().getArguments(), [](BlockArgument arg) { + return !isa(arg.getType()) || isa(arg.getType()); + }); if (alreadyBufferized) return success(); @@ -420,9 +450,17 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel weightCount + 1) + operandIndex = weightCount + (argNumber - 1 - weightCount); + return { - {&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent} + {&coreBatchOp->getOpOperand(operandIndex), BufferRelation::Equivalent} }; } @@ -438,11 +476,21 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel(tiedInput.getType())) + unsigned argNumber = bbArg.getArgNumber(); + if (argNumber == 0) + return failure(); + + Value tiedOperand; + unsigned weightCount = coreBatchOp.getWeights().size(); + if (argNumber <= weightCount) + tiedOperand = coreBatchOp.getWeights()[argNumber - 1]; + else + tiedOperand = coreBatchOp.getInputs()[argNumber - 1 - weightCount]; + + if (auto memRefType = dyn_cast(tiedOperand.getType())) return memRefType; - return bufferization::getBufferType(tiedInput, options, state, invocationStack); + return bufferization::getBufferType(tiedOperand, options, state, invocationStack); } LogicalResult bufferize(Operation* op, @@ -454,8 +502,9 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel(weight.getType()); }) && llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa(input.getType()); }) - && llvm::all_of(coreBatchOp.getBody().front().getArguments(), - [](BlockArgument arg) { return isa(arg.getType()); }); + && llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) { + return !isa(arg.getType()) || isa(arg.getType()); + }); if (alreadyBufferized) return success(); @@ -553,6 +602,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel(op); + auto weightOpt = getBufferOrValue(rewriter, vmmOp.getWeight(), options, state); + if (failed(weightOpt)) + return failure(); + auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state); if (failed(inputOpt)) return failure(); @@ -564,7 +617,7 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); replaceOpWithNewBufferizedOp( - rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt); + rewriter, op, outputBufferOpt->getType(), *weightOpt, contiguousInput, *outputBufferOpt); return success(); } }; diff --git a/src/PIM/Dialect/Spatial/Channels.cpp b/src/PIM/Dialect/Spatial/Channels.cpp index 59847e6..7facf82 100644 --- a/src/PIM/Dialect/Spatial/Channels.cpp +++ b/src/PIM/Dialect/Spatial/Channels.cpp @@ -1,5 +1,6 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Matchers.h" #include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" @@ -9,19 +10,62 @@ namespace onnx_mlir::spatial { namespace { -static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); } +static FailureOr getConstantI64(Value value) { + APInt constantValue; + if (!matchPattern(value, m_ConstantInt(&constantValue))) + return failure(); + return constantValue.getSExtValue(); +} -static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); } +static FailureOr getConstantI32(Value value) { + APInt constantValue; + if (!matchPattern(value, m_ConstantInt(&constantValue))) + return failure(); + return static_cast(constantValue.getSExtValue()); +} + +static FailureOr getChannelId(SpatChannelSendOp sendOp) { + return getConstantI64(sendOp.getChannelId()); +} + +static FailureOr getChannelId(SpatChannelReceiveOp receiveOp) { + return getConstantI64(receiveOp.getChannelId()); +} + +static FailureOr getSourceCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getSourceCoreId()); } + +static FailureOr getSourceCoreId(SpatChannelReceiveOp receiveOp) { + return getConstantI32(receiveOp.getSourceCoreId()); +} + +static FailureOr getTargetCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getTargetCoreId()); } + +static FailureOr getTargetCoreId(SpatChannelReceiveOp receiveOp) { + return getConstantI32(receiveOp.getTargetCoreId()); +} static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) { if (!endpoints.send || !endpoints.receive) return failure(); - if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) { + FailureOr sendSourceCoreId = getSourceCoreId(endpoints.send); + FailureOr receiveSourceCoreId = getSourceCoreId(endpoints.receive); + if (failed(sendSourceCoreId) || failed(receiveSourceCoreId)) { + endpoints.send.emitOpError("channel endpoints must use constant sourceCoreId operands"); + return failure(); + } + if (*sendSourceCoreId != *receiveSourceCoreId) { endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive"); return failure(); } - if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) { + + FailureOr sendTargetCoreId = getTargetCoreId(endpoints.send); + FailureOr receiveTargetCoreId = getTargetCoreId(endpoints.receive); + if (failed(sendTargetCoreId) || failed(receiveTargetCoreId)) { + endpoints.send.emitOpError("channel endpoints must use constant targetCoreId operands"); + return failure(); + } + if (*sendTargetCoreId != *receiveTargetCoreId) { endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive"); return failure(); } @@ -46,20 +90,26 @@ Channels::Channels(func::FuncOp funcOp) { Channels::ChannelId Channels::allocate() { return nextChannelId++; } void Channels::insertSend(SpatChannelSendOp sendOp) { - ChannelId channelId = getChannelId(sendOp); - nextChannelId = std::max(nextChannelId, channelId + 1); - endpoints[channelId].send = sendOp; + FailureOr channelId = getChannelId(sendOp); + if (failed(channelId)) + return; + nextChannelId = std::max(nextChannelId, *channelId + 1); + endpoints[*channelId].send = sendOp; } void Channels::insertReceive(SpatChannelReceiveOp receiveOp) { - ChannelId channelId = getChannelId(receiveOp); - nextChannelId = std::max(nextChannelId, channelId + 1); - endpoints[channelId].receive = receiveOp; + FailureOr channelId = getChannelId(receiveOp); + if (failed(channelId)) + return; + nextChannelId = std::max(nextChannelId, *channelId + 1); + endpoints[*channelId].receive = receiveOp; } void Channels::eraseSend(SpatChannelSendOp sendOp) { - ChannelId channelId = getChannelId(sendOp); - auto it = endpoints.find(channelId); + FailureOr channelId = getChannelId(sendOp); + if (failed(channelId)) + return; + auto it = endpoints.find(*channelId); if (it == endpoints.end()) return; it->second.send = {}; @@ -68,8 +118,10 @@ void Channels::eraseSend(SpatChannelSendOp sendOp) { } void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) { - ChannelId channelId = getChannelId(receiveOp); - auto it = endpoints.find(channelId); + FailureOr channelId = getChannelId(receiveOp); + if (failed(channelId)) + return; + auto it = endpoints.find(*channelId); if (it == endpoints.end()) return; it->second.receive = {}; @@ -85,14 +137,20 @@ FailureOr Channels::lookup(ChannelId id) const { } FailureOr Channels::getReceiveFor(SpatChannelSendOp sendOp) const { - auto endpointsOr = lookup(getChannelId(sendOp)); + FailureOr channelId = getChannelId(sendOp); + if (failed(channelId)) + return failure(); + auto endpointsOr = lookup(*channelId); if (failed(endpointsOr) || !endpointsOr->receive) return failure(); return endpointsOr->receive; } FailureOr Channels::getSendFor(SpatChannelReceiveOp receiveOp) const { - auto endpointsOr = lookup(getChannelId(receiveOp)); + FailureOr channelId = getChannelId(receiveOp); + if (failed(channelId)) + return failure(); + auto endpointsOr = lookup(*channelId); if (failed(endpointsOr) || !endpointsOr->send) return failure(); return endpointsOr->send; diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 80e9a4d..93cd4b6 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -2,8 +2,12 @@ #define SPATIAL_DIALECT_H include "mlir/IR/OpBase.td" +include "mlir/IR/OpAsmInterface.td" include "mlir/IR/BuiltinTypes.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/RegionKindInterface.td" +include "mlir/Interfaces/ParallelCombiningOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" def SpatialDialect : Dialect { let name = "spat"; @@ -22,7 +26,9 @@ def SpatTensor : // Execution //===----------------------------------------------------------------------===// -def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> { +def SpatCompute : SpatOp<"compute", + [SingleBlock, AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { let summary = "Compute region with attached constant weights"; let arguments = (ins @@ -36,14 +42,20 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> { let regions = (region SizedRegion<1>:$body); + let extraClassDeclaration = [{ + ::mlir::BlockArgument getWeightArgument(unsigned idx); + ::mlir::BlockArgument getInputArgument(unsigned idx); + }]; + let hasVerifier = 1; let hasFolder = 1; let hasCustomAssemblyFormat = 1; } def SpatComputeBatch : SpatOp<"compute_batch", - [SingleBlock, AttrSizedOperandSegments]> { - let summary = "Compressed batch of independent equivalent compute lanes"; + [SingleBlock, AttrSizedOperandSegments, + DeclareOpInterfaceMethods]> { + let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs"; let arguments = (ins I32Attr:$laneCount, @@ -57,10 +69,41 @@ def SpatComputeBatch : SpatOp<"compute_batch", let regions = (region SizedRegion<1>:$body); + let extraClassDeclaration = [{ + ::mlir::BlockArgument getLaneArgument(); + ::mlir::BlockArgument getWeightArgument(unsigned idx); + ::mlir::BlockArgument getInputArgument(unsigned idx); + ::mlir::BlockArgument getOutputArgument(unsigned idx); + }]; + let hasVerifier = 1; let hasCustomAssemblyFormat = 1; } +def SpatInParallelOp : SpatOp<"in_parallel", [ + Pure, + Terminator, + DeclareOpInterfaceMethods, + HasParent<"SpatComputeBatch">, + ] # GraphRegionNoTerminator.traits> { + let summary = "Parallel combining terminator for resultful spat.compute_batch"; + + let regions = (region SizedRegion<1>:$region); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins)>, + ]; + + let extraClassDeclaration = [{ + ::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps(); + ::mlir::OpResult getParentResult(int64_t idx); + }]; +} + def SpatYieldOp : SpatOp<"yield", [Terminator]> { let summary = "Yield results from a compute region"; @@ -110,14 +153,14 @@ def SpatChannelSendOp : SpatOp<"channel_send", []> { let summary = "Send a tensor through a logical channel"; let arguments = (ins - I64Attr:$channelId, - I32Attr:$sourceCoreId, - I32Attr:$targetCoreId, + Index:$channelId, + Index:$sourceCoreId, + Index:$targetCoreId, SpatTensor:$input ); let assemblyFormat = [{ - $input attr-dict `:` type($input) + $input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input) }]; } @@ -125,9 +168,9 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> { let summary = "Receive a tensor from a logical channel"; let arguments = (ins - I64Attr:$channelId, - I32Attr:$sourceCoreId, - I32Attr:$targetCoreId + Index:$channelId, + Index:$sourceCoreId, + Index:$targetCoreId ); let results = (outs @@ -135,31 +178,33 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> { ); let assemblyFormat = [{ - attr-dict `:` type($output) + `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output) }]; } -def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> { +def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> { let summary = "Send equal contiguous chunks of one tensor through logical channels"; let arguments = (ins - DenseI64ArrayAttr:$channelIds, - DenseI32ArrayAttr:$sourceCoreIds, - DenseI32ArrayAttr:$targetCoreIds, + Variadic:$channelIds, + Variadic:$sourceCoreIds, + Variadic:$targetCoreIds, SpatTensor:$input ); let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; + let assemblyFormat = [{ + $input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input) + }]; } -def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> { +def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> { let summary = "Receive equal contiguous chunks of one tensor from logical channels"; let arguments = (ins - DenseI64ArrayAttr:$channelIds, - DenseI32ArrayAttr:$sourceCoreIds, - DenseI32ArrayAttr:$targetCoreIds + Variadic:$channelIds, + Variadic:$sourceCoreIds, + Variadic:$targetCoreIds ); let results = (outs @@ -167,44 +212,50 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> { ); let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; + let assemblyFormat = [{ + `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output) + }]; } -def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> { +def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> { let summary = "Send per-lane tensors through logical channels in a batch body"; let arguments = (ins - DenseI64ArrayAttr:$channelIds, - DenseI32ArrayAttr:$sourceCoreIds, - DenseI32ArrayAttr:$targetCoreIds, + Variadic:$channelIds, + Variadic:$sourceCoreIds, + Variadic:$targetCoreIds, SpatTensor:$input ); let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; + let assemblyFormat = [{ + $input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input) + }]; } -def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> { +def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> { let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body"; let arguments = (ins - DenseI64ArrayAttr:$channelIds, - DenseI32ArrayAttr:$sourceCoreIds, - DenseI32ArrayAttr:$targetCoreIds, + Variadic:$channelIds, + Variadic:$sourceCoreIds, + Variadic:$targetCoreIds, SpatTensor:$input ); let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; + let assemblyFormat = [{ + $input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input) + }]; } -def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { +def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> { let summary = "Receive a per-lane tensor through logical channels in a batch body"; let arguments = (ins - DenseI64ArrayAttr:$channelIds, - DenseI32ArrayAttr:$sourceCoreIds, - DenseI32ArrayAttr:$targetCoreIds + Variadic:$channelIds, + Variadic:$sourceCoreIds, + Variadic:$targetCoreIds ); let results = (outs @@ -212,16 +263,18 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { ); let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; + let assemblyFormat = [{ + `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output) + }]; } -def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> { +def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> { let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body"; let arguments = (ins - DenseI64ArrayAttr:$channelIds, - DenseI32ArrayAttr:$sourceCoreIds, - DenseI32ArrayAttr:$targetCoreIds + Variadic:$channelIds, + Variadic:$sourceCoreIds, + Variadic:$targetCoreIds ); let results = (outs @@ -229,7 +282,9 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> ); let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; + let assemblyFormat = [{ + `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output) + }]; } //===----------------------------------------------------------------------===// @@ -240,7 +295,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> { let summary = "Vector-matrix multiplication within a weighted compute operation"; let arguments = (ins - I32Attr:$weightIndex, + SpatTensor:$weight, SpatTensor:$input ); @@ -251,7 +306,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> { let hasVerifier = 1; let assemblyFormat = [{ - `(` $input `)` attr-dict `:` type($input) `->` type($output) + `[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output) }]; } @@ -259,7 +314,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> { let summary = "Matrix-vector multiplication within a weighted compute operation"; let arguments = (ins - I32Attr:$weightIndex, + SpatTensor:$weight, SpatTensor:$input ); @@ -270,7 +325,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> { let hasVerifier = 1; let assemblyFormat = [{ - `(` $input `)` attr-dict `:` type($input) `->` type($output) + `[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output) }]; } diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index aae95d6..43aff68 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -1,10 +1,74 @@ #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include + using namespace mlir; namespace onnx_mlir { namespace spatial { +BlockArgument SpatCompute::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); } + +BlockArgument SpatCompute::getInputArgument(unsigned idx) { + return getBody().front().getArgument(getWeights().size() + idx); +} + +void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { + if (region.empty()) + return; + + for (unsigned index = 0; index < getWeights().size(); ++index) + setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str()); + + for (unsigned index = 0; index < getInputs().size(); ++index) + setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str()); +} + +BlockArgument SpatComputeBatch::getLaneArgument() { return getBody().front().getArgument(0); } + +BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); } + +BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) { + return getBody().front().getArgument(1 + getWeights().size() + idx); +} + +BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) { + return getBody().front().getArgument(1 + getWeights().size() + getInputs().size() + idx); +} + +void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { + if (region.empty()) + return; + + setNameFn(getLaneArgument(), "lane"); + + for (unsigned index = 0; index < getWeights().size(); ++index) + setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str()); + + for (unsigned index = 0; index < getInputs().size(); ++index) + setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str()); + + for (unsigned index = 0; index < getNumResults(); ++index) { + if (index == 0) { + setNameFn(getOutputArgument(index), "out"); + continue; + } + setNameFn(getOutputArgument(index), ("out" + std::to_string(index)).c_str()); + } +} + +void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) { + OpBuilder::InsertionGuard guard(builder); + Region* bodyRegion = result.addRegion(); + builder.createBlock(bodyRegion); +} + +OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); } + +llvm::iterator_range SpatInParallelOp::getYieldingOps() { + return getRegion().front().getOperations(); +} + void SpatialDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST diff --git a/src/PIM/Dialect/Spatial/SpatialOps.hpp b/src/PIM/Dialect/Spatial/SpatialOps.hpp index 15c6650..ce89ef3 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.hpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.hpp @@ -5,7 +5,9 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/RegionKindInterface.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include #include diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index 96286dc..5e517fe 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -23,22 +23,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy)); } -static void printChannelMetadata(OpAsmPrinter& printer, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds) { - printer << " channels "; - printCompressedIntegerList(printer, channelIds); - printer << " from "; - printCompressedIntegerList(printer, sourceCoreIds); - printer << " to "; - printCompressedIntegerList(printer, targetCoreIds); -} - -static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef values) { - return parser.getBuilder().getDenseI64ArrayAttr(values); -} - static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef values) { return parser.getBuilder().getDenseI32ArrayAttr(values); } @@ -47,94 +31,89 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) { return parser.getBuilder().getI32IntegerAttr(value); } -template -static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) { - printer << " "; - printer.printOperand(op.getInput()); - printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds()); - printer.printOptionalAttrDict(op->getAttrs(), - {op.getChannelIdsAttrName().getValue(), - op.getSourceCoreIdsAttrName().getValue(), - op.getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printer.printType(op.getInput().getType()); +static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef arguments) { + printer << "("; + for (auto [index, argument] : llvm::enumerate(arguments)) { + if (index != 0) + printer << ", "; + printer.printOperand(argument); + } + printer << ")"; } -template -static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) { - printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds()); - printer.printOptionalAttrDict(op->getAttrs(), - {op.getChannelIdsAttrName().getValue(), - op.getSourceCoreIdsAttrName().getValue(), - op.getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printer.printType(op.getOutput().getType()); -} - -static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) { - OpAsmParser::UnresolvedOperand input; - Type inputType; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - if (parser.parseOperand(input)) +static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl& arguments) { + if (parser.parseLParen()) return failure(); + if (succeeded(parser.parseOptionalRParen())) + return success(); - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) + OpAsmParser::Argument argument; + if (parser.parseArgument(argument)) + return failure(); + arguments.push_back(argument); + while (succeeded(parser.parseOptionalComma())) { + if (parser.parseArgument(argument)) return failure(); + arguments.push_back(argument); } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)) - return failure(); - - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - } - - return parser.resolveOperand(input, inputType, result.operands); + return parser.parseRParen(); } -static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) { - Type outputType; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; +static void applyBatchRegionArgumentTypes(ArrayRef inputTypes, + ArrayRef weightTypes, + ArrayRef outputTypes, + OpAsmParser::Argument& laneArg, + SmallVectorImpl& weightArgs, + SmallVectorImpl& inputArgs, + SmallVectorImpl& outputArgs, + SmallVectorImpl& regionArgs, + Builder& builder) { + laneArg.type = builder.getIndexType(); + regionArgs.push_back(laneArg); + applyArgumentTypes(weightTypes, weightArgs); + llvm::append_range(regionArgs, weightArgs); + applyArgumentTypes(inputTypes, inputArgs); + applyArgumentTypes(outputTypes, outputArgs); + llvm::append_range(regionArgs, inputArgs); + llvm::append_range(regionArgs, outputArgs); +} - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) - return failure(); - } +static void +printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) { + printCompressedValueList(printer, arguments, delimiter); + printer << " = "; + printCompressedValueList(printer, operands, delimiter); +} - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType)) +static ParseResult parseBoundValueList(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& arguments, + SmallVectorImpl& operands) { + if (parseOpenDelimiter(parser, delimiter)) return failure(); - - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) { + if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) + return failure(); + return success(); } - result.addTypes(outputType); + if (parseOneCompressedArgumentEntry(parser, arguments)) + return failure(); + while (succeeded(parser.parseOptionalComma())) + if (parseOneCompressedArgumentEntry(parser, arguments)) + return failure(); + auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult { + switch (currentDelimiter) { + case ListDelimiter::Paren: + return parser.parseRParen(); + case ListDelimiter::Square: + return parser.parseRSquare(); + } + llvm_unreachable("unsupported delimiter"); + }; + if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) { + return failure(); + } return success(); } @@ -243,9 +222,17 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { void SpatCompute::print(OpAsmPrinter& printer) { printer << " "; - printCompressedValueList(printer, getWeights(), ListDelimiter::Square); + SmallVector weightArgs; + weightArgs.reserve(getWeights().size()); + for (unsigned index = 0; index < getWeights().size(); ++index) + weightArgs.push_back(getWeightArgument(index)); + printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); printer << " "; - printArgumentBindings(printer, getBody().front(), getInputs()); + SmallVector inputArgs; + inputArgs.reserve(getInputs().size()); + for (unsigned index = 0; index < getInputs().size(); ++index) + inputArgs.push_back(getInputArgument(index)); + printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); if (auto coreIdAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) printer << " coreId " << coreIdAttr.getInt(); @@ -264,6 +251,7 @@ void SpatCompute::print(OpAsmPrinter& printer) { } ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { + SmallVector weightArgs; SmallVector regionArgs; SmallVector weights; SmallVector inputs; @@ -272,10 +260,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { SmallVector outputTypes; int32_t coreId = 0; - if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) + if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)) return failure(); - if (parseArgumentBindings(parser, regionArgs, inputs)) + SmallVector inputArgs; + if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs)) return failure(); bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id"); @@ -292,9 +281,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { if (weights.size() != weightTypes.size()) return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); + if (weightArgs.size() != weights.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match"); if (inputs.size() != inputTypes.size()) return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (regionArgs.size() != inputs.size()) + if (inputArgs.size() != inputs.size()) return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName)) return parser.emitError(parser.getCurrentLocation(), @@ -313,19 +304,39 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { result.addTypes(outputTypes); Region* body = result.addRegion(); - applyArgumentTypes(inputTypes, regionArgs); + applyArgumentTypes(weightTypes, weightArgs); + applyArgumentTypes(inputTypes, inputArgs); + llvm::append_range(regionArgs, weightArgs); + llvm::append_range(regionArgs, inputArgs); return parser.parseRegion(*body, regionArgs); } void SpatComputeBatch::print(OpAsmPrinter& printer) { - printer << " lanes " << getLaneCount() << " "; - size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast(getLaneCount()) : 0; - if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane)) - printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square); - else - printCompressedValueList(printer, getWeights(), ListDelimiter::Square); printer << " "; - printArgumentBindings(printer, getBody().front(), getInputs()); + printer.printOperand(getLaneArgument()); + printer << " = 0 to " << getLaneCount(); + + printer << " "; + SmallVector weightArgs; + weightArgs.reserve(getWeights().size()); + for (unsigned index = 0; index < getWeights().size(); ++index) + weightArgs.push_back(getWeightArgument(index)); + printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); + printer << " "; + SmallVector inputArgs; + inputArgs.reserve(getInputs().size()); + for (unsigned index = 0; index < getInputs().size(); ++index) + inputArgs.push_back(getInputArgument(index)); + printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); + + if (getNumResults() != 0) { + printer << " shared_outs"; + SmallVector outputArgs; + outputArgs.reserve(getNumResults()); + for (unsigned index = 0; index < getNumResults(); ++index) + outputArgs.push_back(getOutputArgument(index)); + printBlockArgumentList(printer, outputArgs); + } if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) { printer << " coreIds "; @@ -337,10 +348,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) { {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName}); printer << " : "; - if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane)) - printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square); - else - printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); + printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); printer << " "; printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); printer << " -> "; @@ -350,7 +358,12 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) { } ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) { + int64_t lowerBound = 0; int32_t laneCount = 0; + OpAsmParser::Argument laneArg; + SmallVector weightArgs; + SmallVector inputArgs; + SmallVector outputArgs; SmallVector regionArgs; SmallVector weights; SmallVector inputs; @@ -359,14 +372,21 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) SmallVector outputTypes; SmallVector coreIds; - if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)) + if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound) + || parser.parseKeyword("to") || parser.parseInteger(laneCount)) + return failure(); + if (lowerBound != 0) + return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound"); + + if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)) return failure(); - if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights)) + if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs)) return failure(); - if (parseArgumentBindings(parser, regionArgs, inputs)) - return failure(); + if (succeeded(parser.parseOptionalKeyword("shared_outs"))) + if (parseBlockArgumentList(parser, outputArgs)) + return failure(); bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids"); if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) @@ -381,10 +401,15 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) if (weights.size() != weightTypes.size()) return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); + if (weightArgs.size() != weights.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match"); if (inputs.size() != inputTypes.size()) return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (regionArgs.size() != inputs.size()) + if (inputArgs.size() != inputs.size()) return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); + if (outputArgs.size() != outputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "number of shared output bindings and result types must match"); if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName)) return parser.emitError(parser.getCurrentLocation(), "coreIds cannot be specified both positionally and in attr-dict"); @@ -403,119 +428,28 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) result.addTypes(outputTypes); Region* body = result.addRegion(); - applyArgumentTypes(inputTypes, regionArgs); + applyBatchRegionArgumentTypes( + inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder()); return parser.parseRegion(*body, regionArgs); } -void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); } - -ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) { - return parseTensorSendOp(parser, result); -} - -void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) { +void SpatInParallelOp::print(OpAsmPrinter& printer) { printer << " "; - printer.printOperand(getInput()); - printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printer.printType(getInput().getType()); + printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false); + printer.printOptionalAttrDict((*this)->getAttrs()); } -ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) { - OpAsmParser::UnresolvedOperand input; - Type inputType; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - if (parser.parseOperand(input)) +ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) { + auto& builder = parser.getBuilder(); + std::unique_ptr region = std::make_unique(); + SmallVector regionArgs; + if (parser.parseRegion(*region, regionArgs)) return failure(); - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)) - return failure(); - - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - } - - return parser.resolveOperand(input, inputType, result.operands); -} - -void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); } - -ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) { - return parseTensorSendOp(parser, result); -} - -void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); } - -ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) { - return parseTensorReceiveOp(parser, result); -} - -void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) { - printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printer.printType(getOutput().getType()); -} - -ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) { - Type outputType; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType)) - return failure(); - - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - } - - result.addTypes(outputType); - return success(); -} - -void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); } - -ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) { - return parseTensorReceiveOp(parser, result); + if (region->empty()) + OpBuilder(builder.getContext()).createBlock(region.get()); + result.addRegion(std::move(region)); + return parser.parseOptionalAttrDict(result.attributes); } } // namespace spatial diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 183c0cd..c597f47 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -1,6 +1,9 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" @@ -82,20 +85,11 @@ inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter, return success(); } -static FailureOr> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) { - if (auto computeOp = weightedOp->getParentOfType()) - return cast(computeOp.getWeights()[weightIndex].getType()).getShape(); - - if (auto coreOp = weightedOp->getParentOfType()) - return cast(coreOp.getWeights()[weightIndex].getType()).getShape(); - - if (auto batchOp = weightedOp->getParentOfType()) { - if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size()) - return failure(); - return cast(batchOp.getWeights()[weightIndex].getType()).getShape(); - } - - return failure(); +static FailureOr> getWeightShapeForWeightedOp(Value weight) { + auto shapedType = dyn_cast(weight.getType()); + if (!shapedType) + return failure(); + return shapedType.getShape(); } static FailureOr getParentBatchLaneCount(Operation* op) { @@ -105,15 +99,86 @@ static FailureOr getParentBatchLaneCount(Operation* op) { return batchOp.getLaneCount(); } -static LogicalResult verifyTensorChannelSizes(Operation* op, - Type type, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - StringRef kind) { - if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) +static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) { + if (batchOp.getNumResults() == 0) + return false; + auto blockArg = dyn_cast(value); + if (!blockArg || blockArg.getOwner() != &batchOp.getBody().front()) + return false; + + unsigned argNumber = blockArg.getArgNumber(); + unsigned firstOutputArg = batchOp.getOutputArgument(0).getArgNumber(); + return argNumber >= firstOutputArg && argNumber < firstOutputArg + batchOp.getNumResults(); +} + +static bool isConstantIndexLike(Value value) { + APInt constantValue; + return matchPattern(value, m_ConstantInt(&constantValue)); +} + +static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) { + if (value == laneArg || isConstantIndexLike(value)) + return true; + + auto addOp = value.getDefiningOp(); + if (!addOp) + return false; + return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs())) + || (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs())); +} + +static LogicalResult +verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) { + auto sourceType = dyn_cast(sliceOp.getSource().getType()); + auto resultType = dyn_cast(sliceOp.getResult().getType()); + if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) + return sliceOp.emitOpError() << kind << " requires static ranked tensor types"; + if (!sliceOp.hasUnitStride()) + return sliceOp.emitOpError() << kind << " requires unit strides"; + + for (int64_t size : sliceOp.getStaticSizes()) + if (ShapedType::isDynamic(size)) + return sliceOp.emitOpError() << kind << " requires static slice sizes"; + + auto offsets = sliceOp.getOffsets(); + for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) { + bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset); + if (!supported) + return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets"; + } + + return success(); +} + +static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::ParallelInsertSliceOp sliceOp, + BlockArgument laneArg, + StringRef kind) { + RankedTensorType sourceType = sliceOp.getSourceType(); + RankedTensorType destType = sliceOp.getDestType(); + if (!sourceType.hasStaticShape() || !destType.hasStaticShape()) + return sliceOp.emitOpError() << kind << " requires static ranked tensor types"; + if (!sliceOp.hasUnitStride()) + return sliceOp.emitOpError() << kind << " requires unit strides"; + + for (int64_t size : sliceOp.getStaticSizes()) + if (ShapedType::isDynamic(size)) + return sliceOp.emitOpError() << kind << " requires static slice sizes"; + + auto offsets = sliceOp.getOffsets(); + for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) { + bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset); + if (!supported) + return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets"; + } + + return success(); +} + +static LogicalResult verifyTensorChannelSizes( + Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) { + if (channelCount != sourceCoreCount || channelCount != targetCoreCount) return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); - if (channelIds.empty()) + if (channelCount == 0) return op->emitError() << kind << " must carry at least one chunk"; auto shapedType = dyn_cast(type); @@ -125,40 +190,34 @@ static LogicalResult verifyTensorChannelSizes(Operation* op, return op->emitError() << kind << " requires byte-sized elements"; int64_t totalBytes = shapedType.getNumElements() * elementBits / 8; - if (totalBytes % static_cast(channelIds.size()) != 0) + if (totalBytes % static_cast(channelCount) != 0) return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids"; return success(); } -static LogicalResult verifyBatchChannelSizes(Operation* op, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds) { - if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) +static LogicalResult +verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) { + if (channelCount != sourceCoreCount || channelCount != targetCoreCount) return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); auto laneCount = getParentBatchLaneCount(op); if (failed(laneCount)) return op->emitError("must be nested inside spat.compute_batch"); - if (channelIds.size() != static_cast(*laneCount)) + if (channelCount != static_cast(*laneCount)) return op->emitError("channel metadata length must match parent laneCount"); return success(); } -static LogicalResult verifyTensorBatchChannelSizes(Operation* op, - Type type, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - StringRef kind) { - if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) +static LogicalResult verifyTensorBatchChannelSizes( + Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) { + if (channelCount != sourceCoreCount || channelCount != targetCoreCount) return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); auto laneCount = getParentBatchLaneCount(op); if (failed(laneCount)) return op->emitError("must be nested inside spat.compute_batch"); - if (channelIds.empty() || channelIds.size() % static_cast(*laneCount) != 0) + if (channelCount == 0 || channelCount % static_cast(*laneCount) != 0) return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount"; auto shapedType = dyn_cast(type); @@ -169,7 +228,7 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op, if (elementBits <= 0 || elementBits % 8 != 0) return op->emitError() << kind << " requires byte-sized elements"; - int64_t chunkCount = static_cast(channelIds.size()) / *laneCount; + int64_t chunkCount = static_cast(channelCount) / *laneCount; int64_t totalBytes = shapedType.getNumElements() * elementBits / 8; if (totalBytes % chunkCount != 0) return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane"; @@ -177,28 +236,59 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op, return success(); } -static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) { - auto yieldOp = dyn_cast_or_null(block.getTerminator()); - if (!yieldOp) - return op->emitError("body must terminate with spat.yield"); - if (outputTypes.empty()) { +static Region* getParentRegion(Value value) { + if (auto blockArg = dyn_cast(value)) + return blockArg.getOwner()->getParent(); + if (Operation* definingOp = value.getDefiningOp()) + return definingOp->getParentRegion(); + return nullptr; +} + +static bool isDefinedInsideRegion(Value value, Region& region) { + Region* parentRegion = getParentRegion(value); + return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion)); +} + +static bool isConstantExternalValue(Value value) { + Operation* definingOp = value.getDefiningOp(); + return definingOp && definingOp->hasTrait(); +} + +static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) { + bool hasFailure = false; + region.walk([&](Operation* op) { + for (OpOperand& operand : op->getOpOperands()) { + Value value = operand.get(); + if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)) + continue; + + InFlightDiagnostic diagnostic = ownerOp->emitOpError() + << kind << " body may only directly reference external constants"; + diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber() + << " is used by " << op->getName(); + hasFailure = true; + } + }); + return success(!hasFailure); +} + +static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) { + if (batchOp.getNumResults() == 0) { + auto yieldOp = dyn_cast_or_null(block.getTerminator()); + if (!yieldOp) + return batchOp.emitError("resultless compute_batch body must terminate with spat.yield"); if (yieldOp.getNumOperands() != 0) - return op->emitError("body yield must be empty when compute_batch has no results"); + return batchOp.emitError("resultless compute_batch body yield must be empty"); } - else { - if (yieldOp.getNumOperands() != 1) - return op->emitError("body yield must produce exactly one value"); - if (yieldOp.getOperand(0).getType() != outputTypes[0]) - return op->emitError("body yield type must match output type"); + else if (!isa_and_nonnull(block.getTerminator())) { + return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel"); } + BlockArgument laneArg = batchOp.getLaneArgument(); for (auto& bodyOp : block) { - if (auto wvmm = dyn_cast(&bodyOp)) - if (wvmm.getWeightIndex() < 0 || static_cast(wvmm.getWeightIndex()) >= weightsPerLane) - return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane"); - if (auto wmvm = dyn_cast(&bodyOp)) - if (wmvm.getWeightIndex() < 0 || static_cast(wmvm.getWeightIndex()) >= weightsPerLane) - return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane"); + if (auto extractSlice = dyn_cast(&bodyOp)) + if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice"))) + return failure(); } return success(); } @@ -206,9 +296,9 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp } // namespace LogicalResult SpatMVMOp::verify() { - auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); + auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight()); if (failed(matrixShapeOpt)) - return emitError("SpatMVMOp was not within a SpatCompute or Core op"); + return emitError("weight must be a shaped value"); auto matrixShape = *matrixShapeOpt; auto vectorShape = getInput().getType().getShape(); auto outputShape = getOutput().getType().getShape(); @@ -221,9 +311,9 @@ LogicalResult SpatMVMOp::verify() { } LogicalResult SpatVMMOp::verify() { - auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); + auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight()); if (failed(matrixShapeOpt)) - return emitError("SpatVMMOp was not within a SpatCompute or Core op"); + return emitError("weight must be a shaped value"); auto matrixShape = *matrixShapeOpt; auto vectorShape = getInput().getType().getShape(); auto outputShape = getOutput().getType().getShape(); @@ -347,13 +437,26 @@ LogicalResult verifyComputeResultsUses(Operation* op) { return !(op->getParentOfType() || op->getParentOfType()); }); })) { - return op->emitError("ComputeResult used directly inside another Compute" ); + return op->emitError("ComputeResult used directly inside another Compute"); } return success(); } LogicalResult SpatCompute::verify() { auto& block = getBody().front(); + unsigned expectedArgCount = getWeights().size() + getInputs().size(); + if (block.getNumArguments() != expectedArgCount) + return emitError("compute body must have weight and input block arguments"); + + for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { + if (getWeightArgument(weightIndex).getType() != weight.getType()) + return emitError("compute weight block argument types must match weight operand types exactly"); + } + for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { + if (getInputArgument(inputIndex).getType() != input.getType()) + return emitError("compute input block argument types must match input operand types exactly"); + } + if (block.mightHaveTerminator()) { auto yieldOp = dyn_cast_or_null(block.getTerminator()); if (!yieldOp) @@ -386,9 +489,11 @@ LogicalResult SpatCompute::verify() { } } - for (auto arg : block.getArguments()) - if (arg.use_empty()) + for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex) + if (getInputArgument(inputIndex).use_empty()) return emitError("ComputeOp block argument is not used"); + if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute"))) + return failure(); if (failed(verifyComputeResultsUses(this->getOperation()))) return failure(); return success(); @@ -397,44 +502,46 @@ LogicalResult SpatCompute::verify() { LogicalResult SpatChannelSendTensorOp::verify() { return verifyTensorChannelSizes(getOperation(), getInput().getType(), - getChannelIds(), - getSourceCoreIds(), - getTargetCoreIds(), + getChannelIds().size(), + getSourceCoreIds().size(), + getTargetCoreIds().size(), "channel_send_tensor"); } LogicalResult SpatChannelReceiveTensorOp::verify() { return verifyTensorChannelSizes(getOperation(), getOutput().getType(), - getChannelIds(), - getSourceCoreIds(), - getTargetCoreIds(), + getChannelIds().size(), + getSourceCoreIds().size(), + getTargetCoreIds().size(), "channel_receive_tensor"); } LogicalResult SpatChannelSendBatchOp::verify() { - return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + return verifyBatchChannelSizes( + getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size()); } LogicalResult SpatChannelSendTensorBatchOp::verify() { return verifyTensorBatchChannelSizes(getOperation(), getInput().getType(), - getChannelIds(), - getSourceCoreIds(), - getTargetCoreIds(), + getChannelIds().size(), + getSourceCoreIds().size(), + getTargetCoreIds().size(), "channel_send_tensor_batch"); } LogicalResult SpatChannelReceiveBatchOp::verify() { - return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + return verifyBatchChannelSizes( + getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size()); } LogicalResult SpatChannelReceiveTensorBatchOp::verify() { return verifyTensorBatchChannelSizes(getOperation(), getOutput().getType(), - getChannelIds(), - getSourceCoreIds(), - getTargetCoreIds(), + getChannelIds().size(), + getSourceCoreIds().size(), + getTargetCoreIds().size(), "channel_receive_tensor_batch"); } @@ -444,35 +551,6 @@ LogicalResult SpatComputeBatch::verify() { return emitError("laneCount must be positive"); auto laneCountSz = static_cast(count); - if (getWeights().size() % laneCountSz != 0) - return emitError("number of weights must be a multiple of laneCount"); - - if (!getInputs().empty() && getInputs().size() != laneCountSz) - return emitError("number of inputs must be either 0 or laneCount"); - if (!getOutputs().empty() && getOutputs().size() != laneCountSz) - return emitError("number of outputs must be either 0 or laneCount"); - - size_t weightsPerLane = getWeights().size() / laneCountSz; - for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) { - Type weightType = getWeights()[weightIndex].getType(); - for (size_t lane = 1; lane < laneCountSz; ++lane) - if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType) - return emitError("corresponding weights across lanes must have the same type"); - } - - if (!getInputs().empty()) { - Type inputType = getInputs()[0].getType(); - for (Value in : getInputs().drop_front()) - if (in.getType() != inputType) - return emitError("all inputs must have the same type"); - } - - if (!getOutputs().empty()) { - Type outputType = getOutputs()[0].getType(); - for (Value out : getOutputs().drop_front()) - if (out.getType() != outputType) - return emitError("all outputs must have the same type"); - } if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) { auto coreIdsAttr = dyn_cast(coreIdAttr); @@ -482,27 +560,64 @@ LogicalResult SpatComputeBatch::verify() { return emitError("compute_batch coreIds array length must match laneCount"); if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; })) return emitError("compute_batch coreIds values must be non-negative"); - llvm::SmallDenseSet seenCoreIds; + DenseSet seenCoreIds; for (int32_t coreId : coreIdsAttr.asArrayRef()) if (!seenCoreIds.insert(coreId).second) - return emitError("compute_batch coreIds values must be distinct"); + return emitError("compute_batch coreIds values must be unique"); } Block& block = getBody().front(); - if (getInputs().empty()) { - if (block.getNumArguments() != 0) - return emitError("compute_batch body must have no block arguments when there are no inputs"); + unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults(); + if (block.getNumArguments() != expectedArgCount) + return emitError("compute_batch body must have lane, weight, input, and output block arguments"); + if (!getLaneArgument().getType().isIndex()) + return emitError("compute_batch first block argument must have index type"); + + for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { + if (getWeightArgument(weightIndex).getType() != weight.getType()) + return emitError("compute_batch weight block argument types must match weight operand types exactly"); } - else { - if (block.getNumArguments() != 1) - return emitError("compute_batch body must have exactly one block argument"); - if (block.getArgument(0).getType() != getInputs()[0].getType()) - return emitError("body block argument type must match input type"); + for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { + BlockArgument blockArg = getInputArgument(inputIndex); + if (blockArg.getType() != input.getType()) + return emitError("compute_batch input block argument types must match input operand types exactly"); + } + for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) { + BlockArgument blockArg = getOutputArgument(resultIndex); + if (blockArg.getType() != resultType) + return emitError("compute_batch output block argument types must match result types exactly"); } if (failed(verifyComputeResultsUses(this->getOperation()))) return failure(); - return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane); + if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch"))) + return failure(); + return verifyBatchBody(*this, block); +} + +LogicalResult SpatInParallelOp::verify() { + auto batchOp = getOperation()->getParentOfType(); + if (!batchOp) + return emitOpError("expected spat.compute_batch parent"); + if (batchOp.getNumResults() == 0) + return emitOpError("requires a resultful spat.compute_batch parent"); + + BlockArgument laneArg = batchOp.getLaneArgument(); + for (Operation& op : getRegion().front().getOperations()) { + auto insertSliceOp = dyn_cast(&op); + if (!insertSliceOp) + return emitOpError("expected only tensor.parallel_insert_slice ops"); + + if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, laneArg, "tensor.parallel_insert_slice"))) + return failure(); + + MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations(); + for (OpOperand& destination : destinations) + if (!isBatchOutputArgument(batchOp, destination.get())) + return op.emitOpError("may only insert into a compute_batch output block argument"); + } + + return success(); } } // namespace spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 1cd985d..cabf3b6 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -1,4 +1,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" @@ -33,10 +37,50 @@ using spatial::getComputeInstanceTemplateBlock; using spatial::getComputeInstanceWeights; using spatial::getProducerValueRef; +static Value createIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) { + return getOrCreateHostIndexConstant(anchorOp, value, folder); +} + +static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { + SmallVector constants; + constants.reserve(values.size()); + for (int64_t value : values) + constants.push_back(createIndexConstant(anchorOp, value, folder)); + return constants; +} + +static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { + SmallVector constants; + constants.reserve(values.size()); + for (int32_t value : values) + constants.push_back(createIndexConstant(anchorOp, value, folder)); + return constants; +} + +static Value createIndexTensorConstant(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { + auto tensorType = RankedTensorType::get({static_cast(values.size())}, IndexType::get(anchorOp->getContext())); + auto tensorAttr = DenseIntElementsAttr::get(tensorType, values); + return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder); +} + +static Value createIndexTupleTensorConstant(Operation* anchorOp, + int64_t tupleCount, + int64_t tupleWidth, + ArrayRef values, + OperationFolder& folder) { + auto tensorType = + RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext())); + auto tensorAttr = DenseIntElementsAttr::get(tensorType, values); + return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder); +} + class MergeScheduleMaterializerImpl { public: explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp) - : func(funcOp), loc(funcOp.getLoc()), returnOp(cast(funcOp.getBody().front().getTerminator())) {} + : func(funcOp), + loc(funcOp.getLoc()), + returnOp(cast(funcOp.getBody().front().getTerminator())), + constantFolder(funcOp.getContext()) {} LogicalResult run(const MergeScheduleResult& scheduleResult, int64_t& nextChannelIdRef) { schedule = &scheduleResult; @@ -75,12 +119,15 @@ private: DenseMap weightToIndex; }; + using ProgramKey = std::pair; + struct RemoteSendInfo { ChannelInfo channelInfo; ComputeInstance consumer; size_t inputIndex = 0; size_t consumerOrder = 0; size_t sourceOrder = 0; + bool isTensorInput = false; }; struct RemoteReceiveEntry { @@ -90,11 +137,391 @@ private: size_t sourceOrder = 0; }; + struct BatchYieldInfo { + Value yieldedValue; + tensor::ParallelInsertSliceOp insertSlice; + }; + + struct TensorChannelInfo { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + SmallVector producerInstances; + size_t resultIndex = 0; + }; + + struct ExtractRowsSendRun { + SpatExtractRowsOp extractRows; + int64_t firstRow = 0; + SmallVector sendCounts; + SmallVector channelSourceTargetTuples; + }; + + struct ExtractSliceSendRun { + tensor::ExtractSliceOp extractSlice; + SmallVector baseOffsets; + unsigned varyingDim = 0; + int64_t offsetStep = 0; + SmallVector sendCounts; + SmallVector channelSourceTargetTuples; + }; + static uint64_t getRemoteSendPairKey(const ChannelInfo& channelInfo) { return (static_cast(static_cast(channelInfo.sourceCoreId)) << 32) | static_cast(channelInfo.targetCoreId); } + static ProgramKey getProgramKey(const ScheduledTask& task) { return {task.cpu, 0}; } + + static bool isResultfulBatchInstance(const ComputeInstance& instance) { + auto batch = dyn_cast(instance.op); + return batch && batch.getNumResults() != 0; + } + + static SmallVector buildPrefixSums(ArrayRef counts) { + SmallVector prefixSums; + prefixSums.reserve(counts.size() + 1); + prefixSums.push_back(0); + int64_t running = 0; + for (int64_t count : counts) { + running += count; + prefixSums.push_back(running); + } + return prefixSums; + } + + FailureOr> collectResultfulBatchYieldInfo(SpatComputeBatch batch) { + Block& block = batch.getBody().front(); + auto inParallel = dyn_cast(block.getTerminator()); + if (!inParallel) + return failure(); + + SmallVector resultInfo(batch.getNumResults()); + DenseMap resultIndexByOutputArg; + for (unsigned resultIndex = 0; resultIndex < batch.getNumResults(); ++resultIndex) + resultIndexByOutputArg[batch.getOutputArgument(resultIndex)] = resultIndex; + + for (Operation& op : inParallel.getRegion().front()) { + auto insertSlice = dyn_cast(&op); + if (!insertSlice) + return failure(); + auto outputArg = dyn_cast(insertSlice.getDest()); + auto resultIndexIt = resultIndexByOutputArg.find(outputArg); + if (resultIndexIt == resultIndexByOutputArg.end()) + return failure(); + resultInfo[resultIndexIt->second] = {insertSlice.getSource(), insertSlice}; + } + + if (llvm::any_of(resultInfo, [](const BatchYieldInfo& info) { return !info.yieldedValue; })) + return failure(); + return resultInfo; + } + + SmallVector getTaskOutputTypes(const ComputeInstance& instance) { + if (!isResultfulBatchInstance(instance)) + return getComputeInstanceOutputTypes(instance); + + auto batch = cast(instance.op); + FailureOr> yieldInfo = collectResultfulBatchYieldInfo(batch); + if (failed(yieldInfo)) + return {}; + + SmallVector outputTypes; + outputTypes.reserve(yieldInfo->size()); + for (const BatchYieldInfo& info : *yieldInfo) + outputTypes.push_back(info.yieldedValue.getType()); + return outputTypes; + } + + bool tryCollectExtractRowsSendRun(ArrayRef> sendInfosByResult, + ArrayRef taskYieldValues, + size_t startIndex, + ExtractRowsSendRun& run, + size_t& nextIndex) { + auto firstResult = dyn_cast(taskYieldValues[startIndex]); + auto extractRows = firstResult ? dyn_cast(firstResult.getOwner()) : nullptr; + if (!extractRows || sendInfosByResult[startIndex].empty()) + return false; + + auto inputType = dyn_cast(extractRows.getInput().getType()); + auto rowType = dyn_cast(taskYieldValues[startIndex].getType()); + if (!inputType || !rowType || !inputType.hasStaticShape() || !rowType.hasStaticShape() || inputType.getRank() != 2 + || rowType.getRank() != 2) + return false; + + run = {}; + run.extractRows = extractRows; + run.firstRow = firstResult.getResultNumber(); + + unsigned expectedRow = firstResult.getResultNumber(); + size_t index = startIndex; + while (index < taskYieldValues.size()) { + auto result = dyn_cast(taskYieldValues[index]); + if (!result || result.getOwner() != extractRows.getOperation() || result.getResultNumber() != expectedRow) + break; + + const SmallVector& sendInfos = sendInfosByResult[index]; + run.sendCounts.push_back(static_cast(sendInfos.size())); + for (const RemoteSendInfo& sendInfo : sendInfos) { + run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.channelId); + run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.sourceCoreId); + run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.targetCoreId); + } + + ++index; + ++expectedRow; + } + + nextIndex = index; + return run.sendCounts.size() > 1 && run.channelSourceTargetTuples.size() > 3; + } + + bool tryCollectExtractSliceSendRun(ArrayRef> sendInfosByResult, + ArrayRef taskYieldValues, + size_t startIndex, + ExtractSliceSendRun& run, + size_t& nextIndex) { + auto firstSlice = taskYieldValues[startIndex].getDefiningOp(); + if (!firstSlice || sendInfosByResult[startIndex].empty()) + return false; + if (llvm::any_of(firstSlice.getStaticOffsets(), ShapedType::isDynamic) + || llvm::any_of(firstSlice.getStaticSizes(), ShapedType::isDynamic) + || llvm::any_of(firstSlice.getStaticStrides(), ShapedType::isDynamic)) + return false; + + auto sourceType = dyn_cast(firstSlice.getSourceType()); + auto resultType = dyn_cast(firstSlice.getResultType()); + if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) + return false; + + ArrayRef firstOffsets = firstSlice.getStaticOffsets(); + ArrayRef staticSizes = firstSlice.getStaticSizes(); + ArrayRef staticStrides = firstSlice.getStaticStrides(); + + std::optional varyingDim; + int64_t offsetStep = 0; + run = {}; + run.extractSlice = firstSlice; + run.baseOffsets.assign(firstOffsets.begin(), firstOffsets.end()); + + size_t index = startIndex; + while (index < taskYieldValues.size()) { + auto slice = taskYieldValues[index].getDefiningOp(); + if (!slice || slice.getSource() != firstSlice.getSource() || slice.getResultType() != firstSlice.getResultType() + || slice.getStaticSizes() != staticSizes || slice.getStaticStrides() != staticStrides + || llvm::any_of(slice.getStaticOffsets(), ShapedType::isDynamic)) + break; + + ArrayRef offsets = slice.getStaticOffsets(); + if (index != startIndex) { + SmallVector differingDims; + for (auto [dim, pair] : llvm::enumerate(llvm::zip(firstOffsets, offsets))) + if (std::get<0>(pair) != std::get<1>(pair)) + differingDims.push_back(dim); + if (differingDims.size() != 1) + break; + + unsigned dim = differingDims.front(); + int64_t expectedOffset = firstOffsets[dim] + static_cast(index - startIndex) * offsetStep; + if (!varyingDim) { + varyingDim = dim; + offsetStep = offsets[dim] - firstOffsets[dim]; + expectedOffset = offsets[dim]; + } + if (offsetStep <= 0 || *varyingDim != dim || offsets[dim] != expectedOffset) + break; + } + + const SmallVector& sendInfos = sendInfosByResult[index]; + run.sendCounts.push_back(static_cast(sendInfos.size())); + for (const RemoteSendInfo& sendInfo : sendInfos) { + run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.channelId); + run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.sourceCoreId); + run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.targetCoreId); + } + + ++index; + } + + nextIndex = index; + if (!varyingDim) + return false; + run.varyingDim = *varyingDim; + run.offsetStep = offsetStep; + return run.sendCounts.size() > 1 && run.channelSourceTargetTuples.size() > 3; + } + + void emitInnerSendLoop(Operation* hostAnchor, + IRRewriter& rewriter, + Value sliceValue, + Value lower, + Value upper, + Value channelSourceTargetTuples) { + Value step = getOrCreateHostIndexConstant(hostAnchor, 1, constantFolder); + auto innerLoop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {}); + rewriter.setInsertionPointToStart(innerLoop.getBody()); + Value sendIndex = innerLoop.getInductionVar(); + Value tupleChannelIndex = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder); + Value tupleSourceIndex = getOrCreateHostIndexConstant(hostAnchor, 1, constantFolder); + Value tupleTargetIndex = getOrCreateHostIndexConstant(hostAnchor, 2, constantFolder); + Value channelIdIndex = + tensor::ExtractOp::create(rewriter, loc, channelSourceTargetTuples, ValueRange {sendIndex, tupleChannelIndex}); + Value sourceCoreIdIndex = + tensor::ExtractOp::create(rewriter, loc, channelSourceTargetTuples, ValueRange {sendIndex, tupleSourceIndex}); + Value targetCoreIdIndex = + tensor::ExtractOp::create(rewriter, loc, channelSourceTargetTuples, ValueRange {sendIndex, tupleTargetIndex}); + spatial::SpatChannelSendOp::create(rewriter, loc, channelIdIndex, sourceCoreIdIndex, targetCoreIdIndex, sliceValue); + rewriter.setInsertionPointAfter(innerLoop); + } + + void emitExtractRowsSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractRowsSendRun& run) { + SmallVector prefixSums = buildPrefixSums(run.sendCounts); + Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder); + Value channelSourceTargetTuples = createIndexTupleTensorConstant( + hostAnchor, + static_cast(run.channelSourceTargetTuples.size() / 3), + 3, + run.channelSourceTargetTuples, + constantFolder); + + Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder); + Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast(run.sendCounts.size()), constantFolder); + Value step = getOrCreateHostIndexConstant(hostAnchor, 1, constantFolder); + auto outerLoop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {}); + rewriter.setInsertionPointToStart(outerLoop.getBody()); + + Value rowIndex = outerLoop.getInductionVar(); + if (run.firstRow != 0) { + Value firstRow = getOrCreateHostIndexConstant(hostAnchor, run.firstRow, constantFolder); + rowIndex = arith::AddIOp::create(rewriter, loc, rowIndex, firstRow); + } + + auto rowType = cast(run.extractRows.getResult(0).getType()); + int64_t rowHeight = rowType.getShape()[0]; + if (rowHeight != 1) { + Value rowHeightValue = getOrCreateHostIndexConstant(hostAnchor, rowHeight, constantFolder); + rowIndex = arith::MulIOp::create(rewriter, loc, rowIndex, rowHeightValue); + } + + auto inputType = cast(run.extractRows.getInput().getType()); + SmallVector offsets = {rowIndex, rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(rowType.getShape()[0]), + rewriter.getIndexAttr(inputType.getShape()[1])}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value extractedRow = + tensor::ExtractSliceOp::create(rewriter, loc, rowType, run.extractRows.getInput(), offsets, sizes, strides) + .getResult(); + + Value nextRowIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step); + Value innerLower = + tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()}); + Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextRowIndex}); + emitInnerSendLoop(hostAnchor, rewriter, extractedRow, innerLower, innerUpper, channelSourceTargetTuples); + rewriter.setInsertionPointAfter(outerLoop); + } + + void emitExtractSliceSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractSliceSendRun& run) { + SmallVector prefixSums = buildPrefixSums(run.sendCounts); + Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder); + Value channelSourceTargetTuples = createIndexTupleTensorConstant( + hostAnchor, + static_cast(run.channelSourceTargetTuples.size() / 3), + 3, + run.channelSourceTargetTuples, + constantFolder); + + Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder); + Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast(run.sendCounts.size()), constantFolder); + Value step = getOrCreateHostIndexConstant(hostAnchor, 1, constantFolder); + auto outerLoop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {}); + rewriter.setInsertionPointToStart(outerLoop.getBody()); + + SmallVector offsets; + offsets.reserve(run.baseOffsets.size()); + for (auto [dim, offset] : llvm::enumerate(run.baseOffsets)) { + if (dim != run.varyingDim) { + offsets.push_back(rewriter.getIndexAttr(offset)); + continue; + } + + Value varyingOffset = outerLoop.getInductionVar(); + if (run.offsetStep != 1) { + Value offsetStep = getOrCreateHostIndexConstant(hostAnchor, run.offsetStep, constantFolder); + varyingOffset = arith::MulIOp::create(rewriter, loc, varyingOffset, offsetStep); + } + if (offset != 0) { + Value baseOffset = getOrCreateHostIndexConstant(hostAnchor, offset, constantFolder); + varyingOffset = arith::AddIOp::create(rewriter, loc, varyingOffset, baseOffset); + } + offsets.push_back(varyingOffset); + } + + SmallVector sizes; + SmallVector strides; + for (int64_t size : run.extractSlice.getStaticSizes()) + sizes.push_back(rewriter.getIndexAttr(size)); + for (int64_t stride : run.extractSlice.getStaticStrides()) + strides.push_back(rewriter.getIndexAttr(stride)); + + Value extractedSlice = + tensor::ExtractSliceOp::create( + rewriter, loc, run.extractSlice.getResultType(), run.extractSlice.getSource(), offsets, sizes, strides) + .getResult(); + + Value nextSliceIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step); + Value innerLower = + tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()}); + Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextSliceIndex}); + emitInnerSendLoop(hostAnchor, rewriter, extractedSlice, innerLower, innerUpper, channelSourceTargetTuples); + rewriter.setInsertionPointAfter(outerLoop); + } + + bool tryEmitCompactSendLoops(Operation* hostAnchor, + IRRewriter& rewriter, + ArrayRef> sendInfosByResult, + ArrayRef taskYieldValues, + size_t startIndex, + size_t& nextIndex) { + ExtractRowsSendRun extractRowsRun; + if (tryCollectExtractRowsSendRun(sendInfosByResult, taskYieldValues, startIndex, extractRowsRun, nextIndex)) { + emitExtractRowsSendRun(hostAnchor, rewriter, extractRowsRun); + return true; + } + + ExtractSliceSendRun extractSliceRun; + if (tryCollectExtractSliceSendRun(sendInfosByResult, taskYieldValues, startIndex, extractSliceRun, nextIndex)) { + emitExtractSliceSendRun(hostAnchor, rewriter, extractSliceRun); + return true; + } + + return false; + } + + size_t getLoopableResultlessBatchRunLength(ArrayRef programTasks, size_t startIndex) { + const ScheduledTask& firstTask = programTasks[startIndex]; + auto batch = dyn_cast(firstTask.computeInstance.op); + if (!batch || batch.getNumResults() != 0 || firstTask.computeInstance.laneCount != 1) + return 1; + + SmallVector firstInputs = getComputeInstanceInputs(firstTask.computeInstance); + SmallVector firstWeights = getComputeInstanceWeights(firstTask.computeInstance); + size_t runLength = 1; + uint32_t nextLane = firstTask.computeInstance.laneStart + 1; + for (size_t index = startIndex + 1; index < programTasks.size(); ++index) { + const ScheduledTask& candidate = programTasks[index]; + if (candidate.computeInstance.op != firstTask.computeInstance.op || candidate.computeInstance.laneCount != 1 + || candidate.computeInstance.laneStart != nextLane) + break; + if (getComputeInstanceInputs(candidate.computeInstance) != firstInputs + || getComputeInstanceWeights(candidate.computeInstance) != firstWeights) + break; + + ++runLength; + ++nextLane; + } + return runLength; + } + void collectScheduledTasks() { for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) { oldComputeOps.insert(scheduledInstance.op); @@ -113,21 +540,35 @@ private: for (const ScheduledTask& task : scheduledTasks) { taskByComputeInstance[task.computeInstance] = task; tasksByCpu[task.cpu].push_back(task); + ProgramKey programKey = getProgramKey(task); + tasksByProgram[programKey].push_back(task); + if (seenPrograms.insert(programKey).second) + orderedPrograms.push_back(programKey); markCpuSeen(task.cpu); } llvm::sort(orderedCpus); + llvm::sort(orderedPrograms, [](const ProgramKey& lhs, const ProgramKey& rhs) { + if (lhs.second != rhs.second) + return lhs.second < rhs.second; + return lhs.first < rhs.first; + }); for (size_t cpu : orderedCpus) llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) { return lhs.orderWithinCpu < rhs.orderWithinCpu; }); + for (ProgramKey programKey : orderedPrograms) + llvm::stable_sort(tasksByProgram[programKey], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) { + return lhs.orderWithinCpu < rhs.orderWithinCpu; + }); } void collectExternalInputsAndWeights() { - for (size_t cpu : orderedCpus) { - for (const ScheduledTask& task : tasksByCpu[cpu]) { - auto& thisCpuWeights = cpuWeights[cpu]; - auto& thisSeenWeights = seenWeightsByCpu[cpu]; + for (ProgramKey programKey : orderedPrograms) { + size_t cpu = programKey.first; + for (const ScheduledTask& task : tasksByProgram[programKey]) { + auto& thisCpuWeights = cpuWeights[programKey]; + auto& thisSeenWeights = seenWeightsByProgram[programKey]; auto taskWeights = getComputeInstanceWeights(task.computeInstance); for (Value weight : taskWeights) if (thisSeenWeights.insert(weight).second) @@ -136,26 +577,91 @@ private: auto taskInputs = getComputeInstanceInputs(task.computeInstance); auto& remoteInputs = remoteInputsByTask[task.computeInstance]; remoteInputs.resize(taskInputs.size()); + auto& remoteTensorInputs = remoteTensorInputsByTask[task.computeInstance]; + remoteTensorInputs.resize(taskInputs.size()); for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { - if (auto producerRef = getProducerValueRef(input)) { - auto producerIt = taskByComputeInstance.find(producerRef->instance); - if (producerIt->second.cpu != cpu) { + bool isExternalInput = true; + if (auto producerBatch = dyn_cast_or_null(input.getDefiningOp()); + producerBatch && producerBatch.getNumResults() != 0) { + size_t resultIndex = cast(input).getResultNumber(); + TensorChannelInfo tensorInfo; + tensorInfo.resultIndex = resultIndex; + tensorInfo.channelIds.reserve(static_cast(producerBatch.getLaneCount())); + tensorInfo.sourceCoreIds.reserve(static_cast(producerBatch.getLaneCount())); + tensorInfo.targetCoreIds.reserve(static_cast(producerBatch.getLaneCount())); + tensorInfo.producerInstances.reserve(static_cast(producerBatch.getLaneCount())); + + bool foundAllLaneProducers = true; + for (uint32_t lane = 0; lane < static_cast(producerBatch.getLaneCount()); ++lane) { + ComputeInstance producerInstance = getBatchChunkForLane(producerBatch, lane); + auto producerIt = taskByComputeInstance.find(producerInstance); + if (producerIt == taskByComputeInstance.end()) { + foundAllLaneProducers = false; + break; + } + ChannelInfo info { - (*nextChannelId)++, + producerIt->second.cpu == cpu ? -1 : (*nextChannelId)++, static_cast(producerIt->second.cpu), static_cast(cpu), }; - remoteInputs[inputIndex] = info; - auto& perResultChannels = remoteSendsByTask[producerRef->instance]; - if (perResultChannels.empty()) - perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size()); - perResultChannels[producerRef->resultIndex].push_back( - {info, task.computeInstance, inputIndex, task.orderWithinCpu, 0}); + tensorInfo.channelIds.push_back(info.channelId); + tensorInfo.sourceCoreIds.push_back(info.sourceCoreId); + tensorInfo.targetCoreIds.push_back(info.targetCoreId); + tensorInfo.producerInstances.push_back(producerInstance); + + if (producerIt->second.cpu != cpu) { + auto& perResultChannels = remoteSendsByTask[producerInstance]; + if (perResultChannels.empty()) + perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size()); + perResultChannels[resultIndex].push_back( + {info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, true}); + } + } + + if (foundAllLaneProducers) { + remoteTensorInputs[inputIndex] = std::move(tensorInfo); + continue; } - continue; } - if (seenExternalInputsByCpu[cpu].insert(input).second) - cpuExternalInputs[cpu].push_back(input); + + if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) { + auto producerIt = taskByComputeInstance.find(producerRef->instance); + if (producerIt != taskByComputeInstance.end()) { + isExternalInput = false; + if (producerIt->second.cpu != cpu) { + ChannelInfo info { + (*nextChannelId)++, + static_cast(producerIt->second.cpu), + static_cast(cpu), + }; + remoteInputs[inputIndex] = info; + auto& perResultChannels = remoteSendsByTask[producerRef->instance]; + if (perResultChannels.empty()) + perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size()); + perResultChannels[producerRef->resultIndex].push_back( + {info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, false}); + } + } + } + if (isExternalInput && seenExternalInputsByProgram[programKey].insert(input).second) + cpuExternalInputs[programKey].push_back(input); + } + + if (isResultfulBatchInstance(task.computeInstance)) { + auto batch = cast(task.computeInstance.op); + for (unsigned resultIndex = 0; resultIndex < batch.getNumResults(); ++resultIndex) { + bool hasExternalUser = false; + for (Operation* user : batch.getResult(resultIndex).getUsers()) { + if (!oldComputeOps.contains(user)) { + hasExternalUser = true; + break; + } + } + if (hasExternalUser) + cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex}); + } + continue; } auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance); @@ -168,7 +674,7 @@ private: hasExternalUser = true; } if (hasExternalUser) - cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex}); + cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex}); } } } @@ -184,6 +690,8 @@ private: continue; for (auto& sendInfos : sendsIt->second) { for (RemoteSendInfo& sendInfo : sendInfos) { + if (sendInfo.isTensorInput) + continue; uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++; auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder); @@ -203,6 +711,8 @@ private: for (auto& taskSends : remoteSendsByTask) { for (auto& sendInfos : taskSends.second) { for (RemoteSendInfo& sendInfo : sendInfos) { + if (sendInfo.isTensorInput) + continue; uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); if (pairsNeedingReceiveReorder.contains(pairKey)) reorderedSendsByPair[pairKey].push_back(&sendInfo); @@ -230,6 +740,8 @@ private: for (const auto& taskSends : remoteSendsByTask) { for (const auto& sendInfos : taskSends.second) { for (const RemoteSendInfo& sendInfo : sendInfos) { + if (sendInfo.isTensorInput) + continue; auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer); assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send"); assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range"); @@ -242,6 +754,8 @@ private: for (auto& taskSends : remoteSendsByTask) { for (const auto& sendInfos : taskSends.second) { for (const RemoteSendInfo& sendInfo : sendInfos) { + if (sendInfo.isTensorInput) + continue; uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); if (!pairsNeedingReceiveReorder.contains(pairKey)) continue; @@ -265,30 +779,36 @@ private: void createCpuComputeOps() { IRRewriter rewriter(func.getContext()); - for (size_t cpu : orderedCpus) { + for (ProgramKey programKey : orderedPrograms) { + size_t cpu = programKey.first; SmallVector operands; - operands.reserve(cpuWeights[cpu].size() + cpuExternalInputs[cpu].size()); - llvm::append_range(operands, cpuWeights[cpu]); - llvm::append_range(operands, cpuExternalInputs[cpu]); + operands.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size()); + llvm::append_range(operands, cpuWeights[programKey]); + llvm::append_range(operands, cpuExternalInputs[programKey]); SmallVector resultTypes; - resultTypes.reserve(cpuExternalOutputs[cpu].size()); - for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) { + resultTypes.reserve(cpuExternalOutputs[programKey].size()); + for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) { ScheduledTask task = taskByComputeInstance.at(outputRef.instance); - resultTypes.push_back(getComputeInstanceOutputTypes(task.computeInstance)[outputRef.resultIndex]); + SmallVector outputTypes = getTaskOutputTypes(task.computeInstance); + resultTypes.push_back(outputTypes[outputRef.resultIndex]); } rewriter.setInsertionPoint(returnOp); auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands)); newCompute.getProperties().setOperandSegmentSizes( - {static_cast(cpuWeights[cpu].size()), static_cast(cpuExternalInputs[cpu].size())}); + {static_cast(cpuWeights[programKey].size()), static_cast(cpuExternalInputs[programKey].size())}); newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast(cpu))); SmallVector blockArgTypes; SmallVector blockArgLocs; - blockArgTypes.reserve(cpuExternalInputs[cpu].size()); - blockArgLocs.reserve(cpuExternalInputs[cpu].size()); - for (Value input : cpuExternalInputs[cpu]) { + blockArgTypes.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size()); + blockArgLocs.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size()); + for (Value weight : cpuWeights[programKey]) { + blockArgTypes.push_back(weight.getType()); + blockArgLocs.push_back(loc); + } + for (Value input : cpuExternalInputs[programKey]) { blockArgTypes.push_back(input.getType()); blockArgLocs.push_back(loc); } @@ -297,16 +817,27 @@ private: CpuProgram program; program.op = newCompute; - for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[cpu])) + for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[programKey])) program.weightToIndex[weight] = weightIndex; - for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu])) - program.externalInputMap[input] = newBlock->getArgument(inputIndex); - for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) { + for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[programKey])) + program.externalInputMap[input] = newCompute.getInputArgument(inputIndex); + for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[programKey])) { ScheduledTask task = taskByComputeInstance.at(outputRef.instance); + if (isResultfulBatchInstance(task.computeInstance)) { + auto batch = cast(task.computeInstance.op); + auto& batchResults = resultfulBatchLaneResults[batch.getOperation()]; + if (batchResults.empty()) + batchResults.resize(batch.getNumResults()); + auto& laneResults = batchResults[outputRef.resultIndex]; + if (laneResults.empty()) + laneResults.resize(static_cast(batch.getLaneCount())); + laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex); + continue; + } oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] = newCompute.getResult(resultIndex); } - cpuPrograms[cpu] = std::move(program); + cpuPrograms[programKey] = std::move(program); } } @@ -336,12 +867,11 @@ private: if (consumerInputs.size() <= entry.inputIndex) return failure(); Type inputType = consumerInputs[entry.inputIndex].getType(); - auto receive = spatial::SpatChannelReceiveOp::create(rewriter, - loc, - inputType, - rewriter.getI64IntegerAttr(entry.channelInfo.channelId), - rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId), - rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId)); + Value channelId = createIndexConstant(entry.consumer.op, entry.channelInfo.channelId, constantFolder); + Value sourceCoreId = createIndexConstant(entry.consumer.op, entry.channelInfo.sourceCoreId, constantFolder); + Value targetCoreId = createIndexConstant(entry.consumer.op, entry.channelInfo.targetCoreId, constantFolder); + auto receive = + spatial::SpatChannelReceiveOp::create(rewriter, loc, inputType, channelId, sourceCoreId, targetCoreId); auto& receivedInputs = preReceivedInputsByTask[entry.consumer]; if (receivedInputs.size() <= entry.inputIndex) @@ -355,12 +885,16 @@ private: } LogicalResult cloneTaskBodies() { - for (size_t cpu : orderedCpus) { - CpuProgram& program = cpuPrograms[cpu]; + DenseMap> receiveQueueIndicesByCpu; + DenseMap>> preReceivedInputsByCpu; + + for (ProgramKey programKey : orderedPrograms) { + size_t cpu = programKey.first; + CpuProgram& program = cpuPrograms[programKey]; IRRewriter rewriter(func.getContext()); rewriter.setInsertionPointToEnd(&program.op.getBody().front()); - DenseMap receiveQueueIndices; - DenseMap> preReceivedInputsByTask; + auto& receiveQueueIndices = receiveQueueIndicesByCpu[cpu]; + auto& preReceivedInputsByTask = preReceivedInputsByCpu[cpu]; auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional { auto inputsIt = preReceivedInputsByTask.find(consumer); @@ -372,7 +906,9 @@ private: return value; }; - for (const ScheduledTask& task : tasksByCpu[cpu]) { + ArrayRef programTasks = tasksByProgram[programKey]; + for (size_t taskIndex = 0; taskIndex < programTasks.size(); ++taskIndex) { + const ScheduledTask& task = programTasks[taskIndex]; SmallVector taskInputs = getComputeInstanceInputs(task.computeInstance); auto taskWeights = getComputeInstanceWeights(task.computeInstance); Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance); @@ -380,8 +916,57 @@ private: SmallVector resolvedInputs; resolvedInputs.reserve(taskInputs.size()); auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance); + auto remoteTensorInputsIt = remoteTensorInputsByTask.find(task.computeInstance); for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { - auto producerRef = getProducerValueRef(input); + if (remoteTensorInputsIt != remoteTensorInputsByTask.end() && inputIndex < remoteTensorInputsIt->second.size() + && remoteTensorInputsIt->second[inputIndex]) { + const TensorChannelInfo& tensorInfo = *remoteTensorInputsIt->second[inputIndex]; + bool hasLocalProducer = llvm::is_contained(tensorInfo.sourceCoreIds, static_cast(cpu)); + if (!hasLocalProducer) { + auto receive = spatial::SpatChannelReceiveTensorOp::create( + rewriter, + loc, + input.getType(), + createIndexConstants(program.op, tensorInfo.channelIds, constantFolder), + createIndexConstants(program.op, tensorInfo.sourceCoreIds, constantFolder), + createIndexConstants(program.op, tensorInfo.targetCoreIds, constantFolder)); + resolvedInputs.push_back(receive.getOutput()); + continue; + } + + SmallVector laneValues; + laneValues.reserve(tensorInfo.producerInstances.size()); + for (auto [laneIndex, producerInstance] : llvm::enumerate(tensorInfo.producerInstances)) { + if (tensorInfo.sourceCoreIds[laneIndex] == static_cast(cpu)) { + auto producedIt = producedValuesByTask.find(producerInstance); + if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= tensorInfo.resultIndex) { + task.computeInstance.op->emitOpError( + "missing local tensor lane producer during merge materialization") + << " consumerCpu=" << cpu << " producerLaneStart=" << producerInstance.laneStart; + return failure(); + } + laneValues.push_back(producedIt->second[tensorInfo.resultIndex]); + continue; + } + + auto producerTaskIt = taskByComputeInstance.find(producerInstance); + if (producerTaskIt == taskByComputeInstance.end()) + return failure(); + Type laneType = getTaskOutputTypes(producerTaskIt->second.computeInstance)[tensorInfo.resultIndex]; + Value channelId = createIndexConstant(program.op, tensorInfo.channelIds[laneIndex], constantFolder); + Value sourceCoreId = createIndexConstant(program.op, tensorInfo.sourceCoreIds[laneIndex], constantFolder); + Value targetCoreId = createIndexConstant(program.op, tensorInfo.targetCoreIds[laneIndex], constantFolder); + auto receive = + spatial::SpatChannelReceiveOp::create(rewriter, loc, laneType, channelId, sourceCoreId, targetCoreId); + laneValues.push_back(receive.getResult()); + } + + Value packedInput = tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, ValueRange(laneValues)).getResult(); + resolvedInputs.push_back(packedInput); + continue; + } + + auto producerRef = getProducerValueRef(input, &task.computeInstance); if (producerRef) { auto producerIt = taskByComputeInstance.find(producerRef->instance); if (producerIt != taskByComputeInstance.end()) { @@ -421,13 +1006,11 @@ private: resolvedInputs.push_back(*received); continue; } - auto receive = - spatial::SpatChannelReceiveOp::create(rewriter, - loc, - input.getType(), - rewriter.getI64IntegerAttr(channelInfo.channelId), - rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), - rewriter.getI32IntegerAttr(channelInfo.targetCoreId)); + Value channelId = createIndexConstant(program.op, channelInfo.channelId, constantFolder); + Value sourceCoreId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder); + Value targetCoreId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder); + auto receive = spatial::SpatChannelReceiveOp::create( + rewriter, loc, input.getType(), channelId, sourceCoreId, targetCoreId); resolvedInputs.push_back(receive.getResult()); continue; } @@ -439,8 +1022,11 @@ private: rewriter.setInsertionPointToEnd(&program.op.getBody().front()); if (isa(task.computeInstance.op)) { IRMapping mapper; - for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments())) - mapper.map(oldArg, resolvedInputs[argIndex]); + auto compute = cast(task.computeInstance.op); + for (auto [weightIndex, weight] : llvm::enumerate(taskWeights)) + mapper.map(compute.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight))); + for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs)) + mapper.map(compute.getInputArgument(inputIndex), input); for (Operation& op : templateBlock) { if (auto yield = dyn_cast(&op)) { @@ -449,50 +1035,90 @@ private: continue; } - Operation* clonedOp = rewriter.clone(op, mapper); - if (auto oldWeightedMvmOp = dyn_cast(&op)) { - auto newWeightedMvmOp = cast(clonedOp); - Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()]; - newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight)); - } - if (auto oldWeightedVmmOp = dyn_cast(&op)) { - auto newWeightedVmmOp = cast(clonedOp); - Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()]; - newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight)); - } + rewriter.clone(op, mapper); } } else { - for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) { + auto batch = cast(task.computeInstance.op); + if (batch.getNumResults() != 0) { IRMapping mapper; - if (templateBlock.getNumArguments() == 1) - mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]); - - for (Operation& op : templateBlock) { - if (auto yield = dyn_cast(&op)) { - for (Value yieldOperand : yield.getOperands()) - taskYieldValues.push_back(mapper.lookup(yieldOperand)); - continue; - } + Value laneValue = getOrCreateHostIndexConstant( + program.op, static_cast(task.computeInstance.laneStart), constantFolder); + mapper.map(batch.getLaneArgument(), laneValue); + for (auto [weightIndex, weight] : llvm::enumerate(taskWeights)) + mapper.map(batch.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight))); + for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs)) + mapper.map(batch.getInputArgument(inputIndex), input); + for (Operation& op : templateBlock.without_terminator()) { Operation* clonedOp = rewriter.clone(op, mapper); - if (auto oldWeightedMvmOp = dyn_cast(&op)) { - if (oldWeightedMvmOp.getWeightIndex() != 0) { - task.computeInstance.op->emitOpError( - "batched per-cpu merge materialization expects lane-local weight index 0"); - return failure(); - } - auto newWeightedMvmOp = cast(clonedOp); - newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset])); + for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) + mapper.map(oldResult, newResult); + } + + FailureOr> yieldInfo = collectResultfulBatchYieldInfo(batch); + if (failed(yieldInfo)) + return task.computeInstance.op->emitOpError("failed to collect resultful batch yield info"); + for (const BatchYieldInfo& info : *yieldInfo) + taskYieldValues.push_back(mapper.lookup(info.yieldedValue)); + } + else { + size_t batchLaneCount = static_cast(batch.getLaneCount()); + size_t inputsPerLane = batchLaneCount == 0 ? 0 : batch.getInputs().size() / batchLaneCount; + size_t weightsPerLane = batchLaneCount == 0 ? 0 : batch.getWeights().size() / batchLaneCount; + size_t loopRunLength = getLoopableResultlessBatchRunLength(programTasks, taskIndex); + if (loopRunLength > 1) { + Value lower = getOrCreateHostIndexConstant( + program.op, static_cast(task.computeInstance.laneStart), constantFolder); + Value upper = getOrCreateHostIndexConstant( + program.op, static_cast(task.computeInstance.laneStart + loopRunLength), constantFolder); + Value step = getOrCreateHostIndexConstant(program.op, 1, constantFolder); + auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {}); + rewriter.setInsertionPointToStart(loop.getBody()); + + IRMapping mapper; + mapper.map(batch.getLaneArgument(), loop.getInductionVar()); + for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) + mapper.map(batch.getWeightArgument(weightIndex), + program.op.getWeightArgument(program.weightToIndex.at(taskWeights[weightIndex]))); + for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex) + mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[inputIndex]); + + for (Operation& op : templateBlock) { + if (isa(&op)) + continue; + + Operation* clonedOp = rewriter.clone(op, mapper); + for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) + mapper.map(oldResult, newResult); } - if (auto oldWeightedVmmOp = dyn_cast(&op)) { - if (oldWeightedVmmOp.getWeightIndex() != 0) { - task.computeInstance.op->emitOpError( - "batched per-cpu merge materialization expects lane-local weight index 0"); - return failure(); + rewriter.setInsertionPointAfter(loop); + taskIndex += loopRunLength - 1; + } + else { + for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) { + IRMapping mapper; + Value laneValue = getOrCreateHostIndexConstant( + program.op, static_cast(task.computeInstance.laneStart + laneOffset), constantFolder); + mapper.map(batch.getLaneArgument(), laneValue); + for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) + mapper.map(batch.getWeightArgument(weightIndex), + program.op.getWeightArgument( + program.weightToIndex.at(taskWeights[laneOffset * weightsPerLane + weightIndex]))); + for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex) + mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[laneOffset * inputsPerLane + inputIndex]); + + for (Operation& op : templateBlock) { + if (auto yield = dyn_cast(&op)) { + for (Value yieldOperand : yield.getOperands()) + taskYieldValues.push_back(mapper.lookup(yieldOperand)); + continue; + } + + Operation* clonedOp = rewriter.clone(op, mapper); + for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) + mapper.map(oldResult, newResult); } - auto newWeightedVmmOp = cast(clonedOp); - newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset])); } } } @@ -500,25 +1126,36 @@ private: producedValuesByTask[task.computeInstance] = taskYieldValues; if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) { - for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) { + for (size_t resultIndex = 0; resultIndex < sendsIt->second.size();) { + const SmallVector& sendInfos = sendsIt->second[resultIndex]; if (sendInfos.empty()) + { + ++resultIndex; continue; + } + + size_t nextResultIndex = resultIndex + 1; + if (tryEmitCompactSendLoops( + program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) { + resultIndex = nextResultIndex; + continue; + } + Value producedValue = taskYieldValues[resultIndex]; for (const RemoteSendInfo& sendInfo : sendInfos) { - spatial::SpatChannelSendOp::create(rewriter, - loc, - rewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId), - rewriter.getI32IntegerAttr(sendInfo.channelInfo.sourceCoreId), - rewriter.getI32IntegerAttr(sendInfo.channelInfo.targetCoreId), - producedValue); + Value channelId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder); + Value sourceCoreId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder); + Value targetCoreId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder); + spatial::SpatChannelSendOp::create(rewriter, loc, channelId, sourceCoreId, targetCoreId, producedValue); } + ++resultIndex; } } } SmallVector yieldValues; - yieldValues.reserve(cpuExternalOutputs[cpu].size()); - for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) { + yieldValues.reserve(cpuExternalOutputs[programKey].size()); + for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) { auto producedIt = producedValuesByTask.find(outputRef.instance); if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) { ScheduledTask task = taskByComputeInstance.at(outputRef.instance); @@ -540,6 +1177,33 @@ private: if (!oldComputeOps.contains(use.getOwner())) use.assign(newValue); } + + IRRewriter rewriter(func.getContext()); + for (auto& [op, resultLaneValues] : resultfulBatchLaneResults) { + auto batch = cast(op); + for (auto [resultIndex, laneValues] : llvm::enumerate(resultLaneValues)) { + bool hasNonScheduledUse = false; + for (Operation* user : batch.getResult(resultIndex).getUsers()) { + if (!oldComputeOps.contains(user)) { + hasNonScheduledUse = true; + break; + } + } + if (!hasNonScheduledUse) + continue; + + if (laneValues.size() != static_cast(batch.getLaneCount()) + || llvm::any_of(laneValues, [](Value value) { return !value; })) { + batch.emitOpError("missing materialized lane result while rebuilding resultful compute_batch result"); + continue; + } + + rewriter.setInsertionPoint(returnOp); + Value packedResult = + tensor::ConcatOp::create(rewriter, batch.getLoc(), /*dim=*/0, ValueRange(laneValues)).getResult(); + batch.getResult(resultIndex).replaceAllUsesWith(packedResult); + } + } } LogicalResult eraseOldScheduledOps() { @@ -573,25 +1237,31 @@ private: int64_t* nextChannelId = nullptr; Location loc; func::ReturnOp returnOp; + OperationFolder constantFolder; SmallVector scheduledTasks; DenseSet oldComputeOps; DenseMap taskByComputeInstance; DenseMap> tasksByCpu; + DenseMap> tasksByProgram; SmallVector orderedCpus; + SmallVector orderedPrograms; DenseSet seenCpus; + DenseSet seenPrograms; DenseMap>> remoteSendsByTask; DenseMap>> remoteInputsByTask; - DenseMap> cpuExternalInputs; - DenseMap> cpuWeights; - DenseMap> cpuExternalOutputs; - DenseMap> seenExternalInputsByCpu; - DenseMap> seenWeightsByCpu; + DenseMap, 4>> remoteTensorInputsByTask; + DenseMap> cpuExternalInputs; + DenseMap> cpuWeights; + DenseMap> cpuExternalOutputs; + DenseMap> seenExternalInputsByProgram; + DenseMap> seenWeightsByProgram; DenseSet pairsNeedingReceiveReorder; DenseMap>> receiveQueuesByCpu; - DenseMap cpuPrograms; + DenseMap cpuPrograms; DenseMap oldToNewExternalValueMap; DenseMap> producedValuesByTask; + DenseMap>> resultfulBatchLaneResults; }; } // namespace diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 92016c2..7e503e1 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -167,21 +167,20 @@ bool isTrivialSerialMergeCandidate(SpatCompute compute) { return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size(); } -SmallVector appendMissingWeightsAndBuildIndexMap(SpatCompute target, ValueRange sourceWeights) { +SmallVector appendMissingWeightsAndBuildIndexMap(SmallVectorImpl& targetWeights, ValueRange sourceWeights) { DenseMap> targetWeightIndices; - for (auto [weightIndex, weight] : llvm::enumerate(target.getWeights())) + for (auto [weightIndex, weight] : llvm::enumerate(targetWeights)) targetWeightIndices[weight].push_back(weightIndex); DenseMap usedSourceWeightOccurrences; SmallVector sourceToTargetIndex; sourceToTargetIndex.reserve(sourceWeights.size()); - auto targetWeights = target.getWeightsMutable(); for (Value weight : sourceWeights) { size_t occurrence = usedSourceWeightOccurrences[weight]++; auto& matchingIndices = targetWeightIndices[weight]; if (occurrence >= matchingIndices.size()) { - size_t newIndex = target.getWeights().size(); - targetWeights.append(weight); + size_t newIndex = targetWeights.size(); + targetWeights.push_back(weight); matchingIndices.push_back(newIndex); sourceToTargetIndex.push_back(newIndex); continue; @@ -213,37 +212,36 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) { auto& computeUse = *compute->getUses().begin(); auto child = cast(computeUse.getOwner()); auto usedResult = cast(computeUse.get()).getResultNumber(); - auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size(); + auto childInputIndex = computeUse.getOperandNumber() - child.getWeights().size(); rewriter.setInsertionPointAfter(compute.getOperation()); - auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); - newCompute.getProperties().setOperandSegmentSizes( - {static_cast(compute.getWeights().size()), static_cast(compute.getInputs().size())}); + SmallVector mergedWeights(compute.getWeights().begin(), compute.getWeights().end()); + SmallVector childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(mergedWeights, child.getWeights()); + SmallVector mergedInputs(compute.getInputs().begin(), compute.getInputs().end()); + auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), mergedWeights, mergedInputs); + Block* newBody = rewriter.createBlock(&newCompute.getBodyRegion()); + for (Value weight : mergedWeights) + newBody->addArgument(weight.getType(), loc); + for (Value input : mergedInputs) + newBody->addArgument(input.getType(), loc); IRMapping mapper; - SmallVector childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(newCompute, child.getWeights()); + for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights())) + mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex)); + for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs())) + mapper.map(compute.getInputArgument(inputIndex), newCompute.getInputArgument(inputIndex)); for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights())) - mapper.map(weight, *std::next(newCompute.getWeights().begin(), childWeightToNewIndex[oldIndex])); + mapper.map(child.getWeightArgument(oldIndex), newCompute.getWeightArgument(childWeightToNewIndex[oldIndex])); - compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper); - auto newTerminator = newCompute.getBody().front().getTerminator(); - mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult)); - newTerminator->erase(); + rewriter.setInsertionPointToEnd(newBody); + auto computeYield = cast(compute.getBody().front().getTerminator()); + for (Operation& op : compute.getBody().front().without_terminator()) + rewriter.clone(op, mapper); + mapper.map(child.getInputArgument(childInputIndex), mapper.lookupOrDefault(computeYield.getOperand(usedResult))); - rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end()); - auto remapWeightIndex = [&](auto weightedOp) { - auto oldIndex = weightedOp.getWeightIndex(); - assert(static_cast(oldIndex) < childWeightToNewIndex.size() && "weight index out of range"); - weightedOp.setWeightIndex(childWeightToNewIndex[oldIndex]); - }; - - for (auto& op : child.getBody().front()) { - auto newInst = rewriter.clone(op, mapper); - if (auto weightedMvmOp = dyn_cast(newInst)) - remapWeightIndex(weightedMvmOp); - if (auto weightedVmmOp = dyn_cast(newInst)) - remapWeightIndex(weightedVmmOp); - } + rewriter.setInsertionPointToEnd(newBody); + for (auto& op : child.getBody().front()) + rewriter.clone(op, mapper); child.replaceAllUsesWith(newCompute); toErase.insert(child); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp index 47b1094..c3bae1d 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/DenseMap.h" @@ -61,6 +62,66 @@ std::optional getComputeCoreId(SpatCompute compute) { static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase"; +static FailureOr getConstantI64Value(Value value) { + APInt constantValue; + if (!matchPattern(value, m_ConstantInt(&constantValue))) + return failure(); + return constantValue.getSExtValue(); +} + +static FailureOr getConstantI32Value(Value value) { + APInt constantValue; + if (!matchPattern(value, m_ConstantInt(&constantValue))) + return failure(); + return static_cast(constantValue.getSExtValue()); +} + +static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op, + uint64_t& channelId, + uint32_t& sourceCoreId, + uint32_t& targetCoreId) { + FailureOr constantChannelId = getConstantI64Value(op.getChannelId()); + FailureOr constantSourceCoreId = getConstantI32Value(op.getSourceCoreId()); + FailureOr constantTargetCoreId = getConstantI32Value(op.getTargetCoreId()); + if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId)) + return false; + channelId = static_cast(*constantChannelId); + sourceCoreId = static_cast(*constantSourceCoreId); + targetCoreId = static_cast(*constantTargetCoreId); + return true; +} + +static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op, + uint64_t& channelId, + uint32_t& sourceCoreId, + uint32_t& targetCoreId) { + FailureOr constantChannelId = getConstantI64Value(op.getChannelId()); + FailureOr constantSourceCoreId = getConstantI32Value(op.getSourceCoreId()); + FailureOr constantTargetCoreId = getConstantI32Value(op.getTargetCoreId()); + if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId)) + return false; + channelId = static_cast(*constantChannelId); + sourceCoreId = static_cast(*constantSourceCoreId); + targetCoreId = static_cast(*constantTargetCoreId); + return true; +} + +static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { + SmallVector constants; + constants.reserve(values.size()); + for (int64_t value : values) + constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder)); + return constants; +} + +static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { + SmallVector constants; + constants.reserve(values.size()); + for (int32_t value : values) + constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder)); + return constants; +} + std::optional getComputeRebatchPhase(SpatCompute compute) { if (auto phaseAttr = compute->getAttrOfType(kRebatchPhaseAttrName)) return static_cast(phaseAttr.getInt()); @@ -206,8 +267,215 @@ bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) { return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end(); } +struct BatchYieldInfo { + Value yieldedValue; + tensor::ParallelInsertSliceOp insertSlice; +}; + +static bool isHostOnlyBatchResultUser(Operation* user) { + return isa(user); +} + +static FailureOr> collectBatchYieldInfo(SpatComputeBatch batchOp) { + Block& block = batchOp.getBody().front(); + auto inParallel = dyn_cast(block.getTerminator()); + if (!inParallel) + return failure(); + + DenseMap batchYieldByOutputArg; + for (Operation& op : inParallel.getRegion().front()) { + auto insertSlice = dyn_cast(&op); + if (!insertSlice) + return failure(); + auto outputArg = dyn_cast(insertSlice.getDest()); + if (!outputArg || outputArg.getOwner() != &block) + return failure(); + batchYieldByOutputArg[outputArg] = {insertSlice.getSource(), insertSlice}; + } + return batchYieldByOutputArg; +} + +static FailureOr cloneBatchAsResultless(SpatComputeBatch batchOp, IRRewriter& rewriter) { + auto coreIdsAttr = batchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); + if (!coreIdsAttr) + return failure(); + + Block& oldBlock = batchOp.getBody().front(); + rewriter.setInsertionPoint(batchOp); + auto newBatch = SpatComputeBatch::create(rewriter, + batchOp.getLoc(), + TypeRange {}, + rewriter.getI32IntegerAttr(batchOp.getLaneCount()), + batchOp.getWeights(), + batchOp.getInputs()); + newBatch.getProperties().setOperandSegmentSizes( + {static_cast(batchOp.getWeights().size()), static_cast(batchOp.getInputs().size())}); + newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr); + + SmallVector blockArgTypes; + SmallVector blockArgLocs; + blockArgTypes.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size()); + blockArgLocs.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size()); + blockArgTypes.push_back(batchOp.getLaneArgument().getType()); + blockArgLocs.push_back(batchOp.getLaneArgument().getLoc()); + for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) { + blockArgTypes.push_back(batchOp.getWeightArgument(weightIndex).getType()); + blockArgLocs.push_back(batchOp.getWeightArgument(weightIndex).getLoc()); + } + for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) { + blockArgTypes.push_back(batchOp.getInputArgument(inputIndex).getType()); + blockArgLocs.push_back(batchOp.getInputArgument(inputIndex).getLoc()); + } + + Block* newBlock = + rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + rewriter.setInsertionPointToStart(newBlock); + + IRMapping mapper; + mapper.map(batchOp.getLaneArgument(), newBatch.getLaneArgument()); + for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) + mapper.map(batchOp.getWeightArgument(weightIndex), newBatch.getWeightArgument(weightIndex)); + for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) + mapper.map(batchOp.getInputArgument(inputIndex), newBatch.getInputArgument(inputIndex)); + + for (Operation& op : oldBlock.without_terminator()) { + Operation* cloned = rewriter.clone(op, mapper); + for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapper.map(oldResult, newResult); + } + + return newBatch; +} + +static LogicalResult materializeBatchResultCommunication(func::FuncOp funcOp, int64_t& nextChannelId) { + IRRewriter rewriter(funcOp.getContext()); + OperationFolder constantFolder(funcOp.getContext()); + SmallVector batches(funcOp.getOps()); + + for (auto batchOp : batches) { + if (batchOp.getNumResults() == 0) + continue; + + auto coreIdsAttr = batchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); + if (!coreIdsAttr) + return batchOp.emitOpError("missing coreIds while materializing batch result communication"); + + FailureOr> batchYieldInfo = collectBatchYieldInfo(batchOp); + if (failed(batchYieldInfo)) + return batchOp.emitOpError("failed to collect per-result yielded values from compute_batch body"); + + FailureOr newBatch = cloneBatchAsResultless(batchOp, rewriter); + if (failed(newBatch)) + return batchOp.emitOpError("failed to clone resultful compute_batch as resultless"); + + Block& oldBlock = batchOp.getBody().front(); + Block& newBlock = newBatch->getBody().front(); + IRMapping mapper; + mapper.map(batchOp.getLaneArgument(), newBatch->getLaneArgument()); + for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) + mapper.map(batchOp.getWeightArgument(weightIndex), newBatch->getWeightArgument(weightIndex)); + for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) + mapper.map(batchOp.getInputArgument(inputIndex), newBatch->getInputArgument(inputIndex)); + auto oldIt = oldBlock.begin(); + auto newIt = newBlock.begin(); + for (; oldIt != oldBlock.end() && newIt != newBlock.end(); ++oldIt, ++newIt) + for (auto [oldResult, newResult] : llvm::zip(oldIt->getResults(), newIt->getResults())) + mapper.map(oldResult, newResult); + + SmallVector sourceCoreIds(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); + rewriter.setInsertionPointToEnd(&newBlock); + + for (unsigned resultIndex = 0; resultIndex < batchOp.getNumResults(); ++resultIndex) { + BlockArgument outputArg = batchOp.getOutputArgument(resultIndex); + auto yieldInfoIt = batchYieldInfo->find(outputArg); + if (yieldInfoIt == batchYieldInfo->end()) + return batchOp.emitOpError( + "missing yielded value for compute_batch result during communication materialization"); + Value mappedYieldedValue = mapper.lookup(yieldInfoIt->second.yieldedValue); + + DenseMap> computeUsesByTargetCore; + SmallVector hostUses; + for (OpOperand& use : batchOp.getResult(resultIndex).getUses()) { + if (auto computeOp = dyn_cast(use.getOwner())) { + auto coreIdAttr = computeOp->getAttrOfType(onnx_mlir::kCoreIdAttrName); + if (!coreIdAttr) + return batchOp.emitOpError("compute user of compute_batch result is missing coreId"); + computeUsesByTargetCore[static_cast(coreIdAttr.getInt())].push_back(&use); + continue; + } + if (isHostOnlyBatchResultUser(use.getOwner())) { + hostUses.push_back(&use); + continue; + } + return batchOp.emitOpError("unsupported user of compute_batch result during communication materialization") + << ": " << use.getOwner()->getName(); + } + + auto createReceiveForUses = [&](ArrayRef uses, ArrayRef targetCoreIds) -> LogicalResult { + if (uses.empty()) + return success(); + + SmallVector channelIds; + channelIds.reserve(sourceCoreIds.size()); + for ([[maybe_unused]] int32_t sourceCoreId : sourceCoreIds) + channelIds.push_back(nextChannelId++); + SmallVector sendChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder); + SmallVector sendSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder); + SmallVector sendTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder); + + spatial::SpatChannelSendBatchOp::create(rewriter, + batchOp.getLoc(), + sendChannelIdValues, + sendSourceCoreIdValues, + sendTargetCoreIdValues, + mappedYieldedValue); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(newBatch->getOperation()); + SmallVector receiveChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder); + SmallVector receiveSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder); + SmallVector receiveTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder); + auto received = spatial::SpatChannelReceiveTensorOp::create(rewriter, + batchOp.getLoc(), + batchOp.getResult(resultIndex).getType(), + receiveChannelIdValues, + receiveSourceCoreIdValues, + receiveTargetCoreIdValues); + for (OpOperand* use : uses) + use->set(received.getOutput()); + rewriter.setInsertionPointToEnd(&newBlock); + return success(); + }; + + for (auto& [targetCoreId, uses] : computeUsesByTargetCore) { + SmallVector targetCoreIds(static_cast(batchOp.getLaneCount()), targetCoreId); + if (failed(createReceiveForUses(uses, targetCoreIds))) + return failure(); + } + + if (!hostUses.empty()) { + SmallVector hostTargetCoreIds(static_cast(batchOp.getLaneCount()), 0); + if (failed(createReceiveForUses(hostUses, hostTargetCoreIds))) + return failure(); + } + } + + rewriter.setInsertionPointToEnd(&newBlock); + spatial::SpatYieldOp::create(rewriter, batchOp.getLoc(), ValueRange {}); + rewriter.eraseOp(batchOp); + } + + return success(); +} + void rebatchEquivalentComputes(func::FuncOp funcOp) { IRRewriter rewriter(funcOp.getContext()); + OperationFolder constantFolder(funcOp.getContext()); SmallVector computes(funcOp.getOps()); DenseSet consumed; DenseMap computeOrder; @@ -316,8 +584,10 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) { entries.reserve(group.size()); for (auto [groupIndex, compute] : llvm::enumerate(group)) { auto groupReceive = cast(&*opIts[groupIndex]); - entries.push_back( - {groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()}); + BatchReceiveEntry entry; + if (!getScalarChannelMetadata(groupReceive, entry.channelId, entry.sourceCoreId, entry.targetCoreId)) + return; + entries.push_back(entry); ++opIts[groupIndex]; } SmallVector channelIds; @@ -331,12 +601,15 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) { sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); targetCoreIds.push_back(static_cast(entry.targetCoreId)); } + SmallVector channelIdValues = createIndexConstants(receiveOp, channelIds, constantFolder); + SmallVector sourceCoreIdValues = createIndexConstants(receiveOp, sourceCoreIds, constantFolder); + SmallVector targetCoreIdValues = createIndexConstants(receiveOp, targetCoreIds, constantFolder); auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter, receiveOp.getLoc(), receiveOp.getOutput().getType(), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds)); + channelIdValues, + sourceCoreIdValues, + targetCoreIdValues); mapper.map(receiveOp.getOutput(), batchReceive.getOutput()); continue; } @@ -351,7 +624,10 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) { entries.reserve(group.size()); for (auto [groupIndex, compute] : llvm::enumerate(group)) { auto groupSend = cast(&*opIts[groupIndex]); - entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()}); + BatchSendEntry entry; + if (!getScalarChannelMetadata(groupSend, entry.channelId, entry.sourceCoreId, entry.targetCoreId)) + return; + entries.push_back(entry); ++opIts[groupIndex]; } SmallVector channelIds; @@ -365,11 +641,14 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) { sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); targetCoreIds.push_back(static_cast(entry.targetCoreId)); } + SmallVector channelIdValues = createIndexConstants(sendOp, channelIds, constantFolder); + SmallVector sourceCoreIdValues = createIndexConstants(sendOp, sourceCoreIds, constantFolder); + SmallVector targetCoreIdValues = createIndexConstants(sendOp, targetCoreIds, constantFolder); spatial::SpatChannelSendBatchOp::create(rewriter, sendOp.getLoc(), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds), + channelIdValues, + sourceCoreIdValues, + targetCoreIdValues, mapper.lookup(sendOp.getInput())); continue; } @@ -452,6 +731,11 @@ LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextC ScopedMergePhaseTimer timer("cleanup-dead-packing-ops"); cleanupDeadPackingOps(funcOp); } + { + ScopedMergePhaseTimer timer("materialize-batch-result-communication"); + if (failed(materializeBatchResultCommunication(funcOp, nextChannelId))) + return failure(); + } return success(); } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp index d98ccd4..e56a36a 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp @@ -3,6 +3,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" @@ -30,7 +31,7 @@ enum class RegularStepKind { struct RegularStep { RegularStepKind kind; - int32_t weightIndex = 0; + Value weight; Value invariantOperand; Type resultType; }; @@ -73,15 +74,90 @@ static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) { return (static_cast(sourceCoreId) << 32) | static_cast(targetCoreId); } -static void appendChannelAttrs(SmallVectorImpl& channelIds, - SmallVectorImpl& sourceCoreIds, - SmallVectorImpl& targetCoreIds, - uint64_t channelId, - uint32_t sourceCoreId, - uint32_t targetCoreId) { - channelIds.push_back(static_cast(channelId)); - sourceCoreIds.push_back(static_cast(sourceCoreId)); - targetCoreIds.push_back(static_cast(targetCoreId)); +static FailureOr getConstantI64Value(Value value) { + APInt constantValue; + if (!matchPattern(value, m_ConstantInt(&constantValue))) + return failure(); + return constantValue.getSExtValue(); +} + +static FailureOr getConstantI32Value(Value value) { + APInt constantValue; + if (!matchPattern(value, m_ConstantInt(&constantValue))) + return failure(); + return static_cast(constantValue.getSExtValue()); +} + +static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op, + uint64_t& channelId, + uint32_t& sourceCoreId, + uint32_t& targetCoreId) { + FailureOr constantChannelId = getConstantI64Value(op.getChannelId()); + FailureOr constantSourceCoreId = getConstantI32Value(op.getSourceCoreId()); + FailureOr constantTargetCoreId = getConstantI32Value(op.getTargetCoreId()); + if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId)) + return false; + channelId = static_cast(*constantChannelId); + sourceCoreId = static_cast(*constantSourceCoreId); + targetCoreId = static_cast(*constantTargetCoreId); + return true; +} + +static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op, + uint64_t& channelId, + uint32_t& sourceCoreId, + uint32_t& targetCoreId) { + FailureOr constantChannelId = getConstantI64Value(op.getChannelId()); + FailureOr constantSourceCoreId = getConstantI32Value(op.getSourceCoreId()); + FailureOr constantTargetCoreId = getConstantI32Value(op.getTargetCoreId()); + if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId)) + return false; + channelId = static_cast(*constantChannelId); + sourceCoreId = static_cast(*constantSourceCoreId); + targetCoreId = static_cast(*constantTargetCoreId); + return true; +} + +static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { + SmallVector constants; + constants.reserve(values.size()); + for (int64_t value : values) + constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder)); + return constants; +} + +static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { + SmallVector constants; + constants.reserve(values.size()); + for (int32_t value : values) + constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder)); + return constants; +} + +static SmallVector getScalarChannelMetadataDefs(Operation* channelOp, unsigned metadataOperandCount) { + SmallVector defs; + defs.reserve(metadataOperandCount); + for (unsigned operandIndex = 0; operandIndex < metadataOperandCount; ++operandIndex) { + Operation* def = channelOp->getOperand(operandIndex).getDefiningOp(); + auto constantOp = dyn_cast_or_null(def); + if (!constantOp || def->getBlock() != channelOp->getBlock()) + continue; + defs.push_back(def); + } + llvm::sort(defs, [](Operation* lhs, Operation* rhs) { return lhs->isBeforeInBlock(rhs); }); + return defs; +} + +static void moveScalarChannelBundleBefore(Operation* channelOp, Operation* insertionPoint) { + for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3)) + metadataDef->moveBefore(insertionPoint); + channelOp->moveBefore(insertionPoint); +} + +static void moveScalarChannelBundleBefore(Operation* channelOp, Block* block, Block::iterator insertionPoint) { + for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3)) + metadataDef->moveBefore(block, insertionPoint); + channelOp->moveBefore(block, insertionPoint); } static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) { @@ -196,7 +272,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter } static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) { - return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand + return lhs.kind == rhs.kind && lhs.weight == rhs.weight && lhs.invariantOperand == rhs.invariantOperand && lhs.resultType == rhs.resultType; } @@ -227,8 +303,7 @@ static FailureOr analyzeRegularChunk(spatial::SpatVMMOp startOp) { chunk.input = startOp.getInput(); chunk.output = startOp.getOutput(); chunk.ops.push_back(startOp.getOperation()); - chunk.steps.push_back( - {RegularStepKind::Wvmm, static_cast(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()}); + chunk.steps.push_back({RegularStepKind::Wvmm, startOp.getWeight(), Value(), startOp.getOutput().getType()}); Value currentValue = startOp.getOutput(); while (currentValue.hasOneUse()) { @@ -241,9 +316,9 @@ static FailureOr analyzeRegularChunk(spatial::SpatVMMOp startOp) { break; if (vaddOp.getLhs() == currentValue) - chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()}); + chunk.steps.push_back({RegularStepKind::VAddLhs, Value(), vaddOp.getRhs(), vaddOp.getOutput().getType()}); else if (vaddOp.getRhs() == currentValue) - chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()}); + chunk.steps.push_back({RegularStepKind::VAddRhs, Value(), vaddOp.getLhs(), vaddOp.getOutput().getType()}); else break; @@ -255,7 +330,8 @@ static FailureOr analyzeRegularChunk(spatial::SpatVMMOp startOp) { return chunk; } -static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, ArrayRef run) { +static RegularCompactionResult +compactRegularChunkRun(IRRewriter& rewriter, ArrayRef run, OperationFolder& constantFolder) { assert(!run.empty() && "expected a non-empty regular chunk run"); const RegularChunk& anchorChunk = run.front(); RegularCompactionResult result; @@ -275,9 +351,9 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra auto packedOutputType = getPackedTensorType(outputType, static_cast(run.size())); auto packedInit = tensor::EmptyOp::create( rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType()); - auto zero = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 0); - auto upper = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), run.size()); - auto step = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 1); + auto zero = getOrCreateHostIndexConstant(anchorChunk.startOp, 0, constantFolder); + auto upper = getOrCreateHostIndexConstant(anchorChunk.startOp, static_cast(run.size()), constantFolder); + auto step = getOrCreateHostIndexConstant(anchorChunk.startOp, 1, constantFolder); auto loop = scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()}); @@ -290,8 +366,7 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra Value inputRowOffset = iv; if (inputType.getDimSize(0) != 1) { - auto rowsPerValue = - arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), inputType.getDimSize(0)); + auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, inputType.getDimSize(0), constantFolder); inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue); } @@ -320,8 +395,7 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra Value mappedOutput = mapping.lookup(anchorChunk.output); Value outputRowOffset = iv; if (outputType.getDimSize(0) != 1) { - auto rowsPerValue = - arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), outputType.getDimSize(0)); + auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, outputType.getDimSize(0), constantFolder); outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue); } @@ -389,35 +463,50 @@ void orderBilateralChannelOps(func::FuncOp funcOp) { Block& block = compute.getBody().front(); SmallVector> moves; DenseMap firstForwardedSendByEndpoint; + Operation* firstForwardedSend = nullptr; for (Operation& op : block) { if (auto sendOp = dyn_cast(&op)) { - if (sendOp.getSourceCoreId() == static_cast(coreId) - && isForwardedChannelPayload(sendOp.getInput(), block)) { - uint64_t key = getEndpointKey(sendOp.getSourceCoreId(), sendOp.getTargetCoreId()); + uint64_t channelId = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + if (getScalarChannelMetadata(sendOp, channelId, sourceCoreId, targetCoreId) + && sourceCoreId == static_cast(coreId) && isForwardedChannelPayload(sendOp.getInput(), block)) { + if (!firstForwardedSend) + firstForwardedSend = sendOp.getOperation(); + uint64_t key = getEndpointKey(sourceCoreId, targetCoreId); firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation()); } continue; } auto receiveOp = dyn_cast(&op); - if (!receiveOp || receiveOp.getTargetCoreId() != static_cast(coreId) - || receiveOp.getSourceCoreId() >= static_cast(coreId)) { + uint64_t channelId = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId) + || targetCoreId != static_cast(coreId) || sourceCoreId >= static_cast(coreId)) { continue; } - uint64_t key = getEndpointKey(static_cast(coreId), receiveOp.getSourceCoreId()); + uint64_t key = getEndpointKey(static_cast(coreId), sourceCoreId); auto firstMatchingSend = firstForwardedSendByEndpoint.find(key); if (firstMatchingSend != firstForwardedSendByEndpoint.end()) moves.push_back({receiveOp, firstMatchingSend->second}); + else if (firstForwardedSend && firstForwardedSend->isBeforeInBlock(receiveOp)) + moves.push_back({receiveOp, firstForwardedSend}); } for (auto [receiveOp, insertionPoint] : moves) - receiveOp->moveBefore(insertionPoint); + moveScalarChannelBundleBefore(receiveOp, insertionPoint); for (auto it = block.begin(); it != block.end();) { auto receiveOp = dyn_cast(&*it); - if (!receiveOp || receiveOp.getSourceCoreId() >= static_cast(coreId)) { + uint64_t channelId = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId) + || sourceCoreId >= static_cast(coreId)) { ++it; continue; } @@ -425,18 +514,32 @@ void orderBilateralChannelOps(func::FuncOp funcOp) { Type outputType = receiveOp.getOutput().getType(); auto run = collectConsecutiveRun( it, block.end(), [&](spatial::SpatChannelReceiveOp current) { + uint64_t currentChannelId = 0; + uint32_t currentSourceCoreId = 0; + uint32_t currentTargetCoreId = 0; return current.getOutput().getType() == outputType - && current.getSourceCoreId() < static_cast(coreId); + && getScalarChannelMetadata(current, currentChannelId, currentSourceCoreId, currentTargetCoreId) + && currentSourceCoreId < static_cast(coreId); }); if (run.ops.size() > 1) { SmallVector sorted(run.ops); llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) { - return lhs.getSourceCoreId() > rhs.getSourceCoreId(); + uint64_t lhsChannelId = 0; + uint32_t lhsSourceCoreId = 0; + uint32_t lhsTargetCoreId = 0; + uint64_t rhsChannelId = 0; + uint32_t rhsSourceCoreId = 0; + uint32_t rhsTargetCoreId = 0; + bool lhsHasMetadata = getScalarChannelMetadata(lhs, lhsChannelId, lhsSourceCoreId, lhsTargetCoreId); + bool rhsHasMetadata = getScalarChannelMetadata(rhs, rhsChannelId, rhsSourceCoreId, rhsTargetCoreId); + if (!lhsHasMetadata || !rhsHasMetadata) + return false; + return lhsSourceCoreId > rhsSourceCoreId; }); Block::iterator insertIt = run.end; for (auto op : sorted) - op->moveBefore(&block, insertIt); + moveScalarChannelBundleBefore(op, &block, insertIt); } it = run.end; @@ -446,6 +549,7 @@ void orderBilateralChannelOps(func::FuncOp funcOp) { void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { IRRewriter rewriter(funcOp.getContext()); + OperationFolder constantFolder(funcOp.getContext()); for (auto compute : funcOp.getOps()) { Block& block = compute.getBody().front(); @@ -461,7 +565,14 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { bool hasRepeatedEndpoint = false; DenseSet seenEndpoints; for (auto op : run.ops) { - uint64_t endpointKey = getEndpointKey(op.getSourceCoreId(), op.getTargetCoreId()); + uint64_t channelId = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) { + hasRepeatedEndpoint = true; + break; + } + uint64_t endpointKey = getEndpointKey(sourceCoreId, targetCoreId); if (!seenEndpoints.insert(endpointKey).second) { hasRepeatedEndpoint = true; break; @@ -478,8 +589,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { }; SmallVector sortedEntries; sortedEntries.reserve(run.ops.size()); - for (auto [originalIndex, op] : llvm::enumerate(run.ops)) - sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); + for (auto [originalIndex, op] : llvm::enumerate(run.ops)) { + uint64_t channelId = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) { + sortedEntries.clear(); + break; + } + sortedEntries.push_back({op, originalIndex, sourceCoreId, targetCoreId, channelId}); + } + if (sortedEntries.empty()) { + ++it; + continue; + } SmallVector channelIds; SmallVector sourceCoreIds; @@ -488,8 +611,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { sourceCoreIds.reserve(sortedEntries.size()); targetCoreIds.reserve(sortedEntries.size()); for (ReceiveEntry& entry : sortedEntries) { - appendChannelAttrs( - channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId); + channelIds.push_back(static_cast(entry.channelId)); + sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + targetCoreIds.push_back(static_cast(entry.targetCoreId)); } auto rowType = cast(run.ops.front().getOutput().getType()); @@ -506,13 +630,11 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { : RankedTensorType {}; auto packedType = concatPackedType ? concatPackedType : fallbackPackedType; rewriter.setInsertionPoint(run.ops.front()); - auto compactReceive = - spatial::SpatChannelReceiveTensorOp::create(rewriter, - run.ops.front().getLoc(), - packedType, - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds)); + SmallVector channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder); + SmallVector sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder); + SmallVector targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder); + auto compactReceive = spatial::SpatChannelReceiveTensorOp::create( + rewriter, run.ops.front().getLoc(), packedType, channelIdValues, sourceCoreIdValues, targetCoreIdValues); if (concatOp && concatPackedType) { replaceConcatRunWithPackedValue(concatOp, concatStartIndex, @@ -551,8 +673,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { }; SmallVector sortedEntries; sortedEntries.reserve(run.ops.size()); - for (auto op : run.ops) - sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); + for (auto op : run.ops) { + uint64_t channelId = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) { + sortedEntries.clear(); + break; + } + sortedEntries.push_back({op, sourceCoreId, targetCoreId, channelId}); + } + if (sortedEntries.empty()) { + ++it; + continue; + } SmallVector channelIds; SmallVector sourceCoreIds; @@ -563,20 +697,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { targetCoreIds.reserve(sortedEntries.size()); inputs.reserve(sortedEntries.size()); for (SendEntry& entry : sortedEntries) { - appendChannelAttrs( - channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId); + channelIds.push_back(static_cast(entry.channelId)); + sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + targetCoreIds.push_back(static_cast(entry.targetCoreId)); inputs.push_back(entry.op.getInput()); } rewriter.setInsertionPoint(run.ops.front()); Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc()); if (packedInput) { - spatial::SpatChannelSendTensorOp::create(rewriter, - run.ops.front().getLoc(), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds), - packedInput); + SmallVector channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder); + SmallVector sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder); + SmallVector targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder); + spatial::SpatChannelSendTensorOp::create( + rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput); for (auto op : run.ops) rewriter.eraseOp(op); @@ -606,9 +740,9 @@ void compactBatchChannelRuns(func::FuncOp funcOp) { }); if (run.ops.size() > 1) { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; for (auto op : run.ops) { llvm::append_range(channelIds, op.getChannelIds()); llvm::append_range(sourceCoreIds, op.getSourceCoreIds()); @@ -629,13 +763,8 @@ void compactBatchChannelRuns(func::FuncOp funcOp) { : RankedTensorType {}; auto packedType = concatPackedType ? concatPackedType : fallbackPackedType; rewriter.setInsertionPoint(run.ops.front()); - auto compactReceive = - spatial::SpatChannelReceiveTensorBatchOp::create(rewriter, - run.ops.front().getLoc(), - packedType, - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds)); + auto compactReceive = spatial::SpatChannelReceiveTensorBatchOp::create( + rewriter, run.ops.front().getLoc(), packedType, channelIds, sourceCoreIds, targetCoreIds); if (concatOp && concatPackedType) { replaceConcatRunWithPackedValue( concatOp, concatStartIndex, static_cast(outputs.size()), compactReceive.getOutput(), rewriter); @@ -663,9 +792,9 @@ void compactBatchChannelRuns(func::FuncOp funcOp) { }); if (run.ops.size() > 1) { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; SmallVector inputs; inputs.reserve(run.ops.size()); for (auto op : run.ops) { @@ -678,12 +807,8 @@ void compactBatchChannelRuns(func::FuncOp funcOp) { rewriter.setInsertionPoint(run.ops.front()); Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc()); if (packedInput) { - spatial::SpatChannelSendTensorBatchOp::create(rewriter, - run.ops.front().getLoc(), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds), - packedInput); + spatial::SpatChannelSendTensorBatchOp::create( + rewriter, run.ops.front().getLoc(), channelIds, sourceCoreIds, targetCoreIds, packedInput); for (auto op : run.ops) rewriter.eraseOp(op); @@ -700,6 +825,7 @@ void compactBatchChannelRuns(func::FuncOp funcOp) { void compactRegularOpRuns(func::FuncOp funcOp) { IRRewriter rewriter(funcOp.getContext()); + OperationFolder constantFolder(funcOp.getContext()); auto compactInBlock = [&](Block& block) { for (auto it = block.begin(); it != block.end();) { @@ -740,7 +866,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) { for (const RegularChunk& chunk : run) originalOpCount += chunk.ops.size(); - RegularCompactionResult result = compactRegularChunkRun(rewriter, run); + RegularCompactionResult result = compactRegularChunkRun(rewriter, run, constantFolder); if (result.changed) { assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run"); if (!result.resumeAfter) { @@ -763,6 +889,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) { void compactRowWiseWvmmRuns(func::FuncOp funcOp) { IRRewriter rewriter(funcOp.getContext()); + OperationFolder constantFolder(funcOp.getContext()); for (auto compute : funcOp.getOps()) { Block& block = compute.getBody().front(); @@ -784,7 +911,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) { int64_t expectedRow = static_cast(rowResult.getResultNumber()); auto run = collectConsecutiveRun(it, block.end(), [&](spatial::SpatVMMOp current) { - if (current.getWeightIndex() != wvmmOp.getWeightIndex() + if (current.getWeight() != wvmmOp.getWeight() || current.getInput().getDefiningOp() != extractRowsOp || current.getInput().getType() != wvmmOp.getInput().getType() || current.getOutput().getType() != wvmmOp.getOutput().getType()) @@ -851,9 +978,9 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) { auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType()); rewriter.setInsertionPoint(run.ops.front()); - auto zero = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 0); - auto upper = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), runLength); - auto step = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 1); + auto zero = getOrCreateHostIndexConstant(run.ops.front(), 0, constantFolder); + auto upper = getOrCreateHostIndexConstant(run.ops.front(), runLength, constantFolder); + auto step = getOrCreateHostIndexConstant(run.ops.front(), 1, constantFolder); auto packedInit = tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType()); auto loop = @@ -868,7 +995,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) { Value sourceRow = iv; if (firstRow != 0) { - auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), firstRow); + auto firstRowValue = getOrCreateHostIndexConstant(run.ops.front(), firstRow, constantFolder); sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue); } @@ -883,7 +1010,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) { extractSizes, extractStrides); auto loopWvmm = spatial::SpatVMMOp::create( - rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult()); + rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeight(), extractedRow.getResult()); SmallVector insertOffsets = {iv, rewriter.getIndexAttr(0)}; SmallVector insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)}; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp index 65e0847..b3831f6 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp @@ -23,31 +23,31 @@ using namespace mlir; namespace { -Weight getComputeBodyWeight(Region &body) { +Weight getComputeBodyWeight(Region& body) { constexpr Weight kOperationWeight = 100; Weight numOperations = 0; - for (auto &block : body) - for ([[maybe_unused]] auto &op : block) + for (auto& block : body) + for ([[maybe_unused]] auto& op : block) numOperations = checkedAdd(numOperations, static_cast(1)); return checkedMultiply(numOperations, kOperationWeight); } -CrossbarUsage getComputeBodyCrossbarUsage(Region &body) { +CrossbarUsage getComputeBodyCrossbarUsage(Region& body) { CrossbarUsage crossbarUsage = 0; - for (auto &block : body) - for (auto &op : block) + for (auto& block : body) + for (auto& op : block) if (isa(op)) crossbarUsage = checkedAdd(crossbarUsage, static_cast(1)); return crossbarUsage; } -bool isUsedAsWeightOnly(Operation *producerOp) { +bool isUsedAsWeightOnly(Operation* producerOp) { if (producerOp->getNumResults() == 0) return false; for (Value result : producerOp->getResults()) { if (result.use_empty()) return false; - for (Operation *user : result.getUsers()) { + for (Operation* user : result.getUsers()) { if (auto compute = dyn_cast(user)) { if (!llvm::is_contained(compute.getWeights(), result)) return false; @@ -66,7 +66,7 @@ bool isUsedAsWeightOnly(Operation *producerOp) { std::vector aggregateEdges(llvm::ArrayRef edges) { llvm::DenseMap, Weight> edgeWeights; - for (const ComputeGraphEdge &edge : edges) { + for (const ComputeGraphEdge& edge : edges) { if (edge.source == edge.target) continue; auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost); @@ -76,9 +76,9 @@ std::vector aggregateEdges(llvm::ArrayRef ed std::vector aggregatedEdges; aggregatedEdges.reserve(edgeWeights.size()); - for (const auto &[key, weight] : edgeWeights) + for (const auto& [key, weight] : edgeWeights) aggregatedEdges.push_back({key.first, key.second, weight}); - llvm::sort(aggregatedEdges, [](const ComputeGraphEdge &lhs, const ComputeGraphEdge &rhs) { + llvm::sort(aggregatedEdges, [](const ComputeGraphEdge& lhs, const ComputeGraphEdge& rhs) { if (lhs.source != rhs.source) return lhs.source < rhs.source; return lhs.target < rhs.target; @@ -88,33 +88,33 @@ std::vector aggregateEdges(llvm::ArrayRef ed } // namespace -Weight getComputeInstanceWeight(const ComputeInstance &instance) { +Weight getComputeInstanceWeight(const ComputeInstance& instance) { if (auto spatCompute = dyn_cast(instance.op)) return getSpatComputeWeight(spatCompute); auto batch = cast(instance.op); return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast(instance.laneCount)); } -CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) { +CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) { if (auto spatCompute = dyn_cast(instance.op)) return getSpatComputeCrossbarUsage(spatCompute); auto batch = cast(instance.op); - return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), - static_cast(instance.laneCount)); + return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast(instance.laneCount)); } -ComputeGraph buildComputeGraph(Operation *entryOp) { +ComputeGraph buildComputeGraph(Operation* entryOp) { ComputeGraph graph; - for (Region ®ion : entryOp->getRegions()) { - for (Block &block : region) { - for (Operation &op : block) { + for (Region& region : entryOp->getRegions()) { + for (Block& block : region) { + for (Operation& op : block) { if (auto spatCompute = dyn_cast(&op)) { if (isUsedAsWeightOnly(spatCompute.getOperation())) continue; ComputeInstance instance {spatCompute.getOperation(), 0, 1}; size_t index = graph.nodes.size(); - graph.nodes.push_back({instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index}); + graph.nodes.push_back( + {instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index}); graph.instanceToIndex[instance] = index; continue; } @@ -135,9 +135,21 @@ ComputeGraph buildComputeGraph(Operation *entryOp) { } llvm::SmallVector rawEdges; - for (const auto &[targetIndex, node] : llvm::enumerate(graph.nodes)) { + for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) { for (Value input : getComputeInstanceInputs(node.instance)) { - auto producerInstance = getComputeProducerInstance(input); + if (auto producerBatch = dyn_cast_or_null(input.getDefiningOp()); + producerBatch && producerBatch.getNumResults() != 0 && !isa(node.instance.op)) { + for (uint32_t lane = 0; lane < static_cast(producerBatch.getLaneCount()); ++lane) { + auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane)); + if (producerIt == graph.instanceToIndex.end()) + continue; + rawEdges.push_back( + {producerIt->second, targetIndex, static_cast(getSizeInBytes(cast(input.getType())))}); + } + continue; + } + + auto producerInstance = getComputeProducerInstance(input, &node.instance); if (!producerInstance) continue; auto producerIt = graph.instanceToIndex.find(*producerInstance); @@ -152,7 +164,7 @@ ComputeGraph buildComputeGraph(Operation *entryOp) { graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end()); graph.successors.assign(graph.nodes.size(), {}); graph.predecessors.assign(graph.nodes.size(), {}); - for (const ComputeGraphEdge &edge : graph.edges) { + for (const ComputeGraphEdge& edge : graph.edges) { graph.successors[edge.source].push_back({edge.target, edge.transferCost}); graph.predecessors[edge.target].push_back({edge.source, edge.transferCost}); } @@ -160,7 +172,7 @@ ComputeGraph buildComputeGraph(Operation *entryOp) { return graph; } -bool verifyAcyclic(const ComputeGraph &graph) { +bool verifyAcyclic(const ComputeGraph& graph) { std::vector remainingParents(graph.nodes.size(), 0); std::queue readyNodes; for (size_t node = 0; node < graph.nodes.size(); ++node) { @@ -174,7 +186,7 @@ bool verifyAcyclic(const ComputeGraph &graph) { size_t node = readyNodes.front(); readyNodes.pop(); ++visited; - for (const auto &[child, weight] : graph.successors[node]) { + for (const auto& [child, weight] : graph.successors[node]) { (void) weight; assert(remainingParents[child] > 0 && "remaining parent count underflow"); if (--remainingParents[child] == 0) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp index d583249..698f689 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp @@ -1,6 +1,8 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include +#include #include "ComputeInstanceUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" @@ -18,48 +20,91 @@ size_t getSchedulingCpuBudget() { size_t getBatchChunkTargetCount(int32_t laneCount) { assert(laneCount > 0 && "laneCount must be positive"); - return std::min(static_cast(laneCount), std::max(1, getSchedulingCpuBudget())); + return static_cast(laneCount); } ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) { - size_t totalLanes = batch.getLaneCount(); - size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); - size_t baseChunkSize = totalLanes / chunkCount; - size_t largeChunkCount = totalLanes % chunkCount; - - size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount); - size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0); - return {batch.getOperation(), static_cast(laneStart), static_cast(laneCount)}; + assert(chunkIndex < static_cast(batch.getLaneCount()) && "chunkIndex out of range"); + return {batch.getOperation(), static_cast(chunkIndex), 1}; } ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) { - size_t totalLanes = batch.getLaneCount(); - size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); - size_t baseChunkSize = totalLanes / chunkCount; - size_t largeChunkCount = totalLanes % chunkCount; - size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1); - - size_t chunkIndex = 0; - if (static_cast(lane) < largeChunkSpan) - chunkIndex = static_cast(lane) / (baseChunkSize + 1); - else - chunkIndex = largeChunkCount + (static_cast(lane) - largeChunkSpan) / baseChunkSize; - return getBatchChunkForIndex(batch, chunkIndex); + assert(lane < static_cast(batch.getLaneCount()) && "lane out of range"); + return {batch.getOperation(), lane, 1}; } -std::optional getProducerValueRef(Value value) { - Operation *op = value.getDefiningOp(); +static std::optional getConstantExtractLane(tensor::ExtractSliceOp extract) { + if (extract.getMixedOffsets().empty()) + return std::nullopt; + + OpFoldResult offset = extract.getMixedOffsets().front(); + if (Attribute attr = llvm::dyn_cast(offset)) { + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getInt() < 0) + return std::nullopt; + return static_cast(intAttr.getInt()); + } + + Value offsetValue = llvm::cast(offset); + if (auto constantIndex = offsetValue.getDefiningOp()) { + if (constantIndex.value() < 0) + return std::nullopt; + return static_cast(constantIndex.value()); + } + + return std::nullopt; +} + +static std::optional getResultfulBatchProducerValueRef(SpatComputeBatch batch, + const ComputeInstance* consumerInstance) { + if (!consumerInstance) + return std::nullopt; + if (!isa(consumerInstance->op)) + return std::nullopt; + if (consumerInstance->laneStart + consumerInstance->laneCount > static_cast(batch.getLaneCount())) + return std::nullopt; + return ProducerValueRef { + {batch.getOperation(), consumerInstance->laneStart, consumerInstance->laneCount}, + 0 + }; +} + +std::optional getProducerValueRef(Value value, const ComputeInstance* consumerInstance) { + Operation* op = value.getDefiningOp(); if (!op) return std::nullopt; + while (auto extract = dyn_cast(op)) { + Value source = extract.getSource(); + auto batch = dyn_cast_or_null(source.getDefiningOp()); + if (batch && batch.getNumResults() != 0) { + if (std::optional lane = getConstantExtractLane(extract)) { + if (*lane >= static_cast(batch.getLaneCount())) + return std::nullopt; + return ProducerValueRef { + {batch.getOperation(), *lane, 1}, + 0 + }; + } + return getResultfulBatchProducerValueRef(batch, consumerInstance); + } + + value = source; + op = value.getDefiningOp(); + if (!op) + return std::nullopt; + } + if (auto compute = dyn_cast(op)) { return ProducerValueRef { ComputeInstance {compute.getOperation(), 0, 1}, - static_cast(cast(value).getResultNumber()) + static_cast(cast(value).getResultNumber()) }; } if (auto batch = dyn_cast(op)) { + if (batch.getNumResults() != 0) + return getResultfulBatchProducerValueRef(batch, consumerInstance); uint32_t lane = cast(value).getResultNumber(); ComputeInstance instance = getBatchChunkForLane(batch, lane); size_t resultIndex = lane - instance.laneStart; @@ -69,42 +114,60 @@ std::optional getProducerValueRef(Value value) { return std::nullopt; } -std::optional getComputeProducerInstance(Value value) { - if (std::optional producer = getProducerValueRef(value)) +std::optional getComputeProducerInstance(Value value, const ComputeInstance* consumerInstance) { + if (std::optional producer = getProducerValueRef(value, consumerInstance)) return producer->instance; return std::nullopt; } -llvm::SmallVector getComputeInstanceInputs(const ComputeInstance &instance) { +llvm::SmallVector getComputeInstanceInputs(const ComputeInstance& instance) { if (auto compute = dyn_cast(instance.op)) return llvm::SmallVector(compute.getInputs().begin(), compute.getInputs().end()); auto batch = cast(instance.op); + if (batch.getNumResults() != 0) + return llvm::SmallVector(batch.getInputs().begin(), batch.getInputs().end()); + + assert(batch.getInputs().size() % static_cast(batch.getLaneCount()) == 0 + && "resultless compute_batch inputs must be evenly partitioned by lane"); + size_t inputsPerLane = batch.getInputs().size() / static_cast(batch.getLaneCount()); llvm::SmallVector inputs; - inputs.reserve(instance.laneCount); - for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) - if (!batch.getInputs().empty()) - inputs.push_back(batch.getInputs()[lane]); + inputs.reserve(instance.laneCount * inputsPerLane); + for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) { + size_t firstInput = static_cast(lane) * inputsPerLane; + inputs.append(batch.getInputs().begin() + firstInput, batch.getInputs().begin() + firstInput + inputsPerLane); + } return inputs; } -llvm::SmallVector getComputeInstanceWeights(const ComputeInstance &instance) { +llvm::SmallVector getComputeInstanceWeights(const ComputeInstance& instance) { if (auto compute = dyn_cast(instance.op)) return llvm::SmallVector(compute.getWeights().begin(), compute.getWeights().end()); auto batch = cast(instance.op); + if (batch.getNumResults() != 0) + return llvm::SmallVector(batch.getWeights().begin(), batch.getWeights().end()); + + assert(batch.getWeights().size() % static_cast(batch.getLaneCount()) == 0 + && "resultless compute_batch weights must be evenly partitioned by lane"); + size_t weightsPerLane = batch.getWeights().size() / static_cast(batch.getLaneCount()); llvm::SmallVector weights; - weights.reserve(instance.laneCount); - for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) - weights.push_back(batch.getWeights()[lane]); + weights.reserve(instance.laneCount * weightsPerLane); + for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) { + size_t firstWeight = static_cast(lane) * weightsPerLane; + weights.append(batch.getWeights().begin() + firstWeight, batch.getWeights().begin() + firstWeight + weightsPerLane); + } return weights; } -llvm::SmallVector getComputeInstanceOutputValues(const ComputeInstance &instance) { +llvm::SmallVector getComputeInstanceOutputValues(const ComputeInstance& instance) { if (auto compute = dyn_cast(instance.op)) return llvm::SmallVector(compute.getResults().begin(), compute.getResults().end()); auto batch = cast(instance.op); + if (batch.getNumResults() != 0) + return llvm::SmallVector(batch.getResults().begin(), batch.getResults().end()); + llvm::SmallVector outputs; outputs.reserve(instance.laneCount); for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) @@ -113,14 +176,14 @@ llvm::SmallVector getComputeInstanceOutputValues(const ComputeInstance return outputs; } -llvm::SmallVector getComputeInstanceOutputTypes(const ComputeInstance &instance) { +llvm::SmallVector getComputeInstanceOutputTypes(const ComputeInstance& instance) { llvm::SmallVector outputTypes; for (Value output : getComputeInstanceOutputValues(instance)) outputTypes.push_back(output.getType()); return outputTypes; } -Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance) { +Block& getComputeInstanceTemplateBlock(const ComputeInstance& instance) { if (auto compute = dyn_cast(instance.op)) return compute.getBody().front(); return cast(instance.op).getBody().front(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp index 91f3e39..90af9d7 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp @@ -26,8 +26,10 @@ size_t getBatchChunkTargetCount(int32_t laneCount); ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex); ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane); -std::optional getProducerValueRef(mlir::Value value); -std::optional getComputeProducerInstance(mlir::Value value); +std::optional getProducerValueRef(mlir::Value value, + const ComputeInstance *consumerInstance = nullptr); +std::optional getComputeProducerInstance(mlir::Value value, + const ComputeInstance *consumerInstance = nullptr); llvm::SmallVector getComputeInstanceInputs(const ComputeInstance &instance); llvm::SmallVector getComputeInstanceWeights(const ComputeInstance &instance); diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp index 0e7e89e..922ae87 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp @@ -268,24 +268,31 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern(dstByteOffset)), - rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), rewriter.getI32IntegerAttr(static_cast(sliceBytes))); }); if (failed(status)) @@ -301,24 +308,31 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern(dstByteOffset)), - rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), rewriter.getI32IntegerAttr(static_cast(sliceBytes))); }); if (failed(status)) @@ -355,9 +369,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern(subviewOp.getType()).getElementType(); - auto resultMemRefType = - MemRefType::get(SmallVector(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType); + auto resultMemRefType = cast(subviewOp.getType()); auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape()); if (failed(foldedAttr)) return failure(); diff --git a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp index 8c705ef..735d7de 100644 --- a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp +++ b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp @@ -23,11 +23,11 @@ namespace { static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { if (isa(op)) - return operandIndex == 1; + return operandIndex == 3; if (isa(op)) return operandIndex == 1; if (isa(op)) - return operandIndex == 0; + return operandIndex == 2; return false; } @@ -39,7 +39,10 @@ static int64_t getValueSizeInBytes(Value value) { } template -static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, bool& hasFailure) { +static void materializeHostConstantsInCore(CoreOpTy coreOp, + IRRewriter& rewriter, + OperationFolder& constantFolder, + bool& hasFailure) { DenseMap>> materializedValues; SmallVector ops; coreOp.getBody().front().walk([&](Operation* op) { @@ -48,6 +51,9 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter }); for (Operation* op : ops) { + if (auto loadOp = dyn_cast(op); loadOp && loadOp.getType().isIndex()) + continue; + for (OpOperand& operand : op->getOpOperands()) { Value originalValue = operand.get(); if (!isa(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber())) @@ -105,16 +111,17 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter .getOutput(); } else { - copiedValue = pim::PimMemCopyHostToDevOp::create( - rewriter, - op->getLoc(), - originalType, - deviceDst, - getGlobalOp.getResult(), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(static_cast(resolvedAddress->byteOffset)), - rewriter.getI32IntegerAttr(static_cast(totalBytes))) - .getOutput(); + copiedValue = + pim::PimMemCopyHostToDevOp::create( + rewriter, + op->getLoc(), + originalType, + getOrCreateHostIndexConstant(op, 0, constantFolder), + getOrCreateHostIndexConstant(op, static_cast(resolvedAddress->byteOffset), constantFolder), + deviceDst, + getGlobalOp.getResult(), + rewriter.getI32IntegerAttr(static_cast(totalBytes))) + .getOutput(); } cachedByType[originalType] = copiedValue; @@ -134,6 +141,7 @@ struct MaterializeHostConstantsPass : PassWrapper()) { @@ -141,10 +149,10 @@ struct MaterializeHostConstantsPass : PassWrapper()) - materializeHostConstantsInCore(coreOp, rewriter, hasFailure); + materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure); for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps()) - materializeHostConstantsInCore(coreBatchOp, rewriter, hasFailure); + materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure); SmallVector hostCompactOps; for (Operation& op : funcOp.getBody().front()) diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 46a9eed..057cfe7 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -8,6 +8,7 @@ #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -119,11 +120,27 @@ static bool isConstantGlobalView(Value value) { static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { if (isa(op)) - return operandIndex == 1; + return operandIndex == 3; if (isa(op)) return operandIndex == 1; if (isa(op)) - return operandIndex == 0; + return operandIndex == 2; + return false; +} + +static bool isCoreWeightBlockArgument(Value value) { + auto blockArgument = dyn_cast(value); + if (!blockArgument) + return false; + + if (auto coreOp = dyn_cast(blockArgument.getOwner()->getParentOp())) + return static_cast(blockArgument.getArgNumber()) < coreOp.getWeights().size(); + + if (auto coreBatchOp = dyn_cast(blockArgument.getOwner()->getParentOp())) { + unsigned argNumber = static_cast(blockArgument.getArgNumber()); + return argNumber > 0 && argNumber <= coreBatchOp.getWeights().size(); + } + return false; } @@ -193,7 +210,9 @@ struct VerificationPass : PassWrapper> if (auto coreBatchOp = dyn_cast(&op)) { (void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics); - (void) verifyCoreOperands(coreBatchOp, diagnostics); + for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) + (void) withScalarCoreFromBatchLane( + coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) { return verifyCoreOperands(scalarCore, diagnostics); }); continue; } @@ -297,6 +316,9 @@ private: if (!isa(operand.getType())) continue; + if (isCoreWeightBlockArgument(operand)) + continue; + auto resolvedAddress = resolveContiguousAddress(operand, knowledge); if (failed(resolvedAddress)) { diagnostics.report(&op, [&](Operation* illegalOp) { @@ -327,6 +349,26 @@ private: hasFailure = true; } } + + if (auto storeOp = dyn_cast(op)) { + if (failed(resolveIndexValue(storeOp.getHostTargetOffset(), knowledge)) + || failed(resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge))) { + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen"); + }); + hasFailure = true; + } + } + + if (auto loadOp = dyn_cast(op)) { + if (failed(resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge)) + || failed(resolveIndexValue(loadOp.getHostSourceOffset(), knowledge))) { + diagnostics.report(&op, [](Operation* illegalOp) { + illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen"); + }); + hasFailure = true; + } + } return success(!hasFailure); }); } diff --git a/validation/validate.py b/validation/validate.py index 461e507..9bfb371 100644 --- a/validation/validate.py +++ b/validation/validate.py @@ -67,7 +67,7 @@ def main(): help="Core count to pass to Raptor. Required for PIM validation.") ap.add_argument("--pim-merge-scheduler", choices=("peft", "dcp"), default="peft", help="Scheduler used by the Spatial merge-compute-nodes pass.") - ap.add_argument("--command-timeout-seconds", type=float, default=6000000000000000.0, + ap.add_argument("--command-timeout-seconds", type=float, default=1000000.0, help="Per-subprocess timeout in seconds for compiler, runner, and simulator commands.") ap.add_argument("--clean", action="store_true", help="Remove generated validation artifacts under each model workspace and exit.")