Reduce spatial compile times in convolutions by using an scf.for loop instead of materializing a huge number of instructions.
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -54,6 +54,8 @@ add_pim_library(OMPIMAccel
|
||||
${PIM_PUBLIC_INCLUDE_DIRS}
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRSCFDialect
|
||||
MLIRSCFTransforms
|
||||
onnx
|
||||
OMAccelerator
|
||||
OMPimCompilerUtils
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
@@ -8,6 +10,7 @@
|
||||
#include <fstream>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -240,8 +243,129 @@ bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
||||
return true;
|
||||
}
|
||||
|
||||
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
|
||||
if (!knowledge)
|
||||
return value;
|
||||
|
||||
auto iter = knowledge->aliases.find(value);
|
||||
while (iter != knowledge->aliases.end()) {
|
||||
value = iter->second;
|
||||
iter = knowledge->aliases.find(value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
|
||||
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
|
||||
// and when propagating yielded values across iterations during static unrolling.
|
||||
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
if (auto blockArgument = dyn_cast<BlockArgument>(value))
|
||||
return value;
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return value;
|
||||
|
||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||
if (auto result = dyn_cast<OpResult>(value))
|
||||
if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
|
||||
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
||||
}
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
|
||||
/// Attempts to statically evaluate an index-typed SSA value to a constant integer.
/// Resolution order: bindings supplied via `knowledge` (e.g. a concrete induction
/// variable value from loop unrolling), then arith.constant, then recursive
/// evaluation through arith.index_cast and the integer ops
/// add/sub/mul/divui/remui. Returns failure() for anything else.
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);

  // Knowledge from static unrolling takes precedence over IR inspection.
  if (knowledge) {
    auto iter = knowledge->indexValues.find(value);
    if (iter != knowledge->indexValues.end())
      return iter->second;
  }

  if (auto constantOp = value.getDefiningOp<arith::ConstantOp>()) {
    if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
      return integerAttr.getInt();
  }

  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return failure();

  if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
    return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);

  // All supported binary ops share the same recurse-on-both-operands shape;
  // `combine` folds the two resolved constants (and may itself fail, e.g. on
  // division by zero).
  auto resolveBinary = [knowledge](Value lhsValue, Value rhsValue,
                           llvm::function_ref<FailureOr<int64_t>(int64_t, int64_t)> combine)
      -> FailureOr<int64_t> {
    auto lhs = resolveIndexValueImpl(lhsValue, knowledge);
    auto rhs = resolveIndexValueImpl(rhsValue, knowledge);
    if (failed(lhs) || failed(rhs))
      return failure();
    return combine(*lhs, *rhs);
  };

  if (auto addOp = dyn_cast<arith::AddIOp>(definingOp))
    return resolveBinary(addOp.getLhs(), addOp.getRhs(),
        [](int64_t lhs, int64_t rhs) -> FailureOr<int64_t> { return lhs + rhs; });

  if (auto subOp = dyn_cast<arith::SubIOp>(definingOp))
    return resolveBinary(subOp.getLhs(), subOp.getRhs(),
        [](int64_t lhs, int64_t rhs) -> FailureOr<int64_t> { return lhs - rhs; });

  if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp))
    return resolveBinary(mulOp.getLhs(), mulOp.getRhs(),
        [](int64_t lhs, int64_t rhs) -> FailureOr<int64_t> { return lhs * rhs; });

  // divui/remui have unsigned semantics, so fold through uint64_t. Division by
  // zero is UB in arith; refuse to fold instead.
  if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp))
    return resolveBinary(divOp.getLhs(), divOp.getRhs(),
        [](int64_t lhs, int64_t rhs) -> FailureOr<int64_t> {
          if (rhs == 0)
            return failure();
          return static_cast<int64_t>(static_cast<uint64_t>(lhs) / static_cast<uint64_t>(rhs));
        });

  if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp))
    return resolveBinary(remOp.getLhs(), remOp.getRhs(),
        [](int64_t lhs, int64_t rhs) -> FailureOr<int64_t> {
          if (rhs == 0)
            return failure();
          return static_cast<int64_t>(static_cast<uint64_t>(lhs) % static_cast<uint64_t>(rhs));
        });

  return failure();
}
|
||||
|
||||
/// Resolves an OpFoldResult (either a static attribute or a dynamic SSA value)
/// to a constant integer, using `knowledge` for the dynamic case.
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
  // Dynamic case: an SSA value that must be evaluated through the resolver.
  if (auto dynamicValue = dyn_cast<Value>(ofr))
    return resolveIndexValueImpl(dynamicValue, knowledge);

  // Static case: only integer attributes resolve to a value.
  if (auto integerAttr = dyn_cast<IntegerAttr>(cast<Attribute>(ofr)))
    return integerAttr.getInt();
  return failure();
}
|
||||
|
||||
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
int64_t byteOffset = 0;
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
while (true) {
|
||||
if (isa<BlockArgument>(value))
|
||||
@@ -255,7 +379,29 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
return failure();
|
||||
value = tiedOperand->get();
|
||||
value = resolveAlias(tiedOperand->get(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
|
||||
auto result = dyn_cast<OpResult>(value);
|
||||
if (!result)
|
||||
return failure();
|
||||
|
||||
// Trace the loop carry back to its underlying memref, then if that memref is the
|
||||
// loop's own iter-arg we know the base comes from the corresponding init arg
|
||||
// (every iteration yields the same backing memory in the DPS sense).
|
||||
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
|
||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
value = yieldedValue;
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -265,31 +411,53 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> offsets = subviewOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sizes = subviewOp.getStaticSizes();
|
||||
ArrayRef<int64_t> strides = subviewOp.getStaticStrides();
|
||||
if (llvm::is_contained(offsets, ShapedType::kDynamic) || llvm::is_contained(sizes, ShapedType::kDynamic)
|
||||
|| llvm::is_contained(strides, ShapedType::kDynamic))
|
||||
return failure();
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> sizes;
|
||||
SmallVector<int64_t> strides;
|
||||
offsets.reserve(subviewOp.getMixedOffsets().size());
|
||||
sizes.reserve(subviewOp.getMixedSizes().size());
|
||||
strides.reserve(subviewOp.getMixedStrides().size());
|
||||
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
return failure();
|
||||
offsets.push_back(*resolvedOffset);
|
||||
}
|
||||
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto resolvedSize = resolveOpFoldResult(size, knowledge);
|
||||
if (failed(resolvedSize))
|
||||
return failure();
|
||||
sizes.push_back(*resolvedSize);
|
||||
}
|
||||
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
|
||||
if (failed(resolvedStride))
|
||||
return failure();
|
||||
strides.push_back(*resolvedStride);
|
||||
}
|
||||
|
||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||
return failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||
value = subviewOp.getSource();
|
||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
value = resolveAlias(castOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
value = resolveAlias(collapseOp.getSrc(), knowledge);
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
value = resolveAlias(expandOp.getSrc(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -300,4 +468,79 @@ FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolves `value` to a static integer without any loop-unrolling context.
FailureOr<int64_t> resolveIndexValue(Value value) {
  return resolveIndexValueImpl(value, /*knowledge=*/nullptr);
}
|
||||
|
||||
/// Resolves `value` to a static integer, consulting `knowledge` for values pinned
/// during static scf.for unrolling (induction variables, iter-arg aliases).
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
  return resolveIndexValueImpl(value, &knowledge);
}
|
||||
|
||||
/// Resolves the contiguous base address of `value` without loop-unrolling context.
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
  return resolveContiguousAddressImpl(value, /*knowledge=*/nullptr);
}
|
||||
|
||||
/// Resolves the contiguous base address of `value`, consulting `knowledge` for
/// values pinned during static scf.for unrolling.
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
  return resolveContiguousAddressImpl(value, &knowledge);
}
|
||||
|
||||
/// Public entry point for tracing an scf.for loop carry back to its backing memref.
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
  return resolveLoopCarriedAliasImpl(value, &knowledge);
}
|
||||
|
||||
/// Returns true for ops inside a pim.core body that emit no PIM instruction and
/// only contribute to static addressing / index computation.
bool isCoreStaticAddressOp(Operation* op) {
  // Integer/index arithmetic that only feeds address computations.
  if (isa<arith::ConstantOp,
          arith::AddIOp,
          arith::SubIOp,
          arith::MulIOp,
          arith::DivUIOp,
          arith::RemUIOp,
          arith::IndexCastOp>(op))
    return true;

  // Memref bookkeeping (allocation and view reshaping) that never lowers to a
  // PIM instruction.
  return isa<memref::AllocOp,
             memref::SubViewOp,
             memref::CastOp,
             memref::CollapseShapeOp,
             memref::ExpandShapeOp>(op);
}
|
||||
|
||||
/// Walks `block` (a pim.core body or an scf.for nested in it), statically unrolling
/// any scf.for whose bounds resolve to constants under `knowledge`. For each op that
/// is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
/// invoked with the op and the in-scope knowledge. The walk continues after a
/// callback failure so callers can collect multiple diagnostics; the aggregate
/// result is returned at the end.
LogicalResult walkPimCoreBlock(Block& block,
    const StaticValueKnowledge& knowledge,
    llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
  bool hasFailure = false;
  for (Operation& op : block) {
    // Ops that never emit a PIM instruction are skipped entirely.
    if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
      continue;

    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      Block& loopBody = forOp.getRegion().front();
      auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
      auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
      auto step = resolveIndexValue(forOp.getStep(), knowledge);
      // A non-positive step would make the unrolling loop below diverge.
      if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
        forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
        hasFailure = true;
        continue;
      }

      SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
      // Copy `knowledge` once, outside the unrolling loop: between iterations only
      // the induction-variable entry and the iter-arg aliases change, so those keys
      // are overwritten in place below instead of re-copying both maps per
      // iteration.
      StaticValueKnowledge loopKnowledge = knowledge;
      for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
        loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
        for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
          loopKnowledge.aliases[iterArg] = iterValue;

        if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
          hasFailure = true;

        // Propagate the yielded values to the next iteration's iter-arg bindings.
        auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
        for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
          iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
      }
      continue;
    }

    if (failed(callback(op, knowledge)))
      hasFailure = true;
  }
  return success(!hasFailure);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
@@ -20,6 +21,13 @@ struct ResolvedContiguousAddress {
|
||||
int64_t byteOffset = 0;
|
||||
};
|
||||
|
||||
struct StaticValueKnowledge {
|
||||
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
||||
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
||||
|
||||
StaticValueKnowledge() {}
|
||||
};
|
||||
|
||||
std::string getOutputDir();
|
||||
|
||||
void createDirectory(const std::string& directory);
|
||||
@@ -52,5 +60,29 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge);
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
|
||||
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
|
||||
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
|
||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
|
||||
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
|
||||
/// only contribute to static addressing or index computations (arith integer math,
|
||||
/// memref view ops, memref.alloc, arith.constant).
|
||||
bool isCoreStaticAddressOp(mlir::Operation* op);
|
||||
|
||||
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
|
||||
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
|
||||
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
|
||||
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
|
||||
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
|
||||
mlir::LogicalResult
|
||||
walkPimCoreBlock(mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -84,8 +84,8 @@ PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
||||
return deviceMem.try_emplace(id, memEntriesMap).first->second;
|
||||
}
|
||||
|
||||
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||
auto resolvedAddress = resolveContiguousAddress(value);
|
||||
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
|
||||
if (failed(resolvedAddress)) {
|
||||
errs() << "Failed to resolve contiguous address for value: ";
|
||||
value.print(errs());
|
||||
@@ -199,47 +199,49 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const {
|
||||
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitMemCopyOp("ld",
|
||||
memory.getValueAddress(loadOp.getDeviceTarget()),
|
||||
addressOf(loadOp.getDeviceTarget(), knowledge),
|
||||
loadOp.getDeviceTargetOffset(),
|
||||
memory.getValueAddress(loadOp.getHostSource()),
|
||||
addressOf(loadOp.getHostSource(), knowledge),
|
||||
loadOp.getHostSourceOffset(),
|
||||
loadOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
|
||||
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitMemCopyOp("st",
|
||||
memory.getValueAddress(storeOp.getHostTarget()),
|
||||
addressOf(storeOp.getHostTarget(), knowledge),
|
||||
storeOp.getHostTargetOffset(),
|
||||
memory.getValueAddress(storeOp.getDeviceSource()),
|
||||
addressOf(storeOp.getDeviceSource(), knowledge),
|
||||
storeOp.getDeviceSourceOffset(),
|
||||
storeOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
|
||||
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitMemCopyOp("lmv",
|
||||
memory.getValueAddress(lmvOp.getTarget()),
|
||||
addressOf(lmvOp.getTarget(), knowledge),
|
||||
lmvOp.getTargetOffset(),
|
||||
memory.getValueAddress(lmvOp.getSource()),
|
||||
addressOf(lmvOp.getSource(), knowledge),
|
||||
lmvOp.getSourceOffset(),
|
||||
lmvOp.getSize(),
|
||||
"len");
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
|
||||
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp(
|
||||
"recv", memory.getValueAddress(receiveOp.getOutputBuffer()), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) const {
|
||||
emitCommunicationOp("send", memory.getValueAddress(sendOp.getInput()), sendOp.getTargetCoreId(), sendOp.getSize());
|
||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
|
||||
}
|
||||
|
||||
template <typename MVMTy>
|
||||
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) {
|
||||
emitMvmOp(
|
||||
mvmId, memory.getValueAddress(mvmLikeOp.getOutputBuffer()), 0, memory.getValueAddress(mvmLikeOp.getInput()), 0);
|
||||
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
|
||||
MVMTy mvmLikeOp,
|
||||
bool transposeMatrix,
|
||||
const StaticValueKnowledge& knowledge) {
|
||||
emitMvmOp(mvmId, addressOf(mvmLikeOp.getOutputBuffer(), knowledge), 0, addressOf(mvmLikeOp.getInput(), knowledge), 0);
|
||||
|
||||
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
||||
}
|
||||
@@ -249,10 +251,10 @@ static size_t getValueSizeInBytes(mlir::Value value) {
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
|
||||
auto outputBufferAddr = memory.getValueAddress(vvaddOp.getOutputBuffer());
|
||||
auto lhsAddr = memory.getValueAddress(vvaddOp.getLhs());
|
||||
auto rhsAddr = memory.getValueAddress(vvaddOp.getRhs());
|
||||
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
|
||||
auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
|
||||
auto rhsAddr = addressOf(vvaddOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -265,10 +267,10 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
|
||||
auto outputBufferAddr = memory.getValueAddress(vvsubOp.getOutputBuffer());
|
||||
auto lhsAddr = memory.getValueAddress(vvsubOp.getLhs());
|
||||
auto rhsAddr = memory.getValueAddress(vvsubOp.getRhs());
|
||||
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputBufferAddr = addressOf(vvsubOp.getOutputBuffer(), knowledge);
|
||||
auto lhsAddr = addressOf(vvsubOp.getLhs(), knowledge);
|
||||
auto rhsAddr = addressOf(vvsubOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -281,10 +283,10 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
|
||||
auto outputBufferAddr = memory.getValueAddress(vvmulOp.getOutputBuffer());
|
||||
auto lhsAddr = memory.getValueAddress(vvmulOp.getLhs());
|
||||
auto rhsAddr = memory.getValueAddress(vvmulOp.getRhs());
|
||||
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputBufferAddr = addressOf(vvmulOp.getOutputBuffer(), knowledge);
|
||||
auto lhsAddr = addressOf(vvmulOp.getLhs(), knowledge);
|
||||
auto rhsAddr = addressOf(vvmulOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -297,10 +299,10 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
|
||||
auto outputBufferAddr = memory.getValueAddress(vvmaxOp.getOutputBuffer());
|
||||
auto lhsAddr = memory.getValueAddress(vvmaxOp.getLhs());
|
||||
auto rhsAddr = memory.getValueAddress(vvmaxOp.getRhs());
|
||||
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputBufferAddr = addressOf(vvmaxOp.getOutputBuffer(), knowledge);
|
||||
auto lhsAddr = addressOf(vvmaxOp.getLhs(), knowledge);
|
||||
auto rhsAddr = addressOf(vvmaxOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -313,10 +315,10 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
|
||||
auto outputBufferAddr = memory.getValueAddress(vvdmulOp.getOutputBuffer());
|
||||
auto lhsAddr = memory.getValueAddress(vvdmulOp.getLhs());
|
||||
auto rhsAddr = memory.getValueAddress(vvdmulOp.getRhs());
|
||||
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputBufferAddr = addressOf(vvdmulOp.getOutputBuffer(), knowledge);
|
||||
auto lhsAddr = addressOf(vvdmulOp.getLhs(), knowledge);
|
||||
auto rhsAddr = addressOf(vvdmulOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -329,9 +331,9 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
|
||||
auto outputBufferAddr = memory.getValueAddress(vavgOp.getOutputBuffer());
|
||||
auto inputAddr = memory.getValueAddress(vavgOp.getInput());
|
||||
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputBufferAddr = addressOf(vavgOp.getOutputBuffer(), knowledge);
|
||||
auto inputAddr = addressOf(vavgOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -344,9 +346,9 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
|
||||
auto outputBufferAddr = memory.getValueAddress(vreluOp.getOutputBuffer());
|
||||
auto inputAddr = memory.getValueAddress(vreluOp.getInput());
|
||||
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputBufferAddr = addressOf(vreluOp.getOutputBuffer(), knowledge);
|
||||
auto inputAddr = addressOf(vreluOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -358,9 +360,9 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
|
||||
auto outputBufferAddr = memory.getValueAddress(vtanhOp.getOutputBuffer());
|
||||
auto inputAddr = memory.getValueAddress(vtanhOp.getInput());
|
||||
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputBufferAddr = addressOf(vtanhOp.getOutputBuffer(), knowledge);
|
||||
auto inputAddr = addressOf(vtanhOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -372,9 +374,9 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
|
||||
auto outputBufferAddr = memory.getValueAddress(vsigmOp.getOutputBuffer());
|
||||
auto inputAddr = memory.getValueAddress(vsigmOp.getInput());
|
||||
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputBufferAddr = addressOf(vsigmOp.getOutputBuffer(), knowledge);
|
||||
auto inputAddr = addressOf(vsigmOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -386,9 +388,9 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const {
|
||||
auto outputBufferAddr = memory.getValueAddress(vsoftmaxOp.getOutputBuffer());
|
||||
auto inputAddr = memory.getValueAddress(vsoftmaxOp.getInput());
|
||||
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto outputBufferAddr = addressOf(vsoftmaxOp.getOutputBuffer(), knowledge);
|
||||
auto inputAddr = addressOf(vsoftmaxOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
@@ -400,9 +402,9 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const {
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
||||
auto srcAddr = memory.getValueAddress(transposeOp.getInput());
|
||||
auto dstAddr = memory.getValueAddress(transposeOp.getOutputBuffer());
|
||||
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
|
||||
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
|
||||
|
||||
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
|
||||
auto srcShape = srcType.getShape();
|
||||
@@ -510,57 +512,58 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
|
||||
}
|
||||
|
||||
/// Dispatch all operations in a core region to the appropriate code generator.
|
||||
/// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is
|
||||
/// fully resolved before the JSON instructions are emitted.
|
||||
/// Returns the number of emitted instructions, or -1 on failure.
|
||||
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
||||
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
size_t processedOperations = 0;
|
||||
for (auto& op : coreOp.getBody().front()) {
|
||||
if (isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp, memref::ExpandShapeOp, memref::CollapseShapeOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
|
||||
coreCodeGen.codeGenLoadOp(loadOp);
|
||||
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
|
||||
coreCodeGen.codeGenStoreOp(storeOp);
|
||||
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
|
||||
coreCodeGen.codeGenLmvOp(lmvOp);
|
||||
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
|
||||
coreCodeGen.codeGenReceiveOp(receiveOp);
|
||||
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
||||
coreCodeGen.codeGenSendOp(sendOp);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
|
||||
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp);
|
||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||
coreCodeGen.codeGenVVAddOp(vvaddOp);
|
||||
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
|
||||
coreCodeGen.codeGenVVSubOp(vvsubOp);
|
||||
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
|
||||
coreCodeGen.codeGenVVMulOp(vvmulOp);
|
||||
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
|
||||
coreCodeGen.codeGenVVMaxOp(vvmaxOp);
|
||||
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
|
||||
coreCodeGen.codeGenVVDMulOp(vvdmulOp);
|
||||
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
|
||||
coreCodeGen.codeGenVAvgOp(vavgOp);
|
||||
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
|
||||
coreCodeGen.codeGenVReluOp(vreluOp);
|
||||
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
|
||||
coreCodeGen.codeGenVTanhOp(vtanhOp);
|
||||
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
|
||||
coreCodeGen.codeGenVSigmOp(vsigmOp);
|
||||
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
|
||||
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp);
|
||||
else {
|
||||
op.emitError("Unsupported codegen for this operation");
|
||||
op.dump();
|
||||
return -1;
|
||||
}
|
||||
processedOperations++;
|
||||
}
|
||||
return processedOperations;
|
||||
auto result =
|
||||
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
|
||||
coreCodeGen.codeGenLoadOp(loadOp, knowledge);
|
||||
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
|
||||
coreCodeGen.codeGenStoreOp(storeOp, knowledge);
|
||||
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
|
||||
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
|
||||
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
|
||||
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
|
||||
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
||||
coreCodeGen.codeGenSendOp(sendOp, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
||||
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false, knowledge);
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||
coreCodeGen.codeGenVVAddOp(vvaddOp, knowledge);
|
||||
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
|
||||
coreCodeGen.codeGenVVSubOp(vvsubOp, knowledge);
|
||||
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
|
||||
coreCodeGen.codeGenVVMulOp(vvmulOp, knowledge);
|
||||
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
|
||||
coreCodeGen.codeGenVVMaxOp(vvmaxOp, knowledge);
|
||||
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
|
||||
coreCodeGen.codeGenVVDMulOp(vvdmulOp, knowledge);
|
||||
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
|
||||
coreCodeGen.codeGenVAvgOp(vavgOp, knowledge);
|
||||
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
|
||||
coreCodeGen.codeGenVReluOp(vreluOp, knowledge);
|
||||
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
|
||||
coreCodeGen.codeGenVTanhOp(vtanhOp, knowledge);
|
||||
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
|
||||
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
|
||||
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
|
||||
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
|
||||
else {
|
||||
op.emitError("Unsupported codegen for this operation");
|
||||
op.dump();
|
||||
return failure();
|
||||
}
|
||||
processedOperations++;
|
||||
return success();
|
||||
});
|
||||
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
|
||||
}
|
||||
|
||||
/// Write crossbar weight matrices as padded binary files for a single core.
|
||||
@@ -739,7 +742,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
PimCodeGen coreCodeGen(memory, coreFileStream);
|
||||
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
||||
|
||||
int64_t processedOperations = codeGenCoreOps(coreOp, coreCodeGen);
|
||||
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
||||
if (processedOperations < 0)
|
||||
return CompilerFailure;
|
||||
assert(processedOperations > 0);
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "llvm/Support/JSON.h"
|
||||
|
||||
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -50,13 +51,17 @@ public:
|
||||
|
||||
PimMemory& getOrCreateDeviceMem(size_t id);
|
||||
|
||||
size_t getValueAddress(mlir::Value value) const;
|
||||
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
||||
};
|
||||
|
||||
class PimCodeGen {
|
||||
PimAcceleratorMemory& memory;
|
||||
llvm::raw_fd_ostream& coreFileStream;
|
||||
|
||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||
return memory.getValueAddress(value, knowledge);
|
||||
}
|
||||
|
||||
static llvm::json::Object createEmptyOffset();
|
||||
void emitInstruction(llvm::json::Object instruction) const;
|
||||
|
||||
@@ -80,27 +85,27 @@ public:
|
||||
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
|
||||
: memory(memory), coreFileStream(coreJson) {}
|
||||
|
||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const;
|
||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const;
|
||||
void codeGenLmvOp(pim::PimMemCopyOp lmvOp) const;
|
||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
||||
|
||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const;
|
||||
void codeGenSendOp(pim::PimSendOp sendOp) const;
|
||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
||||
|
||||
template <typename MVMTy>
|
||||
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
|
||||
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);
|
||||
|
||||
void codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const;
|
||||
void codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const;
|
||||
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const;
|
||||
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const;
|
||||
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const;
|
||||
void codeGenVAvgOp(pim::PimVAvgOp vavgOp) const;
|
||||
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
|
||||
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
|
||||
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
|
||||
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const;
|
||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
|
||||
void codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
||||
};
|
||||
|
||||
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||
|
||||
@@ -26,6 +26,7 @@ add_pim_library(OMONNXToSpatial
|
||||
ONNXToSpatialIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRSCFDialect
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
OMPimCompilerOptions
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
@@ -75,7 +76,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
}
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||
target.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
target.addDynamicallyLegalOp<ONNXMatMulOp>(
|
||||
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
||||
target.addIllegalOp<ONNXAddOp>();
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
@@ -169,44 +170,60 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
paddedInput = padOp.getResult();
|
||||
}
|
||||
|
||||
// Build im2col [numPatches, patchSize]:
|
||||
// For each batch/output position (n, oh, ow), extract the patch from x
|
||||
SmallVector<Value> im2colRows;
|
||||
im2colRows.reserve(numPatches);
|
||||
for (int64_t n = 0; n < batchSize; n++) {
|
||||
for (int64_t oh = 0; oh < outHeight; oh++) {
|
||||
for (int64_t ow = 0; ow < outWidth; ow++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(oh * strideHeight),
|
||||
rewriter.getIndexAttr(ow * strideWidth)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
|
||||
// until the late PIM unrolling step.
|
||||
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
|
||||
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
|
||||
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
|
||||
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
|
||||
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
|
||||
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
rowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
im2colRows.push_back(row);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
||||
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
||||
|
||||
// Concatenate all rows: [numPatches, patchSize]
|
||||
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
|
||||
Value patchIndex = im2colLoop.getInductionVar();
|
||||
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
|
||||
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
rowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
|
||||
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
|
||||
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value updatedIm2col =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedIm2col);
|
||||
|
||||
rewriter.setInsertionPointAfter(im2colLoop);
|
||||
Value im2col = im2colLoop.getResult(0);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
||||
});
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ add_pim_library(OMSpatialToPim
|
||||
SpatialToPimIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRSCFDialect
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
OMPimCommon
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
@@ -134,7 +135,12 @@ void SpatialToPimPass::runOnOperation() {
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect, func::FuncDialect, BuiltinDialect>();
|
||||
target.addLegalDialect<PimDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
func::FuncDialect,
|
||||
scf::SCFDialect,
|
||||
BuiltinDialect>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
@@ -13,6 +13,7 @@ add_pim_library(OMPimPasses
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRLinalgDialect
|
||||
MLIRSCFDialect
|
||||
OMCompilerUtils
|
||||
OMPimCommon
|
||||
)
|
||||
|
||||
@@ -85,12 +85,8 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
info.offsets.push_back(*staticOffset);
|
||||
}
|
||||
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
|
||||
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
@@ -106,14 +102,16 @@ FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
return info;
|
||||
}
|
||||
|
||||
int64_t
|
||||
getSubviewChunkOffsetBytes(const StaticSubviewInfo& info, ArrayRef<int64_t> outerIndices, int64_t elementByteWidth) {
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(info.sourceShape.size());
|
||||
for (size_t dim = 0; dim + 1 < info.sourceShape.size(); ++dim)
|
||||
sourceIndices.push_back(info.offsets[dim] + outerIndices[dim] * info.strides[dim]);
|
||||
sourceIndices.push_back(info.offsets.back());
|
||||
return linearizeIndex(sourceIndices, computeRowMajorStrides(info.sourceShape)) * elementByteWidth;
|
||||
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
|
||||
SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.reserve(info.offsets.size());
|
||||
for (OpFoldResult offset : info.offsets) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
staticOffsets.push_back(*staticOffset);
|
||||
}
|
||||
return staticOffsets;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -14,7 +14,7 @@ namespace onnx_mlir {
|
||||
struct StaticSubviewInfo {
|
||||
mlir::Value source;
|
||||
llvm::SmallVector<int64_t> sourceShape;
|
||||
llvm::SmallVector<int64_t> offsets;
|
||||
llvm::SmallVector<mlir::OpFoldResult> offsets;
|
||||
llvm::SmallVector<int64_t> sizes;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
};
|
||||
@@ -34,8 +34,7 @@ llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp modu
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
|
||||
int64_t getSubviewChunkOffsetBytes(const StaticSubviewInfo& info,
|
||||
llvm::ArrayRef<int64_t> outerIndices,
|
||||
int64_t elementByteWidth);
|
||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -120,7 +120,15 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
rewriter.replaceAllUsesExcept(mapOp.getInit(), getGlobalOp.getResult(), mapOp);
|
||||
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
|
||||
pim::PimMemCopyOp::create(rewriter,
|
||||
mapOp.getLoc(),
|
||||
initType,
|
||||
mapOp.getInit(),
|
||||
getGlobalOp.getResult(),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes));
|
||||
rewriter.eraseOp(mapOp);
|
||||
return success();
|
||||
}
|
||||
@@ -416,6 +424,9 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
return failure();
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
@@ -428,7 +439,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
|
||||
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
|
||||
resultValues[i] = sourceValues[srcLinear];
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
@@ -8,6 +10,62 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static bool isSubviewContiguous(const StaticSubviewInfo& info) {
|
||||
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(info.sizes.rbegin(), info.sizes.rend()),
|
||||
llvm::make_range(info.sourceShape.rbegin(), info.sourceShape.rend()));
|
||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||
auto [size, dimension] = sizeAndShape;
|
||||
return size != dimension;
|
||||
});
|
||||
if (firstDifferentSize == sizesAndShape.end())
|
||||
return true;
|
||||
|
||||
++firstDifferentSize;
|
||||
return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
|
||||
auto [size, _dimension] = sizeAndShape;
|
||||
return size == 1;
|
||||
});
|
||||
}
|
||||
|
||||
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
|
||||
if (extraOffset == 0)
|
||||
return baseOffset;
|
||||
|
||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
||||
auto integerAttr = dyn_cast<IntegerAttr>(attr);
|
||||
assert(integerAttr && "expected integer offset attribute");
|
||||
return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
|
||||
}
|
||||
|
||||
auto value = cast<Value>(baseOffset);
|
||||
auto cst = arith::ConstantIndexOp::create(rewriter, value.getLoc(), extraOffset);
|
||||
return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
|
||||
}
|
||||
|
||||
static Value buildSubviewChunk(const StaticSubviewInfo& info,
|
||||
ArrayRef<int64_t> outerIndices,
|
||||
Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
SmallVector<OpFoldResult> chunkOffsets;
|
||||
SmallVector<OpFoldResult> chunkSizes;
|
||||
SmallVector<OpFoldResult> chunkStrides;
|
||||
chunkOffsets.reserve(info.offsets.size());
|
||||
chunkSizes.reserve(info.sizes.size());
|
||||
chunkStrides.reserve(info.strides.size());
|
||||
|
||||
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
||||
int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0;
|
||||
chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
|
||||
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back()));
|
||||
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
||||
}
|
||||
|
||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
}
|
||||
|
||||
template <typename CopyOp, typename CreateCopyOp>
|
||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
Value dst,
|
||||
@@ -19,12 +77,8 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
CreateCopyOp createCopyOp) {
|
||||
auto srcSubview = getStaticSubviewInfo(src);
|
||||
auto dstSubview = getStaticSubviewInfo(dst);
|
||||
const bool splitSrc =
|
||||
succeeded(srcSubview)
|
||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||
const bool splitDst =
|
||||
succeeded(dstSubview)
|
||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||
const bool splitSrc = succeeded(srcSubview) && !isSubviewContiguous(*srcSubview);
|
||||
const bool splitDst = succeeded(dstSubview) && !isSubviewContiguous(*dstSubview);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
|
||||
@@ -35,9 +89,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
return failure();
|
||||
|
||||
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||
return failure();
|
||||
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
@@ -64,18 +118,11 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
const int64_t srcByteOffset =
|
||||
srcOffset
|
||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
|
||||
const int64_t dstByteOffset =
|
||||
dstOffset
|
||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) : linearIndex * sliceBytes);
|
||||
createCopyOp(splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||
splitDst ? dstSubview->source : dst,
|
||||
splitSrc ? srcSubview->source : src,
|
||||
dstByteOffset,
|
||||
srcByteOffset,
|
||||
sliceBytes);
|
||||
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
|
||||
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
|
||||
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
|
||||
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
|
||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
|
||||
}
|
||||
|
||||
return success();
|
||||
@@ -198,6 +245,9 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
|
||||
return failure();
|
||||
if (llvm::any_of(subviewInfo->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
auto staticOffsets = getStaticSubviewOffsets(*subviewInfo);
|
||||
if (failed(staticOffsets))
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
@@ -217,7 +267,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
|
||||
for (auto [off, idx] : llvm::zip_equal(*staticOffsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||
}
|
||||
|
||||
@@ -132,38 +132,37 @@ private:
|
||||
}
|
||||
|
||||
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
|
||||
bool hasFailure = false;
|
||||
for (Operation& op : coreOp.getBody().front()) {
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
return walkPimCoreBlock(
|
||||
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
bool hasFailure = false;
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
||||
if (failed(resolvedAddress)) {
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(operand);
|
||||
if (failed(resolvedAddress)) {
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
op.emitOpError() << "host operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
op.emitOpError() << "host operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
op.emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
op.emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return success(!hasFailure);
|
||||
return success(!hasFailure);
|
||||
});
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Func/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
@@ -57,12 +59,14 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
|
||||
registry.insert<mlir::tensor::TensorDialect>();
|
||||
registry.insert<mlir::tosa::TosaDialect>();
|
||||
registry.insert<mlir::bufferization::BufferizationDialect>();
|
||||
registry.insert<mlir::scf::SCFDialect>();
|
||||
registry.insert<pim::PimDialect>();
|
||||
registry.insert<spatial::SpatialDialect>();
|
||||
mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::tensor::registerInferTypeOpInterfaceExternalModels(registry);
|
||||
mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
spatial::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
|
||||
pim::registerOpBufferizationInterfaces(registry);
|
||||
|
||||
5
validation/.gitignore
vendored
5
validation/.gitignore
vendored
@@ -3,3 +3,8 @@ operations/**/outputs
|
||||
operations/**/raptor
|
||||
operations/**/runner
|
||||
operations/**/simulation
|
||||
networks/**/inputs
|
||||
networks/**/outputs
|
||||
networks/**/raptor
|
||||
networks/**/runner
|
||||
networks/**/simulation
|
||||
|
||||
Reference in New Issue
Block a user