From 1a5d7d2a3fbfd2646c0af5e1c97cf940de05c7e9 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Wed, 27 May 2026 16:15:10 +0200 Subject: [PATCH] fix bufferization and weight emission after new gemm patterns --- src/PIM/Common/IR/WeightUtils.cpp | 188 +++++++++++- src/PIM/Common/IR/WeightUtils.hpp | 38 ++- src/PIM/Compiler/PimCodeGen.cpp | 99 ++++--- src/PIM/Compiler/PimWeightEmitter.cpp | 277 ++++-------------- src/PIM/Compiler/PimWeightEmitter.hpp | 13 +- .../Bufferization/BufferizationUtils.cpp | 9 + .../Bufferization/BufferizationUtils.hpp | 1 + .../OpBufferizationInterfaces.cpp | 14 +- src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp | 21 +- src/PIM/Pass/PimCodegen/VerificationPass.cpp | 6 + 10 files changed, 349 insertions(+), 317 deletions(-) diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp index ab2b113..4cd8168 100644 --- a/src/PIM/Common/IR/WeightUtils.cpp +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -1,8 +1,13 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" +#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -19,6 +24,50 @@ void markWeightAlways(mlir::Operation* op) { namespace { +CompiledIndexExpr makeConstantExpr(int64_t constant) { + CompiledIndexExprNode expr; + expr.kind = CompiledIndexExprNode::Kind::Constant; + expr.constant = constant; + return CompiledIndexExpr(std::make_shared(std::move(expr))); +} + +CompiledIndexExpr makeBinaryExpr(CompiledIndexExprNode::Kind kind, CompiledIndexExpr lhs, CompiledIndexExpr rhs) { + CompiledIndexExprNode expr; + expr.kind = kind; + expr.operands = {std::move(lhs), std::move(rhs)}; + return CompiledIndexExpr(std::make_shared(std::move(expr))); +} + +CompiledIndexExpr addExpr(CompiledIndexExpr lhs, CompiledIndexExpr rhs) { + return makeBinaryExpr(CompiledIndexExprNode::Kind::Add, std::move(lhs), std::move(rhs)); +} + +CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) { + return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs)); +} + +mlir::Value stripWeightViewOps(mlir::Value value) { + while (true) { + if (auto subviewOp = value.getDefiningOp()) { + value = subviewOp.getSource(); + continue; + } + if (auto castOp = value.getDefiningOp()) { + value = castOp.getSource(); + continue; + } + if (auto collapseOp = value.getDefiningOp()) { + value = collapseOp.getSrc(); + continue; + } + if (auto expandOp = value.getDefiningOp()) { + value = expandOp.getSrc(); + continue; + } + return value; + } +} + template bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { auto weightArg = parentOp.getWeightArgument(weightIndex); @@ -96,35 +145,31 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_refwalk([&](pim::PimCoreOp coreOp) { coreOp.walk([&](pim::PimVMMOp vmmOp) { - for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) - if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) { - callback(coreOp->getOpOperand(weightIndex)); - break; - } + if (auto weightIndex = resolveWeightIndex(coreOp.getOperation(), vmmOp.getWeight())) + callback(coreOp->getOpOperand(*weightIndex)); }); }); root->walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOp.walk([&](pim::PimVMMOp vmmOp) { - for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex) - if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) { - callback(coreBatchOp->getOpOperand(weightIndex)); - break; - } + if (auto weightIndex = resolveWeightIndex(coreBatchOp.getOperation(), vmmOp.getWeight())) + callback(coreBatchOp->getOpOperand(*weightIndex)); }); }); } -std::optional resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) { +std::optional resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) { + weight = stripWeightViewOps(weight); + if (auto coreOp = mlir::dyn_cast_or_null(weightOwner)) { for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) - if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) + if (coreOp.getWeightArgument(weightIndex) == weight) return weightIndex; return std::nullopt; } if (auto coreBatchOp = mlir::dyn_cast_or_null(weightOwner)) { for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex) - if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) + if (coreBatchOp.getWeightArgument(weightIndex) == weight) return weightIndex; return std::nullopt; } @@ -132,4 +177,121 @@ std::optional resolveWeightIndex(mlir::Operation* weightOwner, pim::Pi return std::nullopt; } +std::optional resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) { + return resolveWeightIndex(weightOwner, vmmOp.getWeight()); +} + +llvm::FailureOr +resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) { + llvm::SmallVector viewOps; + mlir::Value current = weight; + + while (true) { + if (auto defOp = current.getDefiningOp()) { + if (auto getGlobalOp = mlir::dyn_cast(defOp)) { + auto moduleOp = weightOwner ? weightOwner->getParentOfType() : mlir::ModuleOp {}; + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (!globalOp || !globalOp.getInitialValue()) + return mlir::failure(); + + auto denseAttr = mlir::dyn_cast(*globalOp.getInitialValue()); + if (!denseAttr) + return mlir::failure(); + + ResolvedWeightView view; + view.globalOp = globalOp; + view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end()); + view.strides = computeRowMajorStrides(view.shape); + + CompiledIndexExpr offsetExpr = makeConstantExpr(0); + for (mlir::Operation* viewOp : llvm::reverse(viewOps)) { + if (auto subview = mlir::dyn_cast(viewOp)) { + llvm::SmallVector nextStrides; + nextStrides.reserve(subview.getMixedOffsets().size()); + for (auto [offset, stride, sourceStride] : + llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) { + CompiledIndexExpr offsetValue = makeConstantExpr(0); + if (auto attr = mlir::dyn_cast(offset)) { + auto intAttr = mlir::dyn_cast(attr); + if (!intAttr) + return mlir::failure(); + offsetValue = makeConstantExpr(intAttr.getInt()); + } + else if (auto value = mlir::dyn_cast(offset)) { + auto compiledOffset = compileIndexExpr(value); + if (failed(compiledOffset)) + return mlir::failure(); + offsetValue = *compiledOffset; + } + else { + return mlir::failure(); + } + offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride)); + nextStrides.push_back(stride * sourceStride); + } + view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end()); + view.strides = std::move(nextStrides); + continue; + } + + if (auto collapse = mlir::dyn_cast(viewOp)) { + if (view.strides != computeRowMajorStrides(view.shape)) + return mlir::failure(); + auto resultType = mlir::cast(collapse.getResult().getType()); + view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); + view.strides = computeRowMajorStrides(view.shape); + continue; + } + + if (auto expand = mlir::dyn_cast(viewOp)) { + if (view.strides != computeRowMajorStrides(view.shape)) + return mlir::failure(); + auto resultType = mlir::cast(expand.getResult().getType()); + view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); + view.strides = computeRowMajorStrides(view.shape); + } + } + + auto resolvedOffset = offsetExpr.evaluate(knowledge); + if (failed(resolvedOffset)) + return mlir::failure(); + view.offset = *resolvedOffset; + return view; + } + + if (mlir::isa(defOp)) { + viewOps.push_back(defOp); + if (auto subview = mlir::dyn_cast(defOp)) + current = subview.getSource(); + else if (auto collapse = mlir::dyn_cast(defOp)) + current = collapse.getSrc(); + else + current = mlir::cast(defOp).getSrc(); + continue; + } + + if (auto castOp = mlir::dyn_cast(defOp)) { + current = castOp.getSource(); + continue; + } + + return mlir::failure(); + } + + auto weightIndex = resolveWeightIndex(weightOwner, current); + if (!weightIndex) + return mlir::failure(); + + if (auto coreOp = mlir::dyn_cast_or_null(weightOwner)) { + current = coreOp.getWeights()[*weightIndex]; + continue; + } + if (auto coreBatchOp = mlir::dyn_cast_or_null(weightOwner)) { + current = coreBatchOp.getWeights()[*weightIndex]; + continue; + } + return mlir::failure(); + } +} + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/WeightUtils.hpp b/src/PIM/Common/IR/WeightUtils.hpp index c02c839..b14fc71 100644 --- a/src/PIM/Common/IR/WeightUtils.hpp +++ b/src/PIM/Common/IR/WeightUtils.hpp @@ -1,6 +1,7 @@ #pragma once -#include "mlir/IR/Operation.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Value.h" #include "llvm/ADT/SmallVector.h" @@ -10,12 +11,24 @@ #include +#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways"; namespace onnx_mlir { +struct ResolvedWeightView { + mlir::memref::GlobalOp globalOp; + llvm::SmallVector shape; + llvm::SmallVector strides; + int64_t offset = 0; + + bool operator==(const ResolvedWeightView& other) const { + return globalOp == other.globalOp && shape == other.shape && strides == other.strides && offset == other.offset; + } +}; + bool hasWeightAlways(mlir::Operation* op); /// Tags an op as producing a value that should stay materialized as a reusable @@ -32,24 +45,21 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value); /// passes can identify globals that must remain weight-backed. void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback); +std::optional resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight); +std::optional resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp); +llvm::FailureOr +resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {}); + template llvm::SmallVector getUsedWeightIndices(CoreLikeOpTy coreLikeOp) { llvm::SmallVector indices; - auto addWeight = [&](mlir::Value weight) { - for (unsigned weightIndex = 0; weightIndex < coreLikeOp.getWeights().size(); ++weightIndex) { - if (coreLikeOp.getWeightArgument(weightIndex) != weight) - continue; - if (!llvm::is_contained(indices, weightIndex)) - indices.push_back(weightIndex); - return; - } - }; - - coreLikeOp.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); }); + coreLikeOp.walk([&](pim::PimVMMOp vmmOp) { + auto weightIndex = resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight()); + if (weightIndex && !llvm::is_contained(indices, *weightIndex)) + indices.push_back(*weightIndex); + }); llvm::sort(indices); return indices; } -std::optional resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp); - } // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 1cfd637..ff5e2e9 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -814,7 +814,7 @@ static SmallVector collectTopLevelCoreLikeOps(func::FuncOp funcOp) { struct CoreEmissionResult { OnnxMlirCompilerErrorCodes status = CompilerSuccess; MemoryReportRow reportRow; - llvm::SmallVector usedWeightIndices; + llvm::SmallVector usedWeights; }; template @@ -879,7 +879,6 @@ struct CompiledCoreNode { Kind kind = Kind::Op; Operation* op = nullptr; CompiledCoreOpKind opKind = CompiledCoreOpKind::Load; - std::optional weightIndex; CompiledIndexExpr lowerBound; CompiledIndexExpr upperBound; CompiledIndexExpr step; @@ -978,12 +977,6 @@ static LogicalResult compileCoreEmissionPlan(Block& block, opNode.kind = CompiledCoreNode::Kind::Op; opNode.op = &op; opNode.opKind = *opKind; - if (auto vmmOp = dyn_cast(op)) { - auto weightIndex = onnx_mlir::resolveWeightIndex(weightOwner, vmmOp); - if (!weightIndex) - return failure(); - opNode.weightIndex = *weightIndex; - } plan.push_back(std::move(opNode)); } return success(); @@ -992,6 +985,9 @@ static LogicalResult compileCoreEmissionPlan(Block& block, static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl& plan, PimCodeGen& coreCodeGen, StaticValueKnowledge& knowledge, + llvm::function_ref(pim::PimVMMOp, + const StaticValueKnowledge&)> + resolveWeightSlot, size_t& processedOperations, std::optional batchLane = std::nullopt, std::optional batchLaneCount = std::nullopt) { @@ -1015,7 +1011,7 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl(forOp.getRegion().front().getTerminator()); @@ -1048,9 +1044,10 @@ static LogicalResult executeCompiledCorePlan(const llvm::SmallVectorImpl(node.op), knowledge); break; case CompiledCoreOpKind::Vmm: - assert(node.weightIndex && "compiled VMM op must have cached weight index"); - coreCodeGen.codeGenMVMLikeOp( - *node.weightIndex, cast(node.op), true, knowledge); + if (auto weightSlot = resolveWeightSlot(cast(node.op), knowledge); succeeded(weightSlot)) + coreCodeGen.codeGenMVMLikeOp(*weightSlot, cast(node.op), true, knowledge); + else + return failure(); break; case CompiledCoreOpKind::Transpose: coreCodeGen.codeGenTransposeOp(cast(node.op), knowledge); @@ -1138,6 +1135,9 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen, const StaticValueKnowledge& initialKnowledge, Operation* weightOwner, + llvm::function_ref(pim::PimVMMOp, + const StaticValueKnowledge&)> + resolveWeightSlot, std::optional batchLane = std::nullopt, std::optional batchLaneCount = std::nullopt) { llvm::SmallVector plan; @@ -1146,7 +1146,8 @@ static int64_t codeGenCoreOps(Block& block, size_t processedOperations = 0; StaticValueKnowledge knowledge = initialKnowledge; - auto result = executeCompiledCorePlan(plan, coreCodeGen, knowledge, processedOperations, batchLane, batchLaneCount); + auto result = + executeCompiledCorePlan(plan, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount); return failed(result) ? -1 : static_cast(processedOperations); } @@ -1174,9 +1175,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: size_t maxCoreId = 0; uint64_t nextBatchReportId = 0; - // Create Weight Folder - auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath); - SmallVector coreLikeOps = collectTopLevelCoreLikeOps(funcOp); SmallDenseMap materializedHostGlobals = collectMaterializedHostGlobals(moduleOp, funcOp, memory); @@ -1238,11 +1236,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: ++nextBatchReportId; } - auto linkCoreWeights = [&](size_t originalCoreId, - size_t coreId, - ArrayRef usedIndices, - ValueRange weights, - Operation* weightOwner, + auto linkCoreWeights = [&](size_t coreId, + ArrayRef weightFiles, json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes { auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId); if (auto error = sys::fs::create_directory(coreWeightsDirPath)) { @@ -1250,20 +1245,12 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: return InvalidOutputFileAccess; } - auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId]; - for (unsigned index : usedIndices) { - if (index >= weights.size()) { - weightOwner->emitWarning("Weight index " + std::to_string(index) + " is out of range"); - assert(index < weights.size() && "Weight index is out of range"); - } - mlir::Value weight = weights[index]; - xbarsPerGroup.push_back(index); - assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!"); - auto& fileName = mapWeightToFile[weight]; + for (auto [slot, fileName] : llvm::enumerate(weightFiles)) { + xbarsPerGroup.push_back(static_cast(slot)); if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName, - coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) { + coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")) { errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to " - << (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") + << (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin") << "\nError:" << error.message() << '\n'; return InvalidOutputFileAccess; } @@ -1275,6 +1262,22 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: auto emitJob = [&](const CoreEmissionJob& job) -> CoreEmissionResult { CoreEmissionResult result; PimAcceleratorMemory jobMemory(memory.memEntriesMap, false); + llvm::SmallVector usedWeights; + + auto resolveWeightSlot = [&](pim::PimVMMOp vmmOp, + const StaticValueKnowledge& knowledge) -> llvm::FailureOr { + auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge); + if (failed(weightView)) { + vmmOp.emitOpError("requires a statically resolvable dense global weight view during PIM codegen"); + return failure(); + } + + if (auto it = llvm::find(usedWeights, *weightView); it != usedWeights.end()) + return static_cast(std::distance(usedWeights.begin(), it)); + + usedWeights.push_back(*weightView); + return static_cast(usedWeights.size() - 1); + }; std::error_code errorCode; auto outputCorePath = outputDirPath + "/core_" + std::to_string(job.emittedCoreId) + ".pim"; @@ -1307,21 +1310,20 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId); deviceMemory.allocateCore(coreOp); - int64_t processedOperations = - codeGenCoreOps(coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation()); + int64_t processedOperations = codeGenCoreOps( + coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation(), resolveWeightSlot); if (processedOperations < 0) { result.status = CompilerFailure; return result; } assert(processedOperations > 0); result.reportRow = deviceMemory.getReportRow(); - result.usedWeightIndices = getUsedWeightIndices(coreOp); + result.usedWeights = std::move(usedWeights); } else { auto coreBatchOp = cast(job.coreLikeOp); aliasMaterializedHostGlobals(coreBatchOp, moduleOp, materializedHostGlobals, jobMemory); auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId); - result.usedWeightIndices = getUsedWeightIndices(coreBatchOp); for (unsigned lane : job.lanes) { StaticValueKnowledge knowledge; @@ -1335,6 +1337,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: coreCodeGen, knowledge, coreBatchOp.getOperation(), + resolveWeightSlot, lane, static_cast(coreBatchOp.getLaneCount())); if (processedOperations < 0) { @@ -1345,6 +1348,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: } result.reportRow = deviceMemory.getReportRow(); + result.usedWeights = std::move(usedWeights); } pim_binary::patchInstructionCount(coreBinaryStream, coreCodeGen.getEmittedInstructionCount()); @@ -1368,14 +1372,23 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: if (jobResults[jobIndex].status != CompilerSuccess) return jobResults[jobIndex].status; + llvm::SmallVector weightRequests; + weightRequests.reserve(jobs.size()); + for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) { + WeightFileRequest request; + request.coreId = jobs[jobIndex].emittedCoreId; + request.weights = jobResults[jobIndex].usedWeights; + weightRequests.push_back(std::move(request)); + } + auto mapCoreWeightToFileName = createAndPopulateWeightFolder(weightRequests, outputDirPath); + for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) { const CoreEmissionJob& job = jobs[jobIndex]; const CoreEmissionResult& result = jobResults[jobIndex]; json::Array xbarsPerGroup; if (auto coreOp = dyn_cast(job.coreLikeOp)) { - if (auto err = linkCoreWeights( - job.originalCoreId, job.emittedCoreId, result.usedWeightIndices, coreOp.getWeights(), coreOp.getOperation(), xbarsPerGroup)) + if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup)) return err; xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup); memory.recordCoreReport(job.emittedCoreId, result.reportRow); @@ -1391,14 +1404,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: for (size_t jobIndex : group) { const CoreEmissionJob& job = jobs[jobIndex]; const CoreEmissionResult& result = jobResults[jobIndex]; - auto coreBatchOp = cast(job.coreLikeOp); json::Array xbarsPerGroup; - if (auto err = linkCoreWeights(job.originalCoreId, - job.emittedCoreId, - result.usedWeightIndices, - coreBatchOp.getWeights(), - coreBatchOp.getOperation(), - xbarsPerGroup)) + if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup)) return err; xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup); reportedCoreIds.push_back(static_cast(job.emittedCoreId)); diff --git a/src/PIM/Compiler/PimWeightEmitter.cpp b/src/PIM/Compiler/PimWeightEmitter.cpp index 5bd7ebf..719e52a 100644 --- a/src/PIM/Compiler/PimWeightEmitter.cpp +++ b/src/PIM/Compiler/PimWeightEmitter.cpp @@ -1,25 +1,16 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" #include -#include #include "Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" -#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" -#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" -#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace llvm; using namespace mlir; @@ -27,240 +18,72 @@ using namespace mlir; namespace onnx_mlir { namespace { -struct DenseWeightView { - DenseElementsAttr denseAttr; - SmallVector shape; - SmallVector strides; - int64_t offset = 0; -}; - -FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) { - SmallVector viewOps; - mlir::Value current = weight; - memref::GetGlobalOp getGlobalOp; - - while (true) { - Operation* defOp = current.getDefiningOp(); - if (!defOp) - return failure(); - if ((getGlobalOp = dyn_cast(defOp))) - break; - if (auto subview = dyn_cast(defOp)) { - if (!hasAllStaticSubviewParts(subview)) - return failure(); - viewOps.push_back(subview); - current = subview.getSource(); - continue; - } - if (auto cast = dyn_cast(defOp)) { - current = cast.getSource(); - continue; - } - if (auto collapse = dyn_cast(defOp)) { - auto srcType = dyn_cast(collapse.getSrc().getType()); - auto resultType = dyn_cast(collapse.getResult().getType()); - if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape()) - return failure(); - viewOps.push_back(collapse); - current = collapse.getSrc(); - continue; - } - if (auto expand = dyn_cast(defOp)) { - auto srcType = dyn_cast(expand.getSrc().getType()); - auto resultType = dyn_cast(expand.getResult().getType()); - if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape()) - return failure(); - viewOps.push_back(expand); - current = expand.getSrc(); - continue; - } - return failure(); - } - - auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); - if (!globalOp || !globalOp.getInitialValue()) - return failure(); - - auto denseAttr = dyn_cast(*globalOp.getInitialValue()); - if (!denseAttr) - return failure(); - - DenseWeightView view; - view.denseAttr = denseAttr; - view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end()); - view.strides = computeRowMajorStrides(view.shape); - - for (Operation* viewOp : llvm::reverse(viewOps)) { - if (auto subview = dyn_cast(viewOp)) { - SmallVector nextStrides; - nextStrides.reserve(subview.getStaticStrides().size()); - for (auto [offset, stride, sourceStride] : - llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) { - view.offset += offset * sourceStride; - nextStrides.push_back(stride * sourceStride); - } - view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end()); - view.strides = std::move(nextStrides); - continue; - } - - // Collapse/expand are accepted only as contiguous static reshapes of a - // dense global view, so a row-major stride recomputation preserves layout. - if (auto collapse = dyn_cast(viewOp)) { - if (view.strides != computeRowMajorStrides(view.shape)) - return failure(); - auto resultType = cast(collapse.getResult().getType()); - view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); - view.strides = computeRowMajorStrides(view.shape); - continue; - } - - if (auto expand = dyn_cast(viewOp)) { - if (view.strides != computeRowMajorStrides(view.shape)) - return failure(); - auto resultType = cast(expand.getResult().getType()); - view.shape.assign(resultType.getShape().begin(), resultType.getShape().end()); - view.strides = computeRowMajorStrides(view.shape); - continue; - } - } - - return view; -} - -SmallVector collectTopLevelCoreLikeOps(func::FuncOp funcOp) { - SmallVector coreLikeOps; - for (Operation& op : funcOp.getBody().front()) - if (dyn_cast(&op) || dyn_cast(&op)) - coreLikeOps.push_back(&op); - return coreLikeOps; -} - } // namespace -llvm::DenseMap> -createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { - ModuleOp moduleOp = funcOp->getParentOfType(); +llvm::DenseMap> +createAndPopulateWeightFolder(ArrayRef requests, StringRef outputDirPath) { auto coreWeightsDirPath = outputDirPath + "/weights"; auto error = sys::fs::create_directory(coreWeightsDirPath); assert(!error && "Error creating weights directory"); size_t indexFileName = 0; int64_t xbarSize = crossbarSize.getValue(); - llvm::DenseMap> mapCoreWeightToFileName; - llvm::DenseMap mapGlobalOpToFileName; - llvm::DenseMap mapWeightValueToFileName; + llvm::DenseMap> mapCoreWeightToFileName; + llvm::SmallVector, 16> materializedWeights; - SmallVector coreLikeOps = collectTopLevelCoreLikeOps(funcOp); + auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string { + if (auto it = llvm::find_if(materializedWeights, [&](const auto& entry) { return entry.first == weightView; }); + it != materializedWeights.end()) + return it->second; - for (Operation* op : coreLikeOps) { - auto processWeight = [&](Operation* ownerOp, - mlir::Value weight, - size_t weightIndex, - size_t coreId) -> LogicalResult { - auto weightView = resolveDenseWeightView(moduleOp, weight); - if (failed(weightView)) { - ownerOp->emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex)); - assert(succeeded(weightView) && "Weight is not from a dense memref.global view"); - } + auto globalOp = weightView.globalOp; + auto denseAttr = mlir::dyn_cast(*globalOp.getInitialValue()); + assert(denseAttr && "Weight global must have dense initial value"); - if (mapCoreWeightToFileName[coreId].contains(weight)) - return success(); + ArrayRef shape = weightView.shape; + assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); + int64_t numRows = shape[0]; + int64_t numCols = shape[1]; + assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); - if (auto weightFile = mapWeightValueToFileName.find(weight); weightFile != mapWeightValueToFileName.end()) { - mapCoreWeightToFileName[coreId].insert({weight, weightFile->second}); - return success(); - } + size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType()); - auto getGlobalOp = weight.getDefiningOp(); - auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {}; - if (globalOp && mapGlobalOpToFileName.contains(globalOp)) { - auto& fileName = mapGlobalOpToFileName[globalOp]; - mapWeightValueToFileName[weight] = fileName; - mapCoreWeightToFileName[coreId].insert({weight, fileName}); - return success(); - } - - DenseElementsAttr denseAttr = weightView->denseAttr; - ArrayRef shape = weightView->shape; - assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); - int64_t numRows = shape[0]; - int64_t numCols = shape[1]; - assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); - - size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType()); - - std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; - auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); - std::error_code errorCode; - raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None); - if (errorCode) { - errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n'; - assert(errorCode); - } - - uint64_t zero = 0; - for (int64_t row = 0; row < xbarSize; row++) { - for (int64_t col = 0; col < xbarSize; col++) { - if (row < numRows && col < numCols) { - int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1]; - APInt bits = denseAttr.getValues()[elementIndex].bitcastToAPInt(); - uint64_t word = bits.getZExtValue(); - weightFileStream.write(reinterpret_cast(&word), elementByteWidth); - } - else { - weightFileStream.write(reinterpret_cast(&zero), elementByteWidth); - } - } - } - - weightFileStream.close(); - if (globalOp) - mapGlobalOpToFileName.insert({globalOp, newFileName}); - mapWeightValueToFileName[weight] = newFileName; - mapCoreWeightToFileName[coreId].insert({weight, newFileName}); - return success(); - }; - - auto processCoreLike = [&](auto coreLikeOp) { - auto usedIndices = getUsedWeightIndices(coreLikeOp); - for (unsigned index : usedIndices) { - if (index >= coreLikeOp.getWeights().size()) { - coreLikeOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); - assert(index < coreLikeOp.getWeights().size() && "Weight index is out of range"); - } - } - - if constexpr (std::is_same_v, pim::PimCoreOp>) { - size_t coreId = static_cast(coreLikeOp.getCoreId()); - for (unsigned index : usedIndices) - if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId))) - return failure(); - return success(); - } - else { - auto batchCoreIds = getBatchCoreIds(coreLikeOp); - SmallVector orderedCoreIds; - llvm::SmallSet seenCoreIds; - for (int32_t coreId : batchCoreIds) - if (seenCoreIds.insert(static_cast(coreId)).second) - orderedCoreIds.push_back(static_cast(coreId)); - - for (size_t coreId : orderedCoreIds) - for (unsigned index : usedIndices) - if (failed(processWeight(coreLikeOp, coreLikeOp.getWeights()[index], index, coreId))) - return failure(); - return success(); - } - }; - - if (auto coreOp = dyn_cast(op)) { - (void) processCoreLike(coreOp); - continue; + std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; + auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); + std::error_code errorCode; + raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None); + if (errorCode) { + errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n'; + assert(errorCode); } - (void) processCoreLike(cast(op)); + uint64_t zero = 0; + for (int64_t row = 0; row < xbarSize; row++) { + for (int64_t col = 0; col < xbarSize; col++) { + if (row < numRows && col < numCols) { + int64_t elementIndex = weightView.offset + row * weightView.strides[0] + col * weightView.strides[1]; + APInt bits = denseAttr.getValues()[elementIndex].bitcastToAPInt(); + uint64_t word = bits.getZExtValue(); + weightFileStream.write(reinterpret_cast(&word), elementByteWidth); + } + else { + weightFileStream.write(reinterpret_cast(&zero), elementByteWidth); + } + } + } + + weightFileStream.close(); + materializedWeights.push_back({weightView, newFileName}); + return newFileName; + }; + + for (const WeightFileRequest& request : requests) { + auto& coreFiles = mapCoreWeightToFileName[request.coreId]; + coreFiles.reserve(request.weights.size()); + for (const ResolvedWeightView& weight : request.weights) + coreFiles.push_back(materializeWeight(weight)); } + return mapCoreWeightToFileName; } diff --git a/src/PIM/Compiler/PimWeightEmitter.hpp b/src/PIM/Compiler/PimWeightEmitter.hpp index a620028..2daa5ae 100644 --- a/src/PIM/Compiler/PimWeightEmitter.hpp +++ b/src/PIM/Compiler/PimWeightEmitter.hpp @@ -1,16 +1,23 @@ #pragma once #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Value.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include +#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" + namespace onnx_mlir { -llvm::DenseMap> -createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath); +struct WeightFileRequest { + size_t coreId = 0; + llvm::SmallVector weights; +}; + +llvm::DenseMap> +createAndPopulateWeightFolder(llvm::ArrayRef requests, llvm::StringRef outputDirPath); } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp index f7fcac9..1635ac3 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp @@ -30,6 +30,15 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& .getOutput(); } +Value allocateContiguousMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) { + if (succeeded(resolveContiguousAddress(memrefValue))) + return memrefValue; + + auto shapedType = cast(memrefValue.getType()); + auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType()); + return memref::AllocOp::create(rewriter, loc, contiguousType); +} + FailureOr getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) { if (isa(value.getType())) diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp index d4c5d48..de45384 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp @@ -6,6 +6,7 @@ namespace onnx_mlir::pim { mlir::Value materializeContiguousMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); +mlir::Value allocateContiguousMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); llvm::FailureOr getBufferOrValue(mlir::RewriterBase& rewriter, mlir::Value value, diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 61ce65c..981f473 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -431,8 +431,11 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); + replaceOpWithNewBufferizedOp( - rewriter, op, outputBufferOpt->getType(), *inputOpt, transposeOp.getPermutation(), *outputBufferOpt); + rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput); return success(); } }; @@ -473,9 +476,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp( - rewriter, op, outputBufferOpt->getType(), *weightOpt, contiguousInput, *outputBufferOpt); + rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput); return success(); } }; @@ -512,9 +516,10 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter); + Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); replaceOpWithNewBufferizedOp( - rewriter, op, outputBufferOpt->getType(), contiguousLhs, contiguousRhs, *outputBufferOpt); + rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput); return success(); } }; @@ -546,8 +551,9 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModelgetLoc(), rewriter); + Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); - replaceOpWithNewBufferizedOp(rewriter, op, outputBufferOpt->getType(), contiguousInput, *outputBufferOpt); + replaceOpWithNewBufferizedOp(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput); return success(); } }; diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index 2252586..18e80ca 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -240,10 +240,10 @@ void SpatCompute::print(OpAsmPrinter& printer) { printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); printer << " "; printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); - printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size(); if (auto coreIdAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) printer << " coreId " << coreIdAttr.getInt(); + printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size(); printer.printOptionalAttrDict((*this)->getAttrs(), {getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); @@ -276,13 +276,13 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs)) return failure(); - bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights"); - if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount)) - return failure(); - bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id"); if (hasCoreId && parser.parseInteger(coreId)) return failure(); + + bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights"); + if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount)) + return failure(); (void) crossbarWeightCount; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() @@ -365,13 +365,14 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) { printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); printer << " "; printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); - printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size(); if (getNumResults() != 0) { printer << " shared_outs"; printBlockArgumentList(printer, outputArgs); } + printer << " crossbarWeights " << getComputeInstanceCrossbarUsage({getOperation(), 0, getLaneCount()}).size(); + if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) { printer << " coreIds "; printCompressedIntegerList(printer, coreIdsAttr.asArrayRef()); @@ -423,13 +424,13 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) if (parseBlockArgumentList(parser, outputArgs)) return failure(); - bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights"); - if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount)) - return failure(); - bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids"); if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) return failure(); + + bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights"); + if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount)) + return failure(); (void) crossbarWeightCount; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 8e1dcfc..9ac56c9 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -8,6 +8,7 @@ #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -346,6 +347,11 @@ private: if (isCoreWeightBlockArgument(operand)) continue; + if (auto vmmOp = dyn_cast(&op); + vmmOp && operandIndex == 0 && resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight())) { + continue; + } + auto resolvedAddress = resolveContiguousAddress(operand, knowledge); if (failed(resolvedAddress)) { diagnostics.report(&op, [&](Operation* illegalOp) {