add constant folding and verification pass for pim host operations
better validation scripts output big refactors
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
|
||||
|
||||
181
src/PIM/Pass/PimFoldHostConstantsPass.cpp
Normal file
181
src/PIM/Pass/PimFoldHostConstantsPass.cpp
Normal file
@@ -0,0 +1,181 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!tensorType)
|
||||
return failure();
|
||||
|
||||
int64_t rank = tensorType.getRank();
|
||||
if (static_cast<int64_t>(perms.size()) != rank)
|
||||
return failure();
|
||||
|
||||
llvm::SmallBitVector seen(rank);
|
||||
SmallVector<int64_t> transposedShape;
|
||||
transposedShape.reserve(rank);
|
||||
for (int64_t perm : perms) {
|
||||
if (perm < 0 || perm >= rank || seen.test(perm))
|
||||
return failure();
|
||||
seen.set(perm);
|
||||
transposedShape.push_back(tensorType.getShape()[perm]);
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType());
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> transposedValues(originalValues.size());
|
||||
|
||||
SmallVector<int64_t> originalStrides(rank, 1);
|
||||
SmallVector<int64_t> transposedStrides(rank, 1);
|
||||
for (int64_t dim = rank - 2; dim >= 0; --dim) {
|
||||
originalStrides[dim] = originalStrides[dim + 1] * tensorType.getShape()[dim + 1];
|
||||
transposedStrides[dim] = transposedStrides[dim + 1] * transposedShape[dim + 1];
|
||||
}
|
||||
|
||||
SmallVector<int64_t> originalIndices(rank);
|
||||
SmallVector<int64_t> transposedIndices(rank);
|
||||
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
||||
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
originalIndices[dim] = remaining / originalStrides[dim];
|
||||
remaining %= originalStrides[dim];
|
||||
}
|
||||
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
transposedIndices[dim] = originalIndices[perms[dim]];
|
||||
|
||||
int64_t transposedLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
transposedLinearIndex += transposedIndices[dim] * transposedStrides[dim];
|
||||
|
||||
transposedValues[transposedLinearIndex] = value;
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||
}
|
||||
|
||||
struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimTransposeOp transposeOp, PatternRewriter& rewriter) const override {
|
||||
auto resultType = dyn_cast<MemRefType>(transposeOp.getOutRes().getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
auto sourceGetGlobal = transposeOp.getData().getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!sourceGetGlobal)
|
||||
return failure();
|
||||
|
||||
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto sourceGlobal = lookupGlobalForGetGlobal(moduleOp, sourceGetGlobal);
|
||||
if (!sourceGlobal || !sourceGlobal.getConstant() || !sourceGlobal.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*sourceGlobal.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> perms;
|
||||
perms.reserve(transposeOp.getPerms().size());
|
||||
for (IntegerAttr attr : transposeOp.getPerms().getAsRange<IntegerAttr>())
|
||||
perms.push_back(attr.getInt());
|
||||
FailureOr<DenseElementsAttr> transposedAttr = transposeDenseElements(denseAttr, perms);
|
||||
if (failed(transposedAttr))
|
||||
return failure();
|
||||
|
||||
auto transposedShape = cast<RankedTensorType>(transposedAttr->getType()).getShape();
|
||||
if (!llvm::equal(transposedShape, resultType.getShape()))
|
||||
return failure();
|
||||
|
||||
MemRefType globalType = resultType;
|
||||
|
||||
auto globalName = sourceGlobal.getName().str() + "__folded_transpose";
|
||||
unsigned suffix = 0;
|
||||
while (moduleOp.lookupSymbol(globalName))
|
||||
globalName = sourceGlobal.getName().str() + "__folded_transpose_" + std::to_string(++suffix);
|
||||
|
||||
auto visibility = rewriter.getStringAttr("private");
|
||||
OpBuilder moduleBuilder(moduleOp.getBodyRegion());
|
||||
moduleBuilder.setInsertionPointToStart(moduleOp.getBody());
|
||||
auto newGlobal = memref::GlobalOp::create(moduleBuilder,
|
||||
transposeOp.getLoc(),
|
||||
globalName,
|
||||
visibility,
|
||||
globalType,
|
||||
*transposedAttr,
|
||||
/*constant=*/true,
|
||||
sourceGlobal.getAlignmentAttr());
|
||||
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, transposeOp.getLoc(), globalType, newGlobal.getName());
|
||||
|
||||
bool isAlwaysWeight =
|
||||
!transposeOp->getUsers().empty()
|
||||
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
||||
if (isAlwaysWeight) {
|
||||
markWeightAlways(newGlobal);
|
||||
markWeightAlways(newGetGlobal);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct PimFoldHostConstantsPass : PassWrapper<PimFoldHostConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimFoldHostConstantsPass)
|
||||
|
||||
StringRef getArgument() const override { return "fold-pim-host-constants-pass"; }
|
||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||
|
||||
LogicalResult initialize(MLIRContext* context) override {
|
||||
RewritePatternSet owningPatterns(context);
|
||||
for (auto* dialect : context->getLoadedDialects())
|
||||
dialect->getCanonicalizationPatterns(owningPatterns);
|
||||
for (RegisteredOperationName op : context->getRegisteredOperations())
|
||||
op.getCanonicalizationPatterns(owningPatterns, context);
|
||||
owningPatterns.add<FoldConstantTransposePattern>(context);
|
||||
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
||||
return success();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
GreedyRewriteConfig config;
|
||||
config.enableFolding();
|
||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimFoldHostConstantsPass() { return std::make_unique<PimFoldHostConstantsPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
173
src/PIM/Pass/PimHostVerificationPass.cpp
Normal file
173
src/PIM/Pass/PimHostVerificationPass.cpp
Normal file
@@ -0,0 +1,173 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isAddressOnlyHostOp(Operation* op) {
|
||||
return isa<memref::AllocOp,
|
||||
memref::GetGlobalOp,
|
||||
memref::SubViewOp,
|
||||
memref::CastOp,
|
||||
memref::CollapseShapeOp,
|
||||
memref::ExpandShapeOp,
|
||||
spatial::SpatChannelNewOp>(op);
|
||||
}
|
||||
|
||||
static bool isHostAddressableValue(Value value) {
|
||||
while (true) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return isa<func::FuncOp>(blockArg.getOwner()->getParentOp());
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
|
||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
||||
return true;
|
||||
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
struct PimHostVerificationPass : PassWrapper<PimHostVerificationPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimHostVerificationPass)
|
||||
|
||||
StringRef getArgument() const override { return "verify-pim-host-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Verify that no runtime host-side code remains in bufferized PIM IR";
|
||||
}
|
||||
|
||||
PimHostVerificationPass() {}
|
||||
PimHostVerificationPass(const PimHostVerificationPass& pass) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
bool hasFailure = false;
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||
if (failed(verifyCoreWeights(moduleOp, coreOp)))
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
|
||||
if (failed(verifyReturnOp(returnOp)))
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isAddressOnlyHostOp(&op)) {
|
||||
op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
||||
"fold it to constants or lower it into pim.core");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(verifyAddressOnlyHostOp(&op)))
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure)
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) {
|
||||
bool hasFailure = false;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must be materialized as memref.get_global before JSON codegen";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must come from a constant memref.global with an initial value";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
|
||||
bool hasFailure = false;
|
||||
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
|
||||
if (!isHostAddressableValue(operand)) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by host-addressable storage";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||
return verifyAddressOnlySource(op, subviewOp.getSource());
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(op))
|
||||
return verifyAddressOnlySource(op, castOp.getSource());
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, collapseOp.getSrc());
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, expandOp.getSrc());
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
|
||||
if (isHostAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that still requires host-side execution");
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimHostVerificationPass() { return std::make_unique<PimHostVerificationPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,23 +3,26 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir;
|
||||
#include <string>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::unique_ptr<Pass> createONNXToSpatialPass();
|
||||
std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToGraphvizPass();
|
||||
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToPIMPass();
|
||||
std::unique_ptr<mlir::Pass> createSpatialToPIMPass();
|
||||
|
||||
std::unique_ptr<Pass> createBufferizePimPass();
|
||||
std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
||||
|
||||
std::unique_ptr<Pass> createEmitPimJsonPass();
|
||||
std::unique_ptr<mlir::Pass> createPimFoldHostConstantsPass();
|
||||
|
||||
std::unique_ptr<Pass> createMessagePass(std::string message);
|
||||
std::unique_ptr<mlir::Pass> createPimHostVerificationPass();
|
||||
|
||||
std::unique_ptr<Pass> createCountInstructionPass();
|
||||
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createMessagePass(std::string message);
|
||||
|
||||
std::unique_ptr<mlir::Pass> createCountInstructionPass();
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user