Reduce spatial compile times in convolutions by using an scf.for loop instead of materializing a huge number of instructions.
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
@@ -8,6 +10,7 @@
|
||||
#include <fstream>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -240,8 +243,129 @@ bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
||||
return true;
|
||||
}
|
||||
|
||||
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
|
||||
if (!knowledge)
|
||||
return value;
|
||||
|
||||
auto iter = knowledge->aliases.find(value);
|
||||
while (iter != knowledge->aliases.end()) {
|
||||
value = iter->second;
|
||||
iter = knowledge->aliases.find(value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
|
||||
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
|
||||
// and when propagating yielded values across iterations during static unrolling.
|
||||
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
if (auto blockArgument = dyn_cast<BlockArgument>(value))
|
||||
return value;
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return value;
|
||||
|
||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||
if (auto result = dyn_cast<OpResult>(value))
|
||||
if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
|
||||
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
||||
}
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
|
||||
// Statically evaluates an index-typed SSA value to a constant when possible.
// Resolution consults the recorded knowledge (induction-variable values and
// iter-arg aliases), arith.constant, and a small closed set of arith ops.
// Returns failure() when the value cannot be evaluated at compile time.
//
// Add/sub/mul are computed in uint64_t so that wrap-around follows the arith
// dialect's two's-complement semantics instead of triggering signed-overflow
// UB inside the compiler itself; div/rem already use unsigned arithmetic to
// match DivUIOp/RemUIOp.
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);

  if (knowledge) {
    auto iter = knowledge->indexValues.find(value);
    if (iter != knowledge->indexValues.end())
      return iter->second;
  }

  auto constantOp = value.getDefiningOp<arith::ConstantOp>();
  if (constantOp) {
    if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
      return integerAttr.getInt();
  }

  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return failure();

  if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
    return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);

  if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return failure();
    return static_cast<int64_t>(static_cast<uint64_t>(*lhs) + static_cast<uint64_t>(*rhs));
  }

  if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return failure();
    return static_cast<int64_t>(static_cast<uint64_t>(*lhs) - static_cast<uint64_t>(*rhs));
  }

  if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return failure();
    return static_cast<int64_t>(static_cast<uint64_t>(*lhs) * static_cast<uint64_t>(*rhs));
  }

  if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
    // Division by zero would be UB here and is poison in the IR; bail out.
    if (failed(lhs) || failed(rhs) || *rhs == 0)
      return failure();
    return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
  }

  if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs) || *rhs == 0)
      return failure();
    return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
  }

  return failure();
}
// Statically evaluates an OpFoldResult — either a folded attribute or an SSA
// value — to an integer, or failure() when no constant can be derived.
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
  // Attribute case: only integer attributes carry a usable constant.
  if (auto attr = dyn_cast<Attribute>(ofr)) {
    if (auto integerAttr = dyn_cast<IntegerAttr>(attr))
      return integerAttr.getInt();
    return failure();
  }

  // Value case: defer to the SSA-value evaluator.
  return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
}
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
int64_t byteOffset = 0;
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
while (true) {
|
||||
if (isa<BlockArgument>(value))
|
||||
@@ -255,7 +379,29 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
return failure();
|
||||
value = tiedOperand->get();
|
||||
value = resolveAlias(tiedOperand->get(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
|
||||
auto result = dyn_cast<OpResult>(value);
|
||||
if (!result)
|
||||
return failure();
|
||||
|
||||
// Trace the loop carry back to its underlying memref, then if that memref is the
|
||||
// loop's own iter-arg we know the base comes from the corresponding init arg
|
||||
// (every iteration yields the same backing memory in the DPS sense).
|
||||
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
|
||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
value = yieldedValue;
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -265,31 +411,53 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> offsets = subviewOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sizes = subviewOp.getStaticSizes();
|
||||
ArrayRef<int64_t> strides = subviewOp.getStaticStrides();
|
||||
if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic)
|
||||
|| llvm::is_contained(strides, ShapedType::kDynamic))
|
||||
return failure();
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> sizes;
|
||||
SmallVector<int64_t> strides;
|
||||
offsets.reserve(subviewOp.getMixedOffsets().size());
|
||||
sizes.reserve(subviewOp.getMixedSizes().size());
|
||||
strides.reserve(subviewOp.getMixedStrides().size());
|
||||
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
return failure();
|
||||
offsets.push_back(*resolvedOffset);
|
||||
}
|
||||
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto resolvedSize = resolveOpFoldResult(size, knowledge);
|
||||
if (failed(resolvedSize))
|
||||
return failure();
|
||||
sizes.push_back(*resolvedSize);
|
||||
}
|
||||
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
|
||||
if (failed(resolvedStride))
|
||||
return failure();
|
||||
strides.push_back(*resolvedStride);
|
||||
}
|
||||
|
||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||
return failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||
value = subviewOp.getSource();
|
||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
value = resolveAlias(castOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
value = resolveAlias(collapseOp.getSrc(), knowledge);
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
value = resolveAlias(expandOp.getSrc(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -300,4 +468,79 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluates `value` to a static integer with no contextual knowledge.
FailureOr<int64_t> resolveIndexValue(Value value) {
  return resolveIndexValueImpl(value, nullptr);
}
// Evaluates `value` to a static integer using knowledge gathered while
// statically interpreting enclosing scf.for loops.
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
  return resolveIndexValueImpl(value, &knowledge);
}
// Resolves `value` to a contiguous base address with no contextual knowledge.
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
  return resolveContiguousAddressImpl(value, nullptr);
}
// Resolves `value` to a contiguous base address, additionally consulting
// knowledge gathered while statically interpreting enclosing scf.for loops.
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
  return resolveContiguousAddressImpl(value, &knowledge);
}
// Public entry point: traces a loop-carried value through view-like ops and
// DPS tied operands back to its underlying memref.
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
  return resolveLoopCarriedAliasImpl(value, &knowledge);
}
// Returns true for ops that only participate in static address computation;
// they are folded away by the resolvers above and therefore need no PIM
// codegen of their own.
bool isCoreStaticAddressOp(Operation* op) {
  return isa<arith::ConstantOp, arith::AddIOp, arith::SubIOp, arith::MulIOp,
      arith::DivUIOp, arith::RemUIOp, arith::IndexCastOp, memref::AllocOp,
      memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp,
      memref::ExpandShapeOp>(op);
}
// Visits every operation in a PIM core block that requires codegen,
// statically interpreting scf.for loops so that `callback` observes each body
// operation once per iteration, together with knowledge (induction-variable
// value, iter-arg aliases) that makes its addresses statically evaluable.
// pim.halt, scf.yield and pure static-address ops are skipped. Walking
// continues past failures so every diagnostic is emitted; the overall result
// is failure() if anything failed.
LogicalResult walkPimCoreBlock(Block& block,
    const StaticValueKnowledge& knowledge,
    llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
  bool anyFailed = false;
  for (Operation& op : block) {
    // Terminators and address arithmetic are consumed by address resolution;
    // the callback never needs to see them.
    if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
      continue;

    auto forOp = dyn_cast<scf::ForOp>(op);
    if (!forOp) {
      if (failed(callback(op, knowledge)))
        anyFailed = true;
      continue;
    }

    Block& loopBody = forOp.getRegion().front();
    auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
    auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
    auto step = resolveIndexValue(forOp.getStep(), knowledge);
    if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
      forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
      anyFailed = true;
      continue;
    }

    // Interpret the loop instead of materializing unrolled IR: carry the
    // iter-arg values forward iteration by iteration as alias knowledge.
    SmallVector<Value> carriedValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
    for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound;
         inductionValue += *step) {
      StaticValueKnowledge iterationKnowledge = knowledge;
      iterationKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
      for (auto [iterArg, carried] : llvm::zip_equal(forOp.getRegionIterArgs(), carriedValues))
        iterationKnowledge.aliases[iterArg] = carried;

      if (failed(walkPimCoreBlock(loopBody, iterationKnowledge, callback)))
        anyFailed = true;

      // Propagate the yielded values to the next iteration's iter-args.
      auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
      for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
        carriedValues[index] = resolveLoopCarriedAlias(yieldedValue, iterationKnowledge);
    }
  }
  return success(!anyFailed);
}
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user