big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
@@ -34,35 +33,6 @@ static int64_t getValueSizeInBytes(Value value) {
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
static void expandPimMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<pim::PimMapOp> mapOps;
|
||||
funcOp.walk([&](pim::PimMapOp mapOp) { mapOps.push_back(mapOp); });
|
||||
|
||||
for (auto mapOp : mapOps) {
|
||||
Block& body = mapOp.getBody().front();
|
||||
auto yieldOp = cast<pim::PimYieldOp>(body.getTerminator());
|
||||
|
||||
SmallVector<Value> replacements;
|
||||
replacements.reserve(mapOp.getInputs().size());
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
for (Value input : mapOp.getInputs()) {
|
||||
IRMapping mapping;
|
||||
mapping.map(body.getArgument(0), input);
|
||||
|
||||
for (Operation& op : body.without_terminator()) {
|
||||
Operation* cloned = rewriter.clone(op, mapping);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapping.map(originalResult, clonedResult);
|
||||
rewriter.setInsertionPointAfter(cloned);
|
||||
}
|
||||
|
||||
replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0)));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(mapOp, replacements);
|
||||
}
|
||||
}
|
||||
|
||||
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
|
||||
|
||||
@@ -80,8 +50,6 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
||||
if (funcOp.isExternal())
|
||||
continue;
|
||||
|
||||
expandPimMapOps(funcOp, rewriter);
|
||||
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
|
||||
|
||||
@@ -150,38 +118,11 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
||||
|
||||
SmallVector<Operation*> hostCompactOps;
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
if (isa<pim::PimExtractRowsOp, pim::PimConcatOp>(op))
|
||||
if (isa<pim::PimConcatOp>(op))
|
||||
hostCompactOps.push_back(&op);
|
||||
|
||||
for (Operation* op : hostCompactOps) {
|
||||
rewriter.setInsertionPoint(op);
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op)) {
|
||||
auto inputType = dyn_cast<ShapedType>(extractRowsOp.getInput().getType());
|
||||
if (!inputType || !inputType.hasStaticShape() || inputType.getRank() != 2) {
|
||||
extractRowsOp.emitOpError("host-side extract_rows lowering requires a static rank-2 input");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t numCols = inputType.getDimSize(1);
|
||||
SmallVector<Value> replacementRows;
|
||||
replacementRows.reserve(extractRowsOp.getOutputs().size());
|
||||
for (auto rowIndex : llvm::seq<size_t>(0, extractRowsOp.getOutputs().size())) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)),
|
||||
rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
replacementRows.push_back(memref::SubViewOp::create(
|
||||
rewriter, extractRowsOp.getLoc(), extractRowsOp.getInput(), offsets, sizes, strides)
|
||||
.getResult());
|
||||
}
|
||||
|
||||
extractRowsOp->replaceAllUsesWith(ValueRange(replacementRows));
|
||||
extractRowsOp->erase();
|
||||
continue;
|
||||
}
|
||||
|
||||
auto concatOp = cast<pim::PimConcatOp>(op);
|
||||
concatOp.emitOpError("host-side concat must be folded away or lowered into pim.core before materialization");
|
||||
hasFailure = true;
|
||||
|
||||
Reference in New Issue
Block a user