huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled

remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
NiccoloN
2026-05-12 10:35:44 +02:00
parent feaff820e1
commit 909c4acfdd
84 changed files with 4048 additions and 3310 deletions
@@ -0,0 +1,224 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
SmallVector<int32_t> coreIds;
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
return coreIds;
}
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
Value input = mapper.lookup(sendTensorBatchOp.getInput());
if (auto concatOp = input.getDefiningOp<tensor::ConcatOp>())
if (concatOp.getDim() == 0)
if (Value packedInput =
createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorBatchOp.getLoc()))
input = packedInput;
pim::PimSendTensorBatchOp::create(
rewriter, sendTensorBatchOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds));
}
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
receiveTensorBatchOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
mapper.map(receiveTensorBatchOp.getOutput(), received);
}
} // namespace
LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front();
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
SmallVector<Value> batchInputs;
if (!computeBatchOp.getInputs().empty())
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
rewriter.setInsertionPointAfter(computeBatchOp);
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
loc,
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
ValueRange(batchWeights),
ValueRange(batchInputs));
coreBatchOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : oldBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
Block* newBlock =
rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
rewriter.setInsertionPointToStart(newBlock);
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
newArg,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, newArg))
.getOutput();
mapper.map(oldArg, copied);
}
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
if (auto mapped = mapper.lookupOrNull(capturedTensor))
return mapped;
auto capturedType = cast<ShapedType>(capturedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
capturedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, capturedTensor))
.getOutput();
mapper.map(capturedTensor, copied);
return copied;
};
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : oldBlock) {
if (isa<spatial::SpatYieldOp>(op))
continue;
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
pim::PimSendBatchOp::create(rewriter,
loc,
mapper.lookup(sendBatchOp.getInput()),
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
sendBatchOp.getTargetCoreIdsAttr());
continue;
}
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
continue;
}
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
auto received = pim::PimReceiveBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
receiveBatchOp.getSourceCoreIdsAttr())
.getOutput();
mapper.map(receiveBatchOp.getOutput(), received);
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);
auto clonedTensor = cloned->getResult(0);
auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
clonedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, clonedTensor))
.getOutput();
mapper.map(toTensorOp.getResult(), copied);
continue;
}
}
for (Value operand : op.getOperands()) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock)
continue;
materializeCapturedTensor(operand);
}
Operation* cloned = rewriter.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
rewriter.setInsertionPointToEnd(newBlock);
PimHaltOp::create(rewriter, loc);
return success();
}
} // namespace onnx_mlir
@@ -0,0 +1,10 @@
#pragma once
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
namespace onnx_mlir {
mlir::LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -4,8 +4,16 @@ add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp
BatchCoreLoweringPatterns.cpp
ChannelLoweringPatterns.cpp
Cleanup.cpp
Common.cpp
Patterns.cpp
ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp
GlobalTensorMaterialization.cpp
PhaseVerification.cpp
ReturnPathNormalization.cpp
TensorPackingPatterns.cpp
EXCLUDE_FROM_OM_LIBS
@@ -0,0 +1,136 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
pim::PimSendOp::create(rewriter,
op.getLoc(),
op.getInput(),
getTensorSizeInBytesAttr(rewriter, op.getInput()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
rewriter.eraseOp(op);
return success();
}
};
struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveOp op, PatternRewriter& rewriter) const override {
if (op->use_empty()) {
rewriter.eraseOp(op);
return success();
}
auto outputType = cast<ShapedType>(op.getResult().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received = pim::PimReceiveOp::create(rewriter,
op.getLoc(),
op.getResult().getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, op.getResult()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(op.getTargetCoreIds().size());
for (int32_t targetCoreId : op.getTargetCoreIds())
targetCoreIds.push_back(toPimCoreId(targetCoreId));
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(op);
return success();
}
};
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(op.getSourceCoreIds().size());
for (int32_t sourceCoreId : op.getSourceCoreIds())
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received =
pim::PimReceiveTensorOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatExtractRowsOp op, PatternRewriter& rewriter) const override {
auto inputType = cast<RankedTensorType>(op.getInput().getType());
SmallVector<Value> replacements;
replacements.reserve(op.getNumResults());
for (auto [rowIndex, output] : llvm::enumerate(op.getOutputs())) {
auto outputType = cast<RankedTensorType>(output.getType());
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
rewriter.getIndexAttr(inputType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
replacements.push_back(
tensor::ExtractSliceOp::create(rewriter, op.getLoc(), outputType, op.getInput(), offsets, sizes, strides)
.getResult());
}
rewriter.replaceOp(op, replacements);
return success();
}
};
struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatConcatOp op, PatternRewriter& rewriter) const override {
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value concatenated =
pim::PimConcatOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), op.getAxisAttr(), op.getInputs(), outputBuffer)
.getOutput();
rewriter.replaceOp(op, concatenated);
return success();
}
};
} // namespace
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
patterns.add<ChannelSendLowering,
ChannelReceiveLowering,
ChannelSendTensorLowering,
ChannelReceiveTensorLowering,
ExtractRowsLowering,
ConcatLowering>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -0,0 +1,9 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
@@ -0,0 +1,42 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
while (!pendingOps.empty()) {
bool erasedAnyOp = false;
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
Operation* opToRemove = *it;
if (!opToRemove->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(opToRemove);
it = pendingOps.erase(it);
erasedAnyOp = true;
}
if (erasedAnyOp)
continue;
for (Operation* opToRemove : pendingOps) {
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
for (Operation* user : opToRemove->getUsers()) {
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
}
}
return failure();
}
return success();
}
} // namespace onnx_mlir
@@ -0,0 +1,11 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
namespace onnx_mlir {
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -0,0 +1,44 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigned operandNumber) {
auto getInputIndex = [operandNumber](Operation* op, unsigned inputCount) -> std::optional<unsigned> {
if (inputCount == 0)
return std::nullopt;
unsigned inputBegin = op->getNumOperands() - inputCount;
if (operandNumber < inputBegin)
return std::nullopt;
return operandNumber - inputBegin;
};
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
return getInputIndex(owner, compute.getInputs().size());
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner))
return getInputIndex(owner, computeBatch.getInputs().size());
return std::nullopt;
}
void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
Operation* owner,
unsigned inputIndex,
Value replacement) {
Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument = body.getArgument(inputIndex);
rewriter.startOpModification(owner);
bodyArgument.replaceAllUsesWith(replacement);
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
compute.getInputsMutable().erase(inputIndex);
else
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
body.eraseArgument(inputIndex);
rewriter.finalizeOpModification(owner);
}
} // namespace onnx_mlir
@@ -0,0 +1,17 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>
namespace onnx_mlir {
std::optional<unsigned> getDirectComputeLikeInputIndex(mlir::Operation* owner, unsigned operandNumber);
void replaceAndEraseDirectComputeLikeInput(mlir::PatternRewriter& rewriter,
mlir::Operation* owner,
unsigned inputIndex,
mlir::Value replacement);
} // namespace onnx_mlir
@@ -0,0 +1,213 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
static bool isChannelUseChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
ONNXTransposeOp,
pim::PimTransposeOp>(op);
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
return static_cast<int32_t>(fallbackCoreId++);
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
SmallVectorImpl<Operation*>& helperChain,
bool requireReturnUse = true) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
return failure();
if (requireReturnUse
&& (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin())))
return failure();
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 1)
return failure();
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return failure();
SmallVector<Operation*> reverseChain;
Value currentValue = yieldOp.getOperands().front();
Value blockArg = block.getArgument(0);
while (currentValue != blockArg) {
Operation* definingOp = currentValue.getDefiningOp();
if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp))
return failure();
reverseChain.push_back(definingOp);
currentValue = definingOp->getOperand(0);
}
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
return failure();
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
return success();
}
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
return false;
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
}))
return false;
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 0)
return false;
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return false;
rewriter.setInsertionPoint(computeOp);
IRMapping mapping;
for (Operation& op : block.without_terminator()) {
cloneMappedHelperOperands(&op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
computeOp.getResult(0).replaceAllUsesWith(replacement);
return true;
}
} // namespace
void markOpToRemove(CoreLoweringState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
return success();
SmallVector<Operation*> helperChain;
if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
return success();
auto& block = computeOp.getRegion().front();
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
if (!receiveOp || blockArg.use_empty())
continue;
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
auto outputType = cast<ShapedType>(blockArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
Value received = PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
blockArg.replaceAllUsesWith(received);
markOpToRemove(state, receiveOp);
}
if (computeOp.getNumResults() != yieldOp.getNumOperands())
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
if (result.use_empty())
continue;
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
ReturnPathLoweringResult returnPathResult =
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
if (returnPathResult == ReturnPathLoweringResult::Failure)
return failure();
if (returnPathResult == ReturnPathLoweringResult::Handled)
continue;
auto resultUses = result.getUses();
if (rangeLength(resultUses) == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isa<spatial::SpatChannelSendOp>(resultUser))
continue;
}
return computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering");
}
rewriter.setInsertionPoint(yieldOp);
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
SmallVector<Value> computeWeights;
if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp);
auto coreOp = PimCoreOp::create(rewriter,
loc,
ValueRange(computeWeights),
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
if (!blockArg.use_empty())
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
block.eraseArguments(0, block.getNumArguments());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock);
rewriter.setInsertionPointToEnd(tempComputeBlock);
PimHaltOp::create(rewriter, computeOp.getLoc());
return success();
}
} // namespace onnx_mlir
@@ -0,0 +1,21 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
struct CoreLoweringState {
size_t& nextCoreId;
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
};
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -6,16 +6,17 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -23,33 +24,33 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
unsigned inputCount = compute.getInputs().size();
if (inputCount == 0)
return std::nullopt;
unsigned inputBegin = compute->getNumOperands() - inputCount;
if (operandNumber < inputBegin)
return std::nullopt;
return operandNumber - inputBegin;
}
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner)) {
unsigned inputCount = computeBatch.getInputs().size();
if (inputCount == 0)
return std::nullopt;
unsigned inputBegin = computeBatch->getNumOperands() - inputCount;
if (operandNumber < inputBegin)
return std::nullopt;
return operandNumber - inputBegin;
}
return std::nullopt;
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
std::string name = baseName.str();
unsigned suffix = 0;
while (SymbolTable::lookupSymbolIn(symbolTableOp, name))
name = (baseName + "_" + Twine(suffix++)).str();
return name;
}
static memref::GlobalOp createPrivateMemrefGlobalWithUniqueName(PatternRewriter& rewriter,
Location loc,
ModuleOp moduleOp,
StringRef baseName,
MemRefType type,
Attribute initialValue = {},
UnitAttr constant = {}) {
std::string symbolName = makeUniqueSymbolName(moduleOp, baseName);
return memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(symbolName),
rewriter.getStringAttr("private"),
TypeAttr::get(type),
initialValue,
constant,
IntegerAttr {});
}
// Sinks top-level tensor slices into compute regions so later lowering sees local runtime work.
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
using OpRewritePattern::OpRewritePattern;
@@ -59,7 +60,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
for (auto& uses : extractSliceOp->getUses()) {
if (isa<spatial::SpatCompute>(uses.getOwner())) {
if (!getDirectComputeInputIndex(uses.getOwner(), uses.getOperandNumber()))
if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber()))
return failure();
}
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
@@ -72,7 +73,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
auto inputIndex = getDirectComputeInputIndex(spatCompute, uses.getOperandNumber());
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
@@ -87,14 +88,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute.getOperation()]);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, uses.getOperandNumber());
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
@@ -109,11 +107,8 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatComputeBatch.getOperation()]);
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
}
else {
{
@@ -148,11 +143,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
};
// Turns runtime constants consumed by compute regions into private globals and local loads.
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
static int i = 0;
Location loc = constantOp.getLoc();
if (hasWeightAlways(constantOp))
@@ -177,15 +172,14 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
if (constRankedTensorType) {
mlir::MemRefType memRefType =
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
std::string argName = "const_" + std::to_string(i++);
memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(argName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
constantOp.getValueAttr(),
rewriter.getUnitAttr(),
{});
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
loc,
constantOp->getParentOfType<ModuleOp>(),
"const",
memRefType,
constantOp.getValueAttr(),
rewriter.getUnitAttr());
std::string argName = globalOp.getSymName().str();
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
@@ -193,11 +187,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
@@ -206,18 +199,14 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute.getOperation()]);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
@@ -226,11 +215,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
replaceAndEraseDirectComputeLikeInput(rewriter,
spatComputeBatch.getOperation(),
BBArgIndex,
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
}
else {
{
@@ -272,34 +260,26 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
rewriter.startOpModification(spatComputeBatch.getOperation());
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
}
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
if (!mapSpatComputeToConst.contains(parent)) {
@@ -321,11 +301,13 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
}
}
}
rewriter.eraseOp(constantOp);
if (constantOp->use_empty())
rewriter.eraseOp(constantOp);
return success();
}
};
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
@@ -352,52 +334,36 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
mlir::MemRefType memRefType =
mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
std::string argName = "arg_" + std::to_string(index);
memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(argName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
{},
{},
{});
std::string baseName = ("arg_" + Twine(index)).str();
auto globalOp = createPrivateMemrefGlobalWithUniqueName(
rewriter, loc, funcOp->getParentOfType<ModuleOp>(), baseName, memRefType);
std::string argName = globalOp.getSymName().str();
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
auto argUser = argUses.getOwner();
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
auto inputIndex = getDirectComputeInputIndex(spatCompute, argUses.getOperandNumber());
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(toTensor);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor);
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, argUses.getOperandNumber());
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
rewriter.startOpModification(spatComputeBatch.getOperation());
BBArgValue.replaceAllUsesWith(toTensor);
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, toTensor);
}
else {
rewriter.setInsertionPoint(argUser);
@@ -416,7 +382,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
};
} // namespace
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
patterns.getContext());
}
@@ -0,0 +1,9 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
}
@@ -1,10 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns);
}
@@ -0,0 +1,20 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifySpatialToPimBoundary(ModuleOp moduleOp) {
bool hasFailure = false;
moduleOp.walk([&](Operation* op) {
if (op->getDialect()->getNamespace() != "spat")
return;
op->emitError("illegal Spatial operation remains after Spatial-to-PIM lowering");
hasFailure = true;
});
return success(!hasFailure);
}
} // namespace onnx_mlir
@@ -0,0 +1,9 @@
#pragma once
#include "mlir/IR/BuiltinOps.h"
namespace onnx_mlir {
mlir::LogicalResult verifySpatialToPimBoundary(mlir::ModuleOp moduleOp);
} // namespace onnx_mlir
@@ -0,0 +1,587 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/SymbolTable.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
struct ReturnUseInfo {
size_t returnIndex;
SmallVector<Operation*> helperChain;
};
struct ConcatReturnUseInfo {
size_t returnIndex;
SmallVector<int64_t> sliceOffsets;
SmallVector<int64_t> concatShape;
SmallVector<Operation*> concatChain;
SmallVector<Operation*> helperChain;
};
static bool isReturnHelperChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
ONNXTransposeOp,
pim::PimTransposeOp>(op);
}
static void markOpToRemove(ReturnPathState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
std::string name = baseName.str();
unsigned suffix = 0;
while (SymbolTable::lookupSymbolIn(symbolTableOp, name))
name = (baseName + "_" + Twine(suffix++)).str();
return name;
}
static int64_t computeFlatElementIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> shape) {
int64_t flatIndex = 0;
for (size_t i = 0; i < shape.size(); ++i) {
flatIndex *= shape[i];
flatIndex += indices[i];
}
return flatIndex;
}
static SmallVector<int64_t> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
SmallVector<int64_t> indices(shape.size(), 0);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
indices[dim] = flatIndex % shape[dim];
flatIndex /= shape[dim];
}
return indices;
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
SmallVectorImpl<Operation*>& helperChain) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
return failure();
if (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin()))
return failure();
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 1)
return failure();
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return failure();
SmallVector<Operation*> reverseChain;
Value currentValue = yieldOp.getOperands().front();
Value blockArg = block.getArgument(0);
while (currentValue != blockArg) {
Operation* definingOp = currentValue.getDefiningOp();
if (!definingOp || definingOp->getBlock() != &block || !isReturnHelperChainOp(definingOp))
return failure();
reverseChain.push_back(definingOp);
currentValue = definingOp->getOperand(0);
}
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
return failure();
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
return success();
}
static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
auto uses = value.getUses();
if (rangeLength(uses) != 1)
return std::nullopt;
SmallVector<Operation*> helperChain;
Value currentValue = value;
Operation* currentUser = uses.begin()->getOwner();
while (isReturnHelperChainOp(currentUser)) {
helperChain.push_back(currentUser);
auto currentUses = currentUser->getResult(0).getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentValue = currentUser->getResult(0);
currentUser = currentUses.begin()->getOwner();
}
if (!isa<func::ReturnOp>(currentUser))
return std::nullopt;
return ReturnUseInfo {
currentValue.getUses().begin()->getOperandNumber(),
std::move(helperChain),
};
}
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
auto getConcatResult = [](Operation* op) -> Value {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getResult();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getOutput();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getOutput();
return {};
};
auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getDim();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getAxis();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getAxis();
return std::nullopt;
};
auto getConcatOperands = [](Operation* op) -> OperandRange {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getOperands();
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
return spatialConcat.getInputs();
return cast<pim::PimConcatOp>(op).getInputs();
};
auto uses = value.getUses();
if (rangeLength(uses) != 1
|| !isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
return std::nullopt;
auto valueType = dyn_cast<ShapedType>(value.getType());
if (!valueType || !valueType.hasStaticShape())
return std::nullopt;
SmallVector<int64_t> sliceOffsets(valueType.getRank(), 0);
SmallVector<int64_t> concatShape(valueType.getShape().begin(), valueType.getShape().end());
SmallVector<Operation*> concatChain;
Value currentValue = value;
Operation* currentUser = uses.begin()->getOwner();
while (isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(currentUser)) {
concatChain.push_back(currentUser);
size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
int64_t axis = *getConcatAxis(currentUser);
for (Value operand : getConcatOperands(currentUser).take_front(operandIndex))
sliceOffsets[axis] += cast<ShapedType>(operand.getType()).getShape()[axis];
Value concatResult = getConcatResult(currentUser);
auto concatType = dyn_cast<ShapedType>(concatResult.getType());
if (!concatType || !concatType.hasStaticShape())
return std::nullopt;
concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());
currentValue = concatResult;
auto currentUses = currentValue.getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentUser = currentUses.begin()->getOwner();
}
SmallVector<Operation*> helperChain;
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
return std::nullopt;
if (failed(collectHelperComputeChain(helperCompute, helperChain)))
return std::nullopt;
currentValue = helperCompute.getResult(0);
auto currentUses = currentValue.getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentUser = currentUses.begin()->getOwner();
}
while (isReturnHelperChainOp(currentUser)) {
helperChain.push_back(currentUser);
auto currentUses = currentUser->getResult(0).getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
currentValue = currentUser->getResult(0);
currentUser = currentUses.begin()->getOwner();
}
if (!isa<func::ReturnOp>(currentUser))
return std::nullopt;
return ConcatReturnUseInfo {
currentValue.getUses().begin()->getOperandNumber(),
std::move(sliceOffsets),
std::move(concatShape),
std::move(concatChain),
std::move(helperChain),
};
}
static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndices,
ArrayRef<int64_t> sourceShape,
ArrayRef<Operation*> helperChain,
SmallVectorImpl<int64_t>& mappedIndices) {
SmallVector<int64_t> currentIndices(sourceIndices.begin(), sourceIndices.end());
SmallVector<int64_t> currentShape(sourceShape.begin(), sourceShape.end());
auto reshapeToResultShape = [&](Operation* op) -> LogicalResult {
auto resultType = dyn_cast<ShapedType>(op->getResult(0).getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape);
currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
currentIndices = expandFlatElementIndex(flatIndex, currentShape);
return success();
};
for (Operation* op : helperChain) {
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
auto hasStaticValues = [](ArrayRef<int64_t> values) {
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
};
if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes())
|| !hasStaticValues(extractSliceOp.getStaticStrides()))
return failure();
SmallVector<int64_t> nextIndices;
nextIndices.reserve(currentIndices.size());
for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices,
extractSliceOp.getStaticOffsets(),
extractSliceOp.getStaticSizes(),
extractSliceOp.getStaticStrides())) {
if (stride != 1 || index < offset || index >= offset + size)
return failure();
nextIndices.push_back(index - offset);
}
auto resultType = dyn_cast<ShapedType>(extractSliceOp.getResult().getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
currentIndices = std::move(nextIndices);
currentShape.assign(resultType.getShape().begin(), resultType.getShape().end());
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
SmallVector<int64_t> nextIndices(currentIndices.size());
SmallVector<int64_t> nextShape(currentShape.size());
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
int64_t sourceIndex = attr.getInt();
nextIndices[destIndex] = currentIndices[sourceIndex];
nextShape[destIndex] = currentShape[sourceIndex];
}
currentIndices = std::move(nextIndices);
currentShape = std::move(nextShape);
continue;
}
if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) {
SmallVector<int64_t> nextIndices(currentIndices.size());
SmallVector<int64_t> nextShape(currentShape.size());
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermutation().getAsRange<IntegerAttr>())) {
int64_t sourceIndex = attr.getInt();
nextIndices[destIndex] = currentIndices[sourceIndex];
nextShape[destIndex] = currentShape[sourceIndex];
}
currentIndices = std::move(nextIndices);
currentShape = std::move(nextShape);
continue;
}
if (isa<tensor::CastOp, tosa::ReshapeOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp>(op)) {
if (failed(reshapeToResultShape(op)))
return failure();
continue;
}
return failure();
}
mappedIndices.assign(currentIndices.begin(), currentIndices.end());
return success();
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static void
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
IRMapping mapping;
mapping.map(sourceValue, sourceValue);
clonedValue = sourceValue;
rewriter.setInsertionPointAfterValue(sourceValue);
for (Operation* op : helperChain) {
cloneMappedHelperOperands(op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
clonedValue = clonedOp->getResult(0);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static Value emitHostCopy(IRRewriter& rewriter,
Location loc,
Value outputTensor,
Value sourceValue,
int32_t hostTargetOffset,
int32_t deviceSourceOffset,
int32_t sizeInBytes) {
return PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
outputTensor,
sourceValue,
rewriter.getI32IntegerAttr(hostTargetOffset),
rewriter.getI32IntegerAttr(deviceSourceOffset),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
} // namespace
void addReturnOutputBuffers(func::ReturnOp returnOp,
IRRewriter& rewriter,
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
outputTensors.reserve(returnOp->getNumOperands());
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
Value currentReturnValue = returnValue;
Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!hasWeightAlways(returnValueDefiningOp));
outputTensors.push_back(
[currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; });
}
else {
auto outRankedTensorType = llvm::dyn_cast<RankedTensorType>(currentReturnValue.getType());
auto memRefType = MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
std::string outputBaseName = ("output_" + Twine(index)).str();
std::string outputName = makeUniqueSymbolName(returnOp->getParentOfType<ModuleOp>(), outputBaseName);
rewriter.setInsertionPoint(returnOp.getParentOp());
memref::GlobalOp::create(rewriter,
returnOp.getLoc(),
rewriter.getStringAttr(outputName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
{},
{},
{});
outputTensors.push_back([memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
return toTensor.getResult();
});
}
}
}
ReturnPathLoweringResult lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
auto yieldType = cast<TensorType>(yieldValue.getType());
if (auto returnUse = analyzeReturnUse(result)) {
Value storedValue = yieldValue;
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
for (Operation* op : returnUse->helperChain)
markOpToRemove(state, op);
auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
if (auto storedOp = storedValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
return ReturnPathLoweringResult::Handled;
}
auto resultUses = result.getUses();
if (rangeLength(resultUses) == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
return ReturnPathLoweringResult::Handled;
}
}
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(state, concatOp);
if (concatReturnUse->helperChain.empty()) {
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
emitHostCopy(rewriter,
loc,
outputTensor,
yieldValue,
static_cast<int32_t>(flatOffset * elementSize),
0,
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
return ReturnPathLoweringResult::Handled;
}
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
if (!storedType) {
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
for (auto [dim, idx] : llvm::enumerate(sourceIndices))
sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
SmallVector<int64_t> destinationIndices;
if (failed(mapIndicesThroughHelperChain(
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
SmallVector<OpFoldResult> extractOffsets;
SmallVector<OpFoldResult> extractSizes;
SmallVector<OpFoldResult> extractStrides;
extractOffsets.reserve(storedType.getRank());
extractSizes.reserve(storedType.getRank());
extractStrides.reserve(storedType.getRank());
for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) {
extractOffsets.push_back(rewriter.getIndexAttr(idx));
extractSizes.push_back(rewriter.getIndexAttr(1));
extractStrides.push_back(rewriter.getIndexAttr(1));
}
auto scalarTensorType =
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
auto elementSlice = tensor::ExtractSliceOp::create(
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
rewriter.setInsertionPointAfter(elementSlice);
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
outputTensor = emitHostCopy(rewriter,
loc,
outputTensor,
elementSlice.getResult(),
static_cast<int32_t>(destinationFlatOffset * elementSize),
0,
static_cast<int32_t>(elementSize));
}
return ReturnPathLoweringResult::Handled;
}
return ReturnPathLoweringResult::NotReturnPath;
}
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
if (!op)
return;
bool isExclusivelyOwnedByReturnChain = op->use_empty();
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
Operation* onlyUser = *op->getUsers().begin();
isExclusivelyOwnedByReturnChain =
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|| isReturnHelperChainOp(onlyUser);
}
if (!isExclusivelyOwnedByReturnChain)
return;
if (isReturnHelperChainOp(op)) {
Value source = op->getOperand(0);
markOpToRemove(state, op);
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
markOpToRemove(state, computeOp);
if (!computeOp.getInputs().empty())
for (Value input : computeOp.getInputs())
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
markOpToRemove(state, concatOp);
for (Value operand : concatOp.getOperands())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
markOpToRemove(state, concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
markOpToRemove(state, concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
}
};
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
auto loc = returnOp.getLoc();
for (auto it : llvm::enumerate(originalOperands)) {
size_t orderWithinReturn = it.index();
Operation* returnOperand = it.value().getDefiningOp();
rewriter.setInsertionPoint(returnOp);
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
}
}
} // namespace onnx_mlir
@@ -0,0 +1,37 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include <functional>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
struct ReturnPathState {
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
};
enum class ReturnPathLoweringResult {
Handled,
NotReturnPath,
Failure
};
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
mlir::IRRewriter& rewriter,
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
mlir::OpResult result,
mlir::Value yieldValue,
ReturnPathState& state,
mlir::IRRewriter& rewriter);
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
} // namespace onnx_mlir
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,113 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
// Replaces concat-of-adjacent-slices with one packed slice to keep batch sends compact.
struct FoldConcatOfContiguousSlices : OpRewritePattern<tensor::ConcatOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ConcatOp op, PatternRewriter& rewriter) const override {
if (op.getDim() != 0)
return failure();
Value packed = createPackedExtractSliceTensor(op.getInputs(), rewriter, op.getLoc());
if (!packed)
return failure();
rewriter.replaceOp(op, packed);
return success();
}
};
} // namespace
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
packedShape[0] *= count;
return RankedTensorType::get(packedShape, elementType.getElementType());
}
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
if (values.empty())
return {};
if (values.size() == 1)
return values.front();
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
if (!firstSliceOp)
return {};
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|| firstType.getRank() == 0)
return {};
auto hasStaticValues = [](ArrayRef<int64_t> values) {
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
};
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes())
|| !hasStaticValues(firstSliceOp.getStaticStrides()))
return {};
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
int64_t rowsPerValue = firstSizes[0];
if (ShapedType::isDynamic(rowsPerValue))
return {};
for (size_t index = 1; index < values.size(); ++index) {
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|| !hasStaticValues(sliceOp.getStaticStrides()))
return {};
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
return {};
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue)
return {};
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
return {};
}
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides;
offsets.reserve(firstType.getRank());
sizes.reserve(firstType.getRank());
strides.reserve(firstType.getRank());
offsets.push_back(builder.getIndexAttr(firstOffsets[0]));
sizes.push_back(builder.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
strides.push_back(builder.getIndexAttr(firstStrides[0]));
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
offsets.push_back(builder.getIndexAttr(firstOffsets[dim]));
sizes.push_back(builder.getIndexAttr(firstSizes[dim]));
strides.push_back(builder.getIndexAttr(firstStrides[dim]));
}
bool coversWholeSource = packedType == sourceType;
for (int64_t dim = 0; coversWholeSource && dim < sourceType.getRank(); ++dim)
coversWholeSource = firstOffsets[dim] == 0 && firstStrides[dim] == 1;
if (coversWholeSource)
return firstSliceOp.getSource();
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
.getResult();
}
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
patterns.add<FoldConcatOfContiguousSlices>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -0,0 +1,13 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir