Reduce Spatial compile times in convolutions by using an scf.for instead of materializing a huge number of instructions
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-04-10 18:50:25 +02:00
parent f3a36e9d43
commit f054e66ed0
18 changed files with 623 additions and 241 deletions

View File

@@ -54,6 +54,8 @@ add_pim_library(OMPIMAccel
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
MLIRSCFDialect
MLIRSCFTransforms
onnx
OMAccelerator
OMPimCompilerUtils

View File

@@ -1,4 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -8,6 +10,7 @@
#include <fstream>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -240,8 +243,129 @@ bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
return true;
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
// Chases `value` through the alias map recorded in `knowledge` (if any) until it
// reaches a value with no further alias entry; returns `value` unchanged when no
// knowledge is available.
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
  if (!knowledge)
    return value;
  for (auto it = knowledge->aliases.find(value); it != knowledge->aliases.end();
       it = knowledge->aliases.find(value))
    value = it->second;
  return value;
}
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
// and when propagating yielded values across iterations during static unrolling.
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);
  // A block argument (e.g. the loop's own iter-arg) is the fixpoint we are looking for.
  // Use isa<> here: the original dyn_cast bound a variable that was never used.
  if (isa<BlockArgument>(value))
    return value;
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return value;
  // Destination-style ops reuse the buffer of their tied init operand: follow the tie.
  if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
    if (auto result = dyn_cast<OpResult>(value))
      if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
        return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
  }
  // View-like memref ops do not change the underlying allocation; look through them.
  if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
    return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
  if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
    return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
  if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
    return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
  return value;
}
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
// Tries to fold `value` to a compile-time integer. Consults `knowledge` (alias map and
// recorded index values, e.g. unrolled scf.for induction variables) first, then folds
// arith.constant and a small closed set of arith integer ops recursively.
// Returns failure() when the value cannot be resolved statically.
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);
  // Facts recorded during static unrolling win over folding the defining op.
  if (knowledge) {
    auto iter = knowledge->indexValues.find(value);
    if (iter != knowledge->indexValues.end())
      return iter->second;
  }
  if (auto constantOp = value.getDefiningOp<arith::ConstantOp>()) {
    if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
      return integerAttr.getInt();
  }
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return failure();
  if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
    return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
  // Resolves both operands of a binary op, then combines them; `combine` may itself
  // fail (e.g. division by zero).
  auto foldBinary = [&](Value lhsValue, Value rhsValue, auto combine) -> FailureOr<int64_t> {
    auto lhs = resolveIndexValueImpl(lhsValue, knowledge);
    auto rhs = resolveIndexValueImpl(rhsValue, knowledge);
    if (failed(lhs) || failed(rhs))
      return failure();
    return combine(*lhs, *rhs);
  };
  // Add/sub/mul are evaluated in uint64_t: signed overflow is UB in C++, whereas MLIR's
  // arith integer ops have two's-complement wraparound semantics, which unsigned
  // arithmetic reproduces. NOTE(review): this wraps at 64 bits — assumes index-like
  // (64-bit) operands; confirm narrower integer types cannot reach here with
  // overflowing values.
  if (auto addOp = dyn_cast<arith::AddIOp>(definingOp))
    return foldBinary(addOp.getLhs(), addOp.getRhs(), [](int64_t a, int64_t b) -> FailureOr<int64_t> {
      return static_cast<int64_t>(static_cast<uint64_t>(a) + static_cast<uint64_t>(b));
    });
  if (auto subOp = dyn_cast<arith::SubIOp>(definingOp))
    return foldBinary(subOp.getLhs(), subOp.getRhs(), [](int64_t a, int64_t b) -> FailureOr<int64_t> {
      return static_cast<int64_t>(static_cast<uint64_t>(a) - static_cast<uint64_t>(b));
    });
  if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp))
    return foldBinary(mulOp.getLhs(), mulOp.getRhs(), [](int64_t a, int64_t b) -> FailureOr<int64_t> {
      return static_cast<int64_t>(static_cast<uint64_t>(a) * static_cast<uint64_t>(b));
    });
  if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp))
    return foldBinary(divOp.getLhs(), divOp.getRhs(), [](int64_t a, int64_t b) -> FailureOr<int64_t> {
      if (b == 0)
        return failure();
      return static_cast<int64_t>(static_cast<uint64_t>(a) / static_cast<uint64_t>(b));
    });
  if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp))
    return foldBinary(remOp.getLhs(), remOp.getRhs(), [](int64_t a, int64_t b) -> FailureOr<int64_t> {
      if (b == 0)
        return failure();
      return static_cast<int64_t>(static_cast<uint64_t>(a) % static_cast<uint64_t>(b));
    });
  return failure();
}
// Folds an OpFoldResult to a concrete integer: a Value is resolved through
// resolveIndexValueImpl with the available static knowledge; an Attribute folds
// directly when it is an IntegerAttr, and fails otherwise.
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
  Attribute attr = dyn_cast<Attribute>(ofr);
  if (!attr)
    return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
  if (auto integerAttr = dyn_cast<IntegerAttr>(attr))
    return integerAttr.getInt();
  return failure();
}
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
const StaticValueKnowledge* knowledge) {
int64_t byteOffset = 0;
value = resolveAlias(value, knowledge);
while (true) {
if (isa<BlockArgument>(value))
@@ -255,7 +379,29 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
if (!tiedOperand)
return failure();
value = tiedOperand->get();
value = resolveAlias(tiedOperand->get(), knowledge);
continue;
}
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
auto result = dyn_cast<OpResult>(value);
if (!result)
return failure();
// Trace the loop carry back to its underlying memref, then if that memref is the
// loop's own iter-arg we know the base comes from the corresponding init arg
// (every iteration yields the same backing memory in the DPS sense).
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue;
}
@@ -265,31 +411,53 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
ArrayRef<int64_t> offsets = subviewOp.getStaticOffsets();
ArrayRef<int64_t> sizes = subviewOp.getStaticSizes();
ArrayRef<int64_t> strides = subviewOp.getStaticStrides();
if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic)
|| llvm::is_contained(strides, ShapedType::kDynamic))
return failure();
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
sizes.reserve(subviewOp.getMixedSizes().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
if (failed(resolvedOffset))
return failure();
offsets.push_back(*resolvedOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto resolvedSize = resolveOpFoldResult(size, knowledge);
if (failed(resolvedSize))
return failure();
sizes.push_back(*resolvedSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
if (failed(resolvedStride))
return failure();
strides.push_back(*resolvedStride);
}
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = subviewOp.getSource();
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = castOp.getSource();
value = resolveAlias(castOp.getSource(), knowledge);
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
value = resolveAlias(collapseOp.getSrc(), knowledge);
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
value = resolveAlias(expandOp.getSrc(), knowledge);
continue;
}
@@ -300,4 +468,79 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
}
}
/// Resolves `value` to a compile-time integer without any alias/unrolling context.
FailureOr<int64_t> resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); }
/// Resolves `value` to a compile-time integer using the facts in `knowledge`.
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
/// Resolves `value` to a contiguous (base, byte-offset) address without extra context.
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
/// Resolves `value` to a contiguous (base, byte-offset) address using `knowledge`.
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge);
}
/// Public wrapper over resolveLoopCarriedAliasImpl (see its comment for semantics).
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
// True for ops inside a pim.core body that emit no PIM instruction themselves and only
// contribute to static address or index computation.
bool isCoreStaticAddressOp(Operation* op) {
  // Integer arithmetic used purely for index math.
  if (isa<arith::ConstantOp, arith::AddIOp, arith::SubIOp, arith::MulIOp, arith::DivUIOp,
          arith::RemUIOp, arith::IndexCastOp>(op))
    return true;
  // Memref allocation and view-like reshapes: addressing only, no data movement.
  return isa<memref::AllocOp, memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp,
             memref::ExpandShapeOp>(op);
}
/// Interprets `block` (a pim.core body or a nested scf.for body), statically unrolling
/// every scf.for whose bounds resolve to constants under `knowledge`. Each remaining
/// op is handed to `callback` with the knowledge in scope at that point. The walk
/// continues after a callback failure so multiple diagnostics can be collected, but
/// the aggregate result reflects any failure.
LogicalResult walkPimCoreBlock(Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (Operation& op : block) {
// Terminators and pure address/index ops emit no PIM instruction: skip them.
if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
Block& loopBody = forOp.getRegion().front();
// Unrolling requires all three loop bounds to be compile-time constants and a
// strictly positive step (guards against a non-terminating unroll loop below).
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
hasFailure = true;
continue;
}
// Loop-carried values start as the init args and are refreshed from scf.yield at
// the end of every simulated iteration.
SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
// Fresh per-iteration knowledge: bind the induction variable's value and alias
// each region iter-arg to the memref it carries on this iteration.
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
hasFailure = true;
// Propagate yielded values (resolved to their backing memrefs) into the next
// iteration's iter-arg bindings.
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
// Success iff no loop-bound error and no callback failure anywhere in the walk.
return success(!hasFailure);
}
} // namespace onnx_mlir

