better MaterializeMergeSchedule.cpp (something still broken downstream)
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-22 06:56:39 +02:00
parent 6aaf1c0870
commit 43ed3914b8
13 changed files with 1433 additions and 1620 deletions
@@ -1,6 +1,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
@@ -97,20 +99,73 @@ static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveT
return success();
}
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
if (!result.hasOneUse())
return failure();
auto returnOp = dyn_cast<func::ReturnOp>(*result.getUsers().begin());
if (!returnOp)
return failure();
return result.getUses().begin()->getOperandNumber();
}
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
if (scale == 1)
return base;
auto scaleValue = arith::ConstantIndexOp::create(rewriter, loc, scale).getResult();
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
}
static Value createHostTargetOffset(IRRewriter& rewriter,
tensor::ParallelInsertSliceOp insertSlice,
ShapedType destinationType,
IRMapping& mapper) {
int64_t elementBytes = destinationType.getElementTypeBitWidth() / 8;
SmallVector<int64_t> strides(destinationType.getRank(), 1);
ArrayRef<int64_t> shape = destinationType.getShape();
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
Value totalOffset;
Location loc = insertSlice.getLoc();
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
int64_t scale = strides[dim] * elementBytes;
Value scaledOffset;
if (auto attr = dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
assert(intAttr && "expected integer offset attribute");
scaledOffset = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getInt() * scale).getResult();
}
else {
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
}
totalOffset = totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult()
: scaledOffset;
}
if (!totalOffset)
totalOffset = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
return totalOffset;
}
} // namespace
LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front();
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results; "
"materialize explicit communication before lowering to PIM");
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (!oldYield || oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
auto inParallelOp = dyn_cast<spatial::SpatInParallelOp>(oldBlock.getTerminator());
if (computeBatchOp.getNumResults() == 0) {
if (!oldYield || oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
}
else if (!inParallelOp) {
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
}
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
@@ -128,9 +183,24 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Value> hostOutputTensors;
if (computeBatchOp.getNumResults() != 0) {
hostOutputTensors.resize(computeBatchOp.getNumResults());
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
if (failed(returnOperandIndex))
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
hostOutputTensors[resultIndex] = state.outputTensors[*returnOperandIndex](rewriter, loc);
result.replaceAllUsesWith(hostOutputTensors[resultIndex]);
}
}
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : oldBlock.getArguments()) {
unsigned inputArgLimit = 1 + computeBatchOp.getWeights().size() + computeBatchOp.getInputs().size();
for (BlockArgument arg : oldBlock.getArguments().take_front(inputArgLimit)) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
@@ -183,6 +253,38 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
if (isa<spatial::SpatYieldOp>(op))
continue;
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
unsigned firstOutputArg = computeBatchOp.getOutputArgument(0).getArgNumber();
for (Operation& nestedOp : parallelOp.getRegion().front()) {
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
if (!insertSlice)
return parallelOp.emitOpError("expected only tensor.parallel_insert_slice in spat.in_parallel");
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
if (!outputArg || outputArg.getOwner() != &oldBlock)
return insertSlice.emitOpError("expected compute_batch output block argument destination");
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg;
if (resultIndex >= hostOutputTensors.size())
return insertSlice.emitOpError("result index out of range while lowering host batch output");
Value mappedSource = mapper.lookup(insertSlice.getSource());
auto hostTarget = hostOutputTensors[resultIndex];
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult();
pim::PimMemCopyDevToHostOp::create(rewriter,
insertSlice.getLoc(),
hostTarget.getType(),
hostTargetOffset,
zeroOffset,
hostTarget,
mappedSource,
getTensorSizeInBytesAttr(rewriter, mappedSource));
}
continue;
}
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
if (failed(targetCoreIds))
@@ -6,7 +6,6 @@ add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp
BatchCoreLoweringPatterns.cpp
ChannelLoweringPatterns.cpp
Cleanup.cpp
Common.cpp
ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp
@@ -1,42 +0,0 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
while (!pendingOps.empty()) {
bool erasedAnyOp = false;
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
Operation* opToRemove = *it;
if (!opToRemove->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(opToRemove);
it = pendingOps.erase(it);
erasedAnyOp = true;
}
if (erasedAnyOp)
continue;
for (Operation* opToRemove : pendingOps) {
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
for (Operation* user : opToRemove->getUsers()) {
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
}
}
return failure();
}
return success();
}
} // namespace onnx_mlir
@@ -1,11 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
namespace onnx_mlir {
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -141,152 +141,6 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
};
// Turns runtime constants consumed by compute regions into private globals and local loads.
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
Location loc = constantOp.getLoc();
if (hasWeightAlways(constantOp))
return failure();
if (!isa<func::FuncOp>(constantOp->getParentOp()))
return failure();
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
if (isa<spatial::SpatCompute>(op))
return false;
if (isa<func::FuncOp>(op->getParentOp()))
return true;
return false;
}))
return failure();
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
if (constRankedTensorType) {
mlir::MemRefType memRefType =
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
loc,
constantOp->getParentOfType<ModuleOp>(),
"const",
memRefType,
constantOp.getValueAttr(),
rewriter.getUnitAttr());
std::string argName = globalOp.getSymName().str();
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(rewriter,
spatComputeBatch.getOperation(),
BBArgIndex,
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
}
else {
{
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
}
}
}
}
else if (constantOp.getType().isIntOrIndexOrFloat()) {
Value hostConstant = constantOp.getResult();
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, hostConstant);
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, hostConstant);
}
else if (constUsers->getParentOfType<spatial::SpatCompute>()) {
constUses.set(hostConstant);
}
else {
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
assert(batchParent && "Global Constant used direcly not within a compute");
constUses.set(hostConstant);
}
}
}
if (constantOp->use_empty())
rewriter.eraseOp(constantOp);
return success();
}
};
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
@@ -363,7 +217,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
} // namespace
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(
patterns.getContext());
}
@@ -14,7 +14,6 @@
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
@@ -28,7 +27,6 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
@@ -67,6 +65,7 @@ private:
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
void markOpToRemove(Operation* op);
void eraseOpsToRemove();
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
};
@@ -268,13 +267,7 @@ void SpatialToPimPass::runOnOperation() {
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
signalPassFailure();
return;
}
eraseOpsToRemove();
RewritePatternSet finalTensorPackingPatterns(ctx);
populateTensorPackingPatterns(finalTensorPackingPatterns);
@@ -399,6 +392,13 @@ void SpatialToPimPass::markOpToRemove(Operation* op) {
operationsToRemove.push_back(op);
}
void SpatialToPimPass::eraseOpsToRemove() {
for (Operation* op : operationsToRemove) {
op->dropAllUses();
op->erase();
}
}
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
} // namespace onnx_mlir