This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
|
||||
@@ -24,113 +28,132 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
|
||||
return laneCoreIds;
|
||||
}
|
||||
|
||||
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
|
||||
IRRewriter rewriter(scalarCore.getContext());
|
||||
SmallVector<Operation*> batchOps;
|
||||
scalarCore.walk([&](Operation* op) {
|
||||
if (isa<pim::PimSendBatchOp,
|
||||
pim::PimSendTensorBatchOp,
|
||||
pim::PimReceiveBatchOp,
|
||||
pim::PimReceiveTensorBatchOp,
|
||||
pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
batchOps.push_back(op);
|
||||
}
|
||||
});
|
||||
static void cloneScalarizedLaneBody(OpBuilder& builder,
|
||||
pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
OperationFolder& constantFolder) {
|
||||
Block& oldBlock = coreBatchOp.getBody().front();
|
||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||
size_t weightCount = coreBatchOp.getWeights().size();
|
||||
|
||||
for (Operation* op : batchOps) {
|
||||
rewriter.setInsertionPoint(op);
|
||||
IRMapping mapper;
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (blockArg.getType().isIndex()) {
|
||||
mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast<int64_t>(lane), constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (argIndex <= weightCount) {
|
||||
mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]);
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t inputIndex = argIndex - 1 - weightCount;
|
||||
assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
|
||||
mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock) {
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||
pim::PimSendOp::create(rewriter,
|
||||
sendBatchOp.getLoc(),
|
||||
sendBatchOp.getInput(),
|
||||
sendBatchOp.getSizeAttr(),
|
||||
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
|
||||
rewriter.eraseOp(op);
|
||||
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
|
||||
pim::PimSendOp::create(
|
||||
builder,
|
||||
sendBatchOp.getLoc(),
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
sendBatchOp.getSizeAttr(),
|
||||
getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
|
||||
pim::PimSendTensorOp::create(
|
||||
rewriter,
|
||||
builder,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
sendTensorBatchOp.getInput(),
|
||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||
rewriter.eraseOp(op);
|
||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||
auto scalarReceive =
|
||||
pim::PimReceiveOp::create(rewriter,
|
||||
receiveBatchOp.getLoc(),
|
||||
receiveBatchOp.getOutput().getType(),
|
||||
receiveBatchOp.getOutputBuffer(),
|
||||
receiveBatchOp.getSizeAttr(),
|
||||
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
|
||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
||||
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
|
||||
auto scalarReceive = pim::PimReceiveOp::create(
|
||||
builder,
|
||||
receiveBatchOp.getLoc(),
|
||||
receiveBatchOp.getOutput().getType(),
|
||||
mapper.lookup(receiveBatchOp.getOutputBuffer()),
|
||||
receiveBatchOp.getSizeAttr(),
|
||||
getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder));
|
||||
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
|
||||
auto scalarReceive = pim::PimReceiveTensorOp::create(
|
||||
rewriter,
|
||||
builder,
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
receiveTensorBatchOp.getOutput().getType(),
|
||||
receiveTensorBatchOp.getOutputBuffer(),
|
||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
||||
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
|
||||
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
|
||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
memcpBatchOp.getLoc(),
|
||||
memcpBatchOp.getOutput().getType(),
|
||||
memcpBatchOp.getDeviceTarget(),
|
||||
memcpBatchOp.getHostSource(),
|
||||
memcpBatchOp.getDeviceTargetOffsetAttr(),
|
||||
memcpBatchOp.getHostSourceOffsetAttr(),
|
||||
memcpBatchOp.getSizeAttr());
|
||||
rewriter.replaceOp(op, scalarCopy->getResults());
|
||||
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(
|
||||
builder,
|
||||
memcpBatchOp.getLoc(),
|
||||
memcpBatchOp.getOutput().getType(),
|
||||
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
|
||||
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
|
||||
mapper.lookup(memcpBatchOp.getDeviceTarget()),
|
||||
mapper.lookup(memcpBatchOp.getHostSource()),
|
||||
memcpBatchOp.getSizeAttr());
|
||||
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* cloned = builder.clone(op, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||
LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
|
||||
ArrayRef<unsigned> lanes,
|
||||
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||
assert(!lanes.empty() && "expected at least one batch lane");
|
||||
|
||||
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
|
||||
OpBuilder builder(scratchModule->getContext());
|
||||
OperationFolder constantFolder(scratchModule->getContext());
|
||||
builder.setInsertionPointToStart(scratchModule->getBody());
|
||||
|
||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
|
||||
SmallVector<Value> laneWeights;
|
||||
laneWeights.reserve(weightsPerLane);
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
||||
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
|
||||
|
||||
SmallVector<Value> weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());
|
||||
auto coreIds = getBatchCoreIds(coreBatchOp);
|
||||
auto scalarCore = pim::PimCoreOp::create(
|
||||
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
|
||||
int32_t coreId = coreIds[lanes.front()];
|
||||
for (unsigned lane : lanes)
|
||||
assert(coreIds[lane] == coreId && "all grouped lanes must target the same core");
|
||||
|
||||
auto scalarCore =
|
||||
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
|
||||
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
|
||||
IRMapping mapper;
|
||||
if (coreBatchOp.getBody().front().getNumArguments() == 1)
|
||||
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
|
||||
|
||||
builder.setInsertionPointToEnd(block);
|
||||
for (Operation& op : coreBatchOp.getBody().front()) {
|
||||
Operation* cloned = builder.clone(op, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
}
|
||||
|
||||
for (unsigned lane : lanes)
|
||||
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
|
||||
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
||||
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
||||
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
|
||||
return callback(scalarCore);
|
||||
}
|
||||
|
||||
/// Single-lane convenience entry point.
///
/// Wraps `lane` in a one-element lane group and forwards it, together with
/// `coreBatchOp` and `callback`, to withScalarCoreFromBatchLanes; the result of
/// that call is returned unchanged.
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
                                          unsigned lane,
                                          llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
  // The lane group only has to outlive the forwarded call, so a local array
  // viewed through ArrayRef is sufficient.
  const unsigned singleLaneGroup[] = {lane};
  return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef<unsigned>(singleLaneGroup), callback);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -9,5 +9,8 @@ namespace onnx_mlir {
|
||||
/// Single-lane form of withScalarCoreFromBatchLanes: the definition forwards
/// `lane` as a one-element lane list and returns that call's result.
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||
/// Builds a scalar pim.core inside a scratch module from the given `lanes` of
/// `coreBatchOp`, hands it to `callback`, and returns the callback's result.
///
/// The definition asserts that `lanes` is non-empty and that every listed lane
/// maps to the same core id. The scratch module is held by an OwningOpRef
/// local to the call, so the core passed to `callback` should not be kept
/// after the call returns.
mlir::LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
|
||||
llvm::ArrayRef<unsigned> lanes,
|
||||
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -41,15 +41,23 @@ using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
|
||||
/// Returns the size in bytes of a single element of `elementType`.
///
/// `index` elements are sized as `int64_t`. Integer and floating-point
/// elements are rounded up to a whole number of bytes, so sub-byte types such
/// as i1 or i4 occupy one byte instead of reporting a size of zero. Any other
/// element type is unsupported and reaches llvm_unreachable.
static size_t getElementTypeSizeInBytes(mlir::Type elementType) {
  if (elementType.isIndex())
    return sizeof(int64_t);
  if (elementType.isIntOrFloat()) {
    constexpr size_t bitsPerByte = 8;
    const size_t bitWidth = elementType.getIntOrFloatBitWidth();
    // Round up: truncating division would yield 0 bytes for widths below 8,
    // which makes every tensor of such elements look empty to the allocator.
    return (bitWidth + bitsPerByte - 1) / bitsPerByte;
  }
  llvm_unreachable("unsupported shaped element type");
}
|
||||
|
||||
static size_t getValueSizeInBytes(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
|
||||
}
|
||||
|
||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
assert("Only static shape is supported" && type.hasStaticShape());
|
||||
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
|
||||
MemEntry memEntry = {0, allocSize};
|
||||
return &memEntries.emplace_back(memEntry, value).first;
|
||||
}
|
||||
@@ -398,20 +406,28 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto deviceTargetOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge);
|
||||
auto hostSourceOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge);
|
||||
assert(succeeded(deviceTargetOffset) && succeeded(hostSourceOffset)
|
||||
&& "pim.memcp_hd offsets must be statically resolvable during codegen");
|
||||
emitMemCopyOp("ld",
|
||||
addressOf(loadOp.getDeviceTarget(), knowledge),
|
||||
loadOp.getDeviceTargetOffset(),
|
||||
*deviceTargetOffset,
|
||||
addressOf(loadOp.getHostSource(), knowledge),
|
||||
loadOp.getHostSourceOffset(),
|
||||
*hostSourceOffset,
|
||||
loadOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
|
||||
auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
|
||||
assert(succeeded(hostTargetOffset) && succeeded(deviceSourceOffset)
|
||||
&& "pim.memcp_dh offsets must be statically resolvable during codegen");
|
||||
emitMemCopyOp("st",
|
||||
addressOf(storeOp.getHostTarget(), knowledge),
|
||||
storeOp.getHostTargetOffset(),
|
||||
*hostTargetOffset,
|
||||
addressOf(storeOp.getDeviceSource(), knowledge),
|
||||
storeOp.getDeviceSourceOffset(),
|
||||
*deviceSourceOffset,
|
||||
storeOp.getSize());
|
||||
}
|
||||
|
||||
@@ -426,8 +442,9 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledg
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp(
|
||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||
auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
|
||||
assert(succeeded(sourceCoreId) && "pim.receive source core id must be statically resolvable during codegen");
|
||||
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||
@@ -439,7 +456,9 @@ void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
|
||||
auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
|
||||
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -728,12 +747,19 @@ std::string getMemorySizeAsString(size_t size) {
|
||||
|
||||
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
SmallVector<unsigned, 8> indices;
|
||||
auto addIndex = [&](unsigned weightIndex) {
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
|
||||
auto addWeight = [&](mlir::Value weight) {
|
||||
if (!coreOp)
|
||||
return;
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
|
||||
if (coreOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
@@ -795,6 +821,15 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
||||
/// fully resolved before the JSON instructions are emitted.
|
||||
/// Returns the number of emitted instructions, or -1 on failure.
|
||||
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional<unsigned> {
|
||||
auto coreOp = vmmOp->getParentOfType<pim::PimCoreOp>();
|
||||
if (!coreOp)
|
||||
return std::nullopt;
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
};
|
||||
size_t processedOperations = 0;
|
||||
auto result =
|
||||
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
@@ -814,8 +849,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
|
||||
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
||||
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
|
||||
auto weightIndex = resolveWeightIndex(vmmOp);
|
||||
if (!weightIndex)
|
||||
return failure();
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightIndex, vmmOp, true, knowledge);
|
||||
}
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||
@@ -1004,10 +1043,19 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
reportedCoreIds.reserve(batchCoreIds.size());
|
||||
MemoryReportRow batchRow;
|
||||
std::optional<MemoryReportRow> batchPerCoreRow;
|
||||
llvm::DenseMap<size_t, SmallVector<unsigned>> lanesByCoreId;
|
||||
SmallVector<size_t> orderedOriginalCoreIds;
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||
auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId);
|
||||
if (inserted)
|
||||
orderedOriginalCoreIds.push_back(originalCoreId);
|
||||
it->second.push_back(lane);
|
||||
}
|
||||
|
||||
for (size_t originalCoreId : orderedOriginalCoreIds) {
|
||||
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||
if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) {
|
||||
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
|
||||
MemoryReportRow laneRow;
|
||||
|
||||
@@ -128,12 +128,20 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
|
||||
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
SmallVector<unsigned, 8> indices;
|
||||
auto addIndex = [&](unsigned weightIndex) {
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
|
||||
auto addWeight = [&](mlir::Value weight) {
|
||||
if (!coreOp)
|
||||
return;
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
|
||||
if (coreOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user