View File

@@ -6,6 +6,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
@@ -20,6 +21,13 @@ struct ResolvedContiguousAddress {
int64_t byteOffset = 0;
};
/// Facts known statically at a point of the pim.core walk:
/// - indexValues: SSA values proven equal to a compile-time integer (e.g. the
///   induction variable of a statically unrolled scf.for);
/// - aliases: SSA values (e.g. loop iter-args) known to refer to another value's
///   underlying memref.
struct StaticValueKnowledge {
llvm::DenseMap<mlir::Value, int64_t> indexValues;
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
// Defaulted rather than user-provided empty body (modernize-use-equals-default).
StaticValueKnowledge() = default;
};
std::string getOutputDir();
void createDirectory(const std::string& directory);
@@ -52,5 +60,29 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> strides);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
/// only contribute to static addressing or index computations (arith integer math,
/// memref view ops, memref.alloc, arith.constant).
bool isCoreStaticAddressOp(mlir::Operation* op);
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
} // namespace onnx_mlir

View File

@@ -84,8 +84,8 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
return deviceMem.try_emplace(id, memEntriesMap).first->second;
}
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
auto resolvedAddress = resolveContiguousAddress(value);
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge) const {
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
if (failed(resolvedAddress)) {
errs() << "Failed to resolve contiguous address for value: ";
value.print(errs());
@@ -199,47 +199,49 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const {
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
emitMemCopyOp("ld",
memory.getValueAddress(loadOp.getDeviceTarget()),
addressOf(loadOp.getDeviceTarget(), knowledge),
loadOp.getDeviceTargetOffset(),
memory.getValueAddress(loadOp.getHostSource()),
addressOf(loadOp.getHostSource(), knowledge),
loadOp.getHostSourceOffset(),
loadOp.getSize());
}
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
emitMemCopyOp("st",
memory.getValueAddress(storeOp.getHostTarget()),
addressOf(storeOp.getHostTarget(), knowledge),
storeOp.getHostTargetOffset(),
memory.getValueAddress(storeOp.getDeviceSource()),
addressOf(storeOp.getDeviceSource(), knowledge),
storeOp.getDeviceSourceOffset(),
storeOp.getSize());
}
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const {
emitMemCopyOp("lmv",
memory.getValueAddress(lmvOp.getTarget()),
addressOf(lmvOp.getTarget(), knowledge),
lmvOp.getTargetOffset(),
memory.getValueAddress(lmvOp.getSource()),
addressOf(lmvOp.getSource(), knowledge),
lmvOp.getSourceOffset(),
lmvOp.getSize(),
"len");
}
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp(
"recv", memory.getValueAddress(receiveOp.getOutputBuffer()), receiveOp.getSourceCoreId(), receiveOp.getSize());
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
}
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) const {
emitCommunicationOp("send", memory.getValueAddress(sendOp.getInput()), sendOp.getTargetCoreId(), sendOp.getSize());
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
}
template <typename MVMTy>
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) {
emitMvmOp(
mvmId, memory.getValueAddress(mvmLikeOp.getOutputBuffer()), 0, memory.getValueAddress(mvmLikeOp.getInput()), 0);
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
MVMTy mvmLikeOp,
bool transposeMatrix,
const StaticValueKnowledge& knowledge) {
emitMvmOp(mvmId, addressOf(mvmLikeOp.getOutputBuffer(), knowledge), 0, addressOf(mvmLikeOp.getInput(), knowledge), 0);
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
}
@@ -249,10 +251,10 @@ static size_t getValueSizeInBytes(mlir::Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
auto outputBufferAddr = memory.getValueAddress(vvaddOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvaddOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvaddOp.getRhs());
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
auto rhsAddr = addressOf(vvaddOp.getRhs(), knowledge);
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
@@ -265,10 +267,10 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
auto outputBufferAddr = memory.getValueAddress(vvsubOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvsubOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvsubOp.getRhs());
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvsubOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvsubOp.getLhs(), knowledge);
auto rhsAddr = addressOf(vvsubOp.getRhs(), knowledge);
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
@@ -281,10 +283,10 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
auto outputBufferAddr = memory.getValueAddress(vvmulOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvmulOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvmulOp.getRhs());
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvmulOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvmulOp.getLhs(), knowledge);
auto rhsAddr = addressOf(vvmulOp.getRhs(), knowledge);
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
@@ -297,10 +299,10 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
auto outputBufferAddr = memory.getValueAddress(vvmaxOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvmaxOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvmaxOp.getRhs());
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvmaxOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvmaxOp.getLhs(), knowledge);
auto rhsAddr = addressOf(vvmaxOp.getRhs(), knowledge);
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
@@ -313,10 +315,10 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
auto outputBufferAddr = memory.getValueAddress(vvdmulOp.getOutputBuffer());
auto lhsAddr = memory.getValueAddress(vvdmulOp.getLhs());
auto rhsAddr = memory.getValueAddress(vvdmulOp.getRhs());
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvdmulOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvdmulOp.getLhs(), knowledge);
auto rhsAddr = addressOf(vvdmulOp.getRhs(), knowledge);
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
json::Object json;
@@ -329,9 +331,9 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
auto outputBufferAddr = memory.getValueAddress(vavgOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vavgOp.getInput());
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vavgOp.getOutputBuffer(), knowledge);
auto inputAddr = addressOf(vavgOp.getInput(), knowledge);
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
@@ -344,9 +346,9 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
auto outputBufferAddr = memory.getValueAddress(vreluOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vreluOp.getInput());
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vreluOp.getOutputBuffer(), knowledge);
auto inputAddr = addressOf(vreluOp.getInput(), knowledge);
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
@@ -358,9 +360,9 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
auto outputBufferAddr = memory.getValueAddress(vtanhOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vtanhOp.getInput());
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vtanhOp.getOutputBuffer(), knowledge);
auto inputAddr = addressOf(vtanhOp.getInput(), knowledge);
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
@@ -372,9 +374,9 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
auto outputBufferAddr = memory.getValueAddress(vsigmOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vsigmOp.getInput());
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vsigmOp.getOutputBuffer(), knowledge);
auto inputAddr = addressOf(vsigmOp.getInput(), knowledge);
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
@@ -386,9 +388,9 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const {
auto outputBufferAddr = memory.getValueAddress(vsoftmaxOp.getOutputBuffer());
auto inputAddr = memory.getValueAddress(vsoftmaxOp.getInput());
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vsoftmaxOp.getOutputBuffer(), knowledge);
auto inputAddr = addressOf(vsoftmaxOp.getInput(), knowledge);
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
json::Object json;
@@ -400,9 +402,9 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const {
emitInstruction(std::move(json));
}
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
auto srcAddr = memory.getValueAddress(transposeOp.getInput());
auto dstAddr = memory.getValueAddress(transposeOp.getOutputBuffer());
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
auto srcShape = srcType.getShape();
@@ -510,57 +512,58 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
}
/// Dispatch all operations in a core region to the appropriate code generator.
/// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is
/// fully resolved before the JSON instructions are emitted.
/// Returns the number of emitted instructions, or -1 on failure.
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
size_t processedOperations = 0;
for (auto& op : coreOp.getBody().front()) {
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp, memref::ExpandShapeOp, memref::CollapseShapeOp>(op))
continue;
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
coreCodeGen.codeGenLoadOp(loadOp);
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
coreCodeGen.codeGenStoreOp(storeOp);
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
coreCodeGen.codeGenLmvOp(lmvOp);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
coreCodeGen.codeGenVVAddOp(vvaddOp);
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
coreCodeGen.codeGenVVSubOp(vvsubOp);
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
coreCodeGen.codeGenVVMulOp(vvmulOp);
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
coreCodeGen.codeGenVVMaxOp(vvmaxOp);
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
coreCodeGen.codeGenVVDMulOp(vvdmulOp);
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
coreCodeGen.codeGenVAvgOp(vavgOp);
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
coreCodeGen.codeGenVReluOp(vreluOp);
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
coreCodeGen.codeGenVTanhOp(vtanhOp);
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
coreCodeGen.codeGenVSigmOp(vsigmOp);
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp);
else {
op.emitError("Unsupported codegen for this operation");
op.dump();
return -1;
}
processedOperations++;
}
return processedOperations;
auto result =
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
coreCodeGen.codeGenLoadOp(loadOp, knowledge);
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
coreCodeGen.codeGenStoreOp(storeOp, knowledge);
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false, knowledge);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
coreCodeGen.codeGenVVAddOp(vvaddOp, knowledge);
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
coreCodeGen.codeGenVVSubOp(vvsubOp, knowledge);
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
coreCodeGen.codeGenVVMulOp(vvmulOp, knowledge);
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
coreCodeGen.codeGenVVMaxOp(vvmaxOp, knowledge);
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
coreCodeGen.codeGenVVDMulOp(vvdmulOp, knowledge);
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
coreCodeGen.codeGenVAvgOp(vavgOp, knowledge);
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
coreCodeGen.codeGenVReluOp(vreluOp, knowledge);
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
coreCodeGen.codeGenVTanhOp(vtanhOp, knowledge);
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
else {
op.emitError("Unsupported codegen for this operation");
op.dump();
return failure();
}
processedOperations++;
return success();
});
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
}
/// Write crossbar weight matrices as padded binary files for a single core.
@@ -739,7 +742,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
PimCodeGen coreCodeGen(memory, coreFileStream);
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
int64_t processedOperations = codeGenCoreOps(coreOp, coreCodeGen);
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
if (processedOperations < 0)
return CompilerFailure;
assert(processedOperations > 0);

