From 2c1da813b59b340c58a9dc04a43de5f5526e7b37 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Fri, 22 May 2026 18:53:38 +0200 Subject: [PATCH] fix much stuff --- CMakeLists.txt | 8 +- src/PIM/Common/IR/WeightUtils.cpp | 11 +- .../ONNXToSpatial/Patterns/Math/Gemm.cpp | 31 ++-- .../Conversion/ONNXToSpatial/PostPatterns.cpp | 71 ++++++-- .../BatchCoreLoweringPatterns.cpp | 47 +++-- .../SpatialToPim/ComputeLikeRegionUtils.cpp | 16 +- .../SpatialToPim/CoreLoweringPatterns.cpp | 40 +++-- .../GlobalTensorMaterialization.cpp | 8 +- src/PIM/Dialect/Spatial/Spatial.td | 24 ++- src/PIM/Dialect/Spatial/SpatialOps.cpp | 164 ++++++++++++++++-- src/PIM/Dialect/Spatial/SpatialOps.hpp | 3 + src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp | 68 +++++--- src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp | 54 ++++-- .../MaterializeMergeSchedule.cpp | 105 ++++++----- .../MergeComputeNodesPass.cpp | 28 ++- validation/operations/README.md | 2 +- .../{gemm.onnx => simple/gemm_simple.onnx} | Bin validation/operations/gen_tests.py | 13 ++ 18 files changed, 502 insertions(+), 191 deletions(-) rename validation/operations/gemm/{gemm.onnx => simple/gemm_simple.onnx} (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b7a5d3..ef21f85 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,13 +32,7 @@ function(raptor_write_external_cmake_shim shim_dir external_source_dir descripti \"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\" REALPATH ) -add_subdirectory( - \"\${raptor_external_source_dir}\" - \"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\" -) -if (DEFINED PIM_ENABLED) - set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE) -endif () +include(\"\${raptor_external_source_dir}/CMakeLists.txt\") " ) diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp index fbeff3e..a206f3a 100644 --- a/src/PIM/Common/IR/WeightUtils.cpp +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -21,13 +21,15 @@ namespace { template bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { - mlir::Value weightArg = parentOp.getWeightArgument(weightIndex); + auto weightArg = parentOp.getWeightArgument(weightIndex); + if (!weightArg) + return false; bool found = false; parentOp.walk([&](mlir::Operation* op) { if (auto mvmOp = mlir::dyn_cast(op)) - found |= mvmOp.getWeight() == weightArg; + found |= mvmOp.getWeight() == *weightArg; else if (auto vmmOp = mlir::dyn_cast(op)) - found |= vmmOp.getWeight() == weightArg; + found |= vmmOp.getWeight() == *weightArg; }); return found; } @@ -38,7 +40,8 @@ void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref visited; auto walkWeight = [&](mlir::Value weight) { for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) { - if (parentOp.getWeightArgument(weightIndex) != weight) + auto weightArg = parentOp.getWeightArgument(weightIndex); + if (!weightArg || *weightArg != weight) continue; if (visited.insert(weightIndex).second) callback(parentOp->getOpOperand(weightIndex)); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 707b71e..3bd61c7 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -422,12 +422,17 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, SmallVector vmmOutputs; vmmOutputs.reserve(aHSlices[coreId].size()); - for (auto aHSliceId : llvm::seq(0, aHSlices[coreId].size())) + for (auto aHSliceId : llvm::seq(0, aHSlices[coreId].size())) { + auto weightArg = computeOp.getWeightArgument(aHSliceId); + auto inputArg = computeOp.getInputArgument(aHSliceId); + if (!weightArg || !inputArg) + return failure(); vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, - computeOp.getWeightArgument(aHSliceId), - computeOp.getInputArgument(aHSliceId))); + *weightArg, + *inputArg)); + } if (vmmOutputs.empty()) { gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); return failure(); @@ -561,29 +566,31 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp, rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); rewriter.setInsertionPointToEnd(body); - Value lane = batchOp.getLaneArgument(); - Value weight = batchOp.getWeightArgument(0); - Value packedInput = batchOp.getInputArgument(0); - Value packedOutput = batchOp.getOutputArgument(0); + auto lane = batchOp.getLaneArgument(); + auto weight = batchOp.getWeightArgument(0); + auto packedInput = batchOp.getInputArgument(0); + auto packedOutput = batchOp.getOutputArgument(0); + if (!lane || !weight || !packedInput || !packedOutput) + return failure(); - SmallVector inputOffsets {lane, rewriter.getIndexAttr(0)}; + SmallVector inputOffsets {*lane, rewriter.getIndexAttr(0)}; SmallVector inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))}; SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; Value row = - tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides) + tensor::ExtractSliceOp::create(rewriter, loc, aRowType, *packedInput, inputOffsets, inputSizes, unitStrides) .getResult(); - Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult(); + Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, *weight, row).getResult(); Value laneResult = vmmResult; if (sharedBias) laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult(); auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc); rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); - SmallVector outputOffsets {lane, rewriter.getIndexAttr(0)}; + SmallVector outputOffsets {*lane, rewriter.getIndexAttr(0)}; SmallVector outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))}; tensor::ParallelInsertSliceOp::create( - rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, unitStrides); + rewriter, loc, laneResult, *packedOutput, outputOffsets, outputSizes, unitStrides); rewriter.setInsertionPointAfter(batchOp); rewriter.replaceOp(gemmOp, batchOp.getResults()); diff --git a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp index 8a41f24..a212b6f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp @@ -27,13 +27,16 @@ static bool canPromoteInputBlockArgument(BlockArgument arg) { return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser); } +static bool canPromoteInputBlockArgument(std::optional arg) { + return arg && canPromoteInputBlockArgument(*arg); +} + static bool isDirectConstantValue(Value value) { return isa_and_nonnull(value.getDefiningOp()); } template static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) { - Block& block = compute.getBody().front(); for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { if (!isWeightLikeComputeOperand(input)) continue; @@ -104,20 +107,30 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern(compute.getLaneCount())), newWeights, newInputs); + auto laneArg = compute.getLaneArgument(); + if (!laneArg) + return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument"); SmallVector newBlockArgTypes; SmallVector newBlockArgLocs; newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults()); newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults()); - newBlockArgTypes.push_back(compute.getLaneArgument().getType()); - newBlockArgLocs.push_back(compute.getLaneArgument().getLoc()); + newBlockArgTypes.push_back(laneArg->getType()); + newBlockArgLocs.push_back(laneArg->getLoc()); for (Value weight : newWeights) { newBlockArgTypes.push_back(weight.getType()); newBlockArgLocs.push_back(weight.getLoc()); @@ -197,8 +213,11 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePatterngetLoc()); } auto* newBlock = rewriter.createBlock( @@ -211,25 +230,41 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern(0, compute.getNumResults())) { + auto outputArg = compute.getOutputArgument(resultIndex); + if (!outputArg) + return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite"); + mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex)); } - for (auto resultIndex : llvm::seq(0, compute.getNumResults())) - mapper.map(compute.getOutputArgument(resultIndex), - newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex)); for (Operation& op : oldBlock) rewriter.clone(op, mapper); diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 8cacfb2..6e93896 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -183,17 +183,15 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute {static_cast(batchWeights.size()), static_cast(batchInputs.size())}); coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); - SmallVector hostOutputTensors; + SmallVector returnOperandIndices; if (computeBatchOp.getNumResults() != 0) { - hostOutputTensors.resize(computeBatchOp.getNumResults()); + returnOperandIndices.resize(computeBatchOp.getNumResults()); for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) { FailureOr returnOperandIndex = getDirectReturnOperandIndex(cast(result)); if (failed(returnOperandIndex)) return computeBatchOp.emitOpError( "resultful compute_batch lowering currently requires each result to be used directly by func.return"); - - hostOutputTensors[resultIndex] = outputTensors[*returnOperandIndex](rewriter, loc); - result.replaceAllUsesWith(hostOutputTensors[resultIndex]); + returnOperandIndices[resultIndex] = *returnOperandIndex; } } @@ -209,11 +207,20 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute IRMapping mapper; rewriter.setInsertionPointToStart(newBlock); - mapper.map(computeBatchOp.getLaneArgument(), coreBatchOp.getLaneArgument()); - for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) - mapper.map(computeBatchOp.getWeightArgument(weightIndex), coreBatchOp.getWeightArgument(weightIndex)); + auto oldLaneArg = computeBatchOp.getLaneArgument(); + if (!oldLaneArg) + return computeBatchOp.emitOpError("expected compute_batch lane block argument before lowering"); + mapper.map(*oldLaneArg, coreBatchOp.getLaneArgument()); + for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) { + auto oldWeightArg = computeBatchOp.getWeightArgument(weightIndex); + if (!oldWeightArg) + return computeBatchOp.emitOpError("expected compute_batch weight block arguments before lowering"); + mapper.map(*oldWeightArg, coreBatchOp.getWeightArgument(weightIndex)); + } for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) { - BlockArgument oldArg = computeBatchOp.getInputArgument(inputIndex); + auto oldArg = computeBatchOp.getInputArgument(inputIndex); + if (!oldArg) + return computeBatchOp.emitOpError("expected compute_batch input block arguments before lowering"); BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex); auto newArgType = cast(newArg.getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType); @@ -226,7 +233,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute rewriter.getI32IntegerAttr(0), getTensorSizeInBytesAttr(rewriter, newArg)) .getOutput(); - mapper.map(oldArg, copied); + mapper.map(*oldArg, copied); } auto materializeCapturedTensor = [&](Value capturedTensor) -> Value { @@ -248,13 +255,25 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute return copied; }; + SmallVector hostOutputTensors(returnOperandIndices.size()); + auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value { + Value& hostOutputTensor = hostOutputTensors[resultIndex]; + if (hostOutputTensor) + return hostOutputTensor; + + hostOutputTensor = outputTensors[returnOperandIndices[resultIndex]](rewriter, resultLoc); + return hostOutputTensor; + }; + rewriter.setInsertionPointToEnd(newBlock); for (Operation& op : oldBlock) { if (isa(op)) continue; if (auto parallelOp = dyn_cast(op)) { - unsigned firstOutputArg = computeBatchOp.getOutputArgument(0).getArgNumber(); + auto firstOutputArg = computeBatchOp.getOutputArgument(0); + if (!firstOutputArg) + return computeBatchOp.emitOpError("expected compute_batch output block arguments before lowering"); for (Operation& nestedOp : parallelOp.getRegion().front()) { auto insertSlice = dyn_cast(&nestedOp); if (!insertSlice) @@ -264,12 +283,12 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute if (!outputArg || outputArg.getOwner() != &oldBlock) return insertSlice.emitOpError("expected compute_batch output block argument destination"); - unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg; - if (resultIndex >= hostOutputTensors.size()) + unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); + if (resultIndex >= returnOperandIndices.size()) return insertSlice.emitOpError("result index out of range while lowering host batch output"); Value mappedSource = mapper.lookup(insertSlice.getSource()); - auto hostTarget = hostOutputTensors[resultIndex]; + Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc()); auto hostTargetType = cast(hostTarget.getType()); Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper); Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult(); diff --git a/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp index 655aef3..1860a9d 100644 --- a/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp +++ b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp @@ -1,3 +1,5 @@ +#include + #include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -29,9 +31,17 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter, unsigned inputIndex, Value replacement) { Block& body = owner->getRegion(0).front(); - BlockArgument bodyArgument = isa(owner) - ? cast(owner).getInputArgument(inputIndex) - : cast(owner).getInputArgument(inputIndex); + BlockArgument bodyArgument; + if (auto compute = dyn_cast(owner)) { + auto computeArg = compute.getInputArgument(inputIndex); + assert(computeArg && "expected compute input block argument"); + bodyArgument = *computeArg; + } + else { + auto batchArg = cast(owner).getInputArgument(inputIndex); + assert(batchArg && "expected compute_batch input block argument"); + bodyArgument = *batchArg; + } unsigned bodyArgIndex = bodyArgument.getArgNumber(); rewriter.startOpModification(owner); diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index 171dde5..ce204ca 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -131,8 +131,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute rewriter.setInsertionPoint(computeOp); IRMapping mapping; - for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) - mapping.map(computeOp.getWeightArgument(weightIndex), weight); + for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) { + auto weightArg = computeOp.getWeightArgument(weightIndex); + if (!weightArg) + return false; + mapping.map(*weightArg, weight); + } for (Operation& op : block.without_terminator()) { cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder); Operation* clonedOp = rewriter.clone(op, mapping); @@ -164,31 +168,33 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp auto yieldOp = cast(block.getTerminator()); for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) { - BlockArgument blockArg = computeOp.getInputArgument(inputIndex); + auto blockArg = computeOp.getInputArgument(inputIndex); + if (!blockArg) + return computeOp.emitOpError("expected compute input block arguments during lowering"); auto receiveOp = dyn_cast_or_null(input.getDefiningOp()); - if (receiveOp && !blockArg.use_empty()) { - rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); - auto outputType = cast(blockArg.getType()); + if (receiveOp && !blockArg->use_empty()) { + rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg)); + auto outputType = cast(blockArg->getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); - auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg); Value received = PimReceiveOp::create( rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId()) .getOutput(); - blockArg.replaceAllUsesWith(received); + blockArg->replaceAllUsesWith(received); markOpToRemove(receiveOp); continue; } auto receiveTensorOp = dyn_cast_or_null(input.getDefiningOp()); - if (receiveTensorOp && !blockArg.use_empty()) { + if (receiveTensorOp && !blockArg->use_empty()) { FailureOr> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds()); if (failed(sourceCoreIds)) return receiveTensorOp.emitOpError("expected constant sourceCoreIds"); for (int32_t& sourceCoreId : *sourceCoreIds) sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId); - rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); - auto outputType = cast(blockArg.getType()); + rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg)); + auto outputType = cast(blockArg->getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType); Value received = PimReceiveTensorOp::create(rewriter, receiveTensorOp.getLoc(), @@ -196,7 +202,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp outputBuffer, rewriter.getDenseI32ArrayAttr(*sourceCoreIds)) .getOutput(); - blockArg.replaceAllUsesWith(received); + blockArg->replaceAllUsesWith(received); markOpToRemove(receiveTensorOp); } } @@ -238,12 +244,14 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp rewriter.setInsertionPointToStart(&block); auto& coreOpBlocks = coreOp.getBody().getBlocks(); for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) { - BlockArgument blockArg = computeOp.getInputArgument(inputIndex); - if (blockArg.use_empty()) + auto blockArg = computeOp.getInputArgument(inputIndex); + if (!blockArg) + return computeOp.emitOpError("expected compute input block arguments during input materialization"); + if (blockArg->use_empty()) continue; if (auto constantOp = input.getDefiningOp()) { - blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder)); + blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder)); continue; } @@ -261,7 +269,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp input, getTensorSizeInBytesAttr(rewriter, input)) .getOutput(); - blockArg.replaceAllUsesWith(copied); + blockArg->replaceAllUsesWith(copied); } if (!computeOp.getInputs().empty()) block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size()); diff --git a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp index c406f43..4c0e4d5 100644 --- a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp +++ b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp @@ -77,8 +77,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePatternuse_empty()) continue; rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); @@ -95,8 +97,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePatternuse_empty()) continue; rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 93cd4b6..fff7641 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -43,8 +43,14 @@ def SpatCompute : SpatOp<"compute", let regions = (region SizedRegion<1>:$body); let extraClassDeclaration = [{ - ::mlir::BlockArgument getWeightArgument(unsigned idx); - ::mlir::BlockArgument getInputArgument(unsigned idx); + std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx); + std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx); + std::optional> + insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc); + std::optional> + insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc); + ::mlir::FailureOr> + insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc); }]; let hasVerifier = 1; @@ -70,10 +76,16 @@ def SpatComputeBatch : SpatOp<"compute_batch", let regions = (region SizedRegion<1>:$body); let extraClassDeclaration = [{ - ::mlir::BlockArgument getLaneArgument(); - ::mlir::BlockArgument getWeightArgument(unsigned idx); - ::mlir::BlockArgument getInputArgument(unsigned idx); - ::mlir::BlockArgument getOutputArgument(unsigned idx); + std::optional<::mlir::BlockArgument> getLaneArgument(); + std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx); + std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx); + std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx); + std::optional> + insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc); + std::optional> + insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc); + ::mlir::FailureOr> + insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc); }]; let hasVerifier = 1; diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index 41d3146..4c650f3 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -6,11 +6,81 @@ using namespace mlir; namespace onnx_mlir { namespace spatial { +namespace { -BlockArgument SpatCompute::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); } +std::optional getBatchBodyArgument(Region& body, unsigned argIdx) { + if (body.empty()) + return std::nullopt; -BlockArgument SpatCompute::getInputArgument(unsigned idx) { - return getBody().front().getArgument(getWeights().size() + idx); + Block& block = body.front(); + if (argIdx >= block.getNumArguments()) + return std::nullopt; + return block.getArgument(argIdx); +} + +std::optional insertBatchBodyArgument(Region& body, unsigned argIdx, Type type, Location loc) { + if (body.empty()) + return std::nullopt; + return body.insertArgument(argIdx, type, loc); +} + +void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) { + if (auto compute = dyn_cast(op)) { + compute.getProperties().setOperandSegmentSizes({weightCount, inputCount}); + return; + } + cast(op).getProperties().setOperandSegmentSizes({weightCount, inputCount}); +} + +} // namespace + +std::optional SpatCompute::getWeightArgument(unsigned idx) { return getBatchBodyArgument(getBody(), idx); } + +std::optional SpatCompute::getInputArgument(unsigned idx) { + return getBatchBodyArgument(getBody(), getWeights().size() + idx); +} + +std::optional> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) { + unsigned weightCount = getWeights().size(); + unsigned inputCount = getInputs().size(); + getOperation()->insertOperands(idx, ValueRange {weight}); + setComputeOperandSegmentSizes( + getOperation(), static_cast(weightCount + 1), static_cast(inputCount)); + auto blockArg = insertBatchBodyArgument(getBody(), idx, weight.getType(), loc); + if (!blockArg) + return std::nullopt; + return std::make_tuple(getOperation()->getOperand(idx), *blockArg); +} + +std::optional> SpatCompute::insertInput(unsigned idx, Value input, Location loc) { + unsigned weightCount = getWeights().size(); + unsigned inputCount = getInputs().size(); + getOperation()->insertOperands(weightCount + idx, ValueRange {input}); + setComputeOperandSegmentSizes( + getOperation(), static_cast(weightCount), static_cast(inputCount + 1)); + auto blockArg = insertBatchBodyArgument(getBody(), weightCount + idx, input.getType(), loc); + if (!blockArg) + return std::nullopt; + return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg); +} + +FailureOr> +SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) { + if (idx > getNumResults()) + return failure(); + + rewriter.setInsertionPoint(getOperation()); + SmallVector resultTypes(getResultTypes().begin(), getResultTypes().end()); + resultTypes.insert(resultTypes.begin() + idx, type); + auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs()); + newCompute->setAttrs((*this)->getAttrs()); + setComputeOperandSegmentSizes( + newCompute.getOperation(), static_cast(newCompute.getWeights().size()), static_cast(newCompute.getInputs().size())); + rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end()); + for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx) + getResult(oldResultIdx).replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1)); + rewriter.eraseOp(getOperation()); + return std::make_tuple(cast(newCompute.getResult(idx)), newCompute); } void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { @@ -18,42 +88,102 @@ void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn s return; for (unsigned index = 0; index < getWeights().size(); ++index) - setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str()); + if (auto weightArg = getWeightArgument(index)) + setNameFn(*weightArg, ("w" + std::to_string(index)).c_str()); for (unsigned index = 0; index < getInputs().size(); ++index) - setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str()); + if (auto inputArg = getInputArgument(index)) + setNameFn(*inputArg, ("in" + std::to_string(index)).c_str()); } -BlockArgument SpatComputeBatch::getLaneArgument() { return getBody().front().getArgument(0); } +std::optional SpatComputeBatch::getLaneArgument() { return getBatchBodyArgument(getBody(), 0); } -BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); } - -BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) { - return getBody().front().getArgument(1 + getWeights().size() + idx); +std::optional SpatComputeBatch::getWeightArgument(unsigned idx) { + return getBatchBodyArgument(getBody(), 1 + idx); } -BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) { - return getBody().front().getArgument(1 + getWeights().size() + getInputs().size() + idx); +std::optional SpatComputeBatch::getInputArgument(unsigned idx) { + return getBatchBodyArgument(getBody(), 1 + getWeights().size() + idx); +} + +std::optional SpatComputeBatch::getOutputArgument(unsigned idx) { + return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx); +} + +std::optional> SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) { + unsigned weightCount = getWeights().size(); + unsigned inputCount = getInputs().size(); + getOperation()->insertOperands(idx, ValueRange {weight}); + setComputeOperandSegmentSizes( + getOperation(), static_cast(weightCount + 1), static_cast(inputCount)); + auto blockArg = insertBatchBodyArgument(getBody(), 1 + idx, weight.getType(), loc); + if (!blockArg) + return std::nullopt; + return std::make_tuple(getOperation()->getOperand(idx), *blockArg); +} + +std::optional> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) { + unsigned weightCount = getWeights().size(); + unsigned inputCount = getInputs().size(); + getOperation()->insertOperands(weightCount + idx, ValueRange {input}); + setComputeOperandSegmentSizes( + getOperation(), static_cast(weightCount), static_cast(inputCount + 1)); + auto blockArg = insertBatchBodyArgument(getBody(), 1 + weightCount + idx, input.getType(), loc); + if (!blockArg) + return std::nullopt; + return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg); +} + +FailureOr> +SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) { + if (idx > getNumResults()) + return failure(); + + rewriter.setInsertionPoint(getOperation()); + SmallVector resultTypes(getResultTypes().begin(), getResultTypes().end()); + resultTypes.insert(resultTypes.begin() + idx, type); + auto newBatch = + SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs()); + newBatch->setAttrs((*this)->getAttrs()); + setComputeOperandSegmentSizes( + newBatch.getOperation(), static_cast(newBatch.getWeights().size()), static_cast(newBatch.getInputs().size())); + rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end()); + if (newBatch.getBody().empty()) { + rewriter.eraseOp(newBatch); + return failure(); + } + auto blockArg = newBatch.getBody().front().insertArgument( + 1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc); + for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx) + getResult(oldResultIdx).replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1)); + rewriter.eraseOp(getOperation()); + return std::make_tuple(cast(newBatch.getResult(idx)), blockArg, newBatch); } void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) { if (region.empty()) return; - setNameFn(getLaneArgument(), "lane"); + if (auto laneArg = getLaneArgument()) + setNameFn(*laneArg, "lane"); for (unsigned index = 0; index < getWeights().size(); ++index) - setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str()); + if (auto weightArg = getWeightArgument(index)) + setNameFn(*weightArg, ("w" + std::to_string(index)).c_str()); for (unsigned index = 0; index < getInputs().size(); ++index) - setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str()); + if (auto inputArg = getInputArgument(index)) + setNameFn(*inputArg, ("in" + std::to_string(index)).c_str()); for (unsigned index = 0; index < getNumResults(); ++index) { + auto outputArg = getOutputArgument(index); + if (!outputArg) + continue; if (index == 0) { - setNameFn(getOutputArgument(index), "out"); + setNameFn(*outputArg, "out"); continue; } - setNameFn(getOutputArgument(index), ("out" + std::to_string(index)).c_str()); + setNameFn(*outputArg, ("out" + std::to_string(index)).c_str()); } } diff --git a/src/PIM/Dialect/Spatial/SpatialOps.hpp b/src/PIM/Dialect/Spatial/SpatialOps.hpp index ce89ef3..89069b5 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.hpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.hpp @@ -5,12 +5,15 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/RegionKindInterface.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include +#include #include +#include /// Include the auto-generated header files containing the declarations #include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc" diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index d3e49ce..8751dfa 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -218,17 +218,26 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { } void SpatCompute::print(OpAsmPrinter& printer) { - printer << " "; SmallVector weightArgs; weightArgs.reserve(getWeights().size()); - for (unsigned index = 0; index < getWeights().size(); ++index) - weightArgs.push_back(getWeightArgument(index)); - printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); - printer << " "; + for (unsigned index = 0; index < getWeights().size(); ++index) { + auto weightArg = getWeightArgument(index); + if (!weightArg) + return printer.printGenericOp(getOperation(), /*printOpName=*/false); + weightArgs.push_back(*weightArg); + } SmallVector inputArgs; inputArgs.reserve(getInputs().size()); - for (unsigned index = 0; index < getInputs().size(); ++index) - inputArgs.push_back(getInputArgument(index)); + for (unsigned index = 0; index < getInputs().size(); ++index) { + auto inputArg = getInputArgument(index); + if (!inputArg) + return printer.printGenericOp(getOperation(), /*printOpName=*/false); + inputArgs.push_back(*inputArg); + } + + printer << " "; + printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); + printer << " "; printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); if (auto coreIdAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) @@ -309,29 +318,48 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { } void SpatComputeBatch::print(OpAsmPrinter& printer) { + auto laneArg = getLaneArgument(); + SmallVector weightArgs; + weightArgs.reserve(getWeights().size()); + for (unsigned index = 0; index < getWeights().size(); ++index) { + auto weightArg = getWeightArgument(index); + if (!weightArg) + return printer.printGenericOp(getOperation(), /*printOpName=*/false); + weightArgs.push_back(*weightArg); + } + SmallVector inputArgs; + inputArgs.reserve(getInputs().size()); + for (unsigned index = 0; index < getInputs().size(); ++index) { + auto inputArg = getInputArgument(index); + if (!inputArg) + return printer.printGenericOp(getOperation(), /*printOpName=*/false); + inputArgs.push_back(*inputArg); + } + + SmallVector outputArgs; + if (!laneArg) + return printer.printGenericOp(getOperation(), /*printOpName=*/false); + if (getNumResults() != 0) { + outputArgs.reserve(getNumResults()); + for (unsigned index = 0; index < getNumResults(); ++index) { + auto outputArg = getOutputArgument(index); + if (!outputArg) + return printer.printGenericOp(getOperation(), /*printOpName=*/false); + outputArgs.push_back(*outputArg); + } + } + printer << " "; - printer.printOperand(getLaneArgument()); + printer.printOperand(*laneArg); printer << " = 0 to " << getLaneCount(); printer << " "; - SmallVector weightArgs; - weightArgs.reserve(getWeights().size()); - for (unsigned index = 0; index < getWeights().size(); ++index) - weightArgs.push_back(getWeightArgument(index)); printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); printer << " "; - SmallVector inputArgs; - inputArgs.reserve(getInputs().size()); - for (unsigned index = 0; index < getInputs().size(); ++index) - inputArgs.push_back(getInputArgument(index)); printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); if (getNumResults() != 0) { printer << " shared_outs"; - SmallVector outputArgs; - outputArgs.reserve(getNumResults()); - for (unsigned index = 0; index < getNumResults(); ++index) - outputArgs.push_back(getOutputArgument(index)); printBlockArgumentList(printer, outputArgs); } diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index e753e41..86594a1 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -107,8 +107,11 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) { return false; unsigned argNumber = blockArg.getArgNumber(); - unsigned firstOutputArg = batchOp.getOutputArgument(0).getArgNumber(); - return argNumber >= firstOutputArg && argNumber < firstOutputArg + batchOp.getNumResults(); + auto firstOutputArg = batchOp.getOutputArgument(0); + if (!firstOutputArg) + return false; + unsigned firstOutputArgNumber = firstOutputArg->getArgNumber(); + return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults(); } static bool isConstantIndexLike(Value value) { @@ -293,10 +296,12 @@ static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) { return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel"); } - BlockArgument laneArg = batchOp.getLaneArgument(); + auto laneArg = batchOp.getLaneArgument(); + if (!laneArg) + return batchOp.emitError("compute_batch body must have a lane block argument"); for (auto& bodyOp : block) { if (auto extractSlice = dyn_cast(&bodyOp)) - if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice"))) + if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, *laneArg, "tensor.extract_slice"))) return failure(); } return success(); @@ -457,12 +462,16 @@ LogicalResult SpatCompute::verify() { if (block.getNumArguments() != expectedArgCount) return emitError("compute body must have weight and input block arguments"); - for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) - if (getWeightArgument(weightIndex).getType() != weight.getType()) + for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { + auto blockArg = getWeightArgument(weightIndex); + if (!blockArg || blockArg->getType() != weight.getType()) return emitError("compute weight block argument types must match weight operand types exactly"); - for (auto [inputIndex, input] : llvm::enumerate(getInputs())) - if (getInputArgument(inputIndex).getType() != input.getType()) + } + for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { + auto blockArg = getInputArgument(inputIndex); + if (!blockArg || blockArg->getType() != input.getType()) return emitError("compute input block argument types must match input operand types exactly"); + } if (block.mightHaveTerminator()) { auto yieldOp = dyn_cast_or_null(block.getTerminator()); @@ -497,7 +506,7 @@ LogicalResult SpatCompute::verify() { } for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex) - if (getInputArgument(inputIndex).use_empty()) + if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty()) return emitError("ComputeOp block argument is not used"); if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute"))) return failure(); @@ -574,23 +583,28 @@ LogicalResult SpatComputeBatch::verify() { } Block& block = getBody().front(); + if (block.getNumArguments() == 0) + return emitError("compute_batch body must have exactly one lane block argument"); unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults(); if (block.getNumArguments() != expectedArgCount) - return emitError("compute_batch body must have lane, weight, input, and output block arguments"); - if (!getLaneArgument().getType().isIndex()) + return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results"); + auto laneArg = getLaneArgument(); + if (!laneArg || !laneArg->getType().isIndex()) return emitError("compute_batch first block argument must have index type"); - for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) - if (getWeightArgument(weightIndex).getType() != weight.getType()) + for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { + auto blockArg = getWeightArgument(weightIndex); + if (!blockArg || blockArg->getType() != weight.getType()) return emitError("compute_batch weight block argument types must match weight operand types exactly"); + } for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { - BlockArgument blockArg = getInputArgument(inputIndex); - if (blockArg.getType() != input.getType()) + auto blockArg = getInputArgument(inputIndex); + if (!blockArg || blockArg->getType() != input.getType()) return emitError("compute_batch input block argument types must match input operand types exactly"); } for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) { - BlockArgument blockArg = getOutputArgument(resultIndex); - if (blockArg.getType() != resultType) + auto blockArg = getOutputArgument(resultIndex); + if (!blockArg || blockArg->getType() != resultType) return emitError("compute_batch output block argument types must match result types exactly"); } @@ -608,13 +622,15 @@ LogicalResult SpatInParallelOp::verify() { if (batchOp.getNumResults() == 0) return emitOpError("requires a resultful spat.compute_batch parent"); - BlockArgument laneArg = batchOp.getLaneArgument(); + auto laneArg = batchOp.getLaneArgument(); + if (!laneArg) + return emitOpError("expected compute_batch lane block argument"); for (Operation& op : getRegion().front().getOperations()) { auto insertSliceOp = dyn_cast(&op); if (!insertSliceOp) return emitOpError("expected only tensor.parallel_insert_slice ops"); - if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, laneArg, "tensor.parallel_insert_slice"))) + if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, *laneArg, "tensor.parallel_insert_slice"))) return failure(); MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 54da355..bdc8ca1 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -432,15 +432,6 @@ LogicalResult collectHostOutputs(MaterializerState& state) { return success(); } -void setOperandSegmentSizes(Operation* op, int weightCount, int inputCount) { - if (auto compute = dyn_cast(op)) { - compute.getProperties().setOperandSegmentSizes({weightCount, inputCount}); - return; - } - auto batch = cast(op); - batch.getProperties().setOperandSegmentSizes({weightCount, inputCount}); -} - void createEmptyMaterializedOps(MaterializerState& state) { Location loc = state.func.getLoc(); Block& funcBlock = state.func.getBody().front(); @@ -529,19 +520,17 @@ BlockArgument appendWeight(MaterializerState& state, MaterializedClass& material materializedClass.weights.push_back(weight); if (auto compute = dyn_cast(materializedClass.op)) { - compute.getWeightsMutable().append(ValueRange(weight)); - setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size()); - BlockArgument arg = materializedClass.body->insertArgument(weightIndex, weight.getType(), weight.getLoc()); - materializedClass.weightArgs[weight] = arg; - return arg; + auto arg = compute.insertWeight(weightIndex, weight, weight.getLoc()); + assert(arg && "expected compute body while inserting a weight"); + materializedClass.weightArgs[weight] = std::get<1>(*arg); + return std::get<1>(*arg); } auto batch = cast(materializedClass.op); - batch.getWeightsMutable().append(ValueRange(weight)); - setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size()); - BlockArgument arg = materializedClass.body->insertArgument(1 + weightIndex, weight.getType(), weight.getLoc()); - materializedClass.weightArgs[weight] = arg; - return arg; + auto arg = batch.insertWeight(weightIndex, weight, weight.getLoc()); + assert(arg && "expected compute_batch body while inserting a weight argument"); + materializedClass.weightArgs[weight] = std::get<1>(*arg); + return std::get<1>(*arg); } BlockArgument appendInput(MaterializerState& state, MaterializedClass& materializedClass, Value input) { @@ -551,17 +540,16 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali materializedClass.inputs.push_back(input); if (auto compute = dyn_cast(materializedClass.op)) { - compute.getInputsMutable().append(ValueRange(input)); - BlockArgument arg = materializedClass.body->addArgument(input.getType(), input.getLoc()); - materializedClass.inputArgs[input] = arg; + auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); + assert(arg && "expected compute body while inserting an input"); + materializedClass.inputArgs[input] = std::get<1>(*arg); + return std::get<1>(*arg); } - else { - cast(materializedClass.op).getInputsMutable().append(ValueRange(input)); - setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size()); - BlockArgument arg = materializedClass.body->insertArgument( - materializedClass.body->getNumArguments() - 1, input.getType(), input.getLoc()); - materializedClass.inputArgs[input] = arg; - return arg; + if (auto compute = dyn_cast(materializedClass.op)) { + auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); + assert(arg && "expected compute_batch body while inserting an input argument"); + materializedClass.inputArgs[input] = std::get<1>(*arg); + return std::get<1>(*arg); } llvm_unreachable("Cannot reach here"); } @@ -608,6 +596,8 @@ Value createOriginalLaneValue(MaterializerState& state, return createIndexConstant(state, materializedClass.op, peers.front().laneStart); auto batch = cast(materializedClass.op); + auto laneArg = batch.getLaneArgument(); + assert(laneArg && "expected materialized compute_batch lane argument"); bool identity = true; for (auto [lane, peer] : llvm::enumerate(peers)) { if (peer.laneCount != 1 || peer.laneStart != lane) { @@ -616,7 +606,7 @@ Value createOriginalLaneValue(MaterializerState& state, } } if (identity) - return batch.getLaneArgument(); + return *laneArg; bool affineWithBase = true; int64_t base = static_cast(peers.front().laneStart); @@ -628,9 +618,9 @@ Value createOriginalLaneValue(MaterializerState& state, } if (affineWithBase) { if (base == 0) - return batch.getLaneArgument(); + return *laneArg; Value baseValue = createIndexConstant(state, materializedClass.op, base); - return arith::AddIOp::create(state.rewriter, loc, batch.getLaneArgument(), baseValue).getResult(); + return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult(); } SmallVector laneValues; @@ -641,7 +631,7 @@ Value createOriginalLaneValue(MaterializerState& state, auto tableType = RankedTensorType::get({static_cast(peers.size())}, state.rewriter.getIndexType()); auto tableAttr = DenseIntElementsAttr::get(tableType, laneValues); Value table = arith::ConstantOp::create(state.rewriter, loc, tableType, tableAttr).getResult(); - return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {batch.getLaneArgument()}).getResult(); + return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult(); } bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps) { @@ -838,7 +828,10 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val offsets.reserve(payloadType.getRank()); sizes.reserve(payloadType.getRank()); strides.reserve(payloadType.getRank()); - offsets.push_back(batch.getLaneArgument()); + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + return batch.emitOpError("expected compute_batch lane block argument while materializing batch output"); + offsets.push_back(*laneArg); sizes.push_back(state.rewriter.getIndexAttr(1)); strides.push_back(state.rewriter.getIndexAttr(1)); for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { @@ -847,8 +840,11 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val strides.push_back(state.rewriter.getIndexAttr(1)); } - tensor::ParallelInsertSliceOp::create( - state.rewriter, payload.getLoc(), payload, batch.getOutputArgument(resultIndex), offsets, sizes, strides); + auto outputArg = batch.getOutputArgument(resultIndex); + if (!outputArg) + return batch.emitOpError("expected compute_batch output block argument while materializing batch output"); + + tensor::ParallelInsertSliceOp::create(state.rewriter, payload.getLoc(), payload, *outputArg, offsets, sizes, strides); return success(); } @@ -1136,14 +1132,20 @@ void mapWeights(MaterializerState& state, IRMapping& mapper) { Operation* op = instance.op; if (auto compute = dyn_cast(op)) { - for (auto [index, weight] : llvm::enumerate(compute.getWeights())) - mapper.map(compute.getWeightArgument(index), appendWeight(state, targetClass, weight)); + for (auto [index, weight] : llvm::enumerate(compute.getWeights())) { + auto weightArg = compute.getWeightArgument(index); + assert(weightArg && "expected compute weight block argument"); + mapper.map(*weightArg, appendWeight(state, targetClass, weight)); + } return; } auto batch = cast(op); - for (auto [index, weight] : llvm::enumerate(batch.getWeights())) - mapper.map(batch.getWeightArgument(index), appendWeight(state, targetClass, weight)); + for (auto [index, weight] : llvm::enumerate(batch.getWeights())) { + auto weightArg = batch.getWeightArgument(index); + assert(weightArg && "expected compute_batch weight block argument"); + mapper.map(*weightArg, appendWeight(state, targetClass, weight)); + } } LogicalResult mapInputs(MaterializerState& state, @@ -1156,7 +1158,10 @@ LogicalResult mapInputs(MaterializerState& state, FailureOr mapped = resolveInputValue(state, targetClass, input, instance); if (failed(mapped)) return compute.emitOpError("failed to resolve materialized compute input"); - mapper.map(compute.getInputArgument(index), *mapped); + auto inputArg = compute.getInputArgument(index); + if (!inputArg) + return compute.emitOpError("expected compute input block argument while materializing inputs"); + mapper.map(*inputArg, *mapped); } return success(); } @@ -1166,7 +1171,10 @@ LogicalResult mapInputs(MaterializerState& state, FailureOr mapped = resolveInputValue(state, targetClass, input, instance); if (failed(mapped)) return batch.emitOpError("failed to resolve materialized compute_batch input"); - mapper.map(batch.getInputArgument(index), *mapped); + auto inputArg = batch.getInputArgument(index); + if (!inputArg) + return batch.emitOpError("expected compute_batch input block argument while materializing inputs"); + mapper.map(*inputArg, *mapped); } return success(); } @@ -1186,8 +1194,10 @@ SmallVector collectMappedBatchOutputs(SpatComputeBatch batch, IRMappin if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) continue; - unsigned firstOutputArg = batch.getOutputArgument(0).getArgNumber(); - unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg; + auto firstOutputArg = batch.getOutputArgument(0); + if (!firstOutputArg) + return outputs; + unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); if (resultIndex >= outputs.size()) continue; outputs[resultIndex] = mapper.lookupOrDefault(insert.getSource()); @@ -1217,7 +1227,12 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra return failure(); } } - mapper.map(batch.getLaneArgument(), createOriginalLaneValue(state, targetClass, peers, loc)); + auto laneArg = batch.getLaneArgument(); + if (!laneArg) { + sourceOp->emitError("expected source compute_batch lane block argument"); + return failure(); + } + mapper.map(*laneArg, createOriginalLaneValue(state, targetClass, peers, loc)); } mapWeights(state, targetClass, instance, mapper); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 13b8d72..e9f8f41 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -223,18 +223,32 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) { newBody->addArgument(input.getType(), loc); IRMapping mapper; - for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights())) - mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex)); - for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs())) - mapper.map(compute.getInputArgument(inputIndex), newCompute.getInputArgument(inputIndex)); - for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights())) - mapper.map(child.getWeightArgument(oldIndex), newCompute.getWeightArgument(childWeightToNewIndex[oldIndex])); + for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights())) { + auto oldWeightArg = compute.getWeightArgument(weightIndex); + auto newWeightArg = newCompute.getWeightArgument(weightIndex); + assert(oldWeightArg && newWeightArg && "expected compute weight block arguments"); + mapper.map(*oldWeightArg, *newWeightArg); + } + for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs())) { + auto oldInputArg = compute.getInputArgument(inputIndex); + auto newInputArg = newCompute.getInputArgument(inputIndex); + assert(oldInputArg && newInputArg && "expected compute input block arguments"); + mapper.map(*oldInputArg, *newInputArg); + } + for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights())) { + auto oldWeightArg = child.getWeightArgument(oldIndex); + auto newWeightArg = newCompute.getWeightArgument(childWeightToNewIndex[oldIndex]); + assert(oldWeightArg && newWeightArg && "expected child compute weight block arguments"); + mapper.map(*oldWeightArg, *newWeightArg); + } rewriter.setInsertionPointToEnd(newBody); auto computeYield = cast(compute.getBody().front().getTerminator()); for (Operation& op : compute.getBody().front().without_terminator()) rewriter.clone(op, mapper); - mapper.map(child.getInputArgument(childInputIndex), mapper.lookupOrDefault(computeYield.getOperand(usedResult))); + auto childInputArg = child.getInputArgument(childInputIndex); + assert(childInputArg && "expected child compute input block argument"); + mapper.map(*childInputArg, mapper.lookupOrDefault(computeYield.getOperand(usedResult))); rewriter.setInsertionPointToEnd(newBody); for (auto& op : child.getBody().front()) diff --git a/validation/operations/README.md b/validation/operations/README.md index 2f5b977..2cce699 100644 --- a/validation/operations/README.md +++ b/validation/operations/README.md @@ -30,7 +30,7 @@ python3 validation/operations/gen_tests.py | Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes | |---------------|-------------------------|-----------|------------|----------|--------|-------|------|-------|------------------------------| -| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights | +| Simple | `gemm/simple` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Square weights | | Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N | | With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector | | transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight | diff --git a/validation/operations/gemm/gemm.onnx b/validation/operations/gemm/simple/gemm_simple.onnx similarity index 100% rename from validation/operations/gemm/gemm.onnx rename to validation/operations/gemm/simple/gemm_simple.onnx diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py index 53e5821..6ae2af4 100644 --- a/validation/operations/gen_tests.py +++ b/validation/operations/gen_tests.py @@ -185,6 +185,18 @@ def conv_depthwise_grouped(): # GEMM tests # --------------------------------------------------------------------------- +def gemm_simple(): + """Simple GEMM with square weights: [10, 132] @ [132, 132].""" + B, K, N = 10, 132, 132 + W = numpy_helper.from_array(np.random.default_rng(41).uniform(-1, 1, (K, N)).astype(np.float32), name="W") + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N]) + node = helper.make_node("Gemm", ["A", "W"], ["Y"]) + graph = helper.make_graph([node], "gemm_simple", [A], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/simple", "gemm_simple.onnx") + + def gemm_non_square(): """GEMM with non-square weight matrix: [B, K] @ [K, N], K != N.""" B, K, N = 4, 128, 64 @@ -823,6 +835,7 @@ def div_after_gemm(): if __name__ == "__main__": print("Generating GEMM tests:") + gemm_simple() gemm_non_square() gemm_with_bias() gemm_transB()