diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index b296d29..a168cf9 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -4,6 +4,7 @@ #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" namespace onnx_mlir { @@ -227,7 +228,7 @@ llvm::FailureOr resolveContiguousAddressImpl(mlir::Va continue; } - if (mlir::isa(definingOp)) + if (mlir::isa(definingOp)) return ResolvedContiguousAddress {value, byteOffset}; return mlir::failure(); diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp index 64104ba..63ebcdf 100644 --- a/src/PIM/Common/IR/WeightUtils.cpp +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -54,7 +54,7 @@ bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) { if (!computeOp || operandIndex >= computeOp.getWeights().size()) return false; - return hasMvmVmmWeightUse(computeOp, operandIndex); + return hasMvmVmmWeightUse(computeOp, operandIndex); } bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) { diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index fed460b..49a806f 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -97,6 +97,11 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { if (!allocOp->getParentOfType()) gatherMemEntry(allocOp.getResult()); }); + funcOp.walk([&](pim::PimEmptyManyOp emptyManyOp) { + if (!emptyManyOp->getParentOfType() && !emptyManyOp->getParentOfType()) + for (mlir::Value output : emptyManyOp.getOutputs()) + gatherMemEntry(output); + }); allocateGatheredMemory(); @@ -106,6 +111,10 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { void PimMemory::allocateCore(Operation* op) { op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); }); + op->walk([&](pim::PimEmptyManyOp emptyManyOp) { + for (mlir::Value output : emptyManyOp.getOutputs()) + gatherMemEntry(output); + }); allocateGatheredMemory(); } @@ -957,6 +966,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge); else if (auto getGlobalOp = dyn_cast(op)) coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge); + else if (isa(op)) + return success(); else { op.emitError("Unsupported codegen for this operation"); op.dump(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 52dd163..77052ba 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -381,7 +381,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, vmmOutputs.reserve(aHSlicesArgs.size()); for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) vmmOutputs.push_back( - spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg)); + spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg)); if (vmmOutputs.empty()) { gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); return failure(); @@ -527,7 +527,7 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp, &batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector(1, loc)); rewriter.setInsertionPointToEnd(body); - Value vmmResult = spatial::SpatWeightedVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult(); + Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult(); Value laneResult = vmmResult; if (sharedBias) laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult(); diff --git a/src/PIM/Conversion/SpatialToPim/Common.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp index 6788872..20891fb 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -95,7 +95,7 @@ bool hasLaterUserInBlock(mlir::Value value, Operation* operation) { return false; } -mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) { +mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) { assert("Only support operations with a single result" && operation->getNumResults() == 1); mlir::Value result = operation->getResult(0); auto resultType = result.getType(); diff --git a/src/PIM/Conversion/SpatialToPim/Common.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp index e006189..f0a09e9 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.hpp +++ b/src/PIM/Conversion/SpatialToPim/Common.hpp @@ -41,7 +41,7 @@ mlir::Operation* getEarliestUserWithinBlock(mlir::Value value); mlir::SmallVector getOpOperandsSortedByUses(mlir::Operation* operation); -mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation); +mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation); inline mlir::tensor::EmptyOp createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) { diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index a0fbce5..4b73be5 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -16,13 +16,13 @@ def onnxToPimTranspose : Pat< >; def spatToPimVMM : Pat< - (SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector), + (SpatVMMOp:$srcOpRes $weightIndex, $vector), (PimVMMOp $weightIndex, $vector, (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; def spatToPimMVM : Pat< - (SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector), + (SpatMVMOp:$srcOpRes $weightIndex, $vector), (PimMVMOp $weightIndex, $vector, (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 82401af..6948663 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -252,25 +252,6 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) { rewriter.replaceOp(concatOp, concatenated); } -static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) { - SmallVector wvmmOps; - funcOp.walk([&](spatial::SpatWeightedVMMOp wvmmOp) { - if (wvmmOp->getParentOfType() || wvmmOp->getParentOfType()) - wvmmOps.push_back(wvmmOp); - }); - - for (auto wvmmOp : wvmmOps) { - rewriter.setInsertionPoint(wvmmOp); - auto outputType = cast(wvmmOp.getOutput().getType()); - Value outputBuffer = createEmptyTensorFromShaped(rewriter, wvmmOp.getLoc(), outputType).getResult(); - rewriter.replaceOpWithNewOp(wvmmOp, - wvmmOp.getOutput().getType(), - rewriter.getI32IntegerAttr(wvmmOp.getWeightIndex()), - wvmmOp.getInput(), - outputBuffer); - } -} - static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) { SmallVector mapOps; funcOp.walk([&](spatial::SpatMapOp mapOp) { @@ -736,7 +717,7 @@ void SpatialToPimPass::runOnOperation() { SmallVector coreOps; funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); }); for (auto coreOp : coreOps) { - if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) { + if (failed(applyPartialConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) { signalPassFailure(); return; } @@ -745,15 +726,13 @@ void SpatialToPimPass::runOnOperation() { SmallVector coreBatchOps; funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); }); for (auto coreBatchOp : coreBatchOps) { - if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) { + if (failed(applyPartialConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) { signalPassFailure(); return; } } } - lowerRemainingSpatialMathOps(funcOp, rewriter); - RewritePatternSet channelPatterns(ctx); populateWithGenerated(channelPatterns); if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) { diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 7e27354..c81e0d4 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -96,7 +96,7 @@ def PimEmptyManyOp : PimOp<"empty_many", []> { let summary = "Create many identical empty tensors"; let results = (outs - Variadic:$outputs + Variadic:$outputs ); let hasVerifier = 1; diff --git a/src/PIM/Dialect/Pim/PimOpsVerify.cpp b/src/PIM/Dialect/Pim/PimOpsVerify.cpp index 85ed3a8..a12f8a1 100644 --- a/src/PIM/Dialect/Pim/PimOpsVerify.cpp +++ b/src/PIM/Dialect/Pim/PimOpsVerify.cpp @@ -79,9 +79,9 @@ LogicalResult PimEmptyManyOp::verify() { return emitError("must produce at least one output"); Type firstType = getOutputs().front().getType(); - auto firstTensorType = dyn_cast(firstType); - if (!firstTensorType) - return emitError("outputs must all be ranked tensor types"); + auto firstShapedType = dyn_cast(firstType); + if (!firstShapedType || !firstShapedType.hasRank()) + return emitError("outputs must all be ranked shaped types"); for (Value output : getOutputs().drop_front()) if (output.getType() != firstType) diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index f5a11f3..38382f9 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -34,6 +34,15 @@ static Value materializeContiguousMemRef(Value memrefValue, Location loc, Rewrit .getOutput(); } +static FailureOr getBufferOrValue(RewriterBase& rewriter, + Value value, + const BufferizationOptions& options, + BufferizationState& state) { + if (isa(value.getType())) + return value; + return getBuffer(rewriter, value, options, state); +} + struct MemCopyHostToDevOpInterface : DstBufferizableOpInterfaceExternalModel { LogicalResult bufferize(Operation* op, @@ -44,12 +53,12 @@ struct MemCopyHostToDevOpInterface auto deviceTarget = memCopyHostToDevOp.getDeviceTarget(); auto hostSource = memCopyHostToDevOp.getHostSource(); - auto deviceTargetOpt = getBuffer(rewriter, deviceTarget, options, state); + auto deviceTargetOpt = getBufferOrValue(rewriter, deviceTarget, options, state); if (failed(deviceTargetOpt)) return failure(); auto deviceTargetMemRef = *deviceTargetOpt; - auto hostSourceOpt = getBuffer(rewriter, hostSource, options, state); + auto hostSourceOpt = getBufferOrValue(rewriter, hostSource, options, state); if (failed(hostSourceOpt)) return failure(); auto hostSourceMemRef = *hostSourceOpt; @@ -73,10 +82,10 @@ struct MemCopyHostToDevBatchOpInterface const BufferizationOptions& options, BufferizationState& state) const { auto memCopyHostToDevOp = cast(op); - auto deviceTargetOpt = getBuffer(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state); + auto deviceTargetOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state); if (failed(deviceTargetOpt)) return failure(); - auto hostSourceOpt = getBuffer(rewriter, memCopyHostToDevOp.getHostSource(), options, state); + auto hostSourceOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getHostSource(), options, state); if (failed(hostSourceOpt)) return failure(); @@ -101,13 +110,13 @@ struct MemCopyDevToHostOpInterface auto memCopyDevToHostOp = cast(op); auto hostTarget = memCopyDevToHostOp.getHostTarget(); - auto hostTargetOpt = getBuffer(rewriter, hostTarget, options, state); + auto hostTargetOpt = getBufferOrValue(rewriter, hostTarget, options, state); if (failed(hostTargetOpt)) return failure(); auto hostTargetMemRef = *hostTargetOpt; auto deviceSource = memCopyDevToHostOp.getDeviceSource(); - auto deviceSourceOpt = getBuffer(rewriter, deviceSource, options, state); + auto deviceSourceOpt = getBufferOrValue(rewriter, deviceSource, options, state); if (failed(deviceSourceOpt)) return failure(); auto deviceSourceMemRef = *deviceSourceOpt; @@ -135,7 +144,7 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel(op); - auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state); + auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state); if (failed(outputBufferOpt)) return failure(); @@ -159,7 +168,7 @@ struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel(op); - auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state); + auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state); if (failed(outputBufferOpt)) return failure(); @@ -185,13 +194,11 @@ struct ReceiveManyOpInterface : DstBufferizableOpInterfaceExternalModel(op); SmallVector outputBuffers; SmallVector resultTypes; - SmallVector tensorResults; outputBuffers.reserve(receiveOp.getOutputBuffers().size()); resultTypes.reserve(receiveOp.getOutputBuffers().size()); - tensorResults.reserve(receiveOp.getOutputBuffers().size()); for (Value outputBuffer : receiveOp.getOutputBuffers()) { - auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state); + auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state); if (failed(outputBufferOpt)) return failure(); outputBuffers.push_back(*outputBufferOpt); @@ -200,15 +207,7 @@ struct ReceiveManyOpInterface : DstBufferizableOpInterfaceExternalModel(tensorResult.getType()); - auto toTensor = - bufferization::ToTensorOp::create(rewriter, receiveOp.getLoc(), tensorType, bufferResult, UnitAttr(), UnitAttr()); - tensorResults.push_back(toTensor.getResult()); - } - - rewriter.replaceOp(receiveOp, tensorResults); + rewriter.replaceOp(receiveOp, newOp.getOutputs()); return success(); } }; @@ -226,13 +225,11 @@ struct ReceiveManyBatchOpInterface auto receiveOp = cast(op); SmallVector outputBuffers; SmallVector resultTypes; - SmallVector tensorResults; outputBuffers.reserve(receiveOp.getOutputBuffers().size()); resultTypes.reserve(receiveOp.getOutputBuffers().size()); - tensorResults.reserve(receiveOp.getOutputBuffers().size()); for (Value outputBuffer : receiveOp.getOutputBuffers()) { - auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state); + auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state); if (failed(outputBufferOpt)) return failure(); outputBuffers.push_back(*outputBufferOpt); @@ -244,15 +241,7 @@ struct ReceiveManyBatchOpInterface TypeRange(resultTypes), ValueRange(outputBuffers), receiveOp.getSourceCoreIdsAttr()); - - for (auto [bufferResult, tensorResult] : llvm::zip(newOp.getOutputs(), receiveOp.getOutputs())) { - auto tensorType = cast(tensorResult.getType()); - auto toTensor = - bufferization::ToTensorOp::create(rewriter, receiveOp.getLoc(), tensorType, bufferResult, UnitAttr(), UnitAttr()); - tensorResults.push_back(toTensor.getResult()); - } - - rewriter.replaceOp(receiveOp, tensorResults); + rewriter.replaceOp(receiveOp, newOp.getOutputs()); return success(); } }; @@ -267,7 +256,7 @@ struct ExtractRowsOpInterface : DstBufferizableOpInterfaceExternalModel(op); - auto inputOpt = getBuffer(rewriter, extractRowsOp.getInput(), options, state); + auto inputOpt = getBufferOrValue(rewriter, extractRowsOp.getInput(), options, state); if (failed(inputOpt)) return failure(); @@ -277,7 +266,7 @@ struct ExtractRowsOpInterface : DstBufferizableOpInterfaceExternalModel inputs; inputs.reserve(concatOp.getInputs().size()); for (Value input : concatOp.getInputs()) { - auto inputOpt = getBuffer(rewriter, input, options, state); + auto inputOpt = getBufferOrValue(rewriter, input, options, state); if (failed(inputOpt)) return failure(); inputs.push_back(materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter)); } - auto outputBufferOpt = getBuffer(rewriter, concatOp.getOutputBuffer(), options, state); + auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state); if (failed(outputBufferOpt)) return failure(); @@ -323,6 +312,31 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel { + bool bufferizesToAllocation(Operation* op, Value value) const { return true; } + + bool resultBufferizesToMemoryWrite(Operation* op, OpResult opResult, const AnalysisState& state) const { + return false; + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto emptyManyOp = cast(op); + + SmallVector resultTypes; + resultTypes.reserve(emptyManyOp.getOutputs().size()); + for (Value output : emptyManyOp.getOutputs()) { + auto shapedType = cast(output.getType()); + resultTypes.push_back(MemRefType::get(shapedType.getShape(), shapedType.getElementType())); + } + + replaceOpWithNewBufferizedOp(rewriter, emptyManyOp, TypeRange(resultTypes)); + return success(); + } +}; + struct MapOpInterface : BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } @@ -375,7 +389,7 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel(input.getType())) { - auto inputOpt = getBuffer(rewriter, input, options, state); + auto inputOpt = getBufferOrValue(rewriter, input, options, state); if (failed(inputOpt)) return failure(); inputs.push_back(*inputOpt); @@ -453,6 +467,16 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel(op); + bool alreadyBufferized = llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { + return isa(weight.getType()); + }) && llvm::all_of(coreBatchOp.getInputs(), [](Value input) { + return isa(input.getType()); + }) && llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) { + return isa(arg.getType()); + }); + if (alreadyBufferized) + return success(); + SmallVector weights; SmallVector inputs; weights.reserve(coreBatchOp.getWeights().size()); @@ -460,7 +484,7 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel(weight.getType())) { - auto weightOpt = getBuffer(rewriter, weight, options, state); + auto weightOpt = getBufferOrValue(rewriter, weight, options, state); if (failed(weightOpt)) return failure(); weights.push_back(*weightOpt); @@ -472,7 +496,7 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel(input.getType())) { - auto inputOpt = getBuffer(rewriter, input, options, state); + auto inputOpt = getBufferOrValue(rewriter, input, options, state); if (failed(inputOpt)) return failure(); inputs.push_back(*inputOpt); @@ -510,11 +534,11 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel(op); - auto inputOpt = getBuffer(rewriter, transposeOp.getInput(), options, state); + auto inputOpt = getBufferOrValue(rewriter, transposeOp.getInput(), options, state); if (failed(inputOpt)) return failure(); - auto outputBufferOpt = getBuffer(rewriter, transposeOp.getOutputBuffer(), options, state); + auto outputBufferOpt = getBufferOrValue(rewriter, transposeOp.getOutputBuffer(), options, state); if (failed(outputBufferOpt)) return failure(); @@ -547,11 +571,11 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel(op); - auto inputOpt = getBuffer(rewriter, vmmOp.getInput(), options, state); + auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state); if (failed(inputOpt)) return failure(); - auto outputBufferOpt = getBuffer(rewriter, vmmOp.getOutputBuffer(), options, state); + auto outputBufferOpt = getBufferOrValue(rewriter, vmmOp.getOutputBuffer(), options, state); if (failed(outputBufferOpt)) return failure(); @@ -574,11 +598,11 @@ struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel(op); - auto inputOpt = getBuffer(rewriter, mvmOp.getInput(), options, state); + auto inputOpt = getBufferOrValue(rewriter, mvmOp.getInput(), options, state); if (failed(inputOpt)) return failure(); - auto outputBufferOpt = getBuffer(rewriter, mvmOp.getOutputBuffer(), options, state); + auto outputBufferOpt = getBufferOrValue(rewriter, mvmOp.getOutputBuffer(), options, state); if (failed(outputBufferOpt)) return failure(); @@ -608,15 +632,15 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel(op); - auto lhsOpt = getBuffer(rewriter, binaryOp.getLhs(), options, state); + auto lhsOpt = getBufferOrValue(rewriter, binaryOp.getLhs(), options, state); if (failed(lhsOpt)) return failure(); - auto rhsOpt = getBuffer(rewriter, binaryOp.getRhs(), options, state); + auto rhsOpt = getBufferOrValue(rewriter, binaryOp.getRhs(), options, state); if (failed(rhsOpt)) return failure(); - auto outputBufferOpt = getBuffer(rewriter, binaryOp.getOutputBuffer(), options, state); + auto outputBufferOpt = getBufferOrValue(rewriter, binaryOp.getOutputBuffer(), options, state); if (failed(outputBufferOpt)) return failure(); @@ -647,11 +671,11 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel(op); - auto inputOpt = getBuffer(rewriter, unaryOp.getInput(), options, state); + auto inputOpt = getBufferOrValue(rewriter, unaryOp.getInput(), options, state); if (failed(inputOpt)) return failure(); - auto outputBufferOpt = getBuffer(rewriter, unaryOp.getOutputBuffer(), options, state); + auto outputBufferOpt = getBufferOrValue(rewriter, unaryOp.getOutputBuffer(), options, state); if (failed(outputBufferOpt)) return failure(); @@ -664,6 +688,7 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel(*ctx); PimMapOp::attachInterface(*ctx); PimCoreBatchOp::attachInterface(*ctx); PimReceiveOp::attachInterface(*ctx); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index c9421c0..99ed9ab 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -47,37 +47,26 @@ private: void PimBufferizationPass::runOnOperation() { auto moduleOp = getOperation(); - { - SmallVector emptyManyOps; - moduleOp.walk([&](pim::PimEmptyManyOp emptyManyOp) { emptyManyOps.push_back(emptyManyOp); }); - - IRRewriter rewriter(moduleOp.getContext()); - for (auto emptyManyOp : emptyManyOps) { - SmallVector replacementValues; - replacementValues.reserve(emptyManyOp.getOutputs().size()); - rewriter.setInsertionPoint(emptyManyOp); - for (Value output : emptyManyOp.getOutputs()) { - auto outputType = cast(output.getType()); - replacementValues.push_back( - tensor::EmptyOp::create(rewriter, emptyManyOp.getLoc(), outputType.getShape(), outputType.getElementType())); - } - rewriter.replaceOp(emptyManyOp, replacementValues); - } - } // Refactor this into a function { - auto funcOp = getPimEntryFunc(moduleOp); + auto funcOp = *getPimEntryFunc(moduleOp); - auto coreOps = llvm::to_vector(funcOp->getOps()); + SmallVector coreOps; + funcOp->walk([&](Operation* op) { + if (isa(op)) + coreOps.push_back(op); + }); MLIRContext* ctx = moduleOp.getContext(); // failableParallelForEach will run the lambda in parallel and stop if any thread fails - LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](pim::PimCoreOp coreOp) { + LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](Operation* coreOp) { // Again, allocate state LOCALLY per thread/function bufferization::OneShotBufferizationOptions options; options.allowUnknownOps = true; + if (isa(coreOp)) + options.opFilter.denyOperation([coreOp](Operation* op) { return op == coreOp; }); bufferization::BufferizationState state; if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) { - coreOp.emitError("Failed to bufferize PIM and Spatial ops"); + coreOp->emitError("Failed to bufferize PIM and Spatial ops"); return failure(); } return success(); @@ -89,13 +78,16 @@ void PimBufferizationPass::runOnOperation() { } funcOp->walk([&](bufferization::ToTensorOp toTensorOp) { - if (llvm::isa_and_present(toTensorOp->getParentOp())) + if (llvm::isa_and_present(toTensorOp->getParentOp())) toTensorOp->setAttr("restrict", UnitAttr::get(ctx)); }); // One-Shot-Bufferization bufferization::OneShotBufferizationOptions options; options.allowUnknownOps = true; + options.opFilter.denyOperation([](Operation* op) { + return op->getParentOfType() || op->getParentOfType(); + }); bufferization::BufferizationState state; if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) { diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 9a074a6..43cda62 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -253,7 +253,7 @@ def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> { // Math //===----------------------------------------------------------------------===// -def SpatWeightedVMMOp : SpatOp<"wvmm", []> { +def SpatVMMOp : SpatOp<"wvmm", []> { let summary = "Vector-matrix multiplication within a weighted compute operation"; let arguments = (ins @@ -272,7 +272,7 @@ def SpatWeightedVMMOp : SpatOp<"wvmm", []> { }]; } -def SpatWeightedMVMOp : SpatOp<"Wmvm", []> { +def SpatMVMOp : SpatOp<"Wmvm", []> { let summary = "Matrix-vector multiplication within a weighted compute operation"; let arguments = (ins diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 6dc872b..528d2e7 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -20,7 +20,7 @@ namespace spatial { namespace { -inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter, +inline LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter, ArrayRef& matrixShape, ArrayRef& vectorShape, ArrayRef& outputShape) { @@ -45,7 +45,7 @@ inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter, return success(); } -inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter, +inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter, ArrayRef& matrixShape, ArrayRef& vectorShape, ArrayRef& outputShape) { @@ -177,10 +177,10 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp } for (auto& bodyOp : block) { - if (auto wvmm = dyn_cast(&bodyOp)) + if (auto wvmm = dyn_cast(&bodyOp)) if (wvmm.getWeightIndex() < 0 || static_cast(wvmm.getWeightIndex()) >= weightsPerLane) return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane"); - if (auto wmvm = dyn_cast(&bodyOp)) + if (auto wmvm = dyn_cast(&bodyOp)) if (wmvm.getWeightIndex() < 0 || static_cast(wmvm.getWeightIndex()) >= weightsPerLane) return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane"); } @@ -189,10 +189,10 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp } // namespace -LogicalResult SpatWeightedMVMOp::verify() { +LogicalResult SpatMVMOp::verify() { auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); if (failed(matrixShapeOpt)) - return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op"); + return emitError("SpatMVMOp was not within a SpatCompute or Core op"); auto matrixShape = *matrixShapeOpt; auto vectorShape = getInput().getType().getShape(); auto outputShape = getOutput().getType().getShape(); @@ -204,10 +204,10 @@ LogicalResult SpatWeightedMVMOp::verify() { return emitError("matrix rank must be 2 or 4"); } -LogicalResult SpatWeightedVMMOp::verify() { +LogicalResult SpatVMMOp::verify() { auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); if (failed(matrixShapeOpt)) - return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op"); + return emitError("SpatVMMOp was not within a SpatCompute or Core op"); auto matrixShape = *matrixShapeOpt; auto vectorShape = getInput().getType().getShape(); auto outputShape = getOutput().getType().getShape(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp index 332827d..6d1cea3 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp @@ -133,7 +133,7 @@ CrossbarUsage getComputeBodyCrossbarUsage(Region& body) { CrossbarUsage crossbarUsage = 0; for (auto& block : body) for (auto& op : block) - if (isa(op)) + if (isa(op)) crossbarUsage = checkedAdd(crossbarUsage, static_cast(1)); return crossbarUsage; } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Utils.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Utils.hpp index a4c9465..7a71509 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Utils.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Utils.hpp @@ -105,7 +105,7 @@ inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute CrossbarUsage crossbarUsage = 0; for (auto& region : spatCompute.getBody()) for (auto& inst : region) - if (llvm::isa(inst)) + if (llvm::isa(inst)) crossbarUsage = checkedAdd(crossbarUsage, static_cast(1)); return crossbarUsage; } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index df0422d..2d63429 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -838,9 +838,9 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) { for (auto& op : child.getBody().front()) { auto newInst = rewriter.clone(op, mapper); - if (auto weightedMvmOp = dyn_cast(newInst)) + if (auto weightedMvmOp = dyn_cast(newInst)) remapWeightIndex(weightedMvmOp); - if (auto weightedVmmOp = dyn_cast(newInst)) + if (auto weightedVmmOp = dyn_cast(newInst)) remapWeightIndex(weightedVmmOp); } @@ -884,9 +884,9 @@ void emitMotifProfile(func::FuncOp funcOp) { ComputeMotifInfo& info = computeInfos[index]; for (Operation& op : compute.getBody().front()) { info.instructionCount++; - if (isa(&op)) + if (isa(&op)) info.weightedMvmCount++; - if (isa(&op)) + if (isa(&op)) info.weightedVmmCount++; } if (info.weightedVmmCount > 0) { @@ -1617,13 +1617,13 @@ public: } Operation* clonedOp = cpuRewriter.clone(op, mapper); - if (auto oldWeightedMvmOp = dyn_cast(&op)) { - auto newWeightedMvmOp = cast(clonedOp); + if (auto oldWeightedMvmOp = dyn_cast(&op)) { + auto newWeightedMvmOp = cast(clonedOp); Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()]; newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight)); } - if (auto oldWeightedVmmOp = dyn_cast(&op)) { - auto newWeightedVmmOp = cast(clonedOp); + if (auto oldWeightedVmmOp = dyn_cast(&op)) { + auto newWeightedVmmOp = cast(clonedOp); Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()]; newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight)); } @@ -1643,22 +1643,22 @@ public: } Operation* clonedOp = cpuRewriter.clone(op, mapper); - if (auto oldWeightedMvmOp = dyn_cast(&op)) { + if (auto oldWeightedMvmOp = dyn_cast(&op)) { if (oldWeightedMvmOp.getWeightIndex() != 0) { task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0"); signalPassFailure(); return; } - auto newWeightedMvmOp = cast(clonedOp); + auto newWeightedMvmOp = cast(clonedOp); newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset])); } - if (auto oldWeightedVmmOp = dyn_cast(&op)) { + if (auto oldWeightedVmmOp = dyn_cast(&op)) { if (oldWeightedVmmOp.getWeightIndex() != 0) { task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0"); signalPassFailure(); return; } - auto newWeightedVmmOp = cast(clonedOp); + auto newWeightedVmmOp = cast(clonedOp); newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset])); } } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp index bd50aa2..925ba85 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp @@ -55,7 +55,7 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu [](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); }); } -static FailureOr analyzeRegularChunk(spatial::SpatWeightedVMMOp startOp) { +static FailureOr analyzeRegularChunk(spatial::SpatVMMOp startOp) { RegularChunk chunk; chunk.startOp = startOp.getOperation(); chunk.input = startOp.getInput(); @@ -376,7 +376,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) { auto compactInBlock = [&](Block& block) { for (auto it = block.begin(); it != block.end();) { - auto startOp = dyn_cast(&*it); + auto startOp = dyn_cast(&*it); if (!startOp) { ++it; continue; @@ -391,7 +391,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) { SmallVector run {*anchorChunk}; auto runIt = std::next(it, static_cast(anchorChunk->ops.size())); while (runIt != block.end()) { - auto candidateStart = dyn_cast(&*runIt); + auto candidateStart = dyn_cast(&*runIt); if (!candidateStart) break; @@ -425,7 +425,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) { for (auto compute : funcOp.getOps()) { Block& block = compute.getBody().front(); for (auto it = block.begin(); it != block.end();) { - auto wvmmOp = dyn_cast(&*it); + auto wvmmOp = dyn_cast(&*it); if (!wvmmOp) { ++it; continue; @@ -440,11 +440,11 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) { continue; } - SmallVector run; + SmallVector run; auto runIt = it; int64_t expectedRow = static_cast(rowResult.getResultNumber()); while (runIt != block.end()) { - auto current = dyn_cast(&*runIt); + auto current = dyn_cast(&*runIt); if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex() || current.getInput().getDefiningOp() != extractRowsOp || current.getInput().getType() != wvmmOp.getInput().getType() @@ -545,7 +545,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) { extractOffsets, extractSizes, extractStrides); - auto loopWvmm = spatial::SpatWeightedVMMOp::create( + auto loopWvmm = spatial::SpatVMMOp::create( rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult()); SmallVector insertOffsets = {iv, rewriter.getIndexAttr(0)}; diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index d91ca76..0514dbb 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -18,6 +18,7 @@ namespace { static bool isAddressOnlyHostOp(Operation* op) { return isa(defOp)) + if (isa(defOp)) return true; if (auto subview = dyn_cast(defOp)) { value = subview.getSource(); continue; } if (auto cast = dyn_cast(defOp)) { value = cast.getSource(); continue; } @@ -51,7 +52,7 @@ static bool isCodegenAddressableValue(Value value) { if (failed(resolvedAddress)) return false; return isa(resolvedAddress->base) - || isa(resolvedAddress->base.getDefiningOp()); + || isa(resolvedAddress->base.getDefiningOp()); } static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { @@ -184,7 +185,7 @@ private: continue; } - if (!isa(resolvedAddress->base.getDefiningOp())) { + if (!isa(resolvedAddress->base.getDefiningOp())) { op.emitOpError() << "operand #" << operandIndex << " must be backed by device-local memory; materialize host values with pim.memcp_hd"; hasFailure = true;