rework actually broken dcp merge + compute re-batching (still to refine)
This commit is contained in:
@@ -23,21 +23,42 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||
unsigned inputCount = compute.getInputs().size();
|
||||
if (inputCount == 0)
|
||||
return std::nullopt;
|
||||
|
||||
unsigned inputBegin = compute->getNumOperands() - inputCount;
|
||||
if (operandNumber < inputBegin)
|
||||
return std::nullopt;
|
||||
return operandNumber - inputBegin;
|
||||
}
|
||||
|
||||
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner)) {
|
||||
unsigned inputCount = computeBatch.getInputs().size();
|
||||
if (inputCount == 0)
|
||||
return std::nullopt;
|
||||
|
||||
unsigned inputBegin = computeBatch->getNumOperands() - inputCount;
|
||||
if (operandNumber < inputBegin)
|
||||
return std::nullopt;
|
||||
return operandNumber - inputBegin;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
|
||||
Location loc = extractSliceOp.getLoc();
|
||||
|
||||
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
|
||||
return failure();
|
||||
|
||||
for (auto& uses : extractSliceOp->getUses()) {
|
||||
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
||||
auto spatCompute = cast<spatial::SpatCompute>(uses.getOwner());
|
||||
if (spatCompute.getInputs().empty())
|
||||
return failure();
|
||||
if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex())
|
||||
if (!getDirectComputeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
||||
return failure();
|
||||
}
|
||||
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
|
||||
@@ -50,7 +71,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
||||
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
||||
auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
@@ -69,7 +93,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||
auto BBArgIndex = uses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
@@ -165,8 +192,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
|
||||
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||
@@ -183,8 +212,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
|
||||
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||
@@ -201,7 +232,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
}
|
||||
else {
|
||||
{
|
||||
{
|
||||
|
||||
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
@@ -240,8 +271,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
|
||||
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
@@ -253,8 +286,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
|
||||
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
@@ -265,11 +300,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
}
|
||||
else {
|
||||
if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
if (!mapSpatComputeToConst.contains(parent)) {
|
||||
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
if (!mapSpatComputeToConst.contains(parent)) {
|
||||
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
|
||||
}
|
||||
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
|
||||
@@ -285,9 +319,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
auto parent = constantOp->getParentOp();
|
||||
rewriter.eraseOp(constantOp);
|
||||
return success();
|
||||
}
|
||||
@@ -333,7 +365,10 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
||||
auto argUser = argUses.getOwner();
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
||||
auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, argUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
@@ -347,7 +382,10 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
||||
auto BBArgIndex = argUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
|
||||
@@ -11,20 +11,15 @@
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <filesystem>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
@@ -34,10 +29,8 @@
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
@@ -214,11 +207,12 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
|
||||
extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering");
|
||||
return;
|
||||
}
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)),
|
||||
rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto rowSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
|
||||
auto rowSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
|
||||
replacements.push_back(rowSlice.getResult());
|
||||
}
|
||||
|
||||
@@ -263,19 +257,19 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
|
||||
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
|
||||
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
|
||||
if (!chainSet.contains(&op)
|
||||
&& !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
||||
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
||||
return failure();
|
||||
|
||||
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||
return false;
|
||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(),
|
||||
[](Operation* user) { return isa<spatial::SpatComputeBatch, pim::PimCoreBatchOp>(user); }))
|
||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||
}))
|
||||
return false;
|
||||
|
||||
Block& block = computeOp.getBody().front();
|
||||
@@ -447,8 +441,7 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
||||
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
||||
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
};
|
||||
if (!hasStaticValues(extractSliceOp.getStaticOffsets())
|
||||
|| !hasStaticValues(extractSliceOp.getStaticSizes())
|
||||
if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(extractSliceOp.getStaticStrides()))
|
||||
return failure();
|
||||
|
||||
@@ -510,10 +503,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
||||
return success();
|
||||
}
|
||||
|
||||
static void cloneHelperChain(Value sourceValue,
|
||||
ArrayRef<Operation*> helperChain,
|
||||
IRRewriter& rewriter,
|
||||
Value& clonedValue) {
|
||||
static void
|
||||
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
|
||||
IRMapping mapping;
|
||||
mapping.map(sourceValue, sourceValue);
|
||||
clonedValue = sourceValue;
|
||||
@@ -734,7 +725,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
Location loc = computeOp->getLoc();
|
||||
|
||||
if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter))
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
|
||||
return;
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
@@ -835,7 +826,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
|
||||
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
||||
if (!storedType) {
|
||||
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
computeOp.emitOpError(
|
||||
"has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -848,10 +840,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
|
||||
|
||||
SmallVector<int64_t> destinationIndices;
|
||||
if (failed(mapIndicesThroughHelperChain(sourceIndices,
|
||||
concatReturnUse->concatShape,
|
||||
concatReturnUse->helperChain,
|
||||
destinationIndices))) {
|
||||
if (failed(mapIndicesThroughHelperChain(
|
||||
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
|
||||
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
@@ -897,9 +887,12 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
|
||||
|
||||
// Replace `spat.compute` with `pim.core`
|
||||
SmallVector<Value> computeWeights;
|
||||
if (!computeOp.getWeights().empty())
|
||||
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
auto coreOp = PimCoreOp::create(
|
||||
rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
|
||||
if (!blockArg.use_empty())
|
||||
@@ -933,15 +926,19 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
|
||||
}
|
||||
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||
SmallVector<Value> batchInputs;
|
||||
if (!computeBatchOp.getInputs().empty())
|
||||
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
|
||||
|
||||
rewriter.setInsertionPointAfter(computeBatchOp);
|
||||
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
|
||||
computeBatchOp.getWeights(),
|
||||
computeBatchOp.getInputs());
|
||||
ValueRange(batchWeights),
|
||||
ValueRange(batchInputs));
|
||||
coreBatchOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(computeBatchOp.getWeights().size()), static_cast<int>(computeBatchOp.getInputs().size())});
|
||||
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
@@ -1124,13 +1121,13 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
|
||||
std::string outputName = "output_" + std::to_string(index);
|
||||
rewriter.setInsertionPoint(returnOp.getParentOp());
|
||||
memref::GlobalOp::create(rewriter,
|
||||
returnOp.getLoc(),
|
||||
rewriter.getStringAttr(outputName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(memRefType),
|
||||
{},
|
||||
{},
|
||||
{});
|
||||
returnOp.getLoc(),
|
||||
rewriter.getStringAttr(outputName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(memRefType),
|
||||
{},
|
||||
{},
|
||||
{});
|
||||
outputTensors.push_back(
|
||||
[memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
|
||||
@@ -1210,8 +1207,9 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
||||
|
||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||
markOpToRemove(computeOp);
|
||||
for (Value input : computeOp.getInputs())
|
||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||
if (!computeOp.getInputs().empty())
|
||||
for (Value input : computeOp.getInputs())
|
||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user