huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -3,6 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
|
||||
add_pim_library(OMONNXToSpatial
|
||||
ConversionPatterns.cpp
|
||||
HostFoldability.cpp
|
||||
HostLegality.cpp
|
||||
PrePatterns.cpp
|
||||
PostPatterns.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/Elementwise.cpp
|
||||
Patterns/Math/Gemm.cpp
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#include "ComputeRegionBuilder.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -30,10 +32,29 @@ SmallVector<Value> sliceTensor(
|
||||
|
||||
for (int64_t i = 0; i < numSlices; i++) {
|
||||
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
|
||||
if (i == numSlices - 1 && lastSliceSize != 0)
|
||||
int64_t currentSliceSize = sliceSize;
|
||||
if (i == numSlices - 1 && lastSliceSize != 0) {
|
||||
currentSliceSize = lastSliceSize;
|
||||
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
||||
}
|
||||
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||
SmallVector<int64_t> sliceShape(shape.begin(), shape.end());
|
||||
sliceShape[axis] = currentSliceSize;
|
||||
auto sliceType =
|
||||
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
|
||||
|
||||
Value slice;
|
||||
if (isHostFoldableValue(tensorToSlice)) {
|
||||
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||
}
|
||||
else {
|
||||
auto sliceCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) {
|
||||
Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, computedSlice);
|
||||
});
|
||||
slice = sliceCompute.getResult(0);
|
||||
}
|
||||
slices.push_back(slice);
|
||||
}
|
||||
|
||||
|
||||
@@ -5,15 +5,15 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <class ShapedType>
|
||||
@@ -105,7 +105,8 @@ inline auto getTensorShape(mlir::Value tensor) {
|
||||
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
||||
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
|
||||
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
|
||||
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape() && lhsType.getShape() == rhsType.getShape();
|
||||
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
|
||||
&& lhsType.getShape() == rhsType.getShape();
|
||||
}
|
||||
|
||||
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -28,7 +28,7 @@ bool isWeightLikeComputeOperand(Value value) {
|
||||
while (auto* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (hasWeightAlways(definingOp))
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
|
||||
return true;
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||
patterns.add<removeLRN>(ctx);
|
||||
|
||||
populateElementwisePatterns(patterns, ctx);
|
||||
populateGemmPatterns(patterns, ctx);
|
||||
populateConvPatterns(patterns, ctx);
|
||||
populatePoolPatterns(patterns, ctx);
|
||||
populateReduceMeanPatterns(patterns, ctx);
|
||||
populateReluPatterns(patterns, ctx);
|
||||
populateSigmoidPatterns(patterns, ctx);
|
||||
populateSoftmaxPatterns(patterns, ctx);
|
||||
populateConcatPatterns(patterns, ctx);
|
||||
populateGatherPatterns(patterns, ctx);
|
||||
populateResizePatterns(patterns, ctx);
|
||||
populateReshapePatterns(patterns, ctx);
|
||||
populateSplitPatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+2
@@ -5,6 +5,8 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
@@ -0,0 +1,75 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||
}
|
||||
|
||||
static bool isStaticTensorResult(Operation* op) {
|
||||
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
return shapedType && shapedType.hasStaticShape();
|
||||
});
|
||||
}
|
||||
|
||||
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
if (!op || !visited.insert(op).second)
|
||||
return false;
|
||||
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return true;
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return false;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||
return isHostFoldableValue(transposeOp.getData());
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||
return isHostFoldableValue(collapseShapeOp.getSrc());
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
||||
return isHostFoldableValue(expandShapeOp.getSrc());
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return isHostFoldableValue(extractRowsOp.getInput());
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
|
||||
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool isHostFoldableValue(Value value) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isHostFoldableOpImpl(definingOp, visited);
|
||||
}
|
||||
|
||||
bool isHostFoldableOp(Operation* op) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isHostFoldableOpImpl(op, visited);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isHostFoldableValue(mlir::Value value);
|
||||
|
||||
bool isHostFoldableOp(mlir::Operation* op);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,29 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
bool hasFailure = false;
|
||||
|
||||
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
continue;
|
||||
if (isHostFoldableOp(&op))
|
||||
continue;
|
||||
|
||||
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
hasFailure = true;
|
||||
}
|
||||
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -8,21 +8,17 @@
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -33,8 +29,6 @@ namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
|
||||
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
|
||||
@@ -44,71 +38,64 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
||||
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
||||
void populateEmptyFunction(func::FuncOp funcOp);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
||||
static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
SmallVector<spatial::SpatComputeBatch> batchOps;
|
||||
funcOp.walk([&](spatial::SpatComputeBatch batchOp) { batchOps.push_back(batchOp); });
|
||||
IRMapping mapper;
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
if (!computes.empty())
|
||||
return;
|
||||
|
||||
for (auto batchOp : batchOps) {
|
||||
if (batchOp.getLaneCount() != 1)
|
||||
continue;
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
|
||||
auto loc = batchOp.getLoc();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
computeOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLocs;
|
||||
sourceTypes.reserve(funcOp.getNumArguments());
|
||||
sourceLocs.reserve(funcOp.getNumArguments());
|
||||
for (Value source : funcOp.getArguments()) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLocs.push_back(source.getLoc());
|
||||
}
|
||||
|
||||
Block& templateBlock = batchOp.getBody().front();
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
for (BlockArgument arg : templateBlock.getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
auto newCompute = spatial::SpatCompute::create(
|
||||
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
||||
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
|
||||
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
|
||||
mapper.map(computeArg, blockArg);
|
||||
newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sourceTypes.size())});
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||
mapper.map(oldArg, newArg);
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : templateBlock)
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : funcOp.getOps())
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
batchOp.replaceAllUsesWith(computeOp.getResults());
|
||||
rewriter.eraseOp(batchOp);
|
||||
}
|
||||
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
||||
for (size_t i = 0; i < yield.getNumOperands(); ++i)
|
||||
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
|
||||
op.dropAllUses();
|
||||
rewriter.eraseOp(&op);
|
||||
}
|
||||
|
||||
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
|
||||
returnOp.setOperand(index, computeResult);
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = &getContext();
|
||||
|
||||
RewritePatternSet mergeActivationPatterns(ctx);
|
||||
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
|
||||
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
|
||||
mergeActivationPatterns.add<matMulToGemm>(ctx);
|
||||
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
|
||||
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
||||
RewritePatternSet prePatterns(ctx);
|
||||
populatePrePatterns(prePatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
|
||||
llvm::dbgs() << "Failed to apply pre-patterns, continuing...\n";
|
||||
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
||||
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
||||
|
||||
IRRewriter rewriter(moduleOp);
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
signalPassFailure();
|
||||
@@ -140,34 +127,23 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||
target.addIllegalOp<ONNXSplitOp>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<removeLRN>(ctx);
|
||||
|
||||
populateElementwisePatterns(patterns, ctx);
|
||||
populateGemmPatterns(patterns, ctx);
|
||||
populateConvPatterns(patterns, ctx);
|
||||
populatePoolPatterns(patterns, ctx);
|
||||
populateReduceMeanPatterns(patterns, ctx);
|
||||
populateReluPatterns(patterns, ctx);
|
||||
populateSigmoidPatterns(patterns, ctx);
|
||||
populateSoftmaxPatterns(patterns, ctx);
|
||||
populateConcatPatterns(patterns, ctx);
|
||||
populateGatherPatterns(patterns, ctx);
|
||||
populateResizePatterns(patterns, ctx);
|
||||
populateReshapePatterns(patterns, ctx);
|
||||
populateSplitPatterns(patterns, ctx);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
RewritePatternSet conversionPatterns(ctx);
|
||||
populateConversionPatterns(conversionPatterns, ctx);
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
foldSingleLaneComputeBatches(*entryFunc);
|
||||
RewritePatternSet earlyPostPatterns(ctx);
|
||||
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// Count the number of compute ops and check they do not exceed the core count
|
||||
if (coresCount != -1) {
|
||||
int computeOpsCount = 0;
|
||||
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
|
||||
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
computeOpsCount++;
|
||||
|
||||
@@ -185,355 +161,23 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
RewritePatternSet postPatterns(ctx);
|
||||
populatePostPatterns(postPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
populateEmptyFunction(*entryFunc);
|
||||
|
||||
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "spatial0");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
|
||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
||||
Value source = funcSource(toRemoveOp);
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
mapper.map(source, BB->getArgument(0));
|
||||
auto newInst = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst)) {
|
||||
auto source = toRemoveOp.getSource();
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
mapper.map(source, BB->getArgument(0));
|
||||
auto newInst = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
|
||||
auto sources = toRemoveOp.getInputs();
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
if (llvm::any_of(sources,
|
||||
[](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLoc;
|
||||
for (auto source : sources) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLoc.push_back(loc);
|
||||
}
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||
mapper.map(source, bbArg);
|
||||
auto newConcat = spatial::SpatConcatOp::create(rewriter,
|
||||
loc,
|
||||
toRemoveOp.getType(),
|
||||
rewriter.getI64IntegerAttr(toRemoveOp.getDim()),
|
||||
ValueRange(BB->getArguments()));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat.getOutput());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLoc;
|
||||
for (auto source : sources) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLoc.push_back(loc);
|
||||
}
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||
mapper.map(source, bbArg);
|
||||
auto newConcat = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static FailureOr<bool> sourceOperandHasWeightAlways(Operation* op) {
|
||||
if (op == nullptr)
|
||||
return false;
|
||||
|
||||
Operation* source = nullptr;
|
||||
do {
|
||||
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch>(*op)) {
|
||||
return false;
|
||||
}
|
||||
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(*op)) {
|
||||
auto tmpSource = extractSliceOp.getSource();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(*op)) {
|
||||
auto tmpSource = extractRowsOp.getInput();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(*op)) {
|
||||
auto tmpSource = expandShapeOp.getSrc();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(*op)) {
|
||||
auto tmpSource = transposeOp.getData();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(*op)) {
|
||||
auto tmpSource = collapseShapeOp.getSrc();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto constantOp = dyn_cast<arith::ConstantOp>(*op)) {
|
||||
source = constantOp;
|
||||
}
|
||||
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(*op)) {
|
||||
bool res = false;
|
||||
for (auto operand : concatOp.getOperands()) {
|
||||
res |= hasWeightAlways(operand.getDefiningOp());
|
||||
if (res)
|
||||
return res;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
else if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(*op)) {
|
||||
bool res = false;
|
||||
for (auto operand : concatOp.getOperands()) {
|
||||
res |= hasWeightAlways(operand.getDefiningOp());
|
||||
if (res)
|
||||
return res;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
else {
|
||||
op->emitOpError("unsupported global instruction while promoting weight-backed operands into Spatial computes");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
while (source == nullptr);
|
||||
|
||||
return hasWeightAlways(source);
|
||||
}
|
||||
|
||||
// TODO what we want to keep in global?
|
||||
LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
Location loc = funcOp.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
bool keep = true;
|
||||
while (keep) {
|
||||
keep = false;
|
||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatExtractRowsOp>(instruction)
|
||||
|| isa<func::ReturnOp>(instruction))
|
||||
continue;
|
||||
|
||||
auto weightBacked = sourceOperandHasWeightAlways(&instruction);
|
||||
if (failed(weightBacked))
|
||||
return failure();
|
||||
if (*weightBacked)
|
||||
continue;
|
||||
|
||||
keep |= encapsulateSlice(rewriter, loc, &instruction);
|
||||
|
||||
keep |= encapsulator<tensor::ExpandShapeOp>(
|
||||
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
||||
|
||||
keep |= encapsulator<ONNXTransposeOp>(
|
||||
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
|
||||
|
||||
keep |= encapsulator<tensor::CollapseShapeOp>(
|
||||
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
|
||||
|
||||
keep |= encapsulateConcat(rewriter, loc, &instruction);
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
||||
markWeightAlways(constantOp);
|
||||
});
|
||||
}
|
||||
|
||||
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(&getContext());
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
|
||||
for (auto compute : computes) {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
auto& oldBlock = compute.getBody().front();
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], rewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return compute.emitError("failed to materialize promoted weight-like operand inside compute body");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (auto& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
compute.replaceAllUsesWith(newCompute);
|
||||
compute.erase();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::populateEmptyFunction(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(&getContext());
|
||||
IRMapping mapper;
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
if (!computes.empty())
|
||||
return;
|
||||
auto returnOp = llvm::cast<func::ReturnOp>(funcOp.getRegion().front().getTerminator());
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLoc;
|
||||
for (auto source : funcOp.getArguments()) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLoc.push_back(source.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute = spatial::SpatCompute::create(
|
||||
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
||||
for (auto [bbArg, computeArg] : llvm::zip(BB->getArguments(), newCompute.getOperands()))
|
||||
mapper.map(computeArg, bbArg);
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sourceTypes.size()});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
for (Operation& inst : funcOp.getOps())
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst))
|
||||
rewriter.clone(inst, mapper);
|
||||
|
||||
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
||||
for (size_t i = 0; i < yield.getNumOperands(); ++i)
|
||||
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
||||
|
||||
for (Operation& inst : llvm::make_early_inc_range(funcOp.getOps()))
|
||||
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst)){
|
||||
inst.dropAllUses();
|
||||
rewriter.eraseOp(&inst);
|
||||
}
|
||||
|
||||
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
|
||||
returnOp.setOperand(index, computeResult);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -5,9 +5,9 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -49,6 +50,45 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
|
||||
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
||||
}
|
||||
|
||||
static Value transposeForSpatial(Value value,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (isHostFoldableValue(value))
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (isHostFoldableValue(value))
|
||||
return tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
resultType,
|
||||
value,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
resultType,
|
||||
input,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -81,6 +121,11 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
|
||||
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
|
||||
|
||||
if (isHostFoldableValue(matrix)) {
|
||||
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
|
||||
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||
}
|
||||
|
||||
auto buildRowSlices = [&](Value matrixArg) {
|
||||
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
|
||||
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||
@@ -122,7 +167,8 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||
rootValue = definingOp->getOperand(0);
|
||||
}
|
||||
|
||||
return buildRowSlices(matrix);
|
||||
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
|
||||
return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -175,13 +221,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||
cType = expandedType;
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
@@ -196,25 +236,18 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||
SmallVector<Value> cSlices;
|
||||
if (hasC && cHasNumOutRows)
|
||||
cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);
|
||||
|
||||
SmallVector<Value> gemvOps;
|
||||
gemvOps.reserve(numOutRows);
|
||||
gemvOps.reserve(static_cast<size_t>(numOutRows));
|
||||
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
|
||||
|
||||
Value cSlice = c;
|
||||
if (hasC) {
|
||||
if (cHasNumOutRows) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||
}
|
||||
if (cHasNumOutRows)
|
||||
cSlice = cSlices[static_cast<size_t>(rowIdx)];
|
||||
else if (!isVectorShape(getTensorShape(c))) {
|
||||
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
|
||||
return failure();
|
||||
@@ -224,7 +257,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto gemvOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
outRowType,
|
||||
aSlice,
|
||||
aSlices[static_cast<size_t>(rowIdx)],
|
||||
b,
|
||||
cSlice,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
@@ -267,13 +300,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
gemmLoc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, gemmLoc);
|
||||
cType = expandedType;
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
@@ -305,13 +332,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
if (transA) {
|
||||
auto aShape = aType.getShape();
|
||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
||||
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType());
|
||||
a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||
aType = cast<RankedTensorType>(a.getType());
|
||||
}
|
||||
if (transB) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
|
||||
@@ -335,7 +363,6 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
||||
auto bNumVSlices = aNumHSlices;
|
||||
auto bLastVSliceSize = aLastHSliceSize;
|
||||
auto cNumHSlices = bNumHSlices;
|
||||
auto cLastHSliceSize = bLastHSliceSize;
|
||||
auto outNumHSlices = cNumHSlices;
|
||||
@@ -469,12 +496,15 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
if (gemmOpAdaptor.getTransB()) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
(void) bType;
|
||||
|
||||
if (!isHostFoldableValue(b))
|
||||
return failure();
|
||||
|
||||
Value sharedBias;
|
||||
if (hasC) {
|
||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||
@@ -484,13 +514,7 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto cType = cast<RankedTensorType>(c.getType());
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -36,49 +36,27 @@ static Value extractBatchMatrix(Value value,
|
||||
SmallVector<OpFoldResult> sizes = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides);
|
||||
|
||||
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
matrixType,
|
||||
slice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
}
|
||||
auto buildMatrix = [&](Value input) -> Value {
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
matrixType,
|
||||
slice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
};
|
||||
|
||||
static bool isConstantLikeOperand(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
if (isHostFoldableValue(value))
|
||||
return buildMatrix(value);
|
||||
|
||||
while (auto* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
return true;
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
value = extractSliceOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
value = expandShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
value = transposeOp.getData();
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
auto batchMatrixCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
|
||||
});
|
||||
return batchMatrixCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -107,15 +85,31 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -157,7 +151,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
}
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB());
|
||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||
|
||||
Value lhs = matmulOp.getA();
|
||||
Value rhs = matmulOp.getB();
|
||||
@@ -193,8 +187,14 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
if (useTransposedForm)
|
||||
gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}));
|
||||
if (useTransposedForm) {
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
gemmResult = transposeCompute.getResult(0);
|
||||
}
|
||||
rewriter.replaceOp(matmulOp, gemmResult);
|
||||
return success();
|
||||
}
|
||||
@@ -215,24 +215,30 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
if (useTransposedForm)
|
||||
gemmResult = ONNXTransposeOp::create(
|
||||
rewriter,
|
||||
loc,
|
||||
RankedTensorType::get({m, n}, outType.getElementType()),
|
||||
gemmResult,
|
||||
rewriter.getI64ArrayAttr({1, 0}));
|
||||
batchResults.push_back(tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
batchedOutType,
|
||||
gemmResult,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
}));
|
||||
auto batchResultCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) {
|
||||
Value resultMatrix = input;
|
||||
if (useTransposedForm) {
|
||||
resultMatrix = ONNXTransposeOp::create(rewriter,
|
||||
loc,
|
||||
RankedTensorType::get({m, n}, outType.getElementType()),
|
||||
input,
|
||||
rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
batchedOutType,
|
||||
resultMatrix,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
batchResults.push_back(batchResultCompute.getResult(0));
|
||||
}
|
||||
|
||||
Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults);
|
||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -81,6 +82,24 @@ createAverageCompute(Value input, RankedTensorType resultType, ConversionPattern
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value buildReduceMeanKeepdims(Value input,
|
||||
ArrayRef<bool> reducedAxes,
|
||||
int64_t axis,
|
||||
@@ -100,7 +119,7 @@ static Value buildReduceMeanKeepdims(Value input,
|
||||
for (Value slice : slices)
|
||||
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
||||
|
||||
return createSpatConcat(rewriter, loc, axis, reducedSlices);
|
||||
return concatValues(reducedSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
@@ -115,9 +134,16 @@ static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
|
||||
}
|
||||
|
||||
return tensor::CollapseShapeOp::create(
|
||||
rewriter, loc, resultType, keepdimsValue, buildCollapseReassociation(reducedAxes))
|
||||
.getResult();
|
||||
auto reassociation = buildCollapseReassociation(reducedAxes);
|
||||
if (isHostFoldableValue(keepdimsValue))
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
|
||||
|
||||
auto squeezeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, ValueRange {keepdimsValue}, [&](Value input) {
|
||||
Value collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, collapsed);
|
||||
});
|
||||
return squeezeCompute.getResult(0);
|
||||
}
|
||||
|
||||
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
|
||||
@@ -31,8 +31,8 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
static FailureOr<Value>
|
||||
concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
|
||||
static FailureOr<Value> concatAlongAxis(
|
||||
ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
|
||||
if (values.empty()) {
|
||||
poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
|
||||
return failure();
|
||||
@@ -68,8 +68,8 @@ reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation*
|
||||
return reduced;
|
||||
}
|
||||
|
||||
static FailureOr<Value>
|
||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
|
||||
static FailureOr<Value> scaleAverageWindow(
|
||||
ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
|
||||
if (divisor <= 0) {
|
||||
op->emitOpError("AveragePool divisor must be positive");
|
||||
return failure();
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -32,6 +33,24 @@ static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewrit
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
@@ -47,7 +66,7 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe
|
||||
for (Value slice : slices)
|
||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
||||
|
||||
return createSpatConcat(rewriter, loc, axis, rebuiltSlices);
|
||||
return concatValues(rebuiltSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
@@ -92,8 +111,13 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
Value transposedInput = preTransposeCompute.getResult(0);
|
||||
Value transposedResult = buildSoftmax(
|
||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
result = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
|
||||
auto postTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
result = postTransposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(softmaxOp, result);
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -18,7 +20,17 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||
auto inputs = adaptor.getInputs();
|
||||
int64_t axis = adaptor.getAxis();
|
||||
|
||||
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
|
||||
if (llvm::all_of(inputs, isHostFoldableValue)) {
|
||||
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
|
||||
return success();
|
||||
}
|
||||
|
||||
auto computeOp = createSpatCompute(
|
||||
rewriter, maxpoolOp.getLoc(), TypeRange {maxpoolOp.getResult().getType()}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(
|
||||
rewriter, maxpoolOp.getLoc(), createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, args));
|
||||
});
|
||||
rewriter.replaceOp(maxpoolOp, computeOp.getResults());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -95,18 +98,33 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
return success();
|
||||
}
|
||||
|
||||
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
|
||||
if (isHostFoldableValue(adaptor.getData())) {
|
||||
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
|
||||
return success();
|
||||
}
|
||||
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
|
||||
Value reshaped = buildReshape(data);
|
||||
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
|
||||
});
|
||||
rewriter.replaceOp(reshapeOp, computeOp.getResults());
|
||||
return success();
|
||||
};
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation;
|
||||
if (sourceType.getRank() > resultType.getRank()
|
||||
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||
return success();
|
||||
}
|
||||
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
|
||||
return replaceWithReshape([&](Value data) {
|
||||
return tensor::CollapseShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||
});
|
||||
|
||||
if (sourceType.getRank() < resultType.getRank()
|
||||
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||
return success();
|
||||
}
|
||||
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
|
||||
return replaceWithReshape([&](Value data) {
|
||||
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||
});
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -47,16 +49,40 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||
outputs.reserve(splitOp.getNumResults());
|
||||
|
||||
int64_t offset = 0;
|
||||
SmallVector<RankedTensorType> resultTypes;
|
||||
resultTypes.reserve(splitOp.getNumResults());
|
||||
SmallVector<int64_t> sliceSizes;
|
||||
sliceSizes.reserve(splitOp.getNumResults());
|
||||
for (Value result : splitOp.getResults()) {
|
||||
auto resultType = dyn_cast<RankedTensorType>(result.getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
int64_t sliceSize = resultType.getShape()[axis];
|
||||
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
|
||||
offset += sliceSize;
|
||||
resultTypes.push_back(resultType);
|
||||
sliceSizes.push_back(resultType.getShape()[axis]);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(splitOp, outputs);
|
||||
if (isHostFoldableValue(adaptor.getInput())) {
|
||||
for (int64_t sliceSize : sliceSizes) {
|
||||
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
|
||||
offset += sliceSize;
|
||||
}
|
||||
rewriter.replaceOp(splitOp, outputs);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, splitOp.getLoc(), TypeRange(splitOp.getResultTypes()), {}, adaptor.getInput(), [&](Value input) {
|
||||
SmallVector<Value> runtimeOutputs;
|
||||
runtimeOutputs.reserve(resultTypes.size());
|
||||
int64_t runtimeOffset = 0;
|
||||
for (int64_t sliceSize : sliceSizes) {
|
||||
runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc()));
|
||||
runtimeOffset += sliceSize;
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
|
||||
});
|
||||
|
||||
rewriter.replaceOp(splitOp, computeOp.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -0,0 +1,265 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isWeightMaterializationHelperUser(Operation* op) {
|
||||
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
|
||||
}
|
||||
|
||||
static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
||||
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
|
||||
}
|
||||
|
||||
static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
||||
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
|
||||
if (batchOp.getLaneCount() != 1)
|
||||
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
|
||||
|
||||
auto loc = batchOp.getLoc();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
computeOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
|
||||
Block& templateBlock = batchOp.getBody().front();
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(templateBlock.getNumArguments());
|
||||
blockArgLocs.reserve(templateBlock.getNumArguments());
|
||||
for (BlockArgument arg : templateBlock.getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
IRMapping mapper;
|
||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||
mapper.map(oldArg, newArg);
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : templateBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
batchOp->replaceAllUsesWith(computeOp->getResults());
|
||||
rewriter.eraseOp(batchOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
||||
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatComputeBatch::create(rewriter,
|
||||
compute.getLoc(),
|
||||
compute.getResultTypes(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
newWeights,
|
||||
newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
|
||||
}
|
||||
|
||||
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
||||
}
|
||||
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
||||
markWeightAlways(constantOp);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,25 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
} // namespace
|
||||
|
||||
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||
patterns.add<onnxToArithConstant>(ctx);
|
||||
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
patterns.add<matMulAddToGemm>(ctx);
|
||||
patterns.add<matMulToGemm>(ctx);
|
||||
patterns.add<removeFlattenSameShape>(ctx);
|
||||
populateMatMulRewritePatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user