Better spatial IR compaction with improved custom syntax, scf.for and
Validate Operations / validate-operations (push) Has been cancelled

spat.map
This commit is contained in:
NiccoloN
2026-05-06 12:21:58 +02:00
parent 285773fa55
commit b2dc9c38b6
12 changed files with 1442 additions and 274 deletions
@@ -111,7 +111,7 @@ static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t&
}
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName))
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
SmallVector<int32_t> coreIds;
@@ -178,6 +178,43 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan
rewriter.replaceOp(receiveManyOp, ValueRange(replacements));
}
// Lowers a spat.channel_send_many_batch op by emitting one pim.send_batch per
// input value. The flattened target-core-id array is laid out per value: the
// lane targets for value i occupy [i * laneCount, (i + 1) * laneCount).
// Inputs are resolved through `mapper` to their already-lowered counterparts.
static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp,
                                      int32_t laneCount,
                                      IRMapping& mapper,
                                      IRRewriter& rewriter) {
  auto targetCoreIds = sendManyBatchOp.getTargetCoreIds();
  for (auto [valueIndex, input] : llvm::enumerate(sendManyBatchOp.getInputs())) {
    // Hoisted: the original performed the same mapper lookup twice per iteration.
    Value mappedInput = mapper.lookup(input);
    size_t metadataOffset = valueIndex * static_cast<size_t>(laneCount);
    auto targetSlice = targetCoreIds.slice(metadataOffset, laneCount);
    pim::PimSendBatchOp::create(rewriter,
                                sendManyBatchOp.getLoc(),
                                mappedInput,
                                getTensorSizeInBytesAttr(rewriter, mappedInput),
                                rewriter.getDenseI32ArrayAttr(targetSlice));
  }
}
// Lowers a spat.channel_receive_many_batch op: for each declared output we
// allocate an empty destination tensor, emit a pim.receive_batch that fills it
// from the per-lane source cores, and record the received value in `mapper`.
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
                                         int32_t laneCount,
                                         IRMapping& mapper,
                                         IRRewriter& rewriter) {
  auto loc = receiveManyBatchOp.getLoc();
  auto sourceCoreIds = receiveManyBatchOp.getSourceCoreIds();
  for (auto [resultIndex, result] : llvm::enumerate(receiveManyBatchOp.getOutputs())) {
    // Source core ids for result i occupy lanes [i*laneCount, (i+1)*laneCount)
    // of the flattened source-core-id array.
    size_t laneOffset = resultIndex * static_cast<size_t>(laneCount);
    auto laneSources = sourceCoreIds.slice(laneOffset, laneCount);
    auto resultShape = cast<ShapedType>(result.getType());
    auto destination = createEmptyTensorFromShaped(rewriter, loc, resultShape);
    auto receiveOp = pim::PimReceiveBatchOp::create(rewriter,
                                                    loc,
                                                    destination.getType(),
                                                    destination,
                                                    getTensorSizeInBytesAttr(rewriter, result),
                                                    rewriter.getDenseI32ArrayAttr(laneSources));
    mapper.map(result, receiveOp.getOutput());
  }
}
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
Value input = extractRowsOp.getInput();
RankedTensorType inputType;
@@ -226,6 +263,56 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
rewriter.replaceOp(concatOp, concatenated);
}
// Replaces every spat.weighted_vmm op still nested inside a pim.core or
// pim.core_batch region with a pim.vmm op writing into a fresh empty tensor.
static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Collect first: replacing ops during the walk would invalidate it.
  SmallVector<spatial::SpatWeightedVMMOp> pendingOps;
  funcOp.walk([&](spatial::SpatWeightedVMMOp candidate) {
    if (candidate->getParentOfType<pim::PimCoreOp>() ||
        candidate->getParentOfType<pim::PimCoreBatchOp>())
      pendingOps.push_back(candidate);
  });
  for (auto pendingOp : pendingOps) {
    rewriter.setInsertionPoint(pendingOp);
    auto resultType = cast<ShapedType>(pendingOp.getOutput().getType());
    // Destination buffer for the vmm result.
    Value destination =
        createEmptyTensorFromShaped(rewriter, pendingOp.getLoc(), resultType).getResult();
    rewriter.replaceOpWithNewOp<pim::PimVMMOp>(pendingOp,
                                               pendingOp.getOutput().getType(),
                                               rewriter.getI32IntegerAttr(pendingOp.getWeightIndex()),
                                               pendingOp.getInput(),
                                               destination);
  }
}
// Inlines every spat.map op in `funcOp`: the single-argument map body is
// cloned once per map input (with the block argument bound to that input),
// and the map's results are replaced by the values yielded per clone.
static void expandMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Collect first so the walk is not invalidated by the replacements below.
  SmallVector<spatial::SpatMapOp> mapOps;
  funcOp.walk([&](spatial::SpatMapOp mapOp) { mapOps.push_back(mapOp); });
  for (auto mapOp : mapOps) {
    Block& body = mapOp.getBody().front();
    auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
    SmallVector<Value> replacements;
    replacements.reserve(mapOp.getInputs().size());
    rewriter.setInsertionPoint(mapOp);
    for (Value input : mapOp.getInputs()) {
      IRMapping mapping;
      mapping.map(body.getArgument(0), input);
      for (Operation& op : body.without_terminator()) {
        // rewriter.clone(op, mapping) already records the cloned op's results
        // in `mapping` (Operation::clone semantics), so the original's explicit
        // per-result re-mapping loop was redundant and has been dropped.
        Operation* cloned = rewriter.clone(op, mapping);
        rewriter.setInsertionPointAfter(cloned);
      }
      // lookupOrDefault covers yields of the block argument itself or of any
      // value defined outside the map body (no clone to map to).
      replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0)));
    }
    rewriter.replaceOp(mapOp, replacements);
  }
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
SmallVectorImpl<Operation*>& helperChain,
bool requireReturnUse = true) {
@@ -551,6 +638,7 @@ void SpatialToPimPass::runOnOperation() {
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
expandMapOps(funcOp, rewriter);
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect,
@@ -640,6 +728,32 @@ void SpatialToPimPass::runOnOperation() {
for (auto extractRowsOp : extractRowsOps)
lowerExtractRows(extractRowsOp, rewriter);
{
RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
}
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) {
if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
}
}
lowerRemainingSpatialMathOps(funcOp, rewriter);
RewritePatternSet channelPatterns(ctx);
populateWithGenerated(channelPatterns);
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
@@ -939,7 +1053,7 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
ValueRange(batchInputs));
coreBatchOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
@@ -1000,6 +1114,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
continue;
}
if (auto sendManyBatchOp = dyn_cast<spatial::SpatChannelSendManyBatchOp>(op)) {
lowerChannelSendManyBatch(sendManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
continue;
}
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
@@ -1014,6 +1133,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
continue;
}
if (auto receiveManyBatchOp = dyn_cast<spatial::SpatChannelReceiveManyBatchOp>(op)) {
lowerChannelReceiveManyBatch(receiveManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);