From f054e66ed01b71eb9f198e6fea4c7c7c8bb2fe98 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Fri, 10 Apr 2026 18:50:25 +0200 Subject: [PATCH] reduce spatial compile-times in convolutions using a scf.for instead of materializing a huge number of instructions --- src/PIM/CMakeLists.txt | 2 + src/PIM/Common/PimCommon.cpp | 267 +++++++++++++++++- src/PIM/Common/PimCommon.hpp | 32 +++ src/PIM/Compiler/PimCodeGen.cpp | 215 +++++++------- src/PIM/Compiler/PimCodeGen.hpp | 41 +-- .../Conversion/ONNXToSpatial/CMakeLists.txt | 1 + .../ONNXToSpatial/ONNXToSpatialPass.cpp | 7 +- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 89 +++--- .../Conversion/SpatialToPim/CMakeLists.txt | 1 + .../SpatialToPim/SpatialToPimPass.cpp | 8 +- src/PIM/Pass/CMakeLists.txt | 1 + src/PIM/Pass/Pim/ConstantFolding/Common.cpp | 26 +- src/PIM/Pass/Pim/ConstantFolding/Common.hpp | 7 +- .../Pim/ConstantFolding/Patterns/Constant.cpp | 15 +- .../Pim/ConstantFolding/Patterns/Subview.cpp | 92 ++++-- src/PIM/Pass/Pim/VerificationPass.cpp | 51 ++-- src/PIM/PimAccelerator.cpp | 4 + validation/.gitignore | 5 + 18 files changed, 623 insertions(+), 241 deletions(-) diff --git a/src/PIM/CMakeLists.txt b/src/PIM/CMakeLists.txt index 3ed0853..4531b9e 100644 --- a/src/PIM/CMakeLists.txt +++ b/src/PIM/CMakeLists.txt @@ -54,6 +54,8 @@ add_pim_library(OMPIMAccel ${PIM_PUBLIC_INCLUDE_DIRS} LINK_LIBS PUBLIC + MLIRSCFDialect + MLIRSCFTransforms onnx OMAccelerator OMPimCompilerUtils diff --git a/src/PIM/Common/PimCommon.cpp b/src/PIM/Common/PimCommon.cpp index ebfd18e..0f72d1f 100644 --- a/src/PIM/Common/PimCommon.cpp +++ b/src/PIM/Common/PimCommon.cpp @@ -1,4 +1,6 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" @@ -8,6 +10,7 @@ #include #include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Compiler/CompilerOptions.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -240,8 +243,129 @@ bool isMemoryContiguous(ArrayRef srcShape, return true; } -FailureOr resolveContiguousAddress(Value value) { +static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) { + if (!knowledge) + return value; + + auto iter = knowledge->aliases.find(value); + while (iter != knowledge->aliases.end()) { + value = iter->second; + iter = knowledge->aliases.find(value); + } + return value; +} + +// Walks through view-like ops and DPS tied operands to find the "underlying" memref value +// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop +// and when propagating yielded values across iterations during static unrolling. +static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) { + value = resolveAlias(value, knowledge); + + if (auto blockArgument = dyn_cast(value)) + return value; + + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return value; + + if (auto dpsDefiningOp = dyn_cast(definingOp)) { + if (auto result = dyn_cast(value)) + if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result)) + return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge); + } + + if (auto castOp = dyn_cast(definingOp)) + return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge); + if (auto collapseOp = dyn_cast(definingOp)) + return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge); + if (auto expandOp = dyn_cast(definingOp)) + return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge); + + return value; +} + +static FailureOr resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge); + +static FailureOr resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) { + value = resolveAlias(value, knowledge); + + if (knowledge) { + auto iter = knowledge->indexValues.find(value); + if (iter != knowledge->indexValues.end()) + return iter->second; + } + + auto constantOp = value.getDefiningOp(); + if (constantOp) { + if (auto integerAttr = dyn_cast(constantOp.getValue())) + return integerAttr.getInt(); + } + + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return failure(); + + if (auto indexCastOp = dyn_cast(definingOp)) + return resolveIndexValueImpl(indexCastOp.getIn(), knowledge); + + if (auto addOp = dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs)) + return failure(); + return *lhs + *rhs; + } + + if (auto subOp = dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs)) + return failure(); + return *lhs - *rhs; + } + + if (auto mulOp = dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs)) + return failure(); + return *lhs * *rhs; + } + + if (auto divOp = dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs) || *rhs == 0) + return failure(); + return static_cast(static_cast(*lhs) / static_cast(*rhs)); + } + + if (auto remOp = dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs) || *rhs == 0) + return failure(); + return static_cast(static_cast(*lhs) % static_cast(*rhs)); + } + + return failure(); +} + +static FailureOr resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) { + if (auto attr = dyn_cast(ofr)) { + auto integerAttr = dyn_cast(attr); + if (!integerAttr) + return failure(); + return integerAttr.getInt(); + } + + return resolveIndexValueImpl(cast(ofr), knowledge); +} + +static FailureOr resolveContiguousAddressImpl(Value value, + const StaticValueKnowledge* knowledge) { int64_t byteOffset = 0; + value = resolveAlias(value, knowledge); while (true) { if (isa(value)) @@ -255,7 +379,29 @@ FailureOr resolveContiguousAddress(Value value) { OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast(value)); if (!tiedOperand) return failure(); - value = tiedOperand->get(); + value = resolveAlias(tiedOperand->get(), knowledge); + continue; + } + + if (auto forOp = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return failure(); + + // Trace the loop carry back to its underlying memref, then if that memref is the + // loop's own iter-arg we know the base comes from the corresponding init arg + // (every iteration yields the same backing memory in the DPS sense). + auto yieldOp = cast(forOp.getBody()->getTerminator()); + Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge); + if (auto blockArgument = dyn_cast(yieldedValue)) { + if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0 + && static_cast(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) { + value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge); + continue; + } + } + + value = yieldedValue; continue; } @@ -265,31 +411,53 @@ FailureOr resolveContiguousAddress(Value value) { if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) return failure(); - ArrayRef offsets = subviewOp.getStaticOffsets(); - ArrayRef sizes = subviewOp.getStaticSizes(); - ArrayRef strides = subviewOp.getStaticStrides(); - if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic) - || llvm::is_contained(strides, ShapedType::kDynamic)) - return failure(); + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(subviewOp.getMixedOffsets().size()); + sizes.reserve(subviewOp.getMixedSizes().size()); + strides.reserve(subviewOp.getMixedStrides().size()); + + for (OpFoldResult offset : subviewOp.getMixedOffsets()) { + auto resolvedOffset = resolveOpFoldResult(offset, knowledge); + if (failed(resolvedOffset)) + return failure(); + offsets.push_back(*resolvedOffset); + } + + for (OpFoldResult size : subviewOp.getMixedSizes()) { + auto resolvedSize = resolveOpFoldResult(size, knowledge); + if (failed(resolvedSize)) + return failure(); + sizes.push_back(*resolvedSize); + } + + for (OpFoldResult stride : subviewOp.getMixedStrides()) { + auto resolvedStride = resolveOpFoldResult(stride, knowledge); + if (failed(resolvedStride)) + return failure(); + strides.push_back(*resolvedStride); + } + if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides)) return failure(); auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8; - value = subviewOp.getSource(); + value = resolveAlias(subviewOp.getSource(), knowledge); continue; } if (auto castOp = dyn_cast(definingOp)) { - value = castOp.getSource(); + value = resolveAlias(castOp.getSource(), knowledge); continue; } if (auto collapseOp = dyn_cast(definingOp)) { - value = collapseOp.getSrc(); + value = resolveAlias(collapseOp.getSrc(), knowledge); continue; } if (auto expandOp = dyn_cast(definingOp)) { - value = expandOp.getSrc(); + value = resolveAlias(expandOp.getSrc(), knowledge); continue; } @@ -300,4 +468,79 @@ FailureOr resolveContiguousAddress(Value value) { } } +FailureOr resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); } + +FailureOr resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) { + return resolveIndexValueImpl(value, &knowledge); +} + +FailureOr resolveContiguousAddress(Value value) { + return resolveContiguousAddressImpl(value, nullptr); +} + +FailureOr resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) { + return resolveContiguousAddressImpl(value, &knowledge); +} + +Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) { + return resolveLoopCarriedAliasImpl(value, &knowledge); +} + +bool isCoreStaticAddressOp(Operation* op) { + return isa(op); +} + +LogicalResult walkPimCoreBlock(Block& block, + const StaticValueKnowledge& knowledge, + llvm::function_ref callback) { + bool hasFailure = false; + for (Operation& op : block) { + if (isa(op) || isCoreStaticAddressOp(&op)) + continue; + + if (auto forOp = dyn_cast(op)) { + Block& loopBody = forOp.getRegion().front(); + auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge); + auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge); + auto step = resolveIndexValue(forOp.getStep(), knowledge); + if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) { + forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen"); + hasFailure = true; + continue; + } + + SmallVector iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end()); + for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) { + StaticValueKnowledge loopKnowledge = knowledge; + loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue; + for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues)) + loopKnowledge.aliases[iterArg] = iterValue; + + if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback))) + hasFailure = true; + + auto yieldOp = cast(loopBody.getTerminator()); + for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands())) + iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge); + } + continue; + } + + if (failed(callback(op, knowledge))) + hasFailure = true; + } + return success(!hasFailure); +} + } // namespace onnx_mlir diff --git a/src/PIM/Common/PimCommon.hpp b/src/PIM/Common/PimCommon.hpp index f31492e..99cec4d 100644 --- a/src/PIM/Common/PimCommon.hpp +++ b/src/PIM/Common/PimCommon.hpp @@ -6,6 +6,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -20,6 +21,13 @@ struct ResolvedContiguousAddress { int64_t byteOffset = 0; }; +struct StaticValueKnowledge { + llvm::DenseMap indexValues; + llvm::DenseMap aliases; + + StaticValueKnowledge() {} +}; + std::string getOutputDir(); void createDirectory(const std::string& directory); @@ -52,5 +60,29 @@ bool isMemoryContiguous(llvm::ArrayRef srcShape, llvm::ArrayRef strides); llvm::FailureOr resolveContiguousAddress(mlir::Value value); +llvm::FailureOr resolveContiguousAddress(mlir::Value value, + const StaticValueKnowledge& knowledge); + +llvm::FailureOr resolveIndexValue(mlir::Value value); +llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge); + +/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for +/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries. +mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge); + +/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and +/// only contribute to static addressing or index computations (arith integer math, +/// memref view ops, memref.alloc, arith.constant). +bool isCoreStaticAddressOp(mlir::Operation* op); + +/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically +/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op +/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is +/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback +/// failure so callers can collect multiple diagnostics, but propagates the overall result. +mlir::LogicalResult +walkPimCoreBlock(mlir::Block& block, + const StaticValueKnowledge& knowledge, + llvm::function_ref callback); } // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index c48de47..7006869 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -84,8 +84,8 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) { return deviceMem.try_emplace(id, memEntriesMap).first->second; } -size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const { - auto resolvedAddress = resolveContiguousAddress(value); +size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge) const { + auto resolvedAddress = resolveContiguousAddress(value, knowledge); if (failed(resolvedAddress)) { errs() << "Failed to resolve contiguous address for value: "; value.print(errs()); @@ -199,47 +199,49 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_ emitInstruction(std::move(json)); } -void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const { +void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const { emitMemCopyOp("ld", - memory.getValueAddress(loadOp.getDeviceTarget()), + addressOf(loadOp.getDeviceTarget(), knowledge), loadOp.getDeviceTargetOffset(), - memory.getValueAddress(loadOp.getHostSource()), + addressOf(loadOp.getHostSource(), knowledge), loadOp.getHostSourceOffset(), loadOp.getSize()); } -void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const { +void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const { emitMemCopyOp("st", - memory.getValueAddress(storeOp.getHostTarget()), + addressOf(storeOp.getHostTarget(), knowledge), storeOp.getHostTargetOffset(), - memory.getValueAddress(storeOp.getDeviceSource()), + addressOf(storeOp.getDeviceSource(), knowledge), storeOp.getDeviceSourceOffset(), storeOp.getSize()); } -void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const { +void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const { emitMemCopyOp("lmv", - memory.getValueAddress(lmvOp.getTarget()), + addressOf(lmvOp.getTarget(), knowledge), lmvOp.getTargetOffset(), - memory.getValueAddress(lmvOp.getSource()), + addressOf(lmvOp.getSource(), knowledge), lmvOp.getSourceOffset(), lmvOp.getSize(), "len"); } -void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const { +void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const { emitCommunicationOp( - "recv", memory.getValueAddress(receiveOp.getOutputBuffer()), receiveOp.getSourceCoreId(), receiveOp.getSize()); + "recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize()); } -void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) const { - emitCommunicationOp("send", memory.getValueAddress(sendOp.getInput()), sendOp.getTargetCoreId(), sendOp.getSize()); +void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const { + emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize()); } template -void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) { - emitMvmOp( - mvmId, memory.getValueAddress(mvmLikeOp.getOutputBuffer()), 0, memory.getValueAddress(mvmLikeOp.getInput()), 0); +void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, + MVMTy mvmLikeOp, + bool transposeMatrix, + const StaticValueKnowledge& knowledge) { + emitMvmOp(mvmId, addressOf(mvmLikeOp.getOutputBuffer(), knowledge), 0, addressOf(mvmLikeOp.getInput(), knowledge), 0); // TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix) } @@ -249,10 +251,10 @@ static size_t getValueSizeInBytes(mlir::Value value) { return type.getNumElements() * type.getElementTypeBitWidth() / 8; } -void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const { - auto outputBufferAddr = memory.getValueAddress(vvaddOp.getOutputBuffer()); - auto lhsAddr = memory.getValueAddress(vvaddOp.getLhs()); - auto rhsAddr = memory.getValueAddress(vvaddOp.getRhs()); +void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const { + auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge); + auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge); + auto rhsAddr = addressOf(vvaddOp.getRhs(), knowledge); setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0); json::Object json; @@ -265,10 +267,10 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const { emitInstruction(std::move(json)); } -void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const { - auto outputBufferAddr = memory.getValueAddress(vvsubOp.getOutputBuffer()); - auto lhsAddr = memory.getValueAddress(vvsubOp.getLhs()); - auto rhsAddr = memory.getValueAddress(vvsubOp.getRhs()); +void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const { + auto outputBufferAddr = addressOf(vvsubOp.getOutputBuffer(), knowledge); + auto lhsAddr = addressOf(vvsubOp.getLhs(), knowledge); + auto rhsAddr = addressOf(vvsubOp.getRhs(), knowledge); setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0); json::Object json; @@ -281,10 +283,10 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const { emitInstruction(std::move(json)); } -void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const { - auto outputBufferAddr = memory.getValueAddress(vvmulOp.getOutputBuffer()); - auto lhsAddr = memory.getValueAddress(vvmulOp.getLhs()); - auto rhsAddr = memory.getValueAddress(vvmulOp.getRhs()); +void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const { + auto outputBufferAddr = addressOf(vvmulOp.getOutputBuffer(), knowledge); + auto lhsAddr = addressOf(vvmulOp.getLhs(), knowledge); + auto rhsAddr = addressOf(vvmulOp.getRhs(), knowledge); setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0); json::Object json; @@ -297,10 +299,10 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const { emitInstruction(std::move(json)); } -void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const { - auto outputBufferAddr = memory.getValueAddress(vvmaxOp.getOutputBuffer()); - auto lhsAddr = memory.getValueAddress(vvmaxOp.getLhs()); - auto rhsAddr = memory.getValueAddress(vvmaxOp.getRhs()); +void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const { + auto outputBufferAddr = addressOf(vvmaxOp.getOutputBuffer(), knowledge); + auto lhsAddr = addressOf(vvmaxOp.getLhs(), knowledge); + auto rhsAddr = addressOf(vvmaxOp.getRhs(), knowledge); setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0); json::Object json; @@ -313,10 +315,10 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const { emitInstruction(std::move(json)); } -void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const { - auto outputBufferAddr = memory.getValueAddress(vvdmulOp.getOutputBuffer()); - auto lhsAddr = memory.getValueAddress(vvdmulOp.getLhs()); - auto rhsAddr = memory.getValueAddress(vvdmulOp.getRhs()); +void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const { + auto outputBufferAddr = addressOf(vvdmulOp.getOutputBuffer(), knowledge); + auto lhsAddr = addressOf(vvdmulOp.getLhs(), knowledge); + auto rhsAddr = addressOf(vvdmulOp.getRhs(), knowledge); setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0); json::Object json; @@ -329,9 +331,9 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const { emitInstruction(std::move(json)); } -void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const { - auto outputBufferAddr = memory.getValueAddress(vavgOp.getOutputBuffer()); - auto inputAddr = memory.getValueAddress(vavgOp.getInput()); +void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const { + auto outputBufferAddr = addressOf(vavgOp.getOutputBuffer(), knowledge); + auto inputAddr = addressOf(vavgOp.getInput(), knowledge); setupRdRs1(outputBufferAddr, 0, inputAddr, 0); json::Object json; @@ -344,9 +346,9 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const { emitInstruction(std::move(json)); } -void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const { - auto outputBufferAddr = memory.getValueAddress(vreluOp.getOutputBuffer()); - auto inputAddr = memory.getValueAddress(vreluOp.getInput()); +void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const { + auto outputBufferAddr = addressOf(vreluOp.getOutputBuffer(), knowledge); + auto inputAddr = addressOf(vreluOp.getInput(), knowledge); setupRdRs1(outputBufferAddr, 0, inputAddr, 0); json::Object json; @@ -358,9 +360,9 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const { emitInstruction(std::move(json)); } -void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const { - auto outputBufferAddr = memory.getValueAddress(vtanhOp.getOutputBuffer()); - auto inputAddr = memory.getValueAddress(vtanhOp.getInput()); +void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const { + auto outputBufferAddr = addressOf(vtanhOp.getOutputBuffer(), knowledge); + auto inputAddr = addressOf(vtanhOp.getInput(), knowledge); setupRdRs1(outputBufferAddr, 0, inputAddr, 0); json::Object json; @@ -372,9 +374,9 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const { emitInstruction(std::move(json)); } -void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const { - auto outputBufferAddr = memory.getValueAddress(vsigmOp.getOutputBuffer()); - auto inputAddr = memory.getValueAddress(vsigmOp.getInput()); +void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const { + auto outputBufferAddr = addressOf(vsigmOp.getOutputBuffer(), knowledge); + auto inputAddr = addressOf(vsigmOp.getInput(), knowledge); setupRdRs1(outputBufferAddr, 0, inputAddr, 0); json::Object json; @@ -386,9 +388,9 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const { emitInstruction(std::move(json)); } -void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const { - auto outputBufferAddr = memory.getValueAddress(vsoftmaxOp.getOutputBuffer()); - auto inputAddr = memory.getValueAddress(vsoftmaxOp.getInput()); +void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const { + auto outputBufferAddr = addressOf(vsoftmaxOp.getOutputBuffer(), knowledge); + auto inputAddr = addressOf(vsoftmaxOp.getInput(), knowledge); setupRdRs1(outputBufferAddr, 0, inputAddr, 0); json::Object json; @@ -400,9 +402,9 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const { emitInstruction(std::move(json)); } -void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const { - auto srcAddr = memory.getValueAddress(transposeOp.getInput()); - auto dstAddr = memory.getValueAddress(transposeOp.getOutputBuffer()); +void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const { + auto srcAddr = addressOf(transposeOp.getInput(), knowledge); + auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge); auto srcType = cast(transposeOp.getInput().getType()); auto srcShape = srcType.getShape(); @@ -510,57 +512,58 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& } /// Dispatch all operations in a core region to the appropriate code generator. +/// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is +/// fully resolved before the JSON instructions are emitted. /// Returns the number of emitted instructions, or -1 on failure. -static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { +static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { size_t processedOperations = 0; - for (auto& op : coreOp.getBody().front()) { - if (isa(op)) - continue; - - if (auto loadOp = dyn_cast(op)) - coreCodeGen.codeGenLoadOp(loadOp); - else if (auto storeOp = dyn_cast(op)) - coreCodeGen.codeGenStoreOp(storeOp); - else if (auto lmvOp = dyn_cast(op)) - coreCodeGen.codeGenLmvOp(lmvOp); - else if (auto receiveOp = dyn_cast(op)) - coreCodeGen.codeGenReceiveOp(receiveOp); - else if (auto sendOp = dyn_cast(op)) - coreCodeGen.codeGenSendOp(sendOp); - else if (auto vmmOp = dyn_cast(op)) - coreCodeGen.codeGenMVMLikeOp(vmmOp.getWeightIndex(), vmmOp, true); - else if (auto mvmOp = dyn_cast(op)) - coreCodeGen.codeGenMVMLikeOp(mvmOp.getWeightIndex(), mvmOp, false); - else if (auto transposeOp = dyn_cast(op)) - coreCodeGen.codeGenTransposeOp(transposeOp); - else if (auto vvaddOp = dyn_cast(op)) - coreCodeGen.codeGenVVAddOp(vvaddOp); - else if (auto vvsubOp = dyn_cast(op)) - coreCodeGen.codeGenVVSubOp(vvsubOp); - else if (auto vvmulOp = dyn_cast(op)) - coreCodeGen.codeGenVVMulOp(vvmulOp); - else if (auto vvmaxOp = dyn_cast(op)) - coreCodeGen.codeGenVVMaxOp(vvmaxOp); - else if (auto vvdmulOp = dyn_cast(op)) - coreCodeGen.codeGenVVDMulOp(vvdmulOp); - else if (auto vavgOp = dyn_cast(op)) - coreCodeGen.codeGenVAvgOp(vavgOp); - else if (auto vreluOp = dyn_cast(op)) - coreCodeGen.codeGenVReluOp(vreluOp); - else if (auto vtanhOp = dyn_cast(op)) - coreCodeGen.codeGenVTanhOp(vtanhOp); - else if (auto vsigmOp = dyn_cast(op)) - coreCodeGen.codeGenVSigmOp(vsigmOp); - else if (auto vsoftmaxOp = dyn_cast(op)) - coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp); - else { - op.emitError("Unsupported codegen for this operation"); - op.dump(); - return -1; - } - processedOperations++; - } - return processedOperations; + auto result = + walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) { + if (auto loadOp = dyn_cast(op)) + coreCodeGen.codeGenLoadOp(loadOp, knowledge); + else if (auto storeOp = dyn_cast(op)) + coreCodeGen.codeGenStoreOp(storeOp, knowledge); + else if (auto lmvOp = dyn_cast(op)) + coreCodeGen.codeGenLmvOp(lmvOp, knowledge); + else if (auto receiveOp = dyn_cast(op)) + coreCodeGen.codeGenReceiveOp(receiveOp, knowledge); + else if (auto sendOp = dyn_cast(op)) + coreCodeGen.codeGenSendOp(sendOp, knowledge); + else if (auto vmmOp = dyn_cast(op)) + coreCodeGen.codeGenMVMLikeOp(vmmOp.getWeightIndex(), vmmOp, true, knowledge); + else if (auto mvmOp = dyn_cast(op)) + coreCodeGen.codeGenMVMLikeOp(mvmOp.getWeightIndex(), mvmOp, false, knowledge); + else if (auto transposeOp = dyn_cast(op)) + coreCodeGen.codeGenTransposeOp(transposeOp, knowledge); + else if (auto vvaddOp = dyn_cast(op)) + coreCodeGen.codeGenVVAddOp(vvaddOp, knowledge); + else if (auto vvsubOp = dyn_cast(op)) + coreCodeGen.codeGenVVSubOp(vvsubOp, knowledge); + else if (auto vvmulOp = dyn_cast(op)) + coreCodeGen.codeGenVVMulOp(vvmulOp, knowledge); + else if (auto vvmaxOp = dyn_cast(op)) + coreCodeGen.codeGenVVMaxOp(vvmaxOp, knowledge); + else if (auto vvdmulOp = dyn_cast(op)) + coreCodeGen.codeGenVVDMulOp(vvdmulOp, knowledge); + else if (auto vavgOp = dyn_cast(op)) + coreCodeGen.codeGenVAvgOp(vavgOp, knowledge); + else if (auto vreluOp = dyn_cast(op)) + coreCodeGen.codeGenVReluOp(vreluOp, knowledge); + else if (auto vtanhOp = dyn_cast(op)) + coreCodeGen.codeGenVTanhOp(vtanhOp, knowledge); + else if (auto vsigmOp = dyn_cast(op)) + coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge); + else if (auto vsoftmaxOp = dyn_cast(op)) + coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge); + else { + op.emitError("Unsupported codegen for this operation"); + op.dump(); + return failure(); + } + processedOperations++; + return success(); + }); + return failed(result) ? -1 : static_cast(processedOperations); } /// Write crossbar weight matrices as padded binary files for a single core. @@ -739,7 +742,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: PimCodeGen coreCodeGen(memory, coreFileStream); memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp); - int64_t processedOperations = codeGenCoreOps(coreOp, coreCodeGen); + int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen); if (processedOperations < 0) return CompilerFailure; assert(processedOperations > 0); diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 81f340d..cd0fd4a 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -4,6 +4,7 @@ #include "llvm/Support/JSON.h" #include "onnx-mlir/Compiler/OMCompilerTypes.h" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" namespace onnx_mlir { @@ -50,13 +51,17 @@ public: PimMemory& getOrCreateDeviceMem(size_t id); - size_t getValueAddress(mlir::Value value) const; + size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const; }; class PimCodeGen { PimAcceleratorMemory& memory; llvm::raw_fd_ostream& coreFileStream; + size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const { + return memory.getValueAddress(value, knowledge); + } + static llvm::json::Object createEmptyOffset(); void emitInstruction(llvm::json::Object instruction) const; @@ -80,27 +85,27 @@ public: PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson) : memory(memory), coreFileStream(coreJson) {} - void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const; - void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const; - void codeGenLmvOp(pim::PimMemCopyOp lmvOp) const; + void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const; + void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const; + void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const; - void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const; - void codeGenSendOp(pim::PimSendOp sendOp) const; + void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const; + void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const; template - void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix); + void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge); - void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const; - void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const; - void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const; - void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const; - void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const; - void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const; - void codeGenVReluOp(pim::PimVReluOp vreluOp) const; - void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const; - void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const; - void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const; - void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const; + void codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const; + void codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const; + void codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const; + void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const; + void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const; + void codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const; + void codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const; + void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const; + void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const; + void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const; + void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const; }; OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName); diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index d26673f..994deff 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -26,6 +26,7 @@ add_pim_library(OMONNXToSpatial ONNXToSpatialIncGen LINK_LIBS PUBLIC + MLIRSCFDialect MLIRTosaDialect OMCompilerOptions OMPimCompilerOptions diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index d140ee2..2c5f024 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -75,7 +76,11 @@ void ONNXToSpatialPass::runOnOperation() { } ConversionTarget target(*ctx); - target.addLegalDialect(); + target.addLegalDialect(); target.addDynamicallyLegalOp( [](ONNXMatMulOp op) { return cast(op.getY().getType()).getRank() != 2; }); target.addIllegalOp(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 1c4f38b..7e56e6c 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" @@ -169,44 +170,60 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, paddedInput = padOp.getResult(); } - // Build im2col [numPatches, patchSize]: - // For each batch/output position (n, oh, ow), extract the patch from x - SmallVector im2colRows; - im2colRows.reserve(numPatches); - for (int64_t n = 0; n < batchSize; n++) { - for (int64_t oh = 0; oh < outHeight; oh++) { - for (int64_t ow = 0; ow < outWidth; ow++) { - SmallVector offsets = {rewriter.getIndexAttr(n), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(oh * strideHeight), - rewriter.getIndexAttr(ow * strideWidth)}; - SmallVector sizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(numChannelsIn), - rewriter.getIndexAttr(wHeight), - rewriter.getIndexAttr(wWidth)}; - SmallVector strides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(dilationHeight), - rewriter.getIndexAttr(dilationWidth)}; - auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); - Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); + // Build im2col [numPatches, patchSize] incrementally to keep the IR small + // until the late PIM unrolling step. + Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType); + auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches); + auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch); + auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth); + auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight); + auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth); - // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize] - Value row = tensor::CollapseShapeOp::create(rewriter, - loc, - rowType, - patch, - SmallVector { - {0}, - {1, 2, 3} - }); - im2colRows.push_back(row); - } - } - } + auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit}); + rewriter.setInsertionPointToStart(im2colLoop.getBody()); - // Concatenate all rows: [numPatches, patchSize] - Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows); + Value patchIndex = im2colLoop.getInductionVar(); + Value im2colAcc = im2colLoop.getRegionIterArgs().front(); + + Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch); + Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch); + Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth); + Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth); + Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight); + Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth); + + SmallVector offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset}; + SmallVector sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(numChannelsIn), + rewriter.getIndexAttr(wHeight), + rewriter.getIndexAttr(wWidth)}; + SmallVector strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(dilationHeight), + rewriter.getIndexAttr(dilationWidth)}; + auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); + Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); + + Value row = tensor::CollapseShapeOp::create(rewriter, + loc, + rowType, + patch, + SmallVector { + {0}, + {1, 2, 3} + }); + + SmallVector rowOffsets = {patchIndex, rewriter.getIndexAttr(0)}; + SmallVector rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)}; + SmallVector rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value updatedIm2col = + tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides); + scf::YieldOp::create(rewriter, loc, updatedIm2col); + + rewriter.setInsertionPointAfter(im2colLoop); + Value im2col = im2colLoop.getResult(0); spatial::SpatYieldOp::create(rewriter, loc, im2col); }); diff --git a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt index 166110a..d8222c8 100644 --- a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt @@ -12,6 +12,7 @@ add_pim_library(OMSpatialToPim SpatialToPimIncGen LINK_LIBS PUBLIC + MLIRSCFDialect MLIRTosaDialect OMCompilerOptions OMPimCommon diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 07148d3..0ea92c4 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -1,5 +1,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinDialect.h" @@ -134,7 +135,12 @@ void SpatialToPimPass::runOnOperation() { MLIRContext* ctx = moduleOp.getContext(); ConversionTarget target(*ctx); - target.addLegalDialect(); + target.addLegalDialect(); RewritePatternSet patterns(ctx); populateWithGenerated(patterns); diff --git a/src/PIM/Pass/CMakeLists.txt b/src/PIM/Pass/CMakeLists.txt index db31157..4c4507b 100644 --- a/src/PIM/Pass/CMakeLists.txt +++ b/src/PIM/Pass/CMakeLists.txt @@ -13,6 +13,7 @@ add_pim_library(OMPimPasses LINK_LIBS PUBLIC MLIRLinalgDialect + MLIRSCFDialect OMCompilerUtils OMPimCommon ) diff --git a/src/PIM/Pass/Pim/ConstantFolding/Common.cpp b/src/PIM/Pass/Pim/ConstantFolding/Common.cpp index 16f121d..c0b3b8f 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/Common.cpp +++ b/src/PIM/Pass/Pim/ConstantFolding/Common.cpp @@ -85,12 +85,8 @@ FailureOr getStaticSubviewInfo(Value value) { StaticSubviewInfo info; info.source = source; info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end()); - for (OpFoldResult offset : subviewOp.getMixedOffsets()) { - auto staticOffset = getConstantIntValue(offset); - if (!staticOffset) - return failure(); - info.offsets.push_back(*staticOffset); - } + SmallVector mixedOffsets = subviewOp.getMixedOffsets(); + info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end()); for (OpFoldResult size : subviewOp.getMixedSizes()) { auto staticSize = getConstantIntValue(size); if (!staticSize) @@ -106,14 +102,16 @@ FailureOr getStaticSubviewInfo(Value value) { return info; } -int64_t -getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef outerIndices, int64_t elementByteWidth) { - SmallVector sourceIndices; - sourceIndices.reserve(info.sourceShape.size()); - for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim) - sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]); - sourceIndices.push_back(info.offsets.back()); - return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth; +FailureOr> getStaticSubviewOffsets(const StaticSubviewInfo& info) { + SmallVector staticOffsets; + staticOffsets.reserve(info.offsets.size()); + for (OpFoldResult offset : info.offsets) { + auto staticOffset = getConstantIntValue(offset); + if (!staticOffset) + return failure(); + staticOffsets.push_back(*staticOffset); + } + return staticOffsets; } } // namespace onnx_mlir diff --git a/src/PIM/Pass/Pim/ConstantFolding/Common.hpp b/src/PIM/Pass/Pim/ConstantFolding/Common.hpp index 0483b6a..f41c4a0 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/Common.hpp +++ b/src/PIM/Pass/Pim/ConstantFolding/Common.hpp @@ -14,7 +14,7 @@ namespace onnx_mlir { struct StaticSubviewInfo { mlir::Value source; llvm::SmallVector sourceShape; - llvm::SmallVector offsets; + llvm::SmallVector offsets; llvm::SmallVector sizes; llvm::SmallVector strides; }; @@ -34,8 +34,7 @@ llvm::FailureOr getDenseGlobalValue(mlir::ModuleOp modu llvm::FailureOr getStaticSubviewInfo(mlir::Value value); -int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, - llvm::ArrayRef outerIndices, - int64_t elementByteWidth); +/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic. +llvm::FailureOr> getStaticSubviewOffsets(const StaticSubviewInfo& info); } // namespace onnx_mlir diff --git a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp index 531d5c3..ef1adb2 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Constant.cpp @@ -120,7 +120,15 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern { auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); rewriter.setInsertionPoint(mapOp); - rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp); + auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8; + pim::PimMemCopyOp::create(rewriter, + mapOp.getLoc(), + initType, + mapOp.getInit(), + getGlobalOp.getResult(), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(sizeInBytes)); rewriter.eraseOp(mapOp); return success(); } @@ -416,6 +424,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { return failure(); if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) return failure(); + auto staticOffsets = getStaticSubviewOffsets(*srcSubview); + if (failed(staticOffsets)) + return failure(); auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType()); const int64_t numResultElements = resultTensorType.getNumElements(); @@ -428,7 +439,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides); SmallVector sourceIndices; sourceIndices.reserve(resultIndices.size()); - for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices)) + for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices)) sourceIndices.push_back(off + idx); int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides); resultValues[i] = sourceValues[srcLinear]; diff --git a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp index 5be11b0..acdbc58 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp @@ -1,3 +1,5 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" + #include "../Common.hpp" #include "../Patterns.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" @@ -8,6 +10,62 @@ using namespace mlir; namespace onnx_mlir { namespace { +static bool isSubviewContiguous(const StaticSubviewInfo& info) { + if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; })) + return false; + + auto sizesAndShape = llvm::zip_equal(llvm::make_range(info.sizes.rbegin(), info.sizes.rend()), + llvm::make_range(info.sourceShape.rbegin(), info.sourceShape.rend())); + auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool { + auto [size, dimension] = sizeAndShape; + return size != dimension; + }); + if (firstDifferentSize == sizesAndShape.end()) + return true; + + ++firstDifferentSize; + return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) { + auto [size, _dimension] = sizeAndShape; + return size == 1; + }); +} + +static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) { + if (extraOffset == 0) + return baseOffset; + + if (auto attr = dyn_cast(baseOffset)) { + auto integerAttr = dyn_cast(attr); + assert(integerAttr && "expected integer offset attribute"); + return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset); + } + + auto value = cast(baseOffset); + auto cst = arith::ConstantIndexOp::create(rewriter, value.getLoc(), extraOffset); + return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult(); +} + +static Value buildSubviewChunk(const StaticSubviewInfo& info, + ArrayRef outerIndices, + Location loc, + PatternRewriter& rewriter) { + SmallVector chunkOffsets; + SmallVector chunkSizes; + SmallVector chunkStrides; + chunkOffsets.reserve(info.offsets.size()); + chunkSizes.reserve(info.sizes.size()); + chunkStrides.reserve(info.strides.size()); + + for (size_t dim = 0; dim < info.sizes.size(); ++dim) { + int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0; + chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter)); + chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back())); + chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim])); + } + + return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); +} + template static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, Value dst, @@ -19,12 +77,8 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, CreateCopyOp createCopyOp) { auto srcSubview = getStaticSubviewInfo(src); auto dstSubview = getStaticSubviewInfo(dst); - const bool splitSrc = - succeeded(srcSubview) - && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides); - const bool splitDst = - succeeded(dstSubview) - && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides); + const bool splitSrc = succeeded(srcSubview) && !isSubviewContiguous(*srcSubview); + const bool splitDst = succeeded(dstSubview) && !isSubviewContiguous(*dstSubview); if (!splitSrc && !splitDst) return failure(); @@ -35,9 +89,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, if (sourceType.getElementType() != dstType.getElementType()) return failure(); - if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) + if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))) return failure(); - if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })) + if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))) return failure(); ArrayRef copyShape = splitSrc ? ArrayRef(srcSubview->sizes) : ArrayRef(dstSubview->sizes); @@ -64,18 +118,11 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) { SmallVector outerIndices = outerShape.empty() ? SmallVector {} : delinearizeIndex(linearIndex, outerShape, outerStrides); - const int64_t srcByteOffset = - srcOffset - + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes); - const int64_t dstByteOffset = - dstOffset - + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes); - createCopyOp(splitDst ? cast(dstSubview->source.getType()) : dstType, - splitDst ? dstSubview->source : dst, - splitSrc ? srcSubview->source : src, - dstByteOffset, - srcByteOffset, - sliceBytes); + Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst; + Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src; + const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes; + const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes; + createCopyOp(cast(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes); } return success(); @@ -198,6 +245,9 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePatternstrides, [](int64_t stride) { return stride != 1; })) return failure(); + auto staticOffsets = getStaticSubviewOffsets(*subviewInfo); + if (failed(staticOffsets)) + return failure(); auto sourceType = dyn_cast(denseAttr->getType()); if (!sourceType || !sourceType.hasStaticShape()) @@ -217,7 +267,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern sourceIndices; sourceIndices.reserve(resultIndices.size()); - for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices)) + for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices)) sourceIndices.push_back(off + idx); resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)]; } diff --git a/src/PIM/Pass/Pim/VerificationPass.cpp b/src/PIM/Pass/Pim/VerificationPass.cpp index 809685a..0e67cc2 100644 --- a/src/PIM/Pass/Pim/VerificationPass.cpp +++ b/src/PIM/Pass/Pim/VerificationPass.cpp @@ -132,38 +132,37 @@ private: } static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) { - bool hasFailure = false; - for (Operation& op : coreOp.getBody().front()) { - if (isa(op)) - continue; + return walkPimCoreBlock( + coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) { + bool hasFailure = false; + for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) { + if (!isa(operand.getType())) + continue; - for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) { - if (!isa(operand.getType())) - continue; + auto resolvedAddress = resolveContiguousAddress(operand, knowledge); + if (failed(resolvedAddress)) { + op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage"; + hasFailure = true; + continue; + } - auto resolvedAddress = resolveContiguousAddress(operand); - if (failed(resolvedAddress)) { - op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage"; - hasFailure = true; - continue; - } + if (isExplicitHostOperand(&op, operandIndex)) { + if (!isCodegenAddressableValue(operand)) { + op.emitOpError() << "host operand #" << operandIndex + << " is not backed by contiguous addressable storage"; + hasFailure = true; + } + continue; + } - if (isExplicitHostOperand(&op, operandIndex)) { - if (!isCodegenAddressableValue(operand)) { - op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage"; + if (!isa(resolvedAddress->base.getDefiningOp())) { + op.emitOpError() << "operand #" << operandIndex + << " must be backed by device-local memory; materialize host values with pim.memcp_hd"; hasFailure = true; } - continue; } - - if (!isa(resolvedAddress->base.getDefiningOp())) { - op.emitOpError() << "operand #" << operandIndex - << " must be backed by device-local memory; materialize host values with pim.memcp_hd"; - hasFailure = true; - } - } - } - return success(!hasFailure); + return success(!hasFailure); + }); } static LogicalResult verifyAddressOnlyHostOp(Operation* op) { diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp index 21bd62c..29ceb15 100644 --- a/src/PIM/PimAccelerator.cpp +++ b/src/PIM/PimAccelerator.cpp @@ -3,6 +3,8 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" @@ -57,12 +59,14 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry); mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry); mlir::arith::registerBufferizableOpInterfaceExternalModels(registry); mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry); + mlir::scf::registerBufferizableOpInterfaceExternalModels(registry); spatial::registerBufferizableOpInterfaceExternalModels(registry); spatial::registerONNXBufferizableOpInterfaceExternalModels(registry); pim::registerOpBufferizationInterfaces(registry); diff --git a/validation/.gitignore b/validation/.gitignore index 70dfb95..ad1da66 100644 --- a/validation/.gitignore +++ b/validation/.gitignore @@ -3,3 +3,8 @@ operations/**/outputs operations/**/raptor operations/**/runner operations/**/simulation +networks/**/inputs +networks/**/outputs +networks/**/raptor +networks/**/runner +networks/**/simulation