Better reports, refactor for more code reuse, and pattern-usage fixes
This commit is contained in:
@@ -94,7 +94,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
RewritePatternSet prePatterns(ctx);
|
||||
populatePrePatterns(prePatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
|
||||
llvm::dbgs() << "Failed to apply pre-patterns, continuing...\n";
|
||||
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing");
|
||||
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
@@ -148,7 +148,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
computeOpsCount++;
|
||||
|
||||
if (computeOpsCount > coresCount) {
|
||||
llvm::dbgs() << "Number of compute ops exceeds the core count\n";
|
||||
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
|
||||
<< coresCount << ")";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -157,7 +158,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
PassManager cleanupPM(ctx);
|
||||
cleanupPM.addPass(createCanonicalizerPass());
|
||||
if (failed(cleanupPM.run(moduleOp)))
|
||||
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||
moduleOp.emitWarning("failed to run ONNX-to-Spatial canonicalization cleanup; continuing");
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
|
||||
@@ -70,85 +70,6 @@ private:
|
||||
|
||||
} // namespace
|
||||
|
||||
// Converts a Spatial core id into the id used on the PIM side. The mapping is
// the identity; the value is only narrowed to a signed 32-bit integer.
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) {
  const auto pimCoreId = static_cast<int32_t>(spatialCoreId);
  return pimCoreId;
}
|
||||
|
||||
// Replaces a spatial channel-send with a PimSendOp placed right before it.
// The PIM op receives the same input, the input's size-in-bytes attribute and
// the translated target core id; the spatial op is erased afterwards.
static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) {
  auto payload = sendOp.getInput();
  auto byteSizeAttr = getTensorSizeInBytesAttr(rewriter, payload);
  int32_t pimTargetCoreId = translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId());
  auto pimTargetCoreIdAttr = rewriter.getI32IntegerAttr(pimTargetCoreId);

  rewriter.setInsertionPoint(sendOp);
  PimSendOp::create(rewriter, sendOp.getLoc(), payload, byteSizeAttr, pimTargetCoreIdAttr);
  rewriter.eraseOp(sendOp);
}
|
||||
|
||||
// Replaces a spatial channel-receive with a PimReceiveOp that writes into a
// freshly created empty tensor of the same shaped type. A receive whose result
// has no uses is simply erased.
static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
  // Nothing consumes the received value: drop the op instead of lowering it.
  if (receiveOp->use_empty()) {
    rewriter.eraseOp(receiveOp);
    return;
  }

  auto spatialResult = receiveOp.getResult();
  Location loc = receiveOp.getLoc();
  auto resultShapedType = cast<ShapedType>(spatialResult.getType());

  rewriter.setInsertionPoint(receiveOp);
  auto destination = createEmptyTensorFromShaped(rewriter, loc, resultShapedType);
  auto byteSizeAttr = getTensorSizeInBytesAttr(rewriter, spatialResult);
  int32_t pimSourceCoreId = translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId());
  auto pimSourceCoreIdAttr = rewriter.getI32IntegerAttr(pimSourceCoreId);

  auto pimReceive =
      PimReceiveOp::create(rewriter, loc, destination.getType(), destination, byteSizeAttr, pimSourceCoreIdAttr);
  Value received = pimReceive.getOutput();
  rewriter.replaceOp(receiveOp, received);
}
|
||||
|
||||
// Replaces a spatial tensor-send with a PimSendTensorOp that carries the same
// input and the list of target core ids, each passed through
// translateSpatialCoreIdToPimCoreId. The spatial op is erased afterwards.
static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp, IRRewriter& rewriter) {
  auto spatialTargetIds = sendTensorOp.getTargetCoreIds();
  SmallVector<int32_t> pimTargetIds;
  pimTargetIds.reserve(spatialTargetIds.size());
  for (int32_t spatialId : spatialTargetIds)
    pimTargetIds.push_back(translateSpatialCoreIdToPimCoreId(spatialId));

  auto pimTargetIdsAttr = rewriter.getDenseI32ArrayAttr(pimTargetIds);

  rewriter.setInsertionPoint(sendTensorOp);
  PimSendTensorOp::create(rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), pimTargetIdsAttr);
  rewriter.eraseOp(sendTensorOp);
}
|
||||
|
||||
// Replaces a spatial tensor-receive with a PimReceiveTensorOp. The PIM op gets
// a new empty tensor (same shaped type as the spatial output) as its buffer and
// the translated list of source core ids; its output replaces the spatial op.
static void lowerChannelReceiveTensor(spatial::SpatChannelReceiveTensorOp receiveTensorOp, IRRewriter& rewriter) {
  auto spatialSourceIds = receiveTensorOp.getSourceCoreIds();
  SmallVector<int32_t> pimSourceIds;
  pimSourceIds.reserve(spatialSourceIds.size());
  for (int32_t spatialId : spatialSourceIds)
    pimSourceIds.push_back(translateSpatialCoreIdToPimCoreId(spatialId));

  rewriter.setInsertionPoint(receiveTensorOp);

  Location loc = receiveTensorOp.getLoc();
  auto spatialOutput = receiveTensorOp.getOutput();
  auto outputShapedType = cast<ShapedType>(spatialOutput.getType());
  Value destination = createEmptyTensorFromShaped(rewriter, loc, outputShapedType).getResult();

  auto pimReceive = PimReceiveTensorOp::create(
      rewriter, loc, spatialOutput.getType(), destination, rewriter.getDenseI32ArrayAttr(pimSourceIds));
  Value received = pimReceive.getOutput();
  rewriter.replaceOp(receiveTensorOp, received);
}
|
||||
|
||||
// Replaces a spatial extract-rows op with one tensor.extract_slice per output.
// Output number i is sliced from the input at row offset i * (output dim 0),
// column offset 0, with size (output dim 0) x (input dim 1) and unit strides.
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(extractRowsOp);

  Location loc = extractRowsOp.getLoc();
  auto source = extractRowsOp.getInput();
  auto sourceType = cast<RankedTensorType>(source.getType());

  // These attributes are the same for every slice, so build them once.
  const OpFoldResult zeroOffset = rewriter.getIndexAttr(0);
  const OpFoldResult unitStride = rewriter.getIndexAttr(1);

  SmallVector<Value> slices;
  slices.reserve(extractRowsOp.getNumResults());
  for (auto [position, rowsValue] : llvm::enumerate(extractRowsOp.getOutputs())) {
    auto sliceType = cast<RankedTensorType>(rowsValue.getType());
    const int64_t rowsPerSlice = sliceType.getDimSize(0);

    const OpFoldResult rowOffset = rewriter.getIndexAttr(static_cast<int64_t>(position) * rowsPerSlice);
    const OpFoldResult rowCount = rewriter.getIndexAttr(rowsPerSlice);
    const OpFoldResult columnCount = rewriter.getIndexAttr(sourceType.getDimSize(1));

    SmallVector<OpFoldResult> offsets = {rowOffset, zeroOffset};
    SmallVector<OpFoldResult> sizes = {rowCount, columnCount};
    SmallVector<OpFoldResult> strides = {unitStride, unitStride};

    auto sliceOp = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, source, offsets, sizes, strides);
    slices.push_back(sliceOp.getResult());
  }

  rewriter.replaceOp(extractRowsOp, slices);
}
|
||||
|
||||
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||
auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
@@ -216,97 +137,6 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||
}
|
||||
|
||||
// Rewrites axis-0 spatial concat ops in `funcOp` so that runs of adjacent
// inputs are fed as a single packed value where possible:
//   * a run of inputs defined by tensor.extract_slice is handed to
//     createPackedExtractSliceTensor;
//   * a run of consecutive results of one SpatExtractRowsOp is handed to
//     createPackedExtractRowsSlice.
// When a helper returns a null value the original inputs are kept. A concat
// for which at least one run was packed is replaced by a PimConcatOp writing
// into a new empty tensor. Finally, unused tensor.concat, tensor.extract_slice
// and SpatExtractRowsOp ops are erased.
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Collect first: the loop below replaces concat ops while iterating.
  SmallVector<spatial::SpatConcatOp> concatOps;
  funcOp.walk([&concatOps](spatial::SpatConcatOp op) { concatOps.push_back(op); });

  auto comesFromExtractSlice = [](Value value) {
    return static_cast<bool>(value.getDefiningOp<tensor::ExtractSliceOp>());
  };

  for (auto concatOp : concatOps) {
    auto inputs = concatOp.getInputs();
    if (concatOp.getAxis() != 0 || inputs.empty())
      continue;

    const Location loc = concatOp.getLoc();
    const auto numInputs = inputs.size();

    SmallVector<Value> repacked;
    bool packedAnything = false;
    rewriter.setInsertionPoint(concatOp);

    unsigned cursor = 0;
    while (cursor < numInputs) {
      Value current = inputs[cursor];

      // Run of adjacent extract_slice results.
      if (comesFromExtractSlice(current)) {
        unsigned sliceRunEnd = cursor + 1;
        while (sliceRunEnd < numInputs && comesFromExtractSlice(inputs[sliceRunEnd]))
          ++sliceRunEnd;

        if (Value packed = createPackedExtractSliceTensor(inputs.slice(cursor, sliceRunEnd - cursor), rewriter, loc)) {
          repacked.push_back(packed);
          packedAnything = true;
          cursor = sliceRunEnd;
          continue;
        }
        // Packing was not possible; fall through and treat it like any other result.
      }

      // Values that are not op results are forwarded untouched.
      auto currentResult = dyn_cast<OpResult>(current);
      if (!currentResult) {
        repacked.push_back(current);
        ++cursor;
        continue;
      }

      // Run of consecutive results (k, k+1, ...) of the same producing op.
      Operation* producer = currentResult.getOwner();
      const unsigned firstResultNumber = currentResult.getResultNumber();
      unsigned runEnd = cursor + 1;
      for (; runEnd < numInputs; ++runEnd) {
        auto candidate = dyn_cast<OpResult>(inputs[runEnd]);
        const bool continuesRun = candidate && candidate.getOwner() == producer
                               && candidate.getResultNumber() == firstResultNumber + (runEnd - cursor);
        if (!continuesRun)
          break;
      }
      const unsigned runLength = runEnd - cursor;

      Value packed;
      if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(producer))
        packed = createPackedExtractRowsSlice(extractRowsOp, firstResultNumber, runLength, rewriter, loc);

      if (!packed) {
        for (unsigned keep = cursor; keep < runEnd; ++keep)
          repacked.push_back(inputs[keep]);
      }
      else {
        repacked.push_back(packed);
        packedAnything = true;
      }

      cursor = runEnd;
    }

    if (!packedAnything)
      continue;

    auto resultType = concatOp.getOutput().getType();
    auto destination = createEmptyTensorFromShaped(rewriter, loc, cast<ShapedType>(resultType)).getResult();
    auto packedConcat =
        pim::PimConcatOp::create(rewriter, loc, resultType, concatOp.getAxisAttr(), ValueRange(repacked), destination);
    rewriter.replaceOp(concatOp, packedConcat.getOutput());
  }

  // The tag argument only selects the op type; its value is never used.
  auto eraseIfUnused = [&funcOp, &rewriter](auto opKindTag) {
    using OpTy = decltype(opKindTag);
    SmallVector<OpTy> collected;
    funcOp.walk([&collected](OpTy op) { collected.push_back(op); });
    // Visit in reverse collection order.
    for (auto it = collected.rbegin(), end = collected.rend(); it != end; ++it) {
      OpTy op = *it;
      if (op->use_empty())
        rewriter.eraseOp(op);
    }
  };
  eraseIfUnused(tensor::ConcatOp {});
  eraseIfUnused(tensor::ExtractSliceOp {});
  eraseIfUnused(spatial::SpatExtractRowsOp {});
}
|
||||
|
||||
void SpatialToPimPass::runOnOperation() {
|
||||
coreId = 1;
|
||||
ModuleOp moduleOp = getOperation();
|
||||
@@ -380,7 +210,12 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
}
|
||||
|
||||
compactSpatialTensorGroups(funcOp, rewriter);
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateTensorPackingPatterns(patterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(patterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
||||
@@ -392,37 +227,8 @@ void SpatialToPimPass::runOnOperation() {
|
||||
markOpToRemove(receiveOp);
|
||||
continue;
|
||||
}
|
||||
if (receiveOp->use_empty()) {
|
||||
rewriter.eraseOp(receiveOp);
|
||||
continue;
|
||||
}
|
||||
lowerChannelReceive(receiveOp, rewriter);
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveTensorOp> receiveTensorOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveTensorOp>())
|
||||
receiveTensorOps.push_back(op);
|
||||
for (auto receiveTensorOp : receiveTensorOps)
|
||||
lowerChannelReceiveTensor(receiveTensorOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelSendOp> sendOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
|
||||
sendOps.push_back(op);
|
||||
for (auto sendOp : sendOps)
|
||||
lowerChannelSend(sendOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelSendTensorOp> sendTensorOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelSendTensorOp>())
|
||||
sendTensorOps.push_back(op);
|
||||
for (auto sendTensorOp : sendTensorOps)
|
||||
lowerChannelSendTensor(sendTensorOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
|
||||
extractRowsOps.push_back(op);
|
||||
for (auto extractRowsOp : extractRowsOps)
|
||||
lowerExtractRows(extractRowsOp, rewriter);
|
||||
|
||||
{
|
||||
RewritePatternSet coreBodyPatterns(ctx);
|
||||
populateWithGenerated(coreBodyPatterns);
|
||||
@@ -457,7 +263,12 @@ void SpatialToPimPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
compactSpatialTensorGroups(funcOp, rewriter);
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateTensorPackingPatterns(patterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(patterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
}
|
||||
|
||||
{
|
||||
ConversionTarget communicationTarget(*ctx);
|
||||
|
||||
@@ -1,8 +1,96 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
// Rewrite pattern for axis-0 spatial concat ops. Runs of adjacent inputs are
// replaced by one packed value where possible:
//   * inputs defined by tensor.extract_slice go through
//     createPackedExtractSliceTensor;
//   * consecutive results of one SpatExtractRowsOp go through
//     createPackedExtractRowsSlice.
// When a helper returns a null value the original inputs are kept. The pattern
// fails unless at least one run was packed; on success the concat is replaced
// by a PimConcatOp that writes into a new tensor.empty of the output type.
struct PackSpatialConcatInputsPattern final : OpRewritePattern<spatial::SpatConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatConcatOp concatOp, PatternRewriter& rewriter) const override {
    auto inputs = concatOp.getInputs();
    if (concatOp.getAxis() != 0 || inputs.empty())
      return failure();

    const Location loc = concatOp.getLoc();
    const auto numInputs = inputs.size();

    auto comesFromExtractSlice = [](Value value) {
      return static_cast<bool>(value.getDefiningOp<tensor::ExtractSliceOp>());
    };

    SmallVector<Value> repacked;
    bool packedAnything = false;

    unsigned cursor = 0;
    while (cursor < numInputs) {
      Value current = inputs[cursor];

      // Run of adjacent extract_slice results.
      if (comesFromExtractSlice(current)) {
        unsigned sliceRunEnd = cursor + 1;
        while (sliceRunEnd < numInputs && comesFromExtractSlice(inputs[sliceRunEnd]))
          ++sliceRunEnd;

        if (Value packed = createPackedExtractSliceTensor(inputs.slice(cursor, sliceRunEnd - cursor), rewriter, loc)) {
          repacked.push_back(packed);
          packedAnything = true;
          cursor = sliceRunEnd;
          continue;
        }
        // Packing was not possible; fall through and treat it like any other result.
      }

      // Values that are not op results are forwarded untouched.
      auto currentResult = dyn_cast<OpResult>(current);
      if (!currentResult) {
        repacked.push_back(current);
        ++cursor;
        continue;
      }

      // Run of consecutive results (k, k+1, ...) of the same producing op.
      Operation* producer = currentResult.getOwner();
      const unsigned firstResultNumber = currentResult.getResultNumber();
      unsigned runEnd = cursor + 1;
      for (; runEnd < numInputs; ++runEnd) {
        auto candidate = dyn_cast<OpResult>(inputs[runEnd]);
        const bool continuesRun = candidate && candidate.getOwner() == producer
                               && candidate.getResultNumber() == firstResultNumber + (runEnd - cursor);
        if (!continuesRun)
          break;
      }
      const unsigned runLength = runEnd - cursor;

      Value packed;
      if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(producer))
        packed = createPackedExtractRowsSlice(extractRowsOp, firstResultNumber, runLength, rewriter, loc);

      if (!packed) {
        for (unsigned keep = cursor; keep < runEnd; ++keep)
          repacked.push_back(inputs[keep]);
      }
      else {
        repacked.push_back(packed);
        packedAnything = true;
      }

      cursor = runEnd;
    }

    if (!packedAnything)
      return failure();

    auto resultType = concatOp.getOutput().getType();
    auto resultShapedType = cast<ShapedType>(resultType);
    auto destination =
        tensor::EmptyOp::create(rewriter, loc, resultShapedType.getShape(), resultShapedType.getElementType())
            .getResult();
    auto packedConcat =
        pim::PimConcatOp::create(rewriter, loc, resultType, concatOp.getAxisAttr(), ValueRange(repacked), destination);
    rewriter.replaceOp(concatOp, packedConcat.getOutput());
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
||||
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
||||
@@ -146,4 +234,23 @@ Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Loca
|
||||
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
// Registers PackSpatialConcatInputsPattern in `patterns`, using the context
// the pattern set was created with.
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
  MLIRContext* context = patterns.getContext();
  patterns.add<PackSpatialConcatInputsPattern>(context);
}
|
||||
|
||||
void eraseUnusedTensorPackingOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
auto eraseUnusedOps = [&](auto tag) {
|
||||
using OpTy = decltype(tag);
|
||||
SmallVector<OpTy> ops;
|
||||
funcOp.walk([&](OpTy op) { ops.push_back(op); });
|
||||
for (auto op : llvm::reverse(ops))
|
||||
if (op->use_empty())
|
||||
rewriter.eraseOp(op);
|
||||
};
|
||||
eraseUnusedOps(tensor::ConcatOp {});
|
||||
eraseUnusedOps(tensor::ExtractSliceOp {});
|
||||
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -19,5 +20,7 @@ mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsO
|
||||
mlir::OpBuilder& builder,
|
||||
mlir::Location loc);
|
||||
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
|
||||
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
|
||||
void eraseUnusedTensorPackingOps(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user