View File

@@ -4,6 +4,7 @@
#include "llvm/Support/JSON.h"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
@@ -50,13 +51,17 @@ public:
PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value) const;
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
};
class PimCodeGen {
PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreFileStream;
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge);
}
static llvm::json::Object createEmptyOffset();
void emitInstruction(llvm::json::Object instruction) const;
@@ -80,27 +85,27 @@ public:
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
: memory(memory), coreFileStream(coreJson) {}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp) const;
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const;
void codeGenSendOp(pim::PimSendOp sendOp) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);
void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const;
void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const;
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const;
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const;
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const;
void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const;
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
void codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const;
void codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const;
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const;
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const;
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const;
void codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const;
void codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const;
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
};
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);

View File

@@ -26,6 +26,7 @@ add_pim_library(OMONNXToSpatial
ONNXToSpatialIncGen
LINK_LIBS PUBLIC
MLIRSCFDialect
MLIRTosaDialect
OMCompilerOptions
OMPimCompilerOptions

View File

@@ -1,4 +1,5 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -75,7 +76,11 @@ void ONNXToSpatialPass::runOnOperation() {
}
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addDynamicallyLegalOp<ONNXMatMulOp>(
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXAddOp>();

View File

@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -169,44 +170,60 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
paddedInput = padOp.getResult();
}
// Build im2col [numPatches, patchSize]:
// For each batch/output position (n, oh, ow), extract the patch from x
SmallVector<Value> im2colRows;
im2colRows.reserve(numPatches);
for (int64_t n = 0; n < batchSize; n++) {
for (int64_t oh = 0; oh < outHeight; oh++) {
for (int64_t ow = 0; ow < outWidth; ow++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(oh * strideHeight),
rewriter.getIndexAttr(ow * strideWidth)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight),
rewriter.getIndexAttr(wWidth)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
// until the late PIM unrolling step.
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
rowType,
patch,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
im2colRows.push_back(row);
}
}
}
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
rewriter.setInsertionPointToStart(im2colLoop.getBody());
// Concatenate all rows: [numPatches, patchSize]
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
Value patchIndex = im2colLoop.getInductionVar();
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight),
rewriter.getIndexAttr(wWidth)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
rowType,
patch,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value updatedIm2col =
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
scf::YieldOp::create(rewriter, loc, updatedIm2col);
rewriter.setInsertionPointAfter(im2colLoop);
Value im2col = im2colLoop.getResult(0);
spatial::SpatYieldOp::create(rewriter, loc, im2col);
});

View File

@@ -12,6 +12,7 @@ add_pim_library(OMSpatialToPim
SpatialToPimIncGen
LINK_LIBS PUBLIC
MLIRSCFDialect
MLIRTosaDialect
OMCompilerOptions
OMPimCommon

View File

@@ -1,5 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
@@ -134,7 +135,12 @@ void SpatialToPimPass::runOnOperation() {
MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect, func::FuncDialect, BuiltinDialect>();
target.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
func::FuncDialect,
scf::SCFDialect,
BuiltinDialect>();
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);

View File

@@ -13,6 +13,7 @@ add_pim_library(OMPimPasses
LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect
OMCompilerUtils
OMPimCommon
)

View File

@@ -85,12 +85,8 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
StaticSubviewInfo info;
info.source = source;
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
info.offsets.push_back(*staticOffset);
}
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto staticSize = getConstantIntValue(size);
if (!staticSize)
@@ -106,14 +102,16 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
return info;
}
int64_t
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(info.sourceShape.size());
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
sourceIndices.push_back(info.offsets.back());
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
/// Extracts every offset in `info` as a compile-time constant.
/// Fails as soon as one offset turns out to be dynamic (an SSA value).
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
  SmallVector<int64_t> constantOffsets;
  constantOffsets.reserve(info.offsets.size());
  for (const OpFoldResult& offset : info.offsets) {
    auto maybeConstant = getConstantIntValue(offset);
    if (!maybeConstant.has_value())
      return failure();
    constantOffsets.push_back(*maybeConstant);
  }
  return constantOffsets;
}
} // namespace onnx_mlir

View File

@@ -14,7 +14,7 @@ namespace onnx_mlir {
struct StaticSubviewInfo {
mlir::Value source;
llvm::SmallVector<int64_t> sourceShape;
llvm::SmallVector<int64_t> offsets;
llvm::SmallVector<mlir::OpFoldResult> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
};
@@ -34,8 +34,7 @@ llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp modu
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
llvm::ArrayRef<int64_t> outerIndices,
int64_t elementByteWidth);
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
} // namespace onnx_mlir

View File

@@ -120,7 +120,15 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
rewriter.setInsertionPoint(mapOp);
rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
pim::PimMemCopyOp::create(rewriter,
mapOp.getLoc(),
initType,
mapOp.getInit(),
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes));
rewriter.eraseOp(mapOp);
return success();
}
@@ -416,6 +424,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
return failure();
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
if (failed(staticOffsets))
return failure();
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
const int64_t numResultElements = resultTensorType.getNumElements();
@@ -428,7 +439,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
sourceIndices.push_back(off + idx);
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
resultValues[i] = sourceValues[srcLinear];

View File

@@ -1,3 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "../Common.hpp"
#include "../Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -8,6 +10,62 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
/// Returns true when the subview described by `info` selects a single
/// contiguous run of elements of its row-major source buffer.
///
/// With every stride equal to 1, the selection is contiguous iff, scanning
/// from the innermost dimension outward, every dimension before the first
/// one that does not span the full source extent is fully covered, and every
/// dimension outside that first partial one is degenerate (size 1).
static bool isSubviewContiguous(const StaticSubviewInfo& info) {
  if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
    return false;
  assert(info.sizes.size() == info.sourceShape.size() && "rank mismatch");
  const size_t rank = info.sizes.size();
  // Walk innermost-out: `dim` ends up just past the first partial dimension.
  size_t dim = rank;
  while (dim > 0 && info.sizes[dim - 1] == info.sourceShape[dim - 1])
    --dim;
  if (dim == 0)
    return true; // The subview covers the entire source.
  // The first partial dimension may have any size; every dimension outside it
  // must have size 1 for the selected elements to stay contiguous.
  for (size_t outer = 0; outer + 1 < dim; ++outer)
    if (info.sizes[outer] != 1)
      return false;
  return true;
}
/// Adds the constant `extraOffset` to `baseOffset`, folding the sum into an
/// attribute when the base offset is static and emitting arith ops otherwise.
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
  if (extraOffset == 0)
    return baseOffset; // Nothing to add: reuse the base offset unchanged.
  if (auto attr = dyn_cast<Attribute>(baseOffset)) {
    // Static base offset: fold the addition into a fresh index attribute.
    auto integerAttr = dyn_cast<IntegerAttr>(attr);
    assert(integerAttr && "expected integer offset attribute");
    return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
  }
  // Dynamic base offset: materialize the addition in the IR.
  auto baseValue = cast<Value>(baseOffset);
  auto extraValue = arith::ConstantIndexOp::create(rewriter, baseValue.getLoc(), extraOffset);
  return arith::AddIOp::create(rewriter, baseValue.getLoc(), baseValue, extraValue).getResult();
}
/// Materializes a memref.subview that selects one innermost-dimension "chunk"
/// of the subview described by `info`, positioned by `outerIndices` (one
/// index per non-innermost dimension).
static Value buildSubviewChunk(const StaticSubviewInfo& info,
                               ArrayRef<int64_t> outerIndices,
                               Location loc,
                               PatternRewriter& rewriter) {
  const size_t rank = info.sizes.size();
  SmallVector<OpFoldResult> chunkOffsets;
  SmallVector<OpFoldResult> chunkSizes;
  SmallVector<OpFoldResult> chunkStrides;
  chunkOffsets.reserve(rank);
  chunkSizes.reserve(rank);
  chunkStrides.reserve(rank);
  for (size_t dim = 0; dim < rank; ++dim) {
    const bool innermost = dim + 1 == rank;
    // Outer dims advance by their outer index scaled by the stride; the
    // innermost dim keeps the subview's own offset and spans its full size.
    const int64_t extraOffset = innermost ? 0 : outerIndices[dim] * info.strides[dim];
    chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
    chunkSizes.push_back(rewriter.getIndexAttr(innermost ? info.sizes.back() : 1));
    chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }
  return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
}
template <typename CopyOp, typename CreateCopyOp>
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
Value dst,
@@ -19,12 +77,8 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
auto dstSubview = getStaticSubviewInfo(dst);
const bool splitSrc =
succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst =
succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
const bool splitSrc = succeeded(srcSubview) && !isSubviewContiguous(*srcSubview);
const bool splitDst = succeeded(dstSubview) && !isSubviewContiguous(*dstSubview);
if (!splitSrc && !splitDst)
return failure();
@@ -35,9 +89,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
if (sourceType.getElementType() != dstType.getElementType())
return failure();
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
return failure();
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
return failure();
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
@@ -64,18 +118,11 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset =
srcOffset
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
const int64_t dstByteOffset =
dstOffset
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : dst,
splitSrc ? srcSubview->source : src,
dstByteOffset,
srcByteOffset,
sliceBytes);
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
}
return success();
@@ -198,6 +245,9 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
return failure();
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
return failure();
auto staticOffsets = getStaticSubviewOffsets(*subviewInfo);
if (failed(staticOffsets))
return failure();
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
if (!sourceType || !sourceType.hasStaticShape())
@@ -217,7 +267,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
sourceIndices.push_back(off + idx);
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
}

View File

@@ -132,38 +132,37 @@ private:
}
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
bool hasFailure = false;
for (Operation& op : coreOp.getBody().front()) {
if (isa<pim::PimHaltOp>(op))
continue;
return walkPimCoreBlock(
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
bool hasFailure = false;
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<BaseMemRefType>(operand.getType()))
continue;
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<BaseMemRefType>(operand.getType()))
continue;
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
if (failed(resolvedAddress)) {
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
continue;
}
auto resolvedAddress = resolveContiguousAddress(operand);
if (failed(resolvedAddress)) {
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
hasFailure = true;
continue;
}
if (isExplicitHostOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand)) {
op.emitOpError() << "host operand #" << operandIndex
<< " is not backed by contiguous addressable storage";
hasFailure = true;
}
continue;
}
if (isExplicitHostOperand(&op, operandIndex)) {
if (!isCodegenAddressableValue(operand)) {
op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
hasFailure = true;
}
continue;
}
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
op.emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
hasFailure = true;
}
}
}
return success(!hasFailure);
return success(!hasFailure);
});
}
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {

View File

@@ -3,6 +3,8 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
@@ -57,12 +59,14 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
registry.insert<mlir::tensor::TensorDialect>();
registry.insert<mlir::tosa::TosaDialect>();
registry.insert<mlir::bufferization::BufferizationDialect>();
registry.insert<mlir::scf::SCFDialect>();
registry.insert<pim::PimDialect>();
registry.insert<spatial::SpatialDialect>();
mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry);
mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
mlir::scf::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
pim::registerOpBufferizationInterfaces(registry);

View File

@@ -3,3 +3,8 @@ operations/**/outputs
operations/**/raptor
operations/**/runner
operations/**/simulation
networks/**/inputs
networks/**/outputs
networks/**/raptor
networks/**/runner
networks/**/simulation