Extend operation support for conv and gemm.
Add more tests in validation.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
@@ -79,8 +80,31 @@ private:
|
||||
} // namespace
|
||||
|
||||
/// Returns true if `op` is a shape/layout-adjusting operation that the pass
/// is willing to follow when walking the use chain from a compute result
/// toward its eventual consumer (e.g. a func.return) — see the chain walk in
/// runOnComputeOp, which replays these ops and marks the originals for
/// removal.
///
/// NOTE: the scraped diff left both the old and the new body of this function
/// in place, making the second `return` unreachable; only the extended
/// version (which additionally recognizes pim::PimTransposeOp) is kept here.
static bool isChannelUseChainOp(Operation* op) {
  return isa<tensor::ExtractSliceOp,
             tensor::CollapseShapeOp,
             tensor::ExpandShapeOp,
             tensor::CastOp,
             tosa::ReshapeOp,
             pim::PimTransposeOp>(op);
}
|
||||
|
||||
/// For each operand of `op` not yet remapped in `mapping`, clone the simple
/// helper producer (tensor.empty or arith.constant) that defines it, record
/// the clone's results in `mapping`, and advance the rewriter's insertion
/// point past the clone so later clones keep their relative order. Operands
/// that are block arguments, already mapped, or defined by any other kind of
/// op are left untouched.
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
  for (Value source : op->getOperands()) {
    // Already remapped — nothing to materialize for this operand.
    if (mapping.lookupOrNull(source))
      continue;

    // Only clone materialized helper producers; block arguments (no defining
    // op) and anything other than tensor.empty / arith.constant are skipped.
    Operation* producer = source.getDefiningOp();
    if (!producer || !isa<tensor::EmptyOp, arith::ConstantOp>(producer))
      continue;

    Operation* copy = rewriter.clone(*producer, mapping);
    for (auto [oldResult, newResult] : llvm::zip(producer->getResults(), copy->getResults()))
      mapping.map(oldResult, newResult);
    rewriter.setInsertionPointAfter(copy);
  }
}
|
||||
|
||||
static size_t countComputeLeafUsers(Value value) {
|
||||
@@ -204,6 +228,56 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
||||
OpOperand& resultUse = *resultUses.begin();
|
||||
Operation* resultUser = resultUse.getOwner();
|
||||
|
||||
if (isChannelUseChainOp(resultUser)) {
|
||||
SmallVector<Operation*> returnChain;
|
||||
Value chainedValue = result;
|
||||
Operation* chainUser = resultUser;
|
||||
|
||||
while (isChannelUseChainOp(chainUser)) {
|
||||
returnChain.push_back(chainUser);
|
||||
auto chainUses = chainUser->getResult(0).getUses();
|
||||
if (rangeLength(chainUses) != 1)
|
||||
break;
|
||||
chainedValue = chainUser->getResult(0);
|
||||
chainUser = chainUses.begin()->getOwner();
|
||||
}
|
||||
|
||||
if (isa<func::ReturnOp>(chainUser)) {
|
||||
size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
|
||||
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
IRMapping mapping;
|
||||
mapping.map(result, yieldValue);
|
||||
|
||||
Value storedValue = yieldValue;
|
||||
for (Operation* op : returnChain) {
|
||||
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
storedValue = clonedOp->getResult(0);
|
||||
rewriter.setInsertionPointAfter(clonedOp);
|
||||
markOpToRemove(op);
|
||||
}
|
||||
|
||||
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||
|
||||
Value outputTensor = outputTensors[resultIndexInReturn];
|
||||
if (auto storedOp = storedValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
storedValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (isa<func::ReturnOp>(resultUser)) {
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t offset = 0;
|
||||
@@ -493,6 +567,7 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
||||
IRMapping mapping;
|
||||
mapping.map(channelSourceOp, receivedValue);
|
||||
for (Operation* op : llvm::reverse(clonedChain)) {
|
||||
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
|
||||
Reference in New Issue
Block a user