better MaterializeMergeSchedule.cpp (something still broken downstream)
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-05-22 06:56:39 +02:00
parent 6aaf1c0870
commit 43ed3914b8
13 changed files with 1433 additions and 1620 deletions
@@ -299,10 +299,11 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
if in_path.contains(&waiting_for) { if in_path.contains(&waiting_for) {
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap(); let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
let cycle = &path[cycle_start..]; let cycle = &path[cycle_start..];
let format_core = |core: &i32| (core - 1).to_string();
let cycle_str = cycle let cycle_str = cycle
.iter() .iter()
.map(|c| c.to_string()) .map(format_core)
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(" -> "); .join(" -> ");
@@ -311,19 +312,19 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
.copied() .copied()
.chain(std::iter::once(waiting_for)) .chain(std::iter::once(waiting_for))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for); let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1);
let states_msg = cycle let states_msg = cycle
.iter() .iter()
.filter_map(|core| { .filter_map(|core| {
states.get(core).map(|state| match state { states.get(core).map(|state| match state {
CoreState::SendingTo(target, size) => { CoreState::SendingTo(target, size) => {
format!("core {} send {}B -> {}", core, size, target) format!("core {} send {}B -> {}", core - 1, size, target - 1)
} }
CoreState::ReceivingFrom(source, size) => { CoreState::ReceivingFrom(source, size) => {
format!("core {} recv {}B <- {}", core, size, source) format!("core {} recv {}B <- {}", core - 1, size, source - 1)
} }
CoreState::Working => format!("core {} working", core), CoreState::Working => format!("core {} working", core - 1),
CoreState::Halted => format!("core {} halted", core), CoreState::Halted => format!("core {} halted", core - 1),
}) })
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
+41 -7
View File
@@ -28,23 +28,47 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
return laneCoreIds; return laneCoreIds;
} }
static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) {
if (Value mapped = mapper.lookupOrNull(value))
return mapped;
if (auto blockArgument = dyn_cast<BlockArgument>(value)) {
assert(blockArgument.getOwner() != &oldBlock && "expected block argument to be mapped before cloning");
assert(false && "unexpected captured block argument while scalarizing pim.core_batch");
}
Operation* definingOp = value.getDefiningOp();
assert(definingOp && "expected captured value to be defined by an operation");
assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning");
for (Value operand : definingOp->getOperands())
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
Operation* cloned = builder.clone(*definingOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
return mapper.lookup(value);
}
static void cloneScalarizedLaneBody(OpBuilder& builder, static void cloneScalarizedLaneBody(OpBuilder& builder,
pim::PimCoreBatchOp coreBatchOp, pim::PimCoreBatchOp coreBatchOp,
unsigned lane, unsigned lane,
OperationFolder& constantFolder) { OperationFolder& constantFolder) {
Block& oldBlock = coreBatchOp.getBody().front(); Block& oldBlock = coreBatchOp.getBody().front();
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount()); size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightCount = coreBatchOp.getWeights().size(); size_t weightCount = coreBatchOp.getWeights().size();
IRMapping mapper; IRMapping mapper;
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) { for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
if (blockArg.getType().isIndex()) { if (blockArg.getType().isIndex()) {
mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast<int64_t>(lane), constantFolder)); mapper.map(blockArg, getOrCreateHostIndexConstant(anchorOp, static_cast<int64_t>(lane), constantFolder));
continue; continue;
} }
if (argIndex <= weightCount) { if (argIndex <= weightCount) {
mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]); auto scalarCoreOp = cast<pim::PimCoreOp>(anchorOp);
mapper.map(blockArg, scalarCoreOp.getWeightArgument(argIndex - 1));
continue; continue;
} }
@@ -57,8 +81,10 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
if (isa<pim::PimHaltOp>(op)) if (isa<pim::PimHaltOp>(op))
continue; continue;
for (Value operand : op.getOperands())
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) { if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
pim::PimSendOp::create( pim::PimSendOp::create(
builder, builder,
sendBatchOp.getLoc(), sendBatchOp.getLoc(),
@@ -78,7 +104,6 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
} }
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) { if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
auto scalarReceive = pim::PimReceiveOp::create( auto scalarReceive = pim::PimReceiveOp::create(
builder, builder,
receiveBatchOp.getLoc(), receiveBatchOp.getLoc(),
@@ -106,8 +131,8 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
builder, builder,
memcpBatchOp.getLoc(), memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(), memcpBatchOp.getOutput().getType(),
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder), getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder), getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
mapper.lookup(memcpBatchOp.getDeviceTarget()), mapper.lookup(memcpBatchOp.getDeviceTarget()),
mapper.lookup(memcpBatchOp.getHostSource()), mapper.lookup(memcpBatchOp.getHostSource()),
memcpBatchOp.getSizeAttr()); memcpBatchOp.getSizeAttr());
@@ -141,7 +166,16 @@ LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
auto scalarCore = auto scalarCore =
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId)); pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end()); SmallVector<Type> weightTypes;
SmallVector<Location> weightLocs;
weightTypes.reserve(weights.size());
weightLocs.reserve(weights.size());
for (Value weight : weights) {
weightTypes.push_back(weight.getType());
weightLocs.push_back(weight.getLoc());
}
Block* block =
builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end(), TypeRange(weightTypes), weightLocs);
builder.setInsertionPointToEnd(block); builder.setInsertionPointToEnd(block);
for (unsigned lane : lanes) for (unsigned lane : lanes)
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder); cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
@@ -1,6 +1,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
@@ -97,20 +99,73 @@ static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveT
return success(); return success();
} }
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
if (!result.hasOneUse())
return failure();
auto returnOp = dyn_cast<func::ReturnOp>(*result.getUsers().begin());
if (!returnOp)
return failure();
return result.getUses().begin()->getOperandNumber();
}
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
if (scale == 1)
return base;
auto scaleValue = arith::ConstantIndexOp::create(rewriter, loc, scale).getResult();
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
}
static Value createHostTargetOffset(IRRewriter& rewriter,
tensor::ParallelInsertSliceOp insertSlice,
ShapedType destinationType,
IRMapping& mapper) {
int64_t elementBytes = destinationType.getElementTypeBitWidth() / 8;
SmallVector<int64_t> strides(destinationType.getRank(), 1);
ArrayRef<int64_t> shape = destinationType.getShape();
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
Value totalOffset;
Location loc = insertSlice.getLoc();
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
int64_t scale = strides[dim] * elementBytes;
Value scaledOffset;
if (auto attr = dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
assert(intAttr && "expected integer offset attribute");
scaledOffset = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getInt() * scale).getResult();
}
else {
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
}
totalOffset = totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult()
: scaledOffset;
}
if (!totalOffset)
totalOffset = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
return totalOffset;
}
} // namespace } // namespace
LogicalResult LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) { lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
Location loc = computeBatchOp.getLoc(); Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front(); Block& oldBlock = computeBatchOp.getBody().front();
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results; "
"materialize explicit communication before lowering to PIM");
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator()); auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (!oldYield || oldYield.getNumOperands() != 0) auto inParallelOp = dyn_cast<spatial::SpatInParallelOp>(oldBlock.getTerminator());
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield"); if (computeBatchOp.getNumResults() == 0) {
if (!oldYield || oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
}
else if (!inParallelOp) {
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
}
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId); SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
@@ -128,9 +183,24 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())}); {static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Value> hostOutputTensors;
if (computeBatchOp.getNumResults() != 0) {
hostOutputTensors.resize(computeBatchOp.getNumResults());
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
if (failed(returnOperandIndex))
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
hostOutputTensors[resultIndex] = state.outputTensors[*returnOperandIndex](rewriter, loc);
result.replaceAllUsesWith(hostOutputTensors[resultIndex]);
}
}
SmallVector<Type> blockArgTypes; SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs; SmallVector<Location> blockArgLocs;
for (BlockArgument arg : oldBlock.getArguments()) { unsigned inputArgLimit = 1 + computeBatchOp.getWeights().size() + computeBatchOp.getInputs().size();
for (BlockArgument arg : oldBlock.getArguments().take_front(inputArgLimit)) {
blockArgTypes.push_back(arg.getType()); blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc()); blockArgLocs.push_back(arg.getLoc());
} }
@@ -183,6 +253,38 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
if (isa<spatial::SpatYieldOp>(op)) if (isa<spatial::SpatYieldOp>(op))
continue; continue;
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
unsigned firstOutputArg = computeBatchOp.getOutputArgument(0).getArgNumber();
for (Operation& nestedOp : parallelOp.getRegion().front()) {
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
if (!insertSlice)
return parallelOp.emitOpError("expected only tensor.parallel_insert_slice in spat.in_parallel");
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
if (!outputArg || outputArg.getOwner() != &oldBlock)
return insertSlice.emitOpError("expected compute_batch output block argument destination");
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg;
if (resultIndex >= hostOutputTensors.size())
return insertSlice.emitOpError("result index out of range while lowering host batch output");
Value mappedSource = mapper.lookup(insertSlice.getSource());
auto hostTarget = hostOutputTensors[resultIndex];
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult();
pim::PimMemCopyDevToHostOp::create(rewriter,
insertSlice.getLoc(),
hostTarget.getType(),
hostTargetOffset,
zeroOffset,
hostTarget,
mappedSource,
getTensorSizeInBytesAttr(rewriter, mappedSource));
}
continue;
}
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) { if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds()); FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
if (failed(targetCoreIds)) if (failed(targetCoreIds))
@@ -6,7 +6,6 @@ add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp SpatialToPimPass.cpp
BatchCoreLoweringPatterns.cpp BatchCoreLoweringPatterns.cpp
ChannelLoweringPatterns.cpp ChannelLoweringPatterns.cpp
Cleanup.cpp
Common.cpp Common.cpp
ComputeLikeRegionUtils.cpp ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp CoreLoweringPatterns.cpp
@@ -1,42 +0,0 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
while (!pendingOps.empty()) {
bool erasedAnyOp = false;
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
Operation* opToRemove = *it;
if (!opToRemove->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(opToRemove);
it = pendingOps.erase(it);
erasedAnyOp = true;
}
if (erasedAnyOp)
continue;
for (Operation* opToRemove : pendingOps) {
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
for (Operation* user : opToRemove->getUsers()) {
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
}
}
return failure();
}
return success();
}
} // namespace onnx_mlir
@@ -1,11 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
namespace onnx_mlir {
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -141,152 +141,6 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
} }
}; };
// Turns runtime constants consumed by compute regions into private globals and local loads.
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
Location loc = constantOp.getLoc();
if (hasWeightAlways(constantOp))
return failure();
if (!isa<func::FuncOp>(constantOp->getParentOp()))
return failure();
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
if (isa<spatial::SpatCompute>(op))
return false;
if (isa<func::FuncOp>(op->getParentOp()))
return true;
return false;
}))
return failure();
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
if (constRankedTensorType) {
mlir::MemRefType memRefType =
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
loc,
constantOp->getParentOfType<ModuleOp>(),
"const",
memRefType,
constantOp.getValueAttr(),
rewriter.getUnitAttr());
std::string argName = globalOp.getSymName().str();
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(rewriter,
spatComputeBatch.getOperation(),
BBArgIndex,
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
}
else {
{
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
}
}
}
}
else if (constantOp.getType().isIntOrIndexOrFloat()) {
Value hostConstant = constantOp.getResult();
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, hostConstant);
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, hostConstant);
}
else if (constUsers->getParentOfType<spatial::SpatCompute>()) {
constUses.set(hostConstant);
}
else {
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
assert(batchParent && "Global Constant used direcly not within a compute");
constUses.set(hostConstant);
}
}
}
if (constantOp->use_empty())
rewriter.eraseOp(constantOp);
return success();
}
};
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly. // Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> { struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
@@ -363,7 +217,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
} // namespace } // namespace
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) { void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>( patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(
patterns.getContext()); patterns.getContext());
} }
@@ -14,7 +14,6 @@
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
@@ -28,7 +27,6 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
@@ -67,6 +65,7 @@ private:
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void markOpToRemove(Operation* op); void markOpToRemove(Operation* op);
void eraseOpsToRemove();
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter); void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
}; };
@@ -268,13 +267,7 @@ void SpatialToPimPass::runOnOperation() {
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState); replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
eraseOpsToRemove();
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
signalPassFailure();
return;
}
RewritePatternSet finalTensorPackingPatterns(ctx); RewritePatternSet finalTensorPackingPatterns(ctx);
populateTensorPackingPatterns(finalTensorPackingPatterns); populateTensorPackingPatterns(finalTensorPackingPatterns);
@@ -399,6 +392,13 @@ void SpatialToPimPass::markOpToRemove(Operation* op) {
operationsToRemove.push_back(op); operationsToRemove.push_back(op);
} }
void SpatialToPimPass::eraseOpsToRemove() {
for (Operation* op : operationsToRemove) {
op->dropAllUses();
op->erase();
}
}
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); } std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir
+15 -1
View File
@@ -1,4 +1,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
@@ -6,6 +8,7 @@
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -40,7 +43,18 @@ static bool isDefinedInsideRegion(Value value, Region& region) {
static bool isConstantExternalValue(Value value) { static bool isConstantExternalValue(Value value) {
Operation* definingOp = value.getDefiningOp(); Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>(); if (!definingOp)
return false;
if (definingOp->hasTrait<OpTrait::ConstantLike>())
return true;
auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(definingOp);
if (!getGlobalOp)
return false;
auto moduleOp = definingOp->getParentOfType<ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
return globalOp && globalOp.getConstant();
} }
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) { static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
@@ -120,6 +120,15 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || isConstantIndexLike(value)) if (value == laneArg || isConstantIndexLike(value))
return true; return true;
auto extractOp = value.getDefiningOp<tensor::ExtractOp>();
if (extractOp) {
auto constantTensor = extractOp.getTensor().getDefiningOp<arith::ConstantOp>();
auto denseAttr = constantTensor ? dyn_cast<DenseIntElementsAttr>(constantTensor.getValue()) : nullptr;
if (!denseAttr || denseAttr.getType().getRank() != 1 || extractOp.getIndices().size() != 1)
return false;
return isSupportedLaneOffsetExpr(extractOp.getIndices().front(), laneArg);
}
auto addOp = value.getDefiningOp<arith::AddIOp>(); auto addOp = value.getDefiningOp<arith::AddIOp>();
if (!addOp) if (!addOp)
return false; return false;
File diff suppressed because it is too large Load Diff
@@ -267,212 +267,6 @@ bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end(); return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
} }
struct BatchYieldInfo {
Value yieldedValue;
tensor::ParallelInsertSliceOp insertSlice;
};
static bool isHostOnlyBatchResultUser(Operation* user) {
return isa<func::ReturnOp,
spatial::SpatConcatOp,
tensor::ExtractSliceOp,
tensor::CastOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp>(user);
}
static FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> collectBatchYieldInfo(SpatComputeBatch batchOp) {
Block& block = batchOp.getBody().front();
auto inParallel = dyn_cast<spatial::SpatInParallelOp>(block.getTerminator());
if (!inParallel)
return failure();
DenseMap<BlockArgument, BatchYieldInfo> batchYieldByOutputArg;
for (Operation& op : inParallel.getRegion().front()) {
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
if (!insertSlice)
return failure();
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
if (!outputArg || outputArg.getOwner() != &block)
return failure();
batchYieldByOutputArg[outputArg] = {insertSlice.getSource(), insertSlice};
}
return batchYieldByOutputArg;
}
static FailureOr<SpatComputeBatch> cloneBatchAsResultless(SpatComputeBatch batchOp, IRRewriter& rewriter) {
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
if (!coreIdsAttr)
return failure();
Block& oldBlock = batchOp.getBody().front();
rewriter.setInsertionPoint(batchOp);
auto newBatch = SpatComputeBatch::create(rewriter,
batchOp.getLoc(),
TypeRange {},
rewriter.getI32IntegerAttr(batchOp.getLaneCount()),
batchOp.getWeights(),
batchOp.getInputs());
newBatch.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr);
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
blockArgLocs.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
blockArgTypes.push_back(batchOp.getLaneArgument().getType());
blockArgLocs.push_back(batchOp.getLaneArgument().getLoc());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) {
blockArgTypes.push_back(batchOp.getWeightArgument(weightIndex).getType());
blockArgLocs.push_back(batchOp.getWeightArgument(weightIndex).getLoc());
}
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) {
blockArgTypes.push_back(batchOp.getInputArgument(inputIndex).getType());
blockArgLocs.push_back(batchOp.getInputArgument(inputIndex).getLoc());
}
Block* newBlock =
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
mapper.map(batchOp.getLaneArgument(), newBatch.getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch.getWeightArgument(weightIndex));
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
mapper.map(batchOp.getInputArgument(inputIndex), newBatch.getInputArgument(inputIndex));
for (Operation& op : oldBlock.without_terminator()) {
Operation* cloned = rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(oldResult, newResult);
}
return newBatch;
}
static LogicalResult materializeBatchResultCommunication(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
SmallVector<SpatComputeBatch> batches(funcOp.getOps<SpatComputeBatch>());
for (auto batchOp : batches) {
if (batchOp.getNumResults() == 0)
continue;
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
if (!coreIdsAttr)
return batchOp.emitOpError("missing coreIds while materializing batch result communication");
FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> batchYieldInfo = collectBatchYieldInfo(batchOp);
if (failed(batchYieldInfo))
return batchOp.emitOpError("failed to collect per-result yielded values from compute_batch body");
FailureOr<SpatComputeBatch> newBatch = cloneBatchAsResultless(batchOp, rewriter);
if (failed(newBatch))
return batchOp.emitOpError("failed to clone resultful compute_batch as resultless");
Block& oldBlock = batchOp.getBody().front();
Block& newBlock = newBatch->getBody().front();
IRMapping mapper;
mapper.map(batchOp.getLaneArgument(), newBatch->getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch->getWeightArgument(weightIndex));
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
mapper.map(batchOp.getInputArgument(inputIndex), newBatch->getInputArgument(inputIndex));
auto oldIt = oldBlock.begin();
auto newIt = newBlock.begin();
for (; oldIt != oldBlock.end() && newIt != newBlock.end(); ++oldIt, ++newIt)
for (auto [oldResult, newResult] : llvm::zip(oldIt->getResults(), newIt->getResults()))
mapper.map(oldResult, newResult);
SmallVector<int32_t> sourceCoreIds(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
rewriter.setInsertionPointToEnd(&newBlock);
for (unsigned resultIndex = 0; resultIndex < batchOp.getNumResults(); ++resultIndex) {
BlockArgument outputArg = batchOp.getOutputArgument(resultIndex);
auto yieldInfoIt = batchYieldInfo->find(outputArg);
if (yieldInfoIt == batchYieldInfo->end())
return batchOp.emitOpError(
"missing yielded value for compute_batch result during communication materialization");
Value mappedYieldedValue = mapper.lookup(yieldInfoIt->second.yieldedValue);
DenseMap<int32_t, SmallVector<OpOperand*>> computeUsesByTargetCore;
SmallVector<OpOperand*> hostUses;
for (OpOperand& use : batchOp.getResult(resultIndex).getUses()) {
if (auto computeOp = dyn_cast<SpatCompute>(use.getOwner())) {
auto coreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
return batchOp.emitOpError("compute user of compute_batch result is missing coreId");
computeUsesByTargetCore[static_cast<int32_t>(coreIdAttr.getInt())].push_back(&use);
continue;
}
if (isHostOnlyBatchResultUser(use.getOwner())) {
hostUses.push_back(&use);
continue;
}
return batchOp.emitOpError("unsupported user of compute_batch result during communication materialization")
<< ": " << use.getOwner()->getName();
}
auto createReceiveForUses = [&](ArrayRef<OpOperand*> uses, ArrayRef<int32_t> targetCoreIds) -> LogicalResult {
if (uses.empty())
return success();
SmallVector<int64_t> channelIds;
channelIds.reserve(sourceCoreIds.size());
for ([[maybe_unused]] int32_t sourceCoreId : sourceCoreIds)
channelIds.push_back(nextChannelId++);
SmallVector<Value> sendChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
SmallVector<Value> sendSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
SmallVector<Value> sendTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
spatial::SpatChannelSendBatchOp::create(rewriter,
batchOp.getLoc(),
sendChannelIdValues,
sendSourceCoreIdValues,
sendTargetCoreIdValues,
mappedYieldedValue);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(newBatch->getOperation());
SmallVector<Value> receiveChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
SmallVector<Value> receiveSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
SmallVector<Value> receiveTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
auto received = spatial::SpatChannelReceiveTensorOp::create(rewriter,
batchOp.getLoc(),
batchOp.getResult(resultIndex).getType(),
receiveChannelIdValues,
receiveSourceCoreIdValues,
receiveTargetCoreIdValues);
for (OpOperand* use : uses)
use->set(received.getOutput());
rewriter.setInsertionPointToEnd(&newBlock);
return success();
};
for (auto& [targetCoreId, uses] : computeUsesByTargetCore) {
SmallVector<int32_t> targetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), targetCoreId);
if (failed(createReceiveForUses(uses, targetCoreIds)))
return failure();
}
if (!hostUses.empty()) {
SmallVector<int32_t> hostTargetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), 0);
if (failed(createReceiveForUses(hostUses, hostTargetCoreIds)))
return failure();
}
}
rewriter.setInsertionPointToEnd(&newBlock);
spatial::SpatYieldOp::create(rewriter, batchOp.getLoc(), ValueRange {});
rewriter.eraseOp(batchOp);
}
return success();
}
void rebatchEquivalentComputes(func::FuncOp funcOp) { void rebatchEquivalentComputes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext()); IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext()); OperationFolder constantFolder(funcOp.getContext());
@@ -731,11 +525,6 @@ LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextC
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops"); ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
cleanupDeadPackingOps(funcOp); cleanupDeadPackingOps(funcOp);
} }
{
ScopedMergePhaseTimer timer("materialize-batch-result-communication");
if (failed(materializeBatchResultCommunication(funcOp, nextChannelId)))
return failure();
}
return success(); return success();
} }
@@ -1,7 +1,6 @@
#include "mlir/IR/Threading.h" #include "mlir/IR/Threading.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
@@ -20,7 +19,6 @@ struct ScheduledTask {
size_t processor = std::numeric_limits<size_t>::max(); size_t processor = std::numeric_limits<size_t>::max();
Time startTime = 0; Time startTime = 0;
Time endTime = 0; Time endTime = 0;
size_t slot = 0;
}; };
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) { std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
@@ -244,7 +242,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
llvm::report_fatal_error(llvm::StringRef(message)); llvm::report_fatal_error(llvm::StringRef(message));
} }
schedules[task] = {bestProcessor, bestEst, bestEft, 0}; schedules[task] = {bestProcessor, bestEst, bestEft};
scheduled[task] = true; scheduled[task] = true;
++scheduledCount; ++scheduledCount;
processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage); processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
@@ -278,7 +276,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
// 5. Check if equal schedule in two level // 5. Check if equal schedule in two level
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass; llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) { for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) {
for (size_t controlProcessor = currentProcessor + 1; controlProcessor < processorCount; ++controlProcessor) { for (size_t controlProcessor = currentProcessor; controlProcessor < processorCount; ++controlProcessor) {
if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size()) if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size())
continue; continue;
auto& currentTasks = tasksByProcessor[currentProcessor]; auto& currentTasks = tasksByProcessor[currentProcessor];
@@ -288,7 +286,8 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) { for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) {
const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance; const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance;
const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance; const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance;
if (currentComputeInstance.op != controlComputeInstance.op) { if (currentComputeInstance.op != controlComputeInstance.op
|| currentComputeInstance.laneCount != controlComputeInstance.laneCount) {
equalSchedule = false; equalSchedule = false;
break; break;
} }
@@ -300,7 +299,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
} }
} }
} }
{ /*{
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n"; llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
std::vector<bool> visited(processorCount, false); std::vector<bool> visited(processorCount, false);
size_t uniqueClassCount = 0; size_t uniqueClassCount = 0;
@@ -330,7 +329,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n"; llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
llvm::dbgs() << "--------------------------------------\n"; llvm::dbgs() << "--------------------------------------\n";
} }*/
// 6. Populate Final Result // 6. Populate Final Result
MergeScheduleResult result; MergeScheduleResult result;