From 5ff364027bdc2210ed8ced2e84a70980cda7e5eb Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 11 May 2026 14:38:13 +0200 Subject: [PATCH] big cleanup: remove remaining pim many operations, simplify bufferization logic --- src/PIM/Common/IR/AddressAnalysis.cpp | 2 +- src/PIM/Compiler/PimCodeGen.cpp | 279 ++++++------------ src/PIM/Compiler/PimCodeGen.hpp | 5 - .../SpatialToPim/SpatialToPimPass.cpp | 248 ++++++++++------ src/PIM/Dialect/Pim/Pim.td | 126 -------- src/PIM/Dialect/Pim/PimOpsAsm.cpp | 230 --------------- src/PIM/Dialect/Pim/PimOpsVerify.cpp | 180 ----------- .../OpBufferizationInterfaces.cpp | 253 +++++----------- .../Bufferization/PimBufferizationPass.cpp | 82 +---- .../MaterializeHostConstantsPass.cpp | 63 +--- src/PIM/Pass/PimCodegen/VerificationPass.cpp | 76 ++++- src/PIM/TODO.md | 10 - 12 files changed, 390 insertions(+), 1164 deletions(-) delete mode 100644 src/PIM/TODO.md diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index a168cf9..029f50f 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -228,7 +228,7 @@ llvm::FailureOr resolveContiguousAddressImpl(mlir::Va continue; } - if (mlir::isa(definingOp)) + if (mlir::isa(definingOp)) return ResolvedContiguousAddress {value, byteOffset}; return mlir::failure(); diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 459a1f1..288798c 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -42,6 +42,79 @@ static size_t getValueSizeInBytes(mlir::Value value) { return type.getNumElements() * type.getElementTypeBitWidth() / 8; } +struct DenseWeightView { + DenseElementsAttr denseAttr; + SmallVector shape; + SmallVector strides; + int64_t offset = 0; +}; + +static SmallVector computeRowMajorStridesForShape(ArrayRef shape) { + SmallVector strides(shape.size(), 1); + for (int64_t index = static_cast(shape.size()) - 2; index >= 0; --index) + strides[index] = strides[index + 1] * shape[index + 1]; + return strides; +} + +static bool allStaticSubviewParts(memref::SubViewOp subview) { + return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); }) + && llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); }) + && llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); }); +} + +static FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) { + SmallVector subviews; + mlir::Value current = weight; + memref::GetGlobalOp getGlobalOp; + + while (true) { + Operation* defOp = current.getDefiningOp(); + if (!defOp) + return failure(); + if ((getGlobalOp = dyn_cast(defOp))) + break; + if (auto subview = dyn_cast(defOp)) { + if (!allStaticSubviewParts(subview)) + return failure(); + subviews.push_back(subview); + current = subview.getSource(); + continue; + } + if (auto cast = dyn_cast(defOp)) { + current = cast.getSource(); + continue; + } + return failure(); + } + + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (!globalOp || !globalOp.getInitialValue()) + return failure(); + + auto denseAttr = dyn_cast(*globalOp.getInitialValue()); + if (!denseAttr) + return failure(); + + DenseWeightView view; + view.denseAttr = denseAttr; + view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end()); + view.strides = computeRowMajorStridesForShape(view.shape); + + for (memref::SubViewOp subview : llvm::reverse(subviews)) { + SmallVector nextStrides; + nextStrides.reserve(subview.getStaticStrides().size()); + for (auto [offset, stride, sourceStride] : + llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) { + view.offset += offset * sourceStride; + nextStrides.push_back(stride * sourceStride); + } + view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end()); + view.strides = std::move(nextStrides); + } + + return view; +} + MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { auto type = cast(value.getType()); assert("Only static shape is supported" && type.hasStaticShape()); @@ -97,11 +170,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { if (!allocOp->getParentOfType()) gatherMemEntry(allocOp.getResult()); }); - funcOp.walk([&](pim::PimEmptyManyOp emptyManyOp) { - if (!emptyManyOp->getParentOfType() && !emptyManyOp->getParentOfType()) - for (mlir::Value output : emptyManyOp.getOutputs()) - gatherMemEntry(output); - }); allocateGatheredMemory(); @@ -111,10 +179,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { void PimMemory::allocateCore(Operation* op) { op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); }); - op->walk([&](pim::PimEmptyManyOp emptyManyOp) { - for (mlir::Value output : emptyManyOp.getOutputs()) - gatherMemEntry(output); - }); allocateGatheredMemory(); } @@ -369,13 +433,6 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue "recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize()); } -void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, - const StaticValueKnowledge& knowledge) const { - for (auto [outputBuffer, sourceCoreId] : - llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds())) - emitCommunicationOp("recv", addressOf(outputBuffer, knowledge), sourceCoreId, getValueSizeInBytes(outputBuffer)); -} - void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const { size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge); @@ -388,11 +445,6 @@ void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize()); } -void PimCodeGen::codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const { - for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) - emitCommunicationOp("send", addressOf(input, knowledge), targetCoreId, getValueSizeInBytes(input)); -} - void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const { size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge); size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size(); @@ -400,20 +452,6 @@ void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const St emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize); } -void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, - const StaticValueKnowledge& knowledge) const { - auto inputType = cast(extractRowsOp.getInput().getType()); - assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input"); - - size_t elementSize = inputType.getElementTypeBitWidth() / 8; - size_t rowSizeInBytes = static_cast(inputType.getDimSize(1)) * elementSize; - size_t inputAddr = addressOf(extractRowsOp.getInput(), knowledge); - - for (auto [rowIndex, outputBuffer] : llvm::enumerate(extractRowsOp.getOutputBuffers())) - emitMemCopyOp( - "lmv", addressOf(outputBuffer, knowledge), 0, inputAddr, rowIndex * rowSizeInBytes, rowSizeInBytes, "len"); -} - void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const { auto outputType = cast(concatOp.getOutputBuffer().getType()); assert(outputType.hasStaticShape() && "concat codegen requires static output shape"); @@ -742,23 +780,6 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor continue; } - if (auto sendManyBatchOp = dyn_cast(op)) { - SmallVector laneTargetCoreIds; - laneTargetCoreIds.reserve(sendManyBatchOp.getInputs().size()); - for (auto valueIndex : llvm::seq(0, sendManyBatchOp.getInputs().size())) - laneTargetCoreIds.push_back( - sendManyBatchOp.getTargetCoreIds()[valueIndex * laneCount + static_cast(lane)]); - - SmallVector mappedInputs; - mappedInputs.reserve(sendManyBatchOp.getInputs().size()); - for (mlir::Value input : sendManyBatchOp.getInputs()) - mappedInputs.push_back(mapper.lookup(input)); - - pim::PimSendManyOp::create( - builder, sendManyBatchOp.getLoc(), builder.getDenseI32ArrayAttr(laneTargetCoreIds), ValueRange(mappedInputs)); - continue; - } - if (auto receiveBatchOp = dyn_cast(op)) { auto scalarReceive = pim::PimReceiveOp::create(builder, @@ -771,29 +792,6 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor continue; } - if (auto receiveManyBatchOp = dyn_cast(op)) { - SmallVector laneSourceCoreIds; - laneSourceCoreIds.reserve(receiveManyBatchOp.getOutputs().size()); - for (auto valueIndex : llvm::seq(0, receiveManyBatchOp.getOutputs().size())) - laneSourceCoreIds.push_back( - receiveManyBatchOp.getSourceCoreIds()[valueIndex * laneCount + static_cast(lane)]); - - SmallVector mappedOutputBuffers; - mappedOutputBuffers.reserve(receiveManyBatchOp.getOutputBuffers().size()); - for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers()) - mappedOutputBuffers.push_back(mapper.lookup(outputBuffer)); - - auto scalarReceiveMany = pim::PimReceiveManyOp::create(builder, - receiveManyBatchOp.getLoc(), - receiveManyBatchOp->getResultTypes(), - ValueRange(mappedOutputBuffers), - builder.getDenseI32ArrayAttr(laneSourceCoreIds)); - for (auto [originalOutput, scalarOutput] : - llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs())) - mapper.map(originalOutput, scalarOutput); - continue; - } - if (auto memcpBatchOp = dyn_cast(op)) { mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource()); if (!hostSource) @@ -912,18 +910,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenLmvOp(lmvOp, knowledge); else if (auto receiveOp = dyn_cast(op)) coreCodeGen.codeGenReceiveOp(receiveOp, knowledge); - else if (auto receiveManyOp = dyn_cast(op)) - coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge); else if (auto receiveTensorOp = dyn_cast(op)) coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge); else if (auto sendOp = dyn_cast(op)) coreCodeGen.codeGenSendOp(sendOp, knowledge); - else if (auto sendManyOp = dyn_cast(op)) - coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge); else if (auto sendTensorOp = dyn_cast(op)) coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge); - else if (auto extractRowsOp = dyn_cast(op)) - coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge); else if (auto concatOp = dyn_cast(op)) coreCodeGen.codeGenConcatOp(concatOp, knowledge); else if (auto vmmOp = dyn_cast(op)) @@ -954,8 +946,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge); else if (auto getGlobalOp = dyn_cast(op)) coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge); - else if (isa(op)) - return success(); else { op.emitError("Unsupported codegen for this operation"); op.dump(); @@ -967,84 +957,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { return failed(result) ? -1 : static_cast(processedOperations); } -/// Write crossbar weight matrices as padded binary files for a single core. -static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp, - pim::PimCoreOp coreOp, - StringRef coreWeightsDirPath, - json::Array& xbarsPerGroup) { - int64_t xbarSize = crossbarSize.getValue(); - std::error_code errorCode; - size_t weightIndex = 0; - - for (auto weight : coreOp.getWeights()) { - xbarsPerGroup.push_back(weightIndex); - - auto getGlobalOp = weight.getDefiningOp(); - if (!getGlobalOp) { - coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex)); - weightIndex++; - continue; - } - - auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); - if (!globalOp) { - coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex)); - weightIndex++; - continue; - } - - auto initialValue = globalOp.getInitialValue(); - if (!initialValue) { - coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(weightIndex)); - weightIndex++; - continue; - } - - auto denseAttr = dyn_cast(*initialValue); - if (!denseAttr) { - coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(weightIndex)); - weightIndex++; - continue; - } - - auto type = denseAttr.getType(); - auto shape = type.getShape(); - assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); - int64_t numRows = shape[0]; - int64_t numCols = shape[1]; - assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); - - size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8; - - auto weightFilePath = (coreWeightsDirPath + "/crossbar_" + std::to_string(weightIndex) + ".bin").str(); - raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None); - if (errorCode) { - errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n'; - return InvalidOutputFileAccess; - } - - uint64_t zero = 0; - for (int64_t row = 0; row < xbarSize; row++) { - for (int64_t col = 0; col < xbarSize; col++) { - if (row < numRows && col < numCols) { - int64_t index = row * numCols + col; - APInt bits = denseAttr.getValues()[index].bitcastToAPInt(); - uint64_t word = bits.getZExtValue(); - weightFileStream.write(reinterpret_cast(&word), elementByteWidth); - } - else { - weightFileStream.write(reinterpret_cast(&zero), elementByteWidth); - } - } - } - - weightFileStream.close(); - weightIndex++; - } - - return CompilerSuccess; -} - llvm::DenseMap> createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { ModuleOp moduleOp = funcOp->getParentOfType(); @@ -1079,45 +991,31 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { } mlir::Value weight = coreOp.getWeights()[index]; - auto getGlobalOp = weight.getDefiningOp(); - if (!getGlobalOp) { + auto weightView = resolveDenseWeightView(moduleOp, weight); + if (failed(weightView)) { coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index)); - assert(!getGlobalOp && "Weight is not from a memref.get_global"); + assert(succeeded(weightView) && "Weight is not from a dense memref.global view"); } - auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); - if (!globalOp) { - coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index)); - assert(!globalOp && "Could not find memref.global"); - } + if (mapCoreWeightToFileName[coreId].contains(weight)) + continue; - auto initialValue = globalOp.getInitialValue(); - if (!initialValue) { - coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index)); - assert(!initialValue && "memref.global has no initial value"); - } - - auto denseAttr = dyn_cast(*initialValue); - if (!denseAttr) { - coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index)); - assert(!denseAttr && "memref.global initial value is not dense"); - } - - if (mapGlobalOpToFileName.contains(globalOp)) { + auto getGlobalOp = weight.getDefiningOp(); + auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {}; + if (globalOp && mapGlobalOpToFileName.contains(globalOp)) { auto& fileName = mapGlobalOpToFileName[globalOp]; - std::pair weightToFile = {weight, fileName}; - mapCoreWeightToFileName[coreId].insert(weightToFile); + mapCoreWeightToFileName[coreId].insert({weight, fileName}); continue; } - auto type = denseAttr.getType(); - auto shape = type.getShape(); + DenseElementsAttr denseAttr = weightView->denseAttr; + ArrayRef shape = weightView->shape; assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); int64_t numRows = shape[0]; int64_t numCols = shape[1]; assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); - size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8; + size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8; std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); @@ -1132,8 +1030,8 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { for (int64_t row = 0; row < xbarSize; row++) { for (int64_t col = 0; col < xbarSize; col++) { if (row < numRows && col < numCols) { - int64_t index = row * numCols + col; - APInt bits = denseAttr.getValues()[index].bitcastToAPInt(); + int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1]; + APInt bits = denseAttr.getValues()[elementIndex].bitcastToAPInt(); uint64_t word = bits.getZExtValue(); weightFileStream.write(reinterpret_cast(&word), elementByteWidth); } @@ -1144,7 +1042,8 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { } weightFileStream.close(); - mapGlobalOpToFileName.insert({globalOp, newFileName}); + if (globalOp) + mapGlobalOpToFileName.insert({globalOp, newFileName}); mapCoreWeightToFileName[coreId].insert({weight, newFileName}); } } diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 62a7d5e..d03704b 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -24,8 +24,6 @@ class PimMemory { llvm::SmallVector, 32> memEntries; llvm::SmallDenseMap& globalMemEntriesMap; - size_t maxSize = 0; // 0 for unbounded memory - size_t startAddress = 0; size_t minAlignment = 4; size_t firstAvailableAddress = 0; @@ -117,12 +115,9 @@ public: void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const; void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const; - void codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const; void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const; void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const; - void codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const; void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const; - void codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const; void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const; template diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 7d572de..0ce4618 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -116,7 +116,7 @@ static SmallVector getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co SmallVector coreIds; coreIds.reserve(static_cast(computeBatchOp.getLaneCount())); - for (int32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) + for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) coreIds.push_back(static_cast(fallbackCoreId++)); return coreIds; } @@ -150,40 +150,33 @@ static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewri static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) { rewriter.setInsertionPoint(sendManyOp); - SmallVector targetCoreIds; - targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size()); - for (int32_t targetCoreId : sendManyOp.getTargetCoreIds()) - targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); - PimSendManyOp::create( - rewriter, sendManyOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), sendManyOp.getInputs()); + for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) { + PimSendOp::create(rewriter, + sendManyOp.getLoc(), + input, + getTensorSizeInBytesAttr(rewriter, input), + rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(targetCoreId))); + } rewriter.eraseOp(sendManyOp); } -static SmallVector createManyEmptyTensorsLike(IRRewriter& rewriter, Location loc, TypeRange outputTypes) { - SmallVector tensorTypes; - tensorTypes.reserve(outputTypes.size()); - for (Type outputType : outputTypes) - tensorTypes.push_back(outputType); - - auto emptyMany = pim::PimEmptyManyOp::create(rewriter, loc, TypeRange(tensorTypes)); - return SmallVector(emptyMany.getOutputs().begin(), emptyMany.getOutputs().end()); -} - static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) { rewriter.setInsertionPoint(receiveManyOp); - SmallVector sourceCoreIds; - sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size()); - for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds()) - sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId)); - SmallVector outputBuffers = - createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes()); - - auto receiveMany = PimReceiveManyOp::create(rewriter, - receiveManyOp.getLoc(), - receiveManyOp.getResultTypes(), - ValueRange(outputBuffers), - rewriter.getDenseI32ArrayAttr(sourceCoreIds)); - rewriter.replaceOp(receiveManyOp, receiveMany.getOutputs()); + SmallVector replacements; + replacements.reserve(receiveManyOp.getNumResults()); + for (auto [output, sourceCoreId] : llvm::zip(receiveManyOp.getOutputs(), receiveManyOp.getSourceCoreIds())) { + auto outputType = cast(output.getType()); + Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyOp.getLoc(), outputType).getResult(); + replacements.push_back( + PimReceiveOp::create(rewriter, + receiveManyOp.getLoc(), + output.getType(), + outputBuffer, + getTensorSizeInBytesAttr(rewriter, output), + rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sourceCoreId))) + .getOutput()); + } + rewriter.replaceOp(receiveManyOp, replacements); } static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp, @@ -198,8 +191,17 @@ static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendMa mappedInputs.reserve(sendManyBatchOp.getInputs().size()); for (Value input : sendManyBatchOp.getInputs()) mappedInputs.push_back(mapper.lookup(input)); - pim::PimSendManyBatchOp::create( - rewriter, sendManyBatchOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), ValueRange(mappedInputs)); + for (auto [valueIndex, input] : llvm::enumerate(mappedInputs)) { + SmallVector laneTargetCoreIds; + laneTargetCoreIds.reserve(laneCount); + for (int32_t lane = 0; lane < laneCount; ++lane) + laneTargetCoreIds.push_back(targetCoreIds[valueIndex * laneCount + lane]); + pim::PimSendBatchOp::create(rewriter, + sendManyBatchOp.getLoc(), + input, + getTensorSizeInBytesAttr(rewriter, input), + rewriter.getDenseI32ArrayAttr(laneTargetCoreIds)); + } } static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp, @@ -210,29 +212,44 @@ static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp sourceCoreIds.reserve(receiveManyBatchOp.getSourceCoreIds().size()); for (int32_t sourceCoreId : receiveManyBatchOp.getSourceCoreIds()) sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId)); - SmallVector outputBuffers = - createManyEmptyTensorsLike(rewriter, receiveManyBatchOp.getLoc(), receiveManyBatchOp.getResultTypes()); - auto receiveMany = pim::PimReceiveManyBatchOp::create(rewriter, - receiveManyBatchOp.getLoc(), - receiveManyBatchOp.getResultTypes(), - ValueRange(outputBuffers), - rewriter.getDenseI32ArrayAttr(sourceCoreIds)); - for (auto [output, received] : llvm::zip(receiveManyBatchOp.getOutputs(), receiveMany.getOutputs())) + for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) { + auto outputType = cast(output.getType()); + Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), outputType).getResult(); + SmallVector laneSourceCoreIds; + laneSourceCoreIds.reserve(laneCount); + for (int32_t lane = 0; lane < laneCount; ++lane) + laneSourceCoreIds.push_back(sourceCoreIds[valueIndex * laneCount + lane]); + + auto received = pim::PimReceiveBatchOp::create(rewriter, + receiveManyBatchOp.getLoc(), + output.getType(), + outputBuffer, + getTensorSizeInBytesAttr(rewriter, output), + rewriter.getDenseI32ArrayAttr(laneSourceCoreIds)) + .getOutput(); mapper.map(output, received); + } } static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) { rewriter.setInsertionPoint(extractRowsOp); - SmallVector outputBuffers = - createManyEmptyTensorsLike(rewriter, extractRowsOp.getLoc(), extractRowsOp.getResultTypes()); - - auto extractRows = pim::PimExtractRowsOp::create(rewriter, - extractRowsOp.getLoc(), - extractRowsOp.getResultTypes(), - extractRowsOp.getInput(), - ValueRange(outputBuffers)); - rewriter.replaceOp(extractRowsOp, extractRows.getOutputs()); + auto inputType = cast(extractRowsOp.getInput().getType()); + SmallVector replacements; + replacements.reserve(extractRowsOp.getNumResults()); + for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) { + auto outputType = cast(output.getType()); + SmallVector offsets = { + rewriter.getIndexAttr(static_cast(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)), + rewriter.getIndexAttr(inputType.getDimSize(1))}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + replacements.push_back( + tensor::ExtractSliceOp::create( + rewriter, extractRowsOp.getLoc(), outputType, extractRowsOp.getInput(), offsets, sizes, strides) + .getResult()); + } + rewriter.replaceOp(extractRowsOp, replacements); } static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) { @@ -258,14 +275,26 @@ static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) { for (auto mapOp : mapOps) { Block& body = mapOp.getBody().front(); - rewriter.setInsertionPoint(mapOp); - auto pimMap = pim::PimMapOp::create(rewriter, mapOp.getLoc(), mapOp.getResultTypes(), mapOp.getInputs()); - rewriter.inlineRegionBefore(mapOp.getBody(), pimMap.getBody(), pimMap.getBody().begin()); - auto yieldOp = cast(body.getTerminator()); - rewriter.setInsertionPoint(yieldOp); - rewriter.replaceOpWithNewOp(yieldOp, yieldOp.getOutputs()); - rewriter.replaceOp(mapOp, pimMap.getOutputs()); + SmallVector replacements; + replacements.reserve(mapOp.getInputs().size()); + rewriter.setInsertionPoint(mapOp); + + for (Value input : mapOp.getInputs()) { + IRMapping mapping; + mapping.map(body.getArgument(0), input); + + for (Operation& bodyOp : body.without_terminator()) { + Operation* cloned = rewriter.clone(bodyOp, mapping); + for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults())) + mapping.map(originalResult, clonedResult); + rewriter.setInsertionPointAfter(cloned); + } + + replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0))); + } + + rewriter.replaceOp(mapOp, replacements); } } @@ -295,7 +324,7 @@ static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigne } static Value createPackedExtractRowsSlice( - pim::PimExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { + spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { auto rowType = dyn_cast(extractRowsOp.getOutputs()[startIndex].getType()); auto inputType = dyn_cast(extractRowsOp.getInput().getType()); if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0) @@ -332,14 +361,17 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter if (!getContiguousOpResults(values, owner, startIndex)) return {}; - if (auto extractRowsOp = dyn_cast(owner)) + if (auto extractRowsOp = dyn_cast(owner)) return createPackedExtractRowsSlice(extractRowsOp, startIndex, static_cast(values.size()), rewriter, loc); return {}; } -static Value createPackedReceiveTensor( - pim::PimReceiveManyOp receiveManyOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { +static Value createPackedReceiveTensor(spatial::SpatChannelReceiveManyOp receiveManyOp, + unsigned startIndex, + unsigned count, + IRRewriter& rewriter, + Location loc) { auto rowType = dyn_cast(receiveManyOp.getOutputs()[startIndex].getType()); if (!rowType || !rowType.hasStaticShape() || rowType.getRank() == 0) return {}; @@ -351,15 +383,15 @@ static Value createPackedReceiveTensor( sourceCoreIds.reserve(count); ArrayRef allSourceCoreIds = receiveManyOp.getSourceCoreIds(); for (unsigned index = 0; index < count; ++index) - sourceCoreIds.push_back(allSourceCoreIds[startIndex + index]); + sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(allSourceCoreIds[startIndex + index])); return pim::PimReceiveTensorOp::create( rewriter, loc, packedType, outputBuffer.getResult(), rewriter.getDenseI32ArrayAttr(sourceCoreIds)) .getOutput(); } -static Value -createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { +static Value createPackedMapTensor( + spatial::SpatMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { Value packedInput = createPackedTensorForValues(mapOp.getInputs().slice(startIndex, count), rewriter, loc); if (!packedInput) return {}; @@ -416,7 +448,7 @@ createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count, rewriter.setInsertionPointAfter(cloned); } - auto yieldOp = cast(body.getTerminator()); + auto yieldOp = cast(body.getTerminator()); Value mappedOutput = mapping.lookupOrDefault(yieldOp.getOperand(0)); int64_t outputRowsPerValue = outputType.getDimSize(0); @@ -446,9 +478,9 @@ createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count, return loop.getResult(0); } -static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) { - SmallVector sendManyOps; - funcOp.walk([&](pim::PimSendManyOp sendManyOp) { sendManyOps.push_back(sendManyOp); }); +static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) { + SmallVector sendManyOps; + funcOp.walk([&](spatial::SpatChannelSendManyOp sendManyOp) { sendManyOps.push_back(sendManyOp); }); for (auto sendManyOp : sendManyOps) { if (sendManyOp.getInputs().empty()) continue; @@ -458,12 +490,17 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) { if (!packedInput) continue; - pim::PimSendTensorOp::create(rewriter, sendManyOp.getLoc(), packedInput, sendManyOp.getTargetCoreIdsAttr()); + SmallVector targetCoreIds; + targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size()); + for (int32_t targetCoreId : sendManyOp.getTargetCoreIds()) + targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); + pim::PimSendTensorOp::create( + rewriter, sendManyOp.getLoc(), packedInput, rewriter.getDenseI32ArrayAttr(targetCoreIds)); rewriter.eraseOp(sendManyOp); } - SmallVector concatOps; - funcOp.walk([&](pim::PimConcatOp concatOp) { concatOps.push_back(concatOp); }); + SmallVector concatOps; + funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); }); for (auto concatOp : concatOps) { if (concatOp.getAxis() != 0 || concatOp.getInputs().empty()) continue; @@ -494,11 +531,11 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) { unsigned count = endIndex - index; Value packedInput; - if (auto mapOp = dyn_cast(owner)) + if (auto mapOp = dyn_cast(owner)) packedInput = createPackedMapTensor(mapOp, startIndex, count, rewriter, concatOp.getLoc()); - else if (auto receiveManyOp = dyn_cast(owner)) + else if (auto receiveManyOp = dyn_cast(owner)) packedInput = createPackedReceiveTensor(receiveManyOp, startIndex, count, rewriter, concatOp.getLoc()); - else if (auto extractRowsOp = dyn_cast(owner)) + else if (auto extractRowsOp = dyn_cast(owner)) packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc()); if (packedInput) { @@ -516,12 +553,14 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) { if (!changed) continue; - auto newConcat = pim::PimConcatOp::create(rewriter, - concatOp.getLoc(), - concatOp.getOutput().getType(), - concatOp.getAxisAttr(), - ValueRange(packedInputs), - concatOp.getOutputBuffer()); + auto newConcat = pim::PimConcatOp::create( + rewriter, + concatOp.getLoc(), + concatOp.getOutput().getType(), + concatOp.getAxisAttr(), + ValueRange(packedInputs), + createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), cast(concatOp.getOutput().getType())) + .getResult()); rewriter.replaceOp(concatOp, newConcat.getOutput()); } @@ -533,10 +572,9 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) { if (op->use_empty()) rewriter.eraseOp(op); }; - eraseUnusedOps(pim::PimMapOp {}); - eraseUnusedOps(pim::PimReceiveManyOp {}); - eraseUnusedOps(pim::PimExtractRowsOp {}); - eraseUnusedOps(pim::PimEmptyManyOp {}); + eraseUnusedOps(spatial::SpatMapOp {}); + eraseUnusedOps(spatial::SpatChannelReceiveManyOp {}); + eraseUnusedOps(spatial::SpatExtractRowsOp {}); } static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, @@ -617,6 +655,7 @@ struct ConcatReturnUseInfo { size_t returnIndex; SmallVector sliceOffsets; SmallVector concatShape; + SmallVector concatChain; SmallVector helperChain; }; @@ -669,6 +708,8 @@ static std::optional analyzeConcatReturnUse(Value value) { auto getConcatResult = [](Operation* op) -> Value { if (auto tensorConcat = dyn_cast(op)) return tensorConcat.getResult(); + if (auto spatialConcat = dyn_cast(op)) + return spatialConcat.getOutput(); if (auto pimConcat = dyn_cast(op)) return pimConcat.getOutput(); return {}; @@ -676,6 +717,8 @@ static std::optional analyzeConcatReturnUse(Value value) { auto getConcatAxis = [](Operation* op) -> std::optional { if (auto tensorConcat = dyn_cast(op)) return tensorConcat.getDim(); + if (auto spatialConcat = dyn_cast(op)) + return spatialConcat.getAxis(); if (auto pimConcat = dyn_cast(op)) return pimConcat.getAxis(); return std::nullopt; @@ -683,11 +726,14 @@ static std::optional analyzeConcatReturnUse(Value value) { auto getConcatOperands = [](Operation* op) -> OperandRange { if (auto tensorConcat = dyn_cast(op)) return tensorConcat.getOperands(); + if (auto spatialConcat = dyn_cast(op)) + return spatialConcat.getInputs(); return cast(op).getInputs(); }; auto uses = value.getUses(); - if (rangeLength(uses) != 1 || !isa(uses.begin()->getOwner())) + if (rangeLength(uses) != 1 + || !isa(uses.begin()->getOwner())) return std::nullopt; auto valueType = dyn_cast(value.getType()); @@ -696,10 +742,12 @@ static std::optional analyzeConcatReturnUse(Value value) { SmallVector sliceOffsets(valueType.getRank(), 0); SmallVector concatShape(valueType.getShape().begin(), valueType.getShape().end()); + SmallVector concatChain; Value currentValue = value; Operation* currentUser = uses.begin()->getOwner(); - while (isa(currentUser)) { + while (isa(currentUser)) { + concatChain.push_back(currentUser); size_t operandIndex = currentValue.getUses().begin()->getOperandNumber(); int64_t axis = *getConcatAxis(currentUser); for (Value operand : getConcatOperands(currentUser).take_front(operandIndex)) @@ -749,6 +797,7 @@ static std::optional analyzeConcatReturnUse(Value value) { currentValue.getUses().begin()->getOperandNumber(), std::move(sliceOffsets), std::move(concatShape), + std::move(concatChain), std::move(helperChain), }; } @@ -918,11 +967,6 @@ void SpatialToPimPass::runOnOperation() { return; } - SmallVector concatOps; - funcOp.walk([&](spatial::SpatConcatOp op) { concatOps.push_back(op); }); - for (auto concatOp : concatOps) - lowerConcat(concatOp, rewriter); - for (auto computeOp : funcOp.getOps()) { markOpToRemove(computeOp); runOnComputeOp(computeOp, rewriter); @@ -933,6 +977,7 @@ void SpatialToPimPass::runOnOperation() { runOnComputeBatchOp(computeBatchOp, rewriter); } + compactSpatialTensorGroups(funcOp, rewriter); lowerMapOps(funcOp, rewriter); SmallVector receiveOps; @@ -1036,6 +1081,8 @@ void SpatialToPimPass::runOnOperation() { assert(false && "tracked op removal reached a cycle or missed dependency"); } + compactSpatialTensorGroups(funcOp, rewriter); + SmallVector remainingConcatOps; funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); }); for (auto concatOp : remainingConcatOps) @@ -1066,8 +1113,6 @@ void SpatialToPimPass::runOnOperation() { for (auto extractRowsOp : remainingExtractRowsOps) lowerExtractRows(extractRowsOp, rewriter); - compactPimTensorGroups(funcOp, rewriter); - // Dump to file for debug bool hasSpatialOps = false; moduleOp.walk([&](Operation* op) { @@ -1170,6 +1215,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter if (auto concatReturnUse = analyzeConcatReturnUse(result)) { size_t elementSize = yieldType.getElementTypeBitWidth() / 8; + for (Operation* concatOp : concatReturnUse->concatChain) + markOpToRemove(concatOp); if (concatReturnUse->helperChain.empty()) { rewriter.setInsertionPointAfterValue(yieldValue); @@ -1481,13 +1528,15 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) { outputTensors.reserve(returnOp->getNumOperands()); for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) { - Operation* returnValueDefiningOp = returnValue.getDefiningOp(); + Value currentReturnValue = returnValue; + Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp(); if (returnValueDefiningOp->hasTrait()) { assert(!hasWeightAlways(returnValueDefiningOp)); - outputTensors.push_back([returnValue](IRRewriter& rewriter, Location loc) -> Value { return returnValue; }); + outputTensors.push_back( + [currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; }); } else { - auto outRankedTensorType = llvm::dyn_cast(returnValue.getType()); + auto outRankedTensorType = llvm::dyn_cast(currentReturnValue.getType()); auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType()); std::string outputName = "output_" + std::to_string(index); @@ -1565,7 +1614,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) { Operation* onlyUser = *op->getUsers().begin(); isExclusivelyOwnedByReturnChain = - isa(onlyUser) + isa(onlyUser) || isChannelUseChainOp(onlyUser); } if (!isExclusivelyOwnedByReturnChain) @@ -1593,6 +1642,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri return; } + if (auto concatOp = dyn_cast(op)) { + markOpToRemove(concatOp); + for (Value operand : concatOp.getInputs()) + markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); + return; + } + if (auto concatOp = dyn_cast(op)) { markOpToRemove(concatOp); for (Value operand : concatOp.getInputs()) diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index e475256..f351708 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -71,38 +71,6 @@ def PimYieldOp : PimOp<"yield", [Terminator]> { let hasCustomAssemblyFormat = 1; } -def PimMapOp : PimOp<"map", [SingleBlock]> { - let summary = "Apply the same lane-local region to many independent tensors"; - - let arguments = (ins - Variadic:$inputs - ); - - let results = (outs - Variadic:$outputs - ); - - let regions = (region SizedRegion<1>:$body); - - let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; -} - -//===----------------------------------------------------------------------===// -// Tensor Utilities -//===----------------------------------------------------------------------===// - -def PimEmptyManyOp : PimOp<"empty_many", []> { - let summary = "Create many identical empty tensors"; - - let results = (outs - Variadic:$outputs - ); - - let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; -} - //===----------------------------------------------------------------------===// // Communication //===----------------------------------------------------------------------===// @@ -121,18 +89,6 @@ def PimSendOp : PimOp<"send", []> { }]; } -def PimSendManyOp : PimOp<"send_many", []> { - let summary = "Send multiple tensors to target cores"; - - let arguments = (ins - DenseI32ArrayAttr:$targetCoreIds, - Variadic:$inputs - ); - - let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; -} - def PimSendTensorOp : PimOp<"send_tensor", []> { let summary = "Send equal contiguous chunks of one tensor to target cores"; @@ -157,18 +113,6 @@ def PimSendBatchOp : PimOp<"send_batch", []> { let hasCustomAssemblyFormat = 1; } -def PimSendManyBatchOp : PimOp<"send_many_batch", []> { - let summary = "Send multiple per-lane tensors to target cores from a batched core"; - - let arguments = (ins - DenseI32ArrayAttr:$targetCoreIds, - Variadic:$inputs - ); - - let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; -} - def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> { let summary = "Receive a tensor from another core"; @@ -193,28 +137,6 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> { }]; } -def PimReceiveManyOp : PimOp<"receive_many", [DestinationStyleOpInterface]> { - let summary = "Receive multiple tensors from source cores"; - - let arguments = (ins - Variadic:$outputBuffers, - DenseI32ArrayAttr:$sourceCoreIds - ); - - let results = (outs - Variadic:$outputs - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getOutputBuffersMutable(); - } - }]; - - let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; -} - def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> { let summary = "Receive equal contiguous chunks from source cores into one tensor"; @@ -259,28 +181,6 @@ def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> { let hasCustomAssemblyFormat = 1; } -def PimReceiveManyBatchOp : PimOp<"receive_many_batch", [DestinationStyleOpInterface]> { - let summary = "Receive multiple per-lane tensors from source cores into a batched core"; - - let arguments = (ins - Variadic:$outputBuffers, - DenseI32ArrayAttr:$sourceCoreIds - ); - - let results = (outs - Variadic:$outputs - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getOutputBuffersMutable(); - } - }]; - - let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; -} - def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> { let summary = "Copy a memory region from host memory into device memory"; @@ -385,32 +285,6 @@ def PimMemCopyOp : PimOp<"memcp", [DestinationStyleOpInterface]> { }]; } -//===----------------------------------------------------------------------===// -// Tensor utilities -//===----------------------------------------------------------------------===// - -def PimExtractRowsOp : PimOp<"extract_rows", [DestinationStyleOpInterface]> { - let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors"; - - let arguments = (ins - PimTensor:$input, - Variadic:$outputBuffers - ); - - let results = (outs - Variadic:$outputs - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getOutputBuffersMutable(); - } - }]; - - let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; -} - def PimConcatOp : PimOp<"concat", [DestinationStyleOpInterface]> { let summary = "Concatenate tensors"; diff --git a/src/PIM/Dialect/Pim/PimOpsAsm.cpp b/src/PIM/Dialect/Pim/PimOpsAsm.cpp index 7fb95be..8bee1b4 100644 --- a/src/PIM/Dialect/Pim/PimOpsAsm.cpp +++ b/src/PIM/Dialect/Pim/PimOpsAsm.cpp @@ -147,69 +147,6 @@ ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) { return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands); } -void PimMapOp::print(OpAsmPrinter& printer) { - printer << " "; - printArgumentBindings(printer, getBody().front(), getInputs()); - printer.printOptionalAttrDict((*this)->getAttrs()); - printer << " : "; - printer.printType(getInputs().front().getType()); - printer << " -> "; - printer.printType(getOutputs().front().getType()); - printer << " "; - printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); -} - -ParseResult PimMapOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector regionArgs; - SmallVector inputs; - Type inputType; - Type outputType; - - if (parseArgumentBindings(parser, regionArgs, inputs)) - return failure(); - if (inputs.empty()) - return parser.emitError(parser.getCurrentLocation(), "map requires at least one input"); - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType) - || parser.parseArrow() || parser.parseType(outputType)) - return failure(); - - SmallVector inputTypes(inputs.size(), inputType); - SmallVector outputTypes(inputs.size(), outputType); - if (regionArgs.size() != inputs.size()) - return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); - if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) - return failure(); - result.addTypes(outputTypes); - - applyArgumentTypes(inputTypes, regionArgs); - Region* body = result.addRegion(); - return parser.parseRegion(*body, regionArgs); -} - -void PimEmptyManyOp::print(OpAsmPrinter& printer) { - printer.printOptionalAttrDict((*this)->getAttrs()); - printer << " : "; - printer.printType(getOutputs().front().getType()); - printer << " x" << getOutputs().size(); -} - -ParseResult PimEmptyManyOp::parse(OpAsmParser& parser, OperationState& result) { - Type outputType; - int64_t resultCount = 0; - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType) - || parser.parseKeyword("x") || parser.parseInteger(resultCount)) - return failure(); - - if (resultCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "result count after 'x' must be positive"); - - SmallVector resultTypes(resultCount, outputType); - result.addTypes(resultTypes); - return success(); -} - void PimSendBatchOp::print(OpAsmPrinter& printer) { printer << " "; printer.printOperand(getInput()); @@ -237,36 +174,6 @@ ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) { return parser.resolveOperand(input, inputType, result.operands); } -void PimSendManyOp::print(OpAsmPrinter& printer) { - printer << " "; - printCompressedValueSequence(printer, getInputs()); - printCoreIdList(printer, "to", getTargetCoreIds()); - printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printCompressedTypeSequence(printer, TypeRange(getInputs())); -} - -ParseResult PimSendManyOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector inputs; - SmallVector inputTypes; - SmallVector targetCoreIds; - - if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds) - || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false)) - return failure(); - - if (inputs.size() != inputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds")) - return parser.emitError(parser.getCurrentLocation(), - "targetCoreIds cannot be specified both positionally and in attr-dict"); - if (!targetCoreIds.empty()) - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - - return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands); -} - void PimSendTensorOp::print(OpAsmPrinter& printer) { printer << " "; printer.printOperand(getInput()); @@ -294,72 +201,6 @@ ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result) return parser.resolveOperand(input, inputType, result.operands); } -void PimSendManyBatchOp::print(OpAsmPrinter& printer) { - printer << " "; - printCompressedValueSequence(printer, getInputs()); - printCoreIdList(printer, "to", getTargetCoreIds()); - printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printCompressedTypeSequence(printer, TypeRange(getInputs())); -} - -ParseResult PimSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector inputs; - SmallVector inputTypes; - SmallVector targetCoreIds; - - if (parseCompressedOperandSequence(parser, inputs) || parseOptionalCoreIdList(parser, "to", targetCoreIds) - || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false)) - return failure(); - - if (inputs.size() != inputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds")) - return parser.emitError(parser.getCurrentLocation(), - "targetCoreIds cannot be specified both positionally and in attr-dict"); - if (!targetCoreIds.empty()) - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - - return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands); -} - -void PimReceiveManyOp::print(OpAsmPrinter& printer) { - printCoreIdList(printer, "from", getSourceCoreIds()); - printer << " into "; - printOpenDelimiter(printer, ListDelimiter::Paren); - printCompressedValueSequence(printer, getOutputBuffers()); - printCloseDelimiter(printer, ListDelimiter::Paren); - printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()}); - printer << " : "; - printCompressedTypeSequence(printer, getOutputs().getTypes()); -} - -ParseResult PimReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector outputBuffers; - SmallVector outputTypes; - SmallVector sourceCoreIds; - - if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen() - || parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen() - || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) - return failure(); - - if (outputBuffers.size() != outputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match"); - if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds")) - return parser.emitError(parser.getCurrentLocation(), - "sourceCoreIds cannot be specified both positionally and in attr-dict"); - if (!sourceCoreIds.empty()) - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - - if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands)) - return failure(); - result.addTypes(outputTypes); - return success(); -} - void PimReceiveTensorOp::print(OpAsmPrinter& printer) { printCoreIdList(printer, "from", getSourceCoreIds()); printer << " into "; @@ -434,77 +275,6 @@ ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result return success(); } -void PimReceiveManyBatchOp::print(OpAsmPrinter& printer) { - printCoreIdList(printer, "from", getSourceCoreIds()); - printer << " into "; - printOpenDelimiter(printer, ListDelimiter::Paren); - printCompressedValueSequence(printer, getOutputBuffers()); - printCloseDelimiter(printer, ListDelimiter::Paren); - printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()}); - printer << " : "; - printCompressedTypeSequence(printer, getOutputs().getTypes()); -} - -ParseResult PimReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector outputBuffers; - SmallVector outputTypes; - SmallVector sourceCoreIds; - - if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen() - || parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen() - || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) - return failure(); - - if (outputBuffers.size() != outputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match"); - if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds")) - return parser.emitError(parser.getCurrentLocation(), - "sourceCoreIds cannot be specified both positionally and in attr-dict"); - if (!sourceCoreIds.empty()) - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - - if (parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands)) - return failure(); - result.addTypes(outputTypes); - return success(); -} - -void PimExtractRowsOp::print(OpAsmPrinter& printer) { - printer << " "; - printer.printOperand(getInput()); - printer << " into "; - printOpenDelimiter(printer, ListDelimiter::Paren); - printCompressedValueSequence(printer, getOutputBuffers()); - printCloseDelimiter(printer, ListDelimiter::Paren); - printer.printOptionalAttrDict((*this)->getAttrs()); - printer << " : "; - printer.printType(getInput().getType()); - printer << " -> "; - printCompressedTypeSequence(printer, getOutputs().getTypes()); -} - -ParseResult PimExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) { - OpAsmParser::UnresolvedOperand input; - SmallVector outputBuffers; - Type inputType; - SmallVector outputTypes; - - if (parser.parseOperand(input) || parser.parseKeyword("into") || parser.parseLParen() - || parseCompressedOperandSequence(parser, outputBuffers) || parser.parseRParen() - || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType) - || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) - return failure(); - - if (outputBuffers.size() != outputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of output buffers and output types must match"); - if (parser.resolveOperand(input, inputType, result.operands) - || parser.resolveOperands(outputBuffers, outputTypes, parser.getCurrentLocation(), result.operands)) - return failure(); - result.addTypes(outputTypes); - return success(); -} - void PimConcatOp::print(OpAsmPrinter& printer) { printer << " axis " << getAxis() << " "; printCompressedValueSequence(printer, getInputs()); diff --git a/src/PIM/Dialect/Pim/PimOpsVerify.cpp b/src/PIM/Dialect/Pim/PimOpsVerify.cpp index cae01e5..3d39f4c 100644 --- a/src/PIM/Dialect/Pim/PimOpsVerify.cpp +++ b/src/PIM/Dialect/Pim/PimOpsVerify.cpp @@ -13,12 +13,6 @@ namespace pim { namespace { -static LogicalResult verifyManyCommunicationSizes(Operation* op, ArrayRef coreIds, size_t valueCount) { - if (coreIds.size() != valueCount) - return op->emitError("core id metadata length must match the number of values"); - return success(); -} - static bool haveSameShapedContainerKind(Type lhs, Type rhs) { return (isa(lhs) && isa(rhs)) || (isa(lhs) && isa(rhs)); } @@ -33,28 +27,6 @@ static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type r return success(); } -static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types, StringRef kind) { - if (types.empty()) - return op->emitError() << kind << " must carry at least one value"; - - Type firstType = types.front(); - auto firstShapedType = dyn_cast(firstType); - bool firstIsTensor = isa(firstType); - bool firstIsMemRef = isa(firstType); - for (Type type : types.drop_front()) - if (type != firstType) { - auto shapedType = dyn_cast(type); - if (!firstShapedType || !shapedType) - return op->emitError() << kind << " values must all have the same type"; - if (firstIsTensor != isa(type) || firstIsMemRef != isa(type)) - return op->emitError() << kind << " values must all use the same shaped container kind"; - if (firstShapedType.getElementType() != shapedType.getElementType() - || firstShapedType.getShape() != shapedType.getShape()) - return op->emitError() << kind << " values must all have the same shape and element type"; - } - return success(); -} - static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef coreIds, StringRef kind) { if (coreIds.empty()) return op->emitError() << kind << " must carry at least one chunk"; @@ -74,109 +46,12 @@ static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRe return success(); } -static FailureOr getParentBatchLaneCount(Operation* op) { - auto coreBatchOp = op->getParentOfType(); - if (!coreBatchOp) - return failure(); - return coreBatchOp.getLaneCount(); -} - -static LogicalResult verifyManyBatchCommunicationSizes(Operation* op, ArrayRef coreIds, size_t valueCount) { - auto laneCount = getParentBatchLaneCount(op); - if (failed(laneCount)) - return op->emitError("must be nested inside pim.core_batch"); - if (coreIds.size() != valueCount * static_cast(*laneCount)) - return op->emitError("core id metadata length must match the number of values times parent laneCount"); - return success(); -} - } // namespace -LogicalResult PimEmptyManyOp::verify() { - if (getOutputs().empty()) - return emitError("must produce at least one output"); - - Type firstType = getOutputs().front().getType(); - auto firstShapedType = dyn_cast(firstType); - if (!firstShapedType || !firstShapedType.hasRank()) - return emitError("outputs must all be ranked shaped types"); - - for (Value output : getOutputs().drop_front()) - if (output.getType() != firstType) - return emitError("outputs must all have the same type"); - - return success(); -} - -LogicalResult PimMapOp::verify() { - if (getInputs().empty()) - return emitError("requires at least one input"); - if (getOutputs().size() != getInputs().size()) - return emitError("number of outputs must match number of inputs"); - - Type inputType = getInputs().front().getType(); - for (Value input : getInputs().drop_front()) - if (input.getType() != inputType) - return emitError("all inputs must have the same type"); - - Type outputType = getOutputs().front().getType(); - for (Value output : getOutputs().drop_front()) - if (output.getType() != outputType) - return emitError("all outputs must have the same type"); - - Block& block = getBody().front(); - if (block.getNumArguments() != 1) - return emitError("body must have exactly one block argument"); - if (failed(verifyCompatibleShapedTypes( - getOperation(), block.getArgument(0).getType(), inputType, "body block argument type must match input type"))) - return emitError("body block argument type must match input type"); - - auto yieldOp = dyn_cast_or_null(block.getTerminator()); - if (!yieldOp) - return emitError("body must terminate with pim.yield"); - if (yieldOp.getNumOperands() != 1) - return emitError("body yield must produce exactly one value"); - if (failed(verifyCompatibleShapedTypes( - getOperation(), yieldOp.getOperand(0).getType(), outputType, "body yield type must match output type"))) - return emitError("body yield type must match output type"); - - return success(); -} - -LogicalResult PimSendManyOp::verify() { - if (failed(verifyManyCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size()))) - return failure(); - return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many"); -} - LogicalResult PimSendTensorOp::verify() { return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor"); } -LogicalResult PimSendManyBatchOp::verify() { - if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size()))) - return failure(); - return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many_batch"); -} - -LogicalResult PimReceiveManyOp::verify() { - if (getOutputBuffers().size() != getOutputs().size()) - return emitError("number of output buffers must match the number of outputs"); - if (failed(verifyManyCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size()))) - return failure(); - - if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many"))) - return failure(); - if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many"))) - return failure(); - - for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs())) - if (outputBuffer.getType() != output.getType()) - return emitError("output buffers and outputs must have matching types"); - - return success(); -} - LogicalResult PimReceiveTensorOp::verify() { if (failed(verifyCompatibleShapedTypes( getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match"))) @@ -185,61 +60,6 @@ LogicalResult PimReceiveTensorOp::verify() { return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor"); } -LogicalResult PimReceiveManyBatchOp::verify() { - if (getOutputBuffers().size() != getOutputs().size()) - return emitError("number of output buffers must match the number of outputs"); - if (failed(verifyManyBatchCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size()))) - return failure(); - - if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many_batch"))) - return failure(); - if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many_batch"))) - return failure(); - - for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs())) - if (outputBuffer.getType() != output.getType()) - return emitError("output buffers and outputs must have matching types"); - - return success(); -} - -LogicalResult PimExtractRowsOp::verify() { - if (getOutputBuffers().size() != getOutputs().size()) - return emitError("number of output buffers must match the number of outputs"); - - auto inputType = dyn_cast(getInput().getType()); - if (!inputType || !inputType.hasRank() || inputType.getRank() != 2) - return emitError("input must be a rank-2 shaped type"); - - int64_t numRows = inputType.getShape()[0]; - int64_t numCols = inputType.getShape()[1]; - Type elementType = inputType.getElementType(); - - if (numRows >= 0 && static_cast(getOutputs().size()) != numRows) - return emitError("number of outputs must match the number of input rows"); - - for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs())) { - if (failed(verifyCompatibleShapedTypes( - getOperation(), outputBuffer.getType(), output.getType(), "output buffers and outputs must match"))) - return failure(); - - auto outputType = dyn_cast(output.getType()); - if (!outputType || !outputType.hasRank() || outputType.getRank() != 2) - return emitError("outputs must all be rank-2 shaped types"); - if (!haveSameShapedContainerKind(getInput().getType(), output.getType())) - return emitError("outputs must use the same shaped container kind as the input"); - if (outputType.getElementType() != elementType) - return emitError("output element types must match input element type"); - auto outputShape = outputType.getShape(); - if (outputShape[0] != 1) - return emitError("each output must have exactly one row"); - if (numCols >= 0 && outputShape[1] != numCols) - return emitError("output column count must match input column count"); - } - - return success(); -} - LogicalResult PimConcatOp::verify() { if (getInputs().empty()) return emitError("requires at least one input"); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 9f24c82..d07bd9d 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -180,39 +180,6 @@ struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel { - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - return !cast(op).isDpsInit(&opOperand); - } - - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - auto receiveOp = cast(op); - SmallVector outputBuffers; - SmallVector resultTypes; - outputBuffers.reserve(receiveOp.getOutputBuffers().size()); - resultTypes.reserve(receiveOp.getOutputBuffers().size()); - - for (Value outputBuffer : receiveOp.getOutputBuffers()) { - auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state); - if (failed(outputBufferOpt)) - return failure(); - outputBuffers.push_back(*outputBufferOpt); - resultTypes.push_back(outputBufferOpt->getType()); - } - - auto newOp = PimReceiveManyOp::create(rewriter, - receiveOp.getLoc(), - TypeRange(resultTypes), - ValueRange(outputBuffers), - receiveOp.getSourceCoreIdsAttr()); - rewriter.replaceOp(receiveOp, newOp.getOutputs()); - return success(); - } -}; - struct ReceiveTensorOpInterface : DstBufferizableOpInterfaceExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { @@ -234,77 +201,6 @@ struct ReceiveTensorOpInterface } }; -struct ReceiveManyBatchOpInterface -: DstBufferizableOpInterfaceExternalModel { - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - return !cast(op).isDpsInit(&opOperand); - } - - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - auto receiveOp = cast(op); - SmallVector outputBuffers; - SmallVector resultTypes; - outputBuffers.reserve(receiveOp.getOutputBuffers().size()); - resultTypes.reserve(receiveOp.getOutputBuffers().size()); - - for (Value outputBuffer : receiveOp.getOutputBuffers()) { - auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state); - if (failed(outputBufferOpt)) - return failure(); - outputBuffers.push_back(*outputBufferOpt); - resultTypes.push_back(outputBufferOpt->getType()); - } - - auto newOp = PimReceiveManyBatchOp::create(rewriter, - receiveOp.getLoc(), - TypeRange(resultTypes), - ValueRange(outputBuffers), - receiveOp.getSourceCoreIdsAttr()); - rewriter.replaceOp(receiveOp, newOp.getOutputs()); - return success(); - } -}; - -struct ExtractRowsOpInterface : DstBufferizableOpInterfaceExternalModel { - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - return !cast(op).isDpsInit(&opOperand); - } - - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - auto extractRowsOp = cast(op); - auto inputOpt = getBufferOrValue(rewriter, extractRowsOp.getInput(), options, state); - if (failed(inputOpt)) - return failure(); - - SmallVector outputBuffers; - SmallVector resultTypes; - outputBuffers.reserve(extractRowsOp.getOutputBuffers().size()); - resultTypes.reserve(extractRowsOp.getOutputBuffers().size()); - - for (Value outputBuffer : extractRowsOp.getOutputBuffers()) { - auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state); - if (failed(outputBufferOpt)) - return failure(); - outputBuffers.push_back(*outputBufferOpt); - resultTypes.push_back(outputBufferOpt->getType()); - } - - auto newOp = PimExtractRowsOp::create(rewriter, - extractRowsOp.getLoc(), - TypeRange(resultTypes), - materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), - ValueRange(outputBuffers)); - rewriter.replaceOp(extractRowsOp, newOp.getOutputs()); - return success(); - } -}; - struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return !cast(op).isDpsInit(&opOperand); @@ -334,31 +230,6 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel { - bool bufferizesToAllocation(Operation* op, Value value) const { return true; } - - bool resultBufferizesToMemoryWrite(Operation* op, OpResult opResult, const AnalysisState& state) const { - return false; - } - - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - auto emptyManyOp = cast(op); - - SmallVector resultTypes; - resultTypes.reserve(emptyManyOp.getOutputs().size()); - for (Value output : emptyManyOp.getOutputs()) { - auto shapedType = cast(output.getType()); - resultTypes.push_back(MemRefType::get(shapedType.getShape(), shapedType.getElementType())); - } - - replaceOpWithNewBufferizedOp(rewriter, emptyManyOp, TypeRange(resultTypes)); - return success(); - } -}; - struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } @@ -383,7 +254,7 @@ struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel { +struct SendOpInterface : BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } @@ -392,75 +263,93 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel(op); - auto bbArg = dyn_cast(value); - if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 - || mapOp.getInputs().empty()) - return {}; - - return { - {&mapOp->getOpOperand(0), BufferRelation::Equivalent} - }; - } - - bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; } - - FailureOr getBufferType(Operation* op, - Value value, - const BufferizationOptions& options, - const BufferizationState& state, - SmallVector& invocationStack) const { - auto mapOp = cast(op); - auto bbArg = dyn_cast(value); - if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 - || mapOp.getInputs().empty()) + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto sendOp = cast(op); + auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state); + if (failed(inputOpt)) return failure(); - auto inputType = dyn_cast(mapOp.getInputs().front().getType()); - if (inputType) - return inputType; + replaceOpWithNewBufferizedOp(rewriter, + op, + materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), + sendOp.getSizeAttr(), + sendOp.getTargetCoreIdAttr()); + return success(); + } +}; - auto shapedType = cast(mapOp.getInputs().front().getType()); - return BufferLikeType(MemRefType::get(shapedType.getShape(), shapedType.getElementType())); +struct SendBatchOpInterface : BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } + + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return {}; } LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState& state) const { - auto mapOp = cast(op); + auto sendOp = cast(op); + auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state); + if (failed(inputOpt)) + return failure(); - SmallVector inputs; - SmallVector resultTypes; - inputs.reserve(mapOp.getInputs().size()); - resultTypes.reserve(mapOp.getOutputs().size()); + replaceOpWithNewBufferizedOp(rewriter, + op, + materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), + sendOp.getSizeAttr(), + sendOp.getTargetCoreIdsAttr()); + return success(); + } +}; - for (Value input : mapOp.getInputs()) { - if (isa(input.getType())) { - auto inputOpt = getBufferOrValue(rewriter, input, options, state); - if (failed(inputOpt)) +struct CoreOpInterface : BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } + + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return {}; + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto coreOp = cast(op); + + bool alreadyBufferized = + llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa(weight.getType()); }); + if (alreadyBufferized) + return success(); + + SmallVector weights; + weights.reserve(coreOp.getWeights().size()); + for (Value weight : coreOp.getWeights()) { + if (isa(weight.getType())) { + auto weightOpt = getBufferOrValue(rewriter, weight, options, state); + if (failed(weightOpt)) return failure(); - inputs.push_back(*inputOpt); + weights.push_back(*weightOpt); } else { - inputs.push_back(input); + weights.push_back(weight); } } - for (Value output : mapOp.getOutputs()) { - auto shapedType = cast(output.getType()); - resultTypes.push_back(MemRefType::get(shapedType.getShape(), shapedType.getElementType())); - } - - rewriter.setInsertionPoint(mapOp); - auto newOp = PimMapOp::create(rewriter, mapOp.getLoc(), TypeRange(resultTypes), ValueRange(inputs)); - rewriter.inlineRegionBefore(mapOp.getBody(), newOp.getBody(), newOp.getBody().begin()); + rewriter.setInsertionPoint(coreOp); + auto newOp = PimCoreOp::create(rewriter, coreOp.getLoc(), ValueRange(weights), coreOp.getCoreIdAttr()); + rewriter.inlineRegionBefore(coreOp.getBody(), newOp.getBody(), newOp.getBody().begin()); for (Block& block : newOp.getBody()) if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state))) return failure(); - rewriter.replaceOp(mapOp, newOp.getOutputs()); + rewriter.eraseOp(coreOp); return success(); } }; @@ -730,16 +619,14 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel(*ctx); - PimMapOp::attachInterface(*ctx); + PimCoreOp::attachInterface(*ctx); PimCoreBatchOp::attachInterface(*ctx); PimReceiveOp::attachInterface(*ctx); - PimReceiveManyOp::attachInterface(*ctx); PimReceiveTensorOp::attachInterface(*ctx); PimReceiveBatchOp::attachInterface(*ctx); - PimReceiveManyBatchOp::attachInterface(*ctx); + PimSendOp::attachInterface(*ctx); + PimSendBatchOp::attachInterface(*ctx); PimSendTensorOp::attachInterface(*ctx); - PimExtractRowsOp::attachInterface(*ctx); PimConcatOp::attachInterface(*ctx); PimMemCopyHostToDevOp::attachInterface(*ctx); PimMemCopyHostToDevBatchOp::attachInterface(*ctx); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index 99ed9ab..50d3b6e 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -1,10 +1,9 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Threading.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -47,53 +46,18 @@ private: void PimBufferizationPass::runOnOperation() { auto moduleOp = getOperation(); - // Refactor this into a function - { - auto funcOp = *getPimEntryFunc(moduleOp); + auto funcOp = *getPimEntryFunc(moduleOp); - SmallVector coreOps; - funcOp->walk([&](Operation* op) { - if (isa(op)) - coreOps.push_back(op); - }); - MLIRContext* ctx = moduleOp.getContext(); - // failableParallelForEach will run the lambda in parallel and stop if any thread fails - LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](Operation* coreOp) { - // Again, allocate state LOCALLY per thread/function - bufferization::OneShotBufferizationOptions options; - options.allowUnknownOps = true; - if (isa(coreOp)) - options.opFilter.denyOperation([coreOp](Operation* op) { return op == coreOp; }); - bufferization::BufferizationState state; - if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) { - coreOp->emitError("Failed to bufferize PIM and Spatial ops"); - return failure(); - } - return success(); - }); + bufferization::OneShotBufferizationOptions options; + options.allowUnknownOps = true; + options.bufferizeFunctionBoundaries = true; + options.setFunctionBoundaryTypeConversion(bufferization::LayoutMapOption::IdentityLayoutMap); + bufferization::BufferizationState state; - if (failed(result)) { - moduleOp.emitError("Failed to bufferize-parallel PIM and Spatial ops"); - signalPassFailure(); - } - - funcOp->walk([&](bufferization::ToTensorOp toTensorOp) { - if (llvm::isa_and_present(toTensorOp->getParentOp())) - toTensorOp->setAttr("restrict", UnitAttr::get(ctx)); - }); - - // One-Shot-Bufferization - bufferization::OneShotBufferizationOptions options; - options.allowUnknownOps = true; - options.opFilter.denyOperation([](Operation* op) { - return op->getParentOfType() || op->getParentOfType(); - }); - bufferization::BufferizationState state; - - if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) { - moduleOp.emitError("Failed to bufferize PIM and Spatial ops"); - signalPassFailure(); - } + if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options, state))) { + moduleOp.emitError("Failed to bufferize PIM and Spatial ops"); + signalPassFailure(); + return; } MLIRContext* ctx = moduleOp.getContext(); @@ -119,30 +83,6 @@ void PimBufferizationPass::runOnOperation() { return; } - // Remove toTensor operations: leave memrefs instead - moduleOp.walk([](bufferization::ToTensorOp toTensorOp) { - toTensorOp.replaceAllUsesWith(toTensorOp.getBuffer()); - toTensorOp.erase(); - }); - - // Change main function return types from tensors to memrefs - func::FuncOp funcOp; - for (Operation& op : moduleOp.getBody()->getOperations()) - if ((funcOp = dyn_cast(&op))) - break; - auto oldFuncType = funcOp.getFunctionType(); - SmallVector newResults; - bool changed = false; - for (Type type : oldFuncType.getResults()) - if (auto tensorType = dyn_cast(type)) { - newResults.push_back(MemRefType::get(tensorType.getShape(), tensorType.getElementType())); - changed = true; - } - else - newResults.push_back(type); - if (changed) - funcOp.setType(FunctionType::get(funcOp.getContext(), oldFuncType.getInputs(), newResults)); - annotateWeightsMemrefs(moduleOp, funcOp); // Dump to file for debug diff --git a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp index ca39577..31b72d5 100644 --- a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp +++ b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp @@ -1,13 +1,12 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/MathExtras.h" #include "src/Accelerators/PIM/Common/PimCommon.hpp" @@ -34,35 +33,6 @@ static int64_t getValueSizeInBytes(Value value) { return type.getNumElements() * type.getElementTypeBitWidth() / 8; } -static void expandPimMapOps(func::FuncOp funcOp, IRRewriter& rewriter) { - SmallVector mapOps; - funcOp.walk([&](pim::PimMapOp mapOp) { mapOps.push_back(mapOp); }); - - for (auto mapOp : mapOps) { - Block& body = mapOp.getBody().front(); - auto yieldOp = cast(body.getTerminator()); - - SmallVector replacements; - replacements.reserve(mapOp.getInputs().size()); - rewriter.setInsertionPoint(mapOp); - for (Value input : mapOp.getInputs()) { - IRMapping mapping; - mapping.map(body.getArgument(0), input); - - for (Operation& op : body.without_terminator()) { - Operation* cloned = rewriter.clone(op, mapping); - for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) - mapping.map(originalResult, clonedResult); - rewriter.setInsertionPointAfter(cloned); - } - - replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0))); - } - - rewriter.replaceOp(mapOp, replacements); - } -} - struct MaterializeHostConstantsPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass) @@ -80,8 +50,6 @@ struct MaterializeHostConstantsPass : PassWrapper()) { DenseMap>> materializedValues; @@ -150,38 +118,11 @@ struct MaterializeHostConstantsPass : PassWrapper hostCompactOps; for (Operation& op : funcOp.getBody().front()) - if (isa(op)) + if (isa(op)) hostCompactOps.push_back(&op); for (Operation* op : hostCompactOps) { rewriter.setInsertionPoint(op); - - if (auto extractRowsOp = dyn_cast(op)) { - auto inputType = dyn_cast(extractRowsOp.getInput().getType()); - if (!inputType || !inputType.hasStaticShape() || inputType.getRank() != 2) { - extractRowsOp.emitOpError("host-side extract_rows lowering requires a static rank-2 input"); - hasFailure = true; - continue; - } - - int64_t numCols = inputType.getDimSize(1); - SmallVector replacementRows; - replacementRows.reserve(extractRowsOp.getOutputs().size()); - for (auto rowIndex : llvm::seq(0, extractRowsOp.getOutputs().size())) { - SmallVector offsets = {rewriter.getIndexAttr(static_cast(rowIndex)), - rewriter.getIndexAttr(0)}; - SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)}; - SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - replacementRows.push_back(memref::SubViewOp::create( - rewriter, extractRowsOp.getLoc(), extractRowsOp.getInput(), offsets, sizes, strides) - .getResult()); - } - - extractRowsOp->replaceAllUsesWith(ValueRange(replacementRows)); - extractRowsOp->erase(); - continue; - } - auto concatOp = cast(op); concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization"); hasFailure = true; diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 0514dbb..81b9de9 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -18,7 +18,6 @@ namespace { static bool isAddressOnlyHostOp(Operation* op) { return isa(defOp)) + if (isa(defOp)) return true; - if (auto subview = dyn_cast(defOp)) { value = subview.getSource(); continue; } - if (auto cast = dyn_cast(defOp)) { value = cast.getSource(); continue; } - if (auto collapse = dyn_cast(defOp)) { value = collapse.getSrc(); continue; } - if (auto expand = dyn_cast(defOp)) { value = expand.getSrc(); continue; } + if (auto subview = dyn_cast(defOp)) { + value = subview.getSource(); + continue; + } + if (auto cast = dyn_cast(defOp)) { + value = cast.getSource(); + continue; + } + if (auto collapse = dyn_cast(defOp)) { + value = collapse.getSrc(); + continue; + } + if (auto expand = dyn_cast(defOp)) { + value = expand.getSrc(); + continue; + } return false; } } @@ -52,7 +63,38 @@ static bool isCodegenAddressableValue(Value value) { if (failed(resolvedAddress)) return false; return isa(resolvedAddress->base) - || isa(resolvedAddress->base.getDefiningOp()); + || isa(resolvedAddress->base.getDefiningOp()); +} + +static bool isConstantGlobalView(Value value) { + auto allStaticSubviewParts = [](memref::SubViewOp subview) { + return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); }) + && llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); }) + && llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); }); + }; + + while (true) { + Operation* defOp = value.getDefiningOp(); + if (!defOp) + return false; + if (auto getGlobalOp = dyn_cast(defOp)) { + auto moduleOp = getGlobalOp->getParentOfType(); + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + return globalOp && globalOp.getConstant() && globalOp.getInitialValue() + && isa(*globalOp.getInitialValue()); + } + if (auto subview = dyn_cast(defOp)) { + if (!allStaticSubviewParts(subview)) + return false; + value = subview.getSource(); + continue; + } + if (auto cast = dyn_cast(defOp)) { + value = cast.getSource(); + continue; + } + return false; + } } static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { @@ -125,13 +167,17 @@ private: bool hasFailure = false; for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) { auto getGlobalOp = weight.template getDefiningOp(); - if (!getGlobalOp) { + if (!getGlobalOp && !isConstantGlobalView(weight)) { coreOp.emitOpError() << "weight #" << weightIndex - << " must be materialized as memref.get_global before JSON codegen"; + << " must be materialized as a constant memref.global or a static view of one before JSON " + "codegen"; hasFailure = true; continue; } + if (!getGlobalOp) + continue; + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); if (!globalOp) { coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global"; @@ -185,7 +231,7 @@ private: continue; } - if (!isa(resolvedAddress->base.getDefiningOp())) { + if (!isa(resolvedAddress->base.getDefiningOp())) { op.emitOpError() << "operand #" << operandIndex << " must be backed by device-local memory; materialize host values with pim.memcp_hd"; hasFailure = true; @@ -197,7 +243,7 @@ private: static LogicalResult verifyAddressOnlyHostOp(Operation* op) { if (auto subviewOp = dyn_cast(op)) - return verifyAddressOnlySource(op, subviewOp.getSource()); + return verifyAddressOnlyBase(op, subviewOp.getSource()); if (auto castOp = dyn_cast(op)) return verifyAddressOnlySource(op, castOp.getSource()); if (auto collapseOp = dyn_cast(op)) @@ -221,6 +267,14 @@ private: op->emitOpError("depends on a value that is not backed by contiguous addressable storage"); return failure(); } + + static LogicalResult verifyAddressOnlyBase(Operation* op, Value source) { + if (isBaseAddressableValue(source)) + return success(); + + op->emitOpError("depends on a value that is not backed by addressable storage"); + return failure(); + } }; } // namespace diff --git a/src/PIM/TODO.md b/src/PIM/TODO.md deleted file mode 100644 index 267cd4e..0000000 --- a/src/PIM/TODO.md +++ /dev/null @@ -1,10 +0,0 @@ -Rimuovere la gestione delle send e recive da sptaialtopim (nuovo mergeNode) - -AnalisiDCP -NuovoPasso che inserische le send e recive e gestisce gli input -Passo che fa il merge - -Probabilmente questo rompera' gli input e come venivano gestiti prima. - - -