From b1272d22836937d01fdd941d6ff67e6fefa8c121 Mon Sep 17 00:00:00 2001
From: NiccoloN
Date: Fri, 8 May 2026 14:21:45 +0200
Subject: [PATCH] Fast PIM bufferization using tensors

---
 src/PIM/Compiler/PimCodeGen.cpp               |  58 ++--
 src/PIM/Compiler/PimCodeGen.hpp               |   3 +
 .../SpatialToPim/SpatialToPimPass.cpp         | 291 +++++++++++++++++-
 src/PIM/Dialect/Pim/Pim.td                    |  34 ++
 src/PIM/Dialect/Pim/PimOpsAsm.cpp             |  72 ++++-
 src/PIM/Dialect/Pim/PimOpsVerify.cpp          |  44 ++-
 .../OpBufferizationInterfaces.cpp             | 120 +++++---
 7 files changed, 541 insertions(+), 81 deletions(-)

diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp
index fe21b9d..459a1f1 100644
--- a/src/PIM/Compiler/PimCodeGen.cpp
+++ b/src/PIM/Compiler/PimCodeGen.cpp
@@ -178,7 +178,6 @@ void PimMemory::report(llvm::raw_ostream& file) {
   }
 }
 
-
 void PimMemory::remove(mlir::Value val) {
   if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
     globalMemEntriesMap.erase(removeIter);
@@ -370,11 +369,21 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
       "recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
 }
 
-void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const {
-  for (auto [outputBuffer, sourceCoreId] : llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds()))
+void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp,
+                                      const StaticValueKnowledge& knowledge) const {
+  for (auto [outputBuffer, sourceCoreId] :
+       llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds()))
     emitCommunicationOp("recv", addressOf(outputBuffer, knowledge), sourceCoreId, getValueSizeInBytes(outputBuffer));
 }
 
+void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
+                                        const StaticValueKnowledge& knowledge) const {
+  size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
+  size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size();
+  for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
+    emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
+}
+
 void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
   emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
 }
@@ -384,7 +393,15 @@ void PimCodeGen::codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticVa
     emitCommunicationOp("send", addressOf(input, knowledge), targetCoreId, getValueSizeInBytes(input));
 }
 
-void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const {
+void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
+  size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
+  size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
+  for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
+    emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
+}
+
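Note on the two functions above: send_tensor/receive_tensor codegen splits one contiguous buffer into coreIds.size() equal chunks, so a single packed op replaces a whole send_many/receive_many fan-out. With hypothetical shapes, a tensor<8x64xf32> (8 * 64 * 4 = 2048 bytes) sent to four cores yields four "send" commands of 512 bytes each, at byte offsets 0, 512, 1024 and 1536 from the buffer's base address; codeGenReceiveTensorOp mirrors this with "recv" commands into the output buffer. Assuming the core-id syntax produced by printCoreIdList, such an op prints roughly as:

    pim.send_tensor %packed to [0, 1, 2, 3] : tensor<8x64xf32>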
+void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp,
+                                      const StaticValueKnowledge& knowledge) const {
   auto inputType = cast<ShapedType>(extractRowsOp.getInput().getType());
   assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input");
@@ -393,13 +410,8 @@ void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const
   size_t inputAddr = addressOf(extractRowsOp.getInput(), knowledge);
 
   for (auto [rowIndex, outputBuffer] : llvm::enumerate(extractRowsOp.getOutputBuffers()))
-    emitMemCopyOp("lmv",
-                  addressOf(outputBuffer, knowledge),
-                  0,
-                  inputAddr,
-                  rowIndex * rowSizeInBytes,
-                  rowSizeInBytes,
-                  "len");
+    emitMemCopyOp(
+        "lmv", addressOf(outputBuffer, knowledge), 0, inputAddr, rowIndex * rowSizeInBytes, rowSizeInBytes, "len");
 }
 
 void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
@@ -742,10 +754,8 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
     for (mlir::Value input : sendManyBatchOp.getInputs())
       mappedInputs.push_back(mapper.lookup(input));
 
-    pim::PimSendManyOp::create(builder,
-                               sendManyBatchOp.getLoc(),
-                               builder.getDenseI32ArrayAttr(laneTargetCoreIds),
-                               ValueRange(mappedInputs));
+    pim::PimSendManyOp::create(
+        builder, sendManyBatchOp.getLoc(), builder.getDenseI32ArrayAttr(laneTargetCoreIds), ValueRange(mappedInputs));
     continue;
   }
 
@@ -773,13 +783,13 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
     for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers())
       mappedOutputBuffers.push_back(mapper.lookup(outputBuffer));
 
-    auto scalarReceiveMany =
-        pim::PimReceiveManyOp::create(builder,
-                                      receiveManyBatchOp.getLoc(),
-                                      receiveManyBatchOp->getResultTypes(),
-                                      ValueRange(mappedOutputBuffers),
-                                      builder.getDenseI32ArrayAttr(laneSourceCoreIds));
-    for (auto [originalOutput, scalarOutput] : llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
+    auto scalarReceiveMany = pim::PimReceiveManyOp::create(builder,
+                                                           receiveManyBatchOp.getLoc(),
+                                                           receiveManyBatchOp->getResultTypes(),
+                                                           ValueRange(mappedOutputBuffers),
+                                                           builder.getDenseI32ArrayAttr(laneSourceCoreIds));
+    for (auto [originalOutput, scalarOutput] :
+         llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
       mapper.map(originalOutput, scalarOutput);
     continue;
   }
@@ -904,10 +914,14 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
       coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
     else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(op))
       coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge);
+    else if (auto receiveTensorOp = dyn_cast<pim::PimReceiveTensorOp>(op))
+      coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge);
     else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
      coreCodeGen.codeGenSendOp(sendOp, knowledge);
     else if (auto sendManyOp = dyn_cast<pim::PimSendManyOp>(op))
       coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge);
+    else if (auto sendTensorOp = dyn_cast<pim::PimSendTensorOp>(op))
+      coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
     else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op))
       coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge);
     else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))

diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp
index ae11df9..62a7d5e 100644
--- a/src/PIM/Compiler/PimCodeGen.hpp
+++ b/src/PIM/Compiler/PimCodeGen.hpp
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "mlir/IR/Operation.h"
+
 #include "llvm-project/clang/include/clang/Basic/LLVM.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/JSON.h"
@@ -117,8 +118,10 @@ public:
   void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
   void codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const;
+  void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
   void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
   void codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const;
+  void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
   void codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const;
   void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;

diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
index 6948663..7d572de 100644
--- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
+++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp
@@ -159,9 +159,7 @@ static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRe
   rewriter.eraseOp(sendManyOp);
 }
 
-static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter,
-                                                     Location loc,
-                                                     TypeRange outputTypes) {
+static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter, Location loc, TypeRange outputTypes) {
   SmallVector<Type> tensorTypes;
   tensorTypes.reserve(outputTypes.size());
   for (Type outputType : outputTypes)
@@ -177,7 +175,8 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan
   sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size());
   for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds())
     sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
-  SmallVector<Value> outputBuffers = createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
+  SmallVector<Value> outputBuffers =
+      createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
 
   auto receiveMany = PimReceiveManyOp::create(rewriter,
                                               receiveManyOp.getLoc(),
@@ -199,10 +198,8 @@ static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendMa
   mappedInputs.reserve(sendManyBatchOp.getInputs().size());
   for (Value input : sendManyBatchOp.getInputs())
     mappedInputs.push_back(mapper.lookup(input));
-  pim::PimSendManyBatchOp::create(rewriter,
-                                  sendManyBatchOp.getLoc(),
-                                  rewriter.getDenseI32ArrayAttr(targetCoreIds),
-                                  ValueRange(mappedInputs));
+  pim::PimSendManyBatchOp::create(
+      rewriter, sendManyBatchOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), ValueRange(mappedInputs));
 }
 
 static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
@@ -272,6 +269,276 @@ static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
   }
 }
 
+static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
+  SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
+  packedShape[0] *= count;
+  return RankedTensorType::get(packedShape, elementType.getElementType());
+}
+
+static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
+  if (values.empty())
+    return false;
+
+  auto firstResult = dyn_cast<OpResult>(values.front());
+  if (!firstResult)
+    return false;
+
+  owner = firstResult.getOwner();
+  startIndex = firstResult.getResultNumber();
+  for (auto [index, value] : llvm::enumerate(values)) {
+    auto result = dyn_cast<OpResult>(value);
+    if (!result || result.getOwner() != owner || result.getResultNumber() != startIndex + index)
+      return false;
+  }
+
+  return true;
+}
+
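getContiguousOpResults is the gate for all packing rewrites here: a run of values qualifies only if they are consecutive results of a single producer. createPackedExtractRowsSlice below then folds such a run into one strided slice. A minimal sketch with hypothetical shapes (four consecutive tensor<1x64xf32> rows taken from a tensor<8x64xf32> input, starting at result 0):

    %packed = tensor.extract_slice %input[0, 0] [4, 64] [1, 1]
        : tensor<8x64xf32> to tensor<4x64xf32>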
+static Value createPackedExtractRowsSlice(
+    pim::PimExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
+  auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
+  auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
+  if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
+    return {};
+
+  int64_t rowsPerValue = rowType.getDimSize(0);
+  if (ShapedType::isDynamic(rowsPerValue))
+    return {};
+
+  auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
+  SmallVector<OpFoldResult> offsets;
+  SmallVector<OpFoldResult> sizes;
+  SmallVector<OpFoldResult> strides;
+  offsets.reserve(inputType.getRank());
+  sizes.reserve(inputType.getRank());
+  strides.reserve(inputType.getRank());
+
+  offsets.push_back(rewriter.getIndexAttr(static_cast<int64_t>(startIndex) * rowsPerValue));
+  sizes.push_back(rewriter.getIndexAttr(static_cast<int64_t>(count) * rowsPerValue));
+  strides.push_back(rewriter.getIndexAttr(1));
+  for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
+    offsets.push_back(rewriter.getIndexAttr(0));
+    sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
+    strides.push_back(rewriter.getIndexAttr(1));
+  }
+
+  return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides)
+      .getResult();
+}
+
+static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) {
+  Operation* owner = nullptr;
+  unsigned startIndex = 0;
+  if (!getContiguousOpResults(values, owner, startIndex))
+    return {};
+
+  if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(owner))
+    return createPackedExtractRowsSlice(extractRowsOp, startIndex, static_cast<unsigned>(values.size()), rewriter, loc);
+
+  return {};
+}
+
+static Value createPackedReceiveTensor(
+    pim::PimReceiveManyOp receiveManyOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
+  auto rowType = dyn_cast<RankedTensorType>(receiveManyOp.getOutputs()[startIndex].getType());
+  if (!rowType || !rowType.hasStaticShape() || rowType.getRank() == 0)
+    return {};
+
+  auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
+  auto outputBuffer = tensor::EmptyOp::create(rewriter, loc, packedType.getShape(), packedType.getElementType());
+
+  SmallVector<int32_t> sourceCoreIds;
+  sourceCoreIds.reserve(count);
+  ArrayRef<int32_t> allSourceCoreIds = receiveManyOp.getSourceCoreIds();
+  for (unsigned index = 0; index < count; ++index)
+    sourceCoreIds.push_back(allSourceCoreIds[startIndex + index]);
+
+  return pim::PimReceiveTensorOp::create(
+             rewriter, loc, packedType, outputBuffer.getResult(), rewriter.getDenseI32ArrayAttr(sourceCoreIds))
+      .getOutput();
+}
+
+static Value
+createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
+  Value packedInput = createPackedTensorForValues(mapOp.getInputs().slice(startIndex, count), rewriter, loc);
+  if (!packedInput)
+    return {};
+
+  auto inputType = dyn_cast<RankedTensorType>(mapOp.getInputs()[startIndex].getType());
+  auto outputType = dyn_cast<RankedTensorType>(mapOp.getOutputs()[startIndex].getType());
+  if (!inputType || !outputType || !inputType.hasStaticShape() || !outputType.hasStaticShape()
+      || inputType.getRank() == 0 || outputType.getRank() == 0)
+    return {};
+
+  auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(count));
+  auto packedInit =
+      tensor::EmptyOp::create(rewriter, loc, packedOutputType.getShape(), packedOutputType.getElementType());
+  auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  auto upper = arith::ConstantIndexOp::create(rewriter, loc, count);
+  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
+  auto loop = scf::ForOp::create(rewriter, loc, zero, upper, step, ValueRange {packedInit.getResult()});
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    Block* loopBlock = loop.getBody();
+    rewriter.setInsertionPointToStart(loopBlock);
+    Value iv = loopBlock->getArgument(0);
+    Value acc = loopBlock->getArgument(1);
+
+    int64_t inputRowsPerValue = inputType.getDimSize(0);
+    Value inputRowOffset = iv;
+    if (inputRowsPerValue != 1) {
+      auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, inputRowsPerValue);
+      inputRowOffset = arith::MulIOp::create(rewriter, loc, iv, rowsPerValue);
+    }
+
+    SmallVector<OpFoldResult> extractOffsets;
+    SmallVector<OpFoldResult> extractSizes;
+    SmallVector<OpFoldResult> extractStrides;
+    extractOffsets.push_back(inputRowOffset);
+    extractSizes.push_back(rewriter.getIndexAttr(inputRowsPerValue));
+    extractStrides.push_back(rewriter.getIndexAttr(1));
+    for (int64_t dim = 1; dim < inputType.getRank(); ++dim) {
+      extractOffsets.push_back(rewriter.getIndexAttr(0));
+      extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));
+      extractStrides.push_back(rewriter.getIndexAttr(1));
+    }
+    auto inputSlice = tensor::ExtractSliceOp::create(
+        rewriter, loc, inputType, packedInput, extractOffsets, extractSizes, extractStrides);
+
+    IRMapping mapping;
+    Block& body = mapOp.getBody().front();
+    mapping.map(body.getArgument(0), inputSlice.getResult());
+    for (Operation& bodyOp : body.without_terminator()) {
+      Operation* cloned = rewriter.clone(bodyOp, mapping);
+      for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults()))
+        mapping.map(originalResult, clonedResult);
+      rewriter.setInsertionPointAfter(cloned);
+    }
+
+    auto yieldOp = cast<pim::PimYieldOp>(body.getTerminator());
+    Value mappedOutput = mapping.lookupOrDefault(yieldOp.getOperand(0));
+
+    int64_t outputRowsPerValue = outputType.getDimSize(0);
+    Value outputRowOffset = iv;
+    if (outputRowsPerValue != 1) {
+      auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, outputRowsPerValue);
+      outputRowOffset = arith::MulIOp::create(rewriter, loc, iv, rowsPerValue);
+    }
+
+    SmallVector<OpFoldResult> insertOffsets;
+    SmallVector<OpFoldResult> insertSizes;
+    SmallVector<OpFoldResult> insertStrides;
+    insertOffsets.push_back(outputRowOffset);
+    insertSizes.push_back(rewriter.getIndexAttr(outputRowsPerValue));
+    insertStrides.push_back(rewriter.getIndexAttr(1));
+    for (int64_t dim = 1; dim < outputType.getRank(); ++dim) {
+      insertOffsets.push_back(rewriter.getIndexAttr(0));
+      insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim)));
+      insertStrides.push_back(rewriter.getIndexAttr(1));
+    }
+
+    auto inserted =
+        tensor::InsertSliceOp::create(rewriter, loc, mappedOutput, acc, insertOffsets, insertSizes, insertStrides);
+    scf::YieldOp::create(rewriter, loc, inserted.getResult());
+  }
+
+  return loop.getResult(0);
+}
+
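The loop built by createPackedMapTensor clones the pim.map body once per packed value. For hypothetical one-row values with count = 4 and a body mapping 64 elements to 32, the generated IR is roughly:

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c4 = arith.constant 4 : index
    %init = tensor.empty() : tensor<4x32xf32>
    %out = scf.for %i = %c0 to %c4 step %c1 iter_args(%acc = %init) -> (tensor<4x32xf32>) {
      %in = tensor.extract_slice %packedIn[%i, 0] [1, 64] [1, 1] : tensor<4x64xf32> to tensor<1x64xf32>
      // ... cloned pim.map body computes %res : tensor<1x32xf32> from %in ...
      %next = tensor.insert_slice %res into %acc[%i, 0] [1, 32] [1, 1] : tensor<1x32xf32> into tensor<4x32xf32>
      scf.yield %next : tensor<4x32xf32>
    }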
+static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
+  SmallVector<pim::PimSendManyOp> sendManyOps;
+  funcOp.walk([&](pim::PimSendManyOp sendManyOp) { sendManyOps.push_back(sendManyOp); });
+  for (auto sendManyOp : sendManyOps) {
+    if (sendManyOp.getInputs().empty())
+      continue;
+
+    rewriter.setInsertionPoint(sendManyOp);
+    Value packedInput = createPackedTensorForValues(sendManyOp.getInputs(), rewriter, sendManyOp.getLoc());
+    if (!packedInput)
+      continue;
+
+    pim::PimSendTensorOp::create(rewriter, sendManyOp.getLoc(), packedInput, sendManyOp.getTargetCoreIdsAttr());
+    rewriter.eraseOp(sendManyOp);
+  }
+
+  SmallVector<pim::PimConcatOp> concatOps;
+  funcOp.walk([&](pim::PimConcatOp concatOp) { concatOps.push_back(concatOp); });
+  for (auto concatOp : concatOps) {
+    if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
+      continue;
+
+    SmallVector<Value> packedInputs;
+    bool changed = false;
+    rewriter.setInsertionPoint(concatOp);
+
+    for (unsigned index = 0; index < concatOp.getInputs().size();) {
+      Value input = concatOp.getInputs()[index];
+      auto result = dyn_cast<OpResult>(input);
+      if (!result) {
+        packedInputs.push_back(input);
+        ++index;
+        continue;
+      }
+
+      Operation* owner = result.getOwner();
+      unsigned startIndex = result.getResultNumber();
+      unsigned endIndex = index + 1;
+      while (endIndex < concatOp.getInputs().size()) {
+        auto nextResult = dyn_cast<OpResult>(concatOp.getInputs()[endIndex]);
+        if (!nextResult || nextResult.getOwner() != owner
+            || nextResult.getResultNumber() != startIndex + (endIndex - index))
+          break;
+        ++endIndex;
+      }
+
+      unsigned count = endIndex - index;
+      Value packedInput;
+      if (auto mapOp = dyn_cast<pim::PimMapOp>(owner))
+        packedInput = createPackedMapTensor(mapOp, startIndex, count, rewriter, concatOp.getLoc());
+      else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(owner))
+        packedInput = createPackedReceiveTensor(receiveManyOp, startIndex, count, rewriter, concatOp.getLoc());
+      else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(owner))
+        packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());
+
+      if (packedInput) {
+        packedInputs.push_back(packedInput);
+        changed = true;
+      }
+      else {
+        for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
+          packedInputs.push_back(concatOp.getInputs()[oldIndex]);
+      }
+
+      index = endIndex;
+    }
+
+    if (!changed)
+      continue;
+
+    auto newConcat = pim::PimConcatOp::create(rewriter,
+                                              concatOp.getLoc(),
+                                              concatOp.getOutput().getType(),
+                                              concatOp.getAxisAttr(),
+                                              ValueRange(packedInputs),
+                                              concatOp.getOutputBuffer());
+    rewriter.replaceOp(concatOp, newConcat.getOutput());
+  }
+
+  auto eraseUnusedOps = [&](auto tag) {
+    using OpTy = decltype(tag);
+    SmallVector<OpTy> ops;
+    funcOp.walk([&](OpTy op) { ops.push_back(op); });
+    for (auto op : llvm::reverse(ops))
+      if (op->use_empty())
+        rewriter.eraseOp(op);
+  };
+  eraseUnusedOps(pim::PimMapOp {});
+  eraseUnusedOps(pim::PimReceiveManyOp {});
+  eraseUnusedOps(pim::PimExtractRowsOp {});
+  eraseUnusedOps(pim::PimEmptyManyOp {});
+}
+
 static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
                                                SmallVectorImpl<Operation*>& helperChain,
                                                bool requireReturnUse = true) {
@@ -399,21 +666,21 @@ static std::optional<Value> analyzeReturnUse(Value value) {
 }
 
 static std::optional<Value> analyzeConcatReturnUse(Value value) {
-  auto getConcatResult = [](Operation *op) -> Value {
+  auto getConcatResult = [](Operation* op) -> Value {
     if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
       return tensorConcat.getResult();
     if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
       return pimConcat.getOutput();
     return {};
   };
-  auto getConcatAxis = [](Operation *op) -> std::optional<int64_t> {
+  auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
     if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
      return tensorConcat.getDim();
     if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
      return pimConcat.getAxis();
    return std::nullopt;
  };
-  auto getConcatOperands = [](Operation *op) -> OperandRange {
+  auto getConcatOperands = [](Operation* op) -> OperandRange {
     if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
       return tensorConcat.getOperands();
     return cast<pim::PimConcatOp>(op).getInputs();
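Net effect of compactPimTensorGroups on an axis-0 pim.concat, in abbreviated pseudo-IR (shapes and core ids hypothetical, op syntax per the parsers added below): four consecutive pim.receive_many results feeding the concat collapse into a single packed receive, and the now-unused producers are swept by eraseUnusedOps:

    // before: %r0, %r1, %r2, %r3 = pim.receive_many ...
    //         (four tensor<1x64xf32> values, concatenated on axis 0)
    // after:
    %buf = tensor.empty() : tensor<4x64xf32>
    %r = pim.receive_tensor from [4, 5, 6, 7] into (%buf) : tensor<4x64xf32> -> tensor<4x64xf32>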
@@ -799,6 +1066,8 @@ void SpatialToPimPass::runOnOperation() {
   for (auto extractRowsOp : remainingExtractRowsOps)
     lowerExtractRows(extractRowsOp, rewriter);
 
+  compactPimTensorGroups(funcOp, rewriter);
+
   // Dump to file for debug
   bool hasSpatialOps = false;
   moduleOp.walk([&](Operation* op) {

diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td
index c81e0d4..e475256 100644
--- a/src/PIM/Dialect/Pim/Pim.td
+++ b/src/PIM/Dialect/Pim/Pim.td
@@ -133,6 +133,18 @@ def PimSendManyOp : PimOp<"send_many", []> {
   let hasCustomAssemblyFormat = 1;
 }
 
+def PimSendTensorOp : PimOp<"send_tensor", []> {
+  let summary = "Send equal contiguous chunks of one tensor to target cores";
+
+  let arguments = (ins
+    PimTensor:$input,
+    DenseI32ArrayAttr:$targetCoreIds
+  );
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 def PimSendBatchOp : PimOp<"send_batch", []> {
   let summary = "Send a per-lane tensor to target cores from a batched core";
 
@@ -203,6 +215,28 @@ def PimReceiveManyOp : PimOp<"receive_many", [DestinationStyleOpInterface]> {
   let hasCustomAssemblyFormat = 1;
 }
 
+def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> {
+  let summary = "Receive equal contiguous chunks from source cores into one tensor";
+
+  let arguments = (ins
+    PimTensor:$outputBuffer,
+    DenseI32ArrayAttr:$sourceCoreIds
+  );
+
+  let results = (outs
+    PimTensor:$output
+  );
+
+  let extraClassDeclaration = [{
+    mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputBufferMutable();
+    }
+  }];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
   let summary = "Receive per-lane tensors from source cores into a batched core";
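PimReceiveTensorOp is destination-style: the packed buffer is its only DPS init (getDpsInitsMutable above), so one-shot bufferization can write the received chunks in place. Both new ops take their core-id list either positionally or through the attr-dict; the parsers in PimOpsAsm.cpp below reject specifying both. Assuming the printCoreIdList rendering, these two forms should be equivalent for the parser:

    pim.send_tensor %t to [0, 1] : tensor<2x64xf32>
    pim.send_tensor %t {targetCoreIds = array<i32: 0, 1>} : tensor<2x64xf32>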
diff --git a/src/PIM/Dialect/Pim/PimOpsAsm.cpp b/src/PIM/Dialect/Pim/PimOpsAsm.cpp
index dc491ca..7fb95be 100644
--- a/src/PIM/Dialect/Pim/PimOpsAsm.cpp
+++ b/src/PIM/Dialect/Pim/PimOpsAsm.cpp
@@ -4,8 +4,8 @@
 #include "llvm/Support/LogicalResult.h"
 
-#include "src/Accelerators/PIM/Common/PimCommon.hpp"
 #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
+#include "src/Accelerators/PIM/Common/PimCommon.hpp"
 #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
 
 using namespace mlir;
@@ -100,9 +100,9 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
 
   auto& builder = parser.getBuilder();
   result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
-  result.addAttribute("operandSegmentSizes",
-                      builder.getDenseI32ArrayAttr(
-                          {static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
+  result.addAttribute(
+      "operandSegmentSizes",
+      builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
 
   if (hasCoreIds)
     result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
@@ -267,6 +267,33 @@ ParseResult PimSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
   return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
 }
 
+void PimSendTensorOp::print(OpAsmPrinter& printer) {
+  printer << " ";
+  printer.printOperand(getInput());
+  printCoreIdList(printer, "to", getTargetCoreIds());
+  printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()});
+  printer << " : ";
+  printer.printType(getInput().getType());
+}
+
+ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
+  OpAsmParser::UnresolvedOperand input;
+  Type inputType;
+  SmallVector<int32_t> targetCoreIds;
+
+  if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds)
+      || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
+    return failure();
+
+  if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds"))
+    return parser.emitError(parser.getCurrentLocation(),
+                            "targetCoreIds cannot be specified both positionally and in attr-dict");
+  if (!targetCoreIds.empty())
+    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
+
+  return parser.resolveOperand(input, inputType, result.operands);
+}
+
 void PimSendManyBatchOp::print(OpAsmPrinter& printer) {
   printer << " ";
   printCompressedValueSequence(printer, getInputs());
@@ -333,6 +360,43 @@ ParseResult PimReceiveManyOp::parse(OpAsmParser& parser, OperationState& result)
   return success();
 }
 
+void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
+  printCoreIdList(printer, "from", getSourceCoreIds());
+  printer << " into ";
+  printOpenDelimiter(printer, ListDelimiter::Paren);
+  printer.printOperand(getOutputBuffer());
+  printCloseDelimiter(printer, ListDelimiter::Paren);
+  printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()});
+  printer << " : ";
+  printer.printType(getOutputBuffer().getType());
+  printer << " -> ";
+  printer.printType(getOutput().getType());
+}
+
+ParseResult PimReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
+  OpAsmParser::UnresolvedOperand outputBuffer;
+  Type outputBufferType;
+  Type outputType;
+  SmallVector<int32_t> sourceCoreIds;
+
+  if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen()
+      || parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes)
+      || parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow()
+      || parser.parseType(outputType))
+    return failure();
+
+  if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds"))
+    return parser.emitError(parser.getCurrentLocation(),
+                            "sourceCoreIds cannot be specified both positionally and in attr-dict");
+  if (!sourceCoreIds.empty())
+    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
+
+  if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands))
+    return failure();
+  result.addTypes(outputType);
+  return success();
+}
+
 void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
   printCoreIdList(printer, "from", getSourceCoreIds());
   printer << " into ";
diff --git a/src/PIM/Dialect/Pim/PimOpsVerify.cpp b/src/PIM/Dialect/Pim/PimOpsVerify.cpp
index a12f8a1..cae01e5 100644
--- a/src/PIM/Dialect/Pim/PimOpsVerify.cpp
+++ b/src/PIM/Dialect/Pim/PimOpsVerify.cpp
@@ -48,12 +48,32 @@ static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types
       return op->emitError() << kind << " values must all have the same type";
     if (firstIsTensor != isa<TensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
       return op->emitError() << kind << " values must all use the same shaped container kind";
-    if (firstShapedType.getElementType() != shapedType.getElementType() || firstShapedType.getShape() != shapedType.getShape())
+    if (firstShapedType.getElementType() != shapedType.getElementType()
+        || firstShapedType.getShape() != shapedType.getShape())
       return op->emitError() << kind << " values must all have the same shape and element type";
   }
   return success();
 }
 
+static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
+  if (coreIds.empty())
+    return op->emitError() << kind << " must carry at least one chunk";
+
+  auto shapedType = dyn_cast<ShapedType>(type);
+  if (!shapedType || !shapedType.hasStaticShape())
+    return op->emitError() << kind << " requires a static shaped tensor or memref";
+
+  int64_t elementBits = shapedType.getElementTypeBitWidth();
+  if (elementBits <= 0 || elementBits % 8 != 0)
+    return op->emitError() << kind << " requires byte-sized elements";
+
+  int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
+  if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
+    return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";
+
+  return success();
+}
+
 static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
   auto coreBatchOp = op->getParentOfType<pim::PimCoreBatchOp>();
   if (!coreBatchOp)
@@ -61,9 +81,7 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
   return coreBatchOp.getLaneCount();
 }
 
-static LogicalResult verifyManyBatchCommunicationSizes(Operation* op,
-                                                       ArrayRef<int32_t> coreIds,
-                                                       size_t valueCount) {
+static LogicalResult verifyManyBatchCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
   auto laneCount = getParentBatchLaneCount(op);
   if (failed(laneCount))
     return op->emitError("must be nested inside pim.core_batch");
@@ -109,7 +127,8 @@ LogicalResult PimMapOp::verify() {
   Block& block = getBody().front();
   if (block.getNumArguments() != 1)
     return emitError("body must have exactly one block argument");
-  if (block.getArgument(0).getType() != inputType)
+  if (failed(verifyCompatibleShapedTypes(
+          getOperation(), block.getArgument(0).getType(), inputType, "body block argument type must match input type")))
     return emitError("body block argument type must match input type");
 
   auto yieldOp = dyn_cast_or_null<pim::PimYieldOp>(block.getTerminator());
@@ -117,7 +136,8 @@ LogicalResult PimMapOp::verify() {
     return emitError("body must terminate with pim.yield");
   if (yieldOp.getNumOperands() != 1)
     return emitError("body yield must produce exactly one value");
-  if (yieldOp.getOperand(0).getType() != outputType)
+  if (failed(verifyCompatibleShapedTypes(
+          getOperation(), yieldOp.getOperand(0).getType(), outputType, "body yield type must match output type")))
     return emitError("body yield type must match output type");
 
   return success();
@@ -129,6 +149,10 @@ LogicalResult PimSendManyOp::verify() {
   return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many");
 }
 
+LogicalResult PimSendTensorOp::verify() {
+  return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
+}
+
 LogicalResult PimSendManyBatchOp::verify() {
   if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
     return failure();
@@ -153,6 +177,14 @@ LogicalResult PimReceiveManyOp::verify() {
   return success();
 }
 
+LogicalResult PimReceiveTensorOp::verify() {
+  if (failed(verifyCompatibleShapedTypes(
+          getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
+    return failure();
+
+  return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
+}
+
 LogicalResult PimReceiveManyBatchOp::verify() {
   if (getOutputBuffers().size() != getOutputs().size())
     return emitError("number of output buffers must match the number of outputs");
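Concrete consequences of verifyTensorCommunication, with hypothetical types: tensor<8x64xf32> holds 2048 bytes, so four cores give valid 512-byte chunks while three cores fail the divisibility check; dynamic shapes and sub-byte element types such as i1 are rejected outright.

    pim.send_tensor %t to [0, 1, 2, 3] : tensor<8x64xf32>  // ok: four 512-byte chunks
    pim.send_tensor %t to [0, 1, 2] : tensor<8x64xf32>     // error: 2048 % 3 != 0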
diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp
index 38382f9..9f24c82 100644
--- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp
+++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp
@@ -34,10 +34,8 @@ static Value materializeContiguousMemRef(Value memrefValue, Location loc, Rewrit
       .getOutput();
 }
 
-static FailureOr<Value> getBufferOrValue(RewriterBase& rewriter,
-                                         Value value,
-                                         const BufferizationOptions& options,
-                                         BufferizationState& state) {
+static FailureOr<Value>
+getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) {
   if (isa<MemRefType>(value.getType()))
     return value;
   return getBuffer(rewriter, value, options, state);
@@ -205,13 +203,37 @@ struct ReceiveManyOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveM
       resultTypes.push_back(outputBuffer->getType());
     }
 
-    auto newOp = PimReceiveManyOp::create(
-        rewriter, receiveOp.getLoc(), TypeRange(resultTypes), ValueRange(outputBuffers), receiveOp.getSourceCoreIdsAttr());
+    auto newOp = PimReceiveManyOp::create(rewriter,
+                                          receiveOp.getLoc(),
+                                          TypeRange(resultTypes),
+                                          ValueRange(outputBuffers),
+                                          receiveOp.getSourceCoreIdsAttr());
     rewriter.replaceOp(receiveOp, newOp.getOutputs());
     return success();
   }
 };
 
+struct ReceiveTensorOpInterface
+: DstBufferizableOpInterfaceExternalModel<ReceiveTensorOpInterface, PimReceiveTensorOp> {
+  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
+    return !cast<PimReceiveTensorOp>(op).isDpsInit(&opOperand);
+  }
+
+  LogicalResult bufferize(Operation* op,
+                          RewriterBase& rewriter,
+                          const BufferizationOptions& options,
+                          BufferizationState& state) const {
+    auto receiveOp = cast<PimReceiveTensorOp>(op);
+    auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
+    if (failed(outputBufferOpt))
+      return failure();
+
+    replaceOpWithNewBufferizedOp<PimReceiveTensorOp>(
+        rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
+    return success();
+  }
+};
+
 struct ReceiveManyBatchOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveManyBatchOpInterface, PimReceiveManyBatchOp> {
   bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
@@ -337,6 +359,30 @@ struct EmptyManyOpInterface : BufferizableOpInterface::ExternalModel<EmptyManyOp
 };
 
+struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensorOpInterface, PimSendTensorOp> {
+  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
+
+  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
+
+  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
+    return {};
+  }
+
+  LogicalResult bufferize(Operation* op,
+                          RewriterBase& rewriter,
+                          const BufferizationOptions& options,
+                          BufferizationState& state) const {
+    auto sendOp = cast<PimSendTensorOp>(op);
+    auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
+    if (failed(inputOpt))
+      return failure();
+
+    replaceOpWithNewBufferizedOp<PimSendTensorOp>(
+        rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
+    return success();
+  }
+};
+
 struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, PimMapOp> {
   bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
 
@@ -349,23 +395,26 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
     auto mapOp = cast<PimMapOp>(op);
     auto bbArg = dyn_cast<BlockArgument>(value);
-    if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty())
+    if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
+        || mapOp.getInputs().empty())
      return {};
-    return {{&mapOp->getOpOperand(0), BufferRelation::Equivalent}};
+    return {
+        {&mapOp->getOpOperand(0), BufferRelation::Equivalent}
+    };
   }
 
   bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
 
-  FailureOr<BaseMemRefType>
-  getBufferType(Operation* op,
-                Value value,
-                const BufferizationOptions& options,
-                const BufferizationState& state,
-                SmallVector<Value>& invocationStack) const {
+  FailureOr<BaseMemRefType> getBufferType(Operation* op,
+                                          Value value,
+                                          const BufferizationOptions& options,
+                                          const BufferizationState& state,
+                                          SmallVector<Value>& invocationStack) const {
     auto mapOp = cast<PimMapOp>(op);
     auto bbArg = dyn_cast<BlockArgument>(value);
-    if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty())
+    if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
+        || mapOp.getInputs().empty())
       return failure();
 
     auto inputType = dyn_cast<RankedTensorType>(mapOp.getInputs().front().getType());
@@ -417,13 +466,9 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
 
 struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOpInterface, PimCoreBatchOp> {
-  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
-    return true;
-  }
+  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
 
-  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
-    return false;
-  }
+  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
 
   AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
     return {};
@@ -436,19 +481,18 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
 
-    return {{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}};
+    return {
+        {&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}
+    };
   }
 
-  bool isWritable(Operation* op, Value value, const AnalysisState& state) const {
-    return false;
-  }
+  bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
 
-  FailureOr<BaseMemRefType>
-  getBufferType(Operation* op,
-                Value value,
-                const BufferizationOptions& options,
-                const BufferizationState& state,
-                SmallVector<Value>& invocationStack) const {
+  FailureOr<BaseMemRefType> getBufferType(Operation* op,
+                                          Value value,
+                                          const BufferizationOptions& options,
+                                          const BufferizationState& state,
+                                          SmallVector<Value>& invocationStack) const {
     auto coreBatchOp = cast<PimCoreBatchOp>(op);
     auto bbArg = dyn_cast<BlockArgument>(value);
     if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
@@ -467,13 +511,11 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
     auto coreBatchOp = cast<PimCoreBatchOp>(op);
-    bool alreadyBufferized = llvm::all_of(coreBatchOp.getWeights(), [](Value weight) {
-      return isa<MemRefType>(weight.getType());
-    }) && llvm::all_of(coreBatchOp.getInputs(), [](Value input) {
-      return isa<MemRefType>(input.getType());
-    }) && llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) {
-      return isa<MemRefType>(arg.getType());
-    });
+    bool alreadyBufferized =
+        llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<MemRefType>(weight.getType()); })
+        && llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<MemRefType>(input.getType()); })
+        && llvm::all_of(coreBatchOp.getBody().front().getArguments(),
+                        [](BlockArgument arg) { return isa<MemRefType>(arg.getType()); });
 
     if (alreadyBufferized)
       return success();
@@ -693,8 +735,10 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
     PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
     PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
     PimReceiveManyOp::attachInterface<ReceiveManyOpInterface>(*ctx);
+    PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
     PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
     PimReceiveManyBatchOp::attachInterface<ReceiveManyBatchOpInterface>(*ctx);
+    PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
     PimExtractRowsOp::attachInterface<ExtractRowsOpInterface>(*ctx);
     PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
     PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
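With the interfaces above registered, one-shot bufferization keeps the packed ops intact and only swaps tensors for buffers. A sketch, assuming the tensor.empty init lowers to a plain allocation and that the custom assembly prints memref types the same way:

    // before bufferization
    %buf = tensor.empty() : tensor<4x64xf32>
    %out = pim.receive_tensor from [4, 5] into (%buf) : tensor<4x64xf32> -> tensor<4x64xf32>

    // after bufferization (pim.send_tensor likewise ends up reading the
    // contiguous memref produced by materializeContiguousMemRef)
    %m = memref.alloc() : memref<4x64xf32>
    %out_m = pim.receive_tensor from [4, 5] into (%m) : memref<4x64xf32> -> memref<4x64xf32>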