centralize logic for materializing contiguous memory into bufferization
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
fix codegen symlinks overwrite remove deprecated pim memcp_hd_batch op
This commit is contained in:
@@ -616,31 +616,38 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
|
|||||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
llvm::SmallVector<int64_t> staticOffsets;
|
|
||||||
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
|
|
||||||
llvm::SmallVector<int64_t> staticSizes;
|
llvm::SmallVector<int64_t> staticSizes;
|
||||||
staticSizes.reserve(subviewOp.getMixedSizes().size());
|
staticSizes.reserve(subviewOp.getMixedSizes().size());
|
||||||
llvm::SmallVector<int64_t> staticStrides;
|
llvm::SmallVector<int64_t> staticStrides;
|
||||||
staticStrides.reserve(subviewOp.getMixedStrides().size());
|
staticStrides.reserve(subviewOp.getMixedStrides().size());
|
||||||
bool allStatic = true;
|
llvm::SmallVector<int64_t> staticOffsets;
|
||||||
|
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
|
||||||
|
bool hasOnlyStaticOffsets = true;
|
||||||
|
|
||||||
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
|
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
|
||||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
||||||
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||||
else
|
else
|
||||||
allStatic = false;
|
hasOnlyStaticOffsets = false;
|
||||||
for (mlir::OpFoldResult size : subviewOp.getMixedSizes())
|
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
|
auto attr = mlir::dyn_cast<mlir::Attribute>(size);
|
||||||
|
if (!attr)
|
||||||
|
return mlir::failure();
|
||||||
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||||
else
|
}
|
||||||
allStatic = false;
|
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||||
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides())
|
auto attr = mlir::dyn_cast<mlir::Attribute>(stride);
|
||||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
|
if (!attr)
|
||||||
|
return mlir::failure();
|
||||||
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||||
else
|
}
|
||||||
allStatic = false;
|
|
||||||
|
|
||||||
if (allStatic) {
|
if (!isContiguousSubviewWithDynamicOffsets(
|
||||||
|
sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides)) {
|
||||||
|
return mlir::failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hasOnlyStaticOffsets) {
|
||||||
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|
||||||
|
|||||||
@@ -20,8 +20,6 @@ llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds,
|
|||||||
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
|
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
|
||||||
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
|
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
|
||||||
return operandIndex == 3;
|
return operandIndex == 3;
|
||||||
if (mlir::isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
|
||||||
return operandIndex == 1;
|
|
||||||
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
|
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
|
||||||
return operandIndex == 2;
|
return operandIndex == 2;
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -111,4 +111,56 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
|
||||||
|
llvm::ArrayRef<int64_t> staticSizes,
|
||||||
|
llvm::ArrayRef<int64_t> staticStrides) {
|
||||||
|
if (sourceShape.size() != mixedOffsets.size() || sourceShape.size() != staticSizes.size()
|
||||||
|
|| sourceShape.size() != staticStrides.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llvm::any_of(staticStrides, [](int64_t stride) { return stride != 1; }))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto reversedTriples =
|
||||||
|
llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(mixedOffsets), llvm::reverse(staticSizes));
|
||||||
|
|
||||||
|
auto firstNonZeroOrDynamicOffset = llvm::find_if(reversedTriples, [](auto triple) {
|
||||||
|
auto [_sourceDim, offset, _size] = triple;
|
||||||
|
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
||||||
|
return mlir::cast<mlir::IntegerAttr>(attr).getInt() != 0;
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstNonZeroOrDynamicOffset != reversedTriples.end()) {
|
||||||
|
auto [sourceDim, offset, size] = *firstNonZeroOrDynamicOffset;
|
||||||
|
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
|
||||||
|
int64_t staticOffset = mlir::cast<mlir::IntegerAttr>(attr).getInt();
|
||||||
|
if (size > sourceDim - staticOffset)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
++firstNonZeroOrDynamicOffset;
|
||||||
|
for (auto it = firstNonZeroOrDynamicOffset; it != reversedTriples.end(); ++it)
|
||||||
|
if (std::get<2>(*it) != 1)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto reversedSizes = llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(staticSizes));
|
||||||
|
auto firstDifferentSize = llvm::find_if(reversedSizes, [](auto pair) {
|
||||||
|
auto [sourceDim, size] = pair;
|
||||||
|
return size != sourceDim;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstDifferentSize != reversedSizes.end()) {
|
||||||
|
++firstDifferentSize;
|
||||||
|
for (auto it = firstDifferentSize; it != reversedSizes.end(); ++it)
|
||||||
|
if (std::get<1>(*it) != 1)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
@@ -30,4 +31,9 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
|||||||
llvm::ArrayRef<int64_t> sizes,
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
llvm::ArrayRef<int64_t> strides);
|
llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
|
||||||
|
llvm::ArrayRef<int64_t> staticSizes,
|
||||||
|
llvm::ArrayRef<int64_t> staticStrides);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -31,6 +31,19 @@ Value stripMemRefViewOps(Value value) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value stripMemRefAddressingOps(Value value) {
|
||||||
|
while (true) {
|
||||||
|
if (auto subviewOp = value.getDefiningOp<memref::SubViewOp>()) {
|
||||||
|
value = subviewOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Value strippedValue = stripMemRefViewOps(value);
|
||||||
|
if (strippedValue == value)
|
||||||
|
return value;
|
||||||
|
value = strippedValue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
|
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
|
||||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||||
@@ -81,4 +94,13 @@ FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo&
|
|||||||
return staticOffsets;
|
return staticOffsets;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isMemRefBaseAddressableValue(Value value) {
|
||||||
|
value = stripMemRefAddressingOps(value);
|
||||||
|
if (isa<BlockArgument>(value))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
Operation* defOp = value.getDefiningOp();
|
||||||
|
return defOp && isa<memref::AllocOp, memref::GetGlobalOp>(defOp);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ mlir::Value stripMemRefCasts(mlir::Value value);
|
|||||||
|
|
||||||
mlir::Value stripMemRefViewOps(mlir::Value value);
|
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||||
|
|
||||||
|
mlir::Value stripMemRefAddressingOps(mlir::Value value);
|
||||||
|
|
||||||
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
|
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
|
||||||
|
|
||||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||||
@@ -27,4 +29,6 @@ llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
|||||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||||
|
|
||||||
|
bool isMemRefBaseAddressableValue(mlir::Value value);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -47,28 +47,6 @@ CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
|
|||||||
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
|
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Value stripWeightViewOps(mlir::Value value) {
|
|
||||||
while (true) {
|
|
||||||
if (auto subviewOp = value.getDefiningOp<mlir::memref::SubViewOp>()) {
|
|
||||||
value = subviewOp.getSource();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto castOp = value.getDefiningOp<mlir::memref::CastOp>()) {
|
|
||||||
value = castOp.getSource();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto collapseOp = value.getDefiningOp<mlir::memref::CollapseShapeOp>()) {
|
|
||||||
value = collapseOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto expandOp = value.getDefiningOp<mlir::memref::ExpandShapeOp>()) {
|
|
||||||
value = expandOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename VMMOpTy, typename ParentOpTy>
|
template <typename VMMOpTy, typename ParentOpTy>
|
||||||
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||||
@@ -159,7 +137,7 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
|
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
|
||||||
weight = stripWeightViewOps(weight);
|
weight = stripMemRefAddressingOps(weight);
|
||||||
|
|
||||||
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
|
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
|
||||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||||
|
|||||||
@@ -479,16 +479,6 @@ void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticVa
|
|||||||
loadOp.getSize());
|
loadOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp,
|
|
||||||
const StaticValueKnowledge& knowledge) const {
|
|
||||||
emitMemCopyOp("ld",
|
|
||||||
addressOf(loadOp.getDeviceTarget(), knowledge),
|
|
||||||
loadOp.getDeviceTargetOffset(),
|
|
||||||
addressOf(loadOp.getHostSource(), knowledge),
|
|
||||||
loadOp.getHostSourceOffset(),
|
|
||||||
loadOp.getSize());
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
|
||||||
auto hostTargetOffset = indexOf(storeOp.getHostTargetOffset(), knowledge);
|
auto hostTargetOffset = indexOf(storeOp.getHostTargetOffset(), knowledge);
|
||||||
auto deviceSourceOffset = indexOf(storeOp.getDeviceSourceOffset(), knowledge);
|
auto deviceSourceOffset = indexOf(storeOp.getDeviceSourceOffset(), knowledge);
|
||||||
@@ -848,7 +838,6 @@ public:
|
|||||||
|
|
||||||
enum class CompiledCoreOpKind : uint8_t {
|
enum class CompiledCoreOpKind : uint8_t {
|
||||||
Load,
|
Load,
|
||||||
LoadBatch,
|
|
||||||
Store,
|
Store,
|
||||||
Lmv,
|
Lmv,
|
||||||
Receive,
|
Receive,
|
||||||
@@ -887,8 +876,6 @@ struct CompiledCoreNode {
|
|||||||
static FailureOr<CompiledCoreOpKind> classifyCompiledCoreOpKind(Operation& op) {
|
static FailureOr<CompiledCoreOpKind> classifyCompiledCoreOpKind(Operation& op) {
|
||||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||||
return CompiledCoreOpKind::Load;
|
return CompiledCoreOpKind::Load;
|
||||||
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
|
||||||
return CompiledCoreOpKind::LoadBatch;
|
|
||||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||||
return CompiledCoreOpKind::Store;
|
return CompiledCoreOpKind::Store;
|
||||||
if (isa<pim::PimMemCopyOp>(op))
|
if (isa<pim::PimMemCopyOp>(op))
|
||||||
@@ -1027,9 +1014,6 @@ static LogicalResult executeCompiledCorePlan(
|
|||||||
case CompiledCoreOpKind::Load:
|
case CompiledCoreOpKind::Load:
|
||||||
coreCodeGen.codeGenLoadOp(cast<pim::PimMemCopyHostToDevOp>(node.op), knowledge);
|
coreCodeGen.codeGenLoadOp(cast<pim::PimMemCopyHostToDevOp>(node.op), knowledge);
|
||||||
break;
|
break;
|
||||||
case CompiledCoreOpKind::LoadBatch:
|
|
||||||
coreCodeGen.codeGenLoadBatchOp(cast<pim::PimMemCopyHostToDevBatchOp>(node.op), knowledge);
|
|
||||||
break;
|
|
||||||
case CompiledCoreOpKind::Store:
|
case CompiledCoreOpKind::Store:
|
||||||
coreCodeGen.codeGenStoreOp(cast<pim::PimMemCopyDevToHostOp>(node.op), knowledge);
|
coreCodeGen.codeGenStoreOp(cast<pim::PimMemCopyDevToHostOp>(node.op), knowledge);
|
||||||
break;
|
break;
|
||||||
@@ -1213,17 +1197,18 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
auto linkCoreWeights =
|
auto linkCoreWeights =
|
||||||
[&](size_t coreId, ArrayRef<std::string> weightFiles, json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
|
[&](size_t coreId, ArrayRef<std::string> weightFiles, json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
|
||||||
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
|
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
|
||||||
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
|
if (auto error = sys::fs::create_directory(coreWeightsDirPath); error && error != std::errc::file_exists) {
|
||||||
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
|
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
|
||||||
return InvalidOutputFileAccess;
|
return InvalidOutputFileAccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto [slot, fileName] : llvm::enumerate(weightFiles)) {
|
for (auto [slot, fileName] : llvm::enumerate(weightFiles)) {
|
||||||
xbarsPerGroup.push_back(static_cast<int64_t>(slot));
|
xbarsPerGroup.push_back(static_cast<int64_t>(slot));
|
||||||
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
|
std::string sourcePath = outputDirPath + "/weights/" + fileName;
|
||||||
coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")) {
|
std::string targetPath = coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin";
|
||||||
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
|
sys::fs::remove(targetPath);
|
||||||
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin") << "\nError:" << error.message()
|
if (auto error = sys::fs::create_link(sourcePath, targetPath)) {
|
||||||
|
errs() << "Error creating link file: " << sourcePath << " to " << targetPath << "\nError:" << error.message()
|
||||||
<< '\n';
|
<< '\n';
|
||||||
return InvalidOutputFileAccess;
|
return InvalidOutputFileAccess;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -175,7 +175,6 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, const StaticValueKnowledge& knowledge) const;
|
|
||||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "IndexingUtils.hpp"
|
#include "IndexingUtils.hpp"
|
||||||
@@ -20,35 +17,6 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
APInt lhsConst;
|
|
||||||
if (matchPattern(lhs, m_ConstantInt(&lhsConst)) && lhsConst.isZero())
|
|
||||||
return rhs;
|
|
||||||
|
|
||||||
APInt rhsConst;
|
|
||||||
if (matchPattern(rhs, m_ConstantInt(&rhsConst)) && rhsConst.isZero())
|
|
||||||
return lhs;
|
|
||||||
|
|
||||||
return arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
APInt factorConst;
|
|
||||||
if (auto attr = dyn_cast<Attribute>(factor))
|
|
||||||
factorConst = cast<IntegerAttr>(attr).getValue();
|
|
||||||
else if (!matchPattern(cast<Value>(factor), m_ConstantInt(&factorConst)))
|
|
||||||
return arith::MulIOp::create(rewriter, loc, value, cast<Value>(factor)).getResult();
|
|
||||||
|
|
||||||
if (factorConst.isZero())
|
|
||||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
|
||||||
if (factorConst.isOne())
|
|
||||||
return value;
|
|
||||||
|
|
||||||
auto factorValue =
|
|
||||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), factorConst.getSExtValue());
|
|
||||||
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
|
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||||
}
|
}
|
||||||
@@ -124,39 +92,6 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
|
|||||||
return sizes;
|
return sizes;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
|
|
||||||
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
|
||||||
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()
|
|
||||||
|| sourceType.getRank() != resultType.getRank())
|
|
||||||
return false;
|
|
||||||
|
|
||||||
for (OpFoldResult stride : strides) {
|
|
||||||
APInt strideValue;
|
|
||||||
if (auto attr = dyn_cast<Attribute>(stride)) {
|
|
||||||
if (cast<IntegerAttr>(attr).getInt() != 1)
|
|
||||||
return false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (!matchPattern(cast<Value>(stride), m_ConstantInt(&strideValue)) || !strideValue.isOne())
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(resultType.getShape().rbegin(), resultType.getShape().rend()),
|
|
||||||
llvm::make_range(sourceType.getShape().rbegin(), sourceType.getShape().rend()));
|
|
||||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, dimension] = sizeAndShape;
|
|
||||||
return size != dimension;
|
|
||||||
});
|
|
||||||
if (firstDifferentSize == sizesAndShape.end())
|
|
||||||
return true;
|
|
||||||
|
|
||||||
++firstDifferentSize;
|
|
||||||
return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
|
|
||||||
auto [size, _dimension] = sizeAndShape;
|
|
||||||
return size == 1;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> sliceTensor(
|
SmallVector<Value> sliceTensor(
|
||||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||||
@@ -222,90 +157,6 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewri
|
|||||||
return slicesPerCore;
|
return slicesPerCore;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value materializeContiguousTensorSlice(Value source,
|
|
||||||
RankedTensorType resultType,
|
|
||||||
ArrayRef<OpFoldResult> offsets,
|
|
||||||
ArrayRef<OpFoldResult> strides,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
assert(resultType.hasStaticShape() && "expected static result type");
|
|
||||||
size_t rank = static_cast<size_t>(resultType.getRank());
|
|
||||||
assert(offsets.size() == rank && "expected rank-matching offsets");
|
|
||||||
assert(strides.size() == rank && "expected rank-matching strides");
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> sizes;
|
|
||||||
sizes.reserve(resultType.getRank());
|
|
||||||
for (int64_t size : resultType.getShape())
|
|
||||||
sizes.push_back(rewriter.getIndexAttr(size));
|
|
||||||
|
|
||||||
if (isContiguousTensorSlice(source, resultType, strides))
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
|
||||||
|
|
||||||
if (resultType.getRank() == 0)
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
|
||||||
|
|
||||||
Value init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
|
|
||||||
SmallVector<Value> zeroIndices(resultType.getRank());
|
|
||||||
for (Value& zeroIndex : zeroIndices)
|
|
||||||
zeroIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
|
||||||
|
|
||||||
SmallVector<Value> resultIndices;
|
|
||||||
resultIndices.reserve(resultType.getRank());
|
|
||||||
|
|
||||||
auto buildLoopNest = [&](auto&& self, unsigned dim, Value accumulator) -> Value {
|
|
||||||
if (dim == resultType.getRank()) {
|
|
||||||
SmallVector<Value> sourceIndices;
|
|
||||||
sourceIndices.reserve(resultType.getRank());
|
|
||||||
for (unsigned idx = 0; idx < resultType.getRank(); ++idx) {
|
|
||||||
Value offsetValue = getOrMaterializeIndexValue(rewriter, offsets[idx]);
|
|
||||||
Value scaledIndex = multiplyIndexValue(resultIndices[idx], strides[idx], rewriter, loc);
|
|
||||||
sourceIndices.push_back(addIndexValues(offsetValue, scaledIndex, rewriter, loc));
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> sourceOffsets;
|
|
||||||
SmallVector<OpFoldResult> destinationOffsets;
|
|
||||||
SmallVector<OpFoldResult> unitSizes;
|
|
||||||
SmallVector<OpFoldResult> unitStrides;
|
|
||||||
sourceOffsets.reserve(resultType.getRank());
|
|
||||||
destinationOffsets.reserve(resultType.getRank());
|
|
||||||
unitSizes.reserve(resultType.getRank());
|
|
||||||
unitStrides.reserve(resultType.getRank());
|
|
||||||
for (Value index : sourceIndices)
|
|
||||||
sourceOffsets.push_back(index);
|
|
||||||
for (Value index : resultIndices)
|
|
||||||
destinationOffsets.push_back(index);
|
|
||||||
for (int64_t idx = 0; idx < resultType.getRank(); ++idx) {
|
|
||||||
unitSizes.push_back(rewriter.getIndexAttr(1));
|
|
||||||
unitStrides.push_back(rewriter.getIndexAttr(1));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto elementTensorType =
|
|
||||||
RankedTensorType::get(SmallVector<int64_t>(resultType.getRank(), 1), resultType.getElementType());
|
|
||||||
Value elementSlice =
|
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, elementTensorType, source, sourceOffsets, unitSizes, unitStrides)
|
|
||||||
.getResult();
|
|
||||||
return tensor::InsertSliceOp::create(
|
|
||||||
rewriter, loc, elementSlice, accumulator, destinationOffsets, unitSizes, unitStrides)
|
|
||||||
.getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
Value lower = zeroIndices[dim];
|
|
||||||
Value upper =
|
|
||||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim));
|
|
||||||
Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
|
||||||
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator});
|
|
||||||
rewriter.setInsertionPointToStart(loop.getBody());
|
|
||||||
resultIndices.push_back(loop.getInductionVar());
|
|
||||||
Value updated = self(self, dim + 1, loop.getRegionIterArgs().front());
|
|
||||||
resultIndices.pop_back();
|
|
||||||
scf::YieldOp::create(rewriter, loc, updated);
|
|
||||||
rewriter.setInsertionPointAfter(loop);
|
|
||||||
return loop.getResult(0);
|
|
||||||
};
|
|
||||||
|
|
||||||
return buildLoopNest(buildLoopNest, 0, init);
|
|
||||||
}
|
|
||||||
|
|
||||||
Value extractAxisSlice(
|
Value extractAxisSlice(
|
||||||
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
||||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||||
|
|||||||
@@ -108,13 +108,6 @@ llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
|||||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||||
|
|
||||||
mlir::Value materializeContiguousTensorSlice(mlir::Value source,
|
|
||||||
mlir::RankedTensorType resultType,
|
|
||||||
llvm::ArrayRef<mlir::OpFoldResult> offsets,
|
|
||||||
llvm::ArrayRef<mlir::OpFoldResult> strides,
|
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
|
||||||
mlir::Location loc);
|
|
||||||
|
|
||||||
mlir::Value extractAxisSlice(
|
mlir::Value extractAxisSlice(
|
||||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||||
|
|
||||||
|
|||||||
@@ -303,9 +303,11 @@ createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewri
|
|||||||
static Value extractDynamicGemmBColumn(
|
static Value extractDynamicGemmBColumn(
|
||||||
Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column};
|
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column};
|
||||||
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)};
|
||||||
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType());
|
auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType());
|
||||||
Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc);
|
Value columnSlice =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, columnSliceType, matrix, offsets, sizes, strides).getResult();
|
||||||
SmallVector<ReassociationIndices> collapseReassociation {
|
SmallVector<ReassociationIndices> collapseReassociation {
|
||||||
ReassociationIndices {0, 1}
|
ReassociationIndices {0, 1}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
static Value materializeTileTensor(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||||
auto tileType = cast<RankedTensorType>(tile.getType());
|
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||||
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||||
return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank()));
|
return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank()));
|
||||||
@@ -319,7 +319,7 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
Value windowValue =
|
Value windowValue =
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
||||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
windowValue = materializeTileTensor(rewriter, loc, windowValue);
|
||||||
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
|
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -335,7 +335,7 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
Value scaleSlice = tensor::ExtractSliceOp::create(
|
Value scaleSlice = tensor::ExtractSliceOp::create(
|
||||||
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
||||||
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
|
scaleSlice = materializeTileTensor(rewriter, loc, scaleSlice);
|
||||||
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -59,10 +59,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
|||||||
ShapedType destinationType,
|
ShapedType destinationType,
|
||||||
IRMapping& mapper) {
|
IRMapping& mapper) {
|
||||||
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
||||||
SmallVector<int64_t> strides(destinationType.getRank(), 1);
|
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
|
||||||
ArrayRef<int64_t> shape = destinationType.getShape();
|
|
||||||
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
|
|
||||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
|
||||||
|
|
||||||
Value totalOffset;
|
Value totalOffset;
|
||||||
Location loc = insertSlice.getLoc();
|
Location loc = insertSlice.getLoc();
|
||||||
@@ -162,13 +159,14 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
|||||||
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
|
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
|
||||||
auto newArgType = cast<ShapedType>(newArg.getType());
|
auto newArgType = cast<ShapedType>(newArg.getType());
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
||||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||||
|
auto copied = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputBuffer.getType(),
|
outputBuffer.getType(),
|
||||||
|
zeroOffset,
|
||||||
|
zeroOffset,
|
||||||
outputBuffer,
|
outputBuffer,
|
||||||
newArg,
|
newArg,
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
getTensorSizeInBytesAttr(rewriter, newArg))
|
getTensorSizeInBytesAttr(rewriter, newArg))
|
||||||
.getOutput();
|
.getOutput();
|
||||||
mapper.map(*oldArg, copied);
|
mapper.map(*oldArg, copied);
|
||||||
@@ -233,13 +231,14 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
|||||||
}
|
}
|
||||||
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
||||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||||
|
auto copied = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputBuffer.getType(),
|
outputBuffer.getType(),
|
||||||
|
zeroOffset,
|
||||||
|
zeroOffset,
|
||||||
outputBuffer,
|
outputBuffer,
|
||||||
clonedTensor,
|
clonedTensor,
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
getTensorSizeInBytesAttr(rewriter, clonedTensor))
|
getTensorSizeInBytesAttr(rewriter, clonedTensor))
|
||||||
.getOutput();
|
.getOutput();
|
||||||
mapper.map(toTensorOp.getResult(), copied);
|
mapper.map(toTensorOp.getResult(), copied);
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
#include "mlir/IR/ValueRange.h"
|
#include "mlir/IR/ValueRange.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstddef>
|
|
||||||
|
|
||||||
#include "Common.hpp"
|
#include "Common.hpp"
|
||||||
|
|
||||||
@@ -13,48 +11,6 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
|
||||||
/*
|
|
||||||
EXAMPLE RUN:
|
|
||||||
[1, 10, 3, 4] inputShape
|
|
||||||
[0, 2, 1, 3] offsets
|
|
||||||
|
|
||||||
acc = 1
|
|
||||||
---
|
|
||||||
ret = 3
|
|
||||||
acc = 4
|
|
||||||
---
|
|
||||||
ret = 3 + 4 * 1 = 7
|
|
||||||
acc = 12
|
|
||||||
---
|
|
||||||
ret = 7 + 12 * 2 = 31
|
|
||||||
acc = 120
|
|
||||||
---
|
|
||||||
ret = 31 + 120 * 0 = 31
|
|
||||||
acc = 120
|
|
||||||
*/
|
|
||||||
|
|
||||||
size_t returnValue = 0;
|
|
||||||
|
|
||||||
auto sliceOffsets = sliceOp.getStaticOffsets();
|
|
||||||
auto inputDimSizes = inputShape.getShape();
|
|
||||||
|
|
||||||
assert(sliceOffsets.size() == inputDimSizes.size());
|
|
||||||
|
|
||||||
size_t accumulatedDimensionSize = 1;
|
|
||||||
|
|
||||||
// Reverse iterate the two vectors
|
|
||||||
for (auto it : reverse(zip(sliceOffsets, inputDimSizes))) {
|
|
||||||
auto curSliceOffset = std::get<0>(it);
|
|
||||||
auto curInputDimSize = std::get<1>(it);
|
|
||||||
|
|
||||||
returnValue += accumulatedDimensionSize * curSliceOffset;
|
|
||||||
accumulatedDimensionSize *= curInputDimSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
return returnValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
||||||
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,20 +6,6 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
/**
|
|
||||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
|
||||||
* its static tensor input.
|
|
||||||
*
|
|
||||||
* The static offsets represent the starting position of the slice in each
|
|
||||||
* dimension, while the static tensor input gives its dimension size.
|
|
||||||
*
|
|
||||||
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
|
|
||||||
* calculated.
|
|
||||||
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
|
|
||||||
* \return The actual offset of the ExtractSliceOp.
|
|
||||||
*/
|
|
||||||
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
|
||||||
|
|
||||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
|
|||||||
@@ -83,18 +83,6 @@ static Value createZeroedDeviceHVector(IRRewriter& rewriter,
|
|||||||
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
||||||
auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0);
|
auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0);
|
||||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
||||||
|
|
||||||
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
|
|
||||||
return PimMemCopyHostToDevBatchOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
tensorType,
|
|
||||||
outputBuffer,
|
|
||||||
zeroValue,
|
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
sizeAttr)
|
|
||||||
.getOutput();
|
|
||||||
|
|
||||||
return PimMemCopyHostToDevOp::create(
|
return PimMemCopyHostToDevOp::create(
|
||||||
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
|
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
|
||||||
.getOutput();
|
.getOutput();
|
||||||
|
|||||||
@@ -144,32 +144,6 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimMemCopyHostToDevBatchOp : PimOp<"memcp_hd_batch", [DestinationStyleOpInterface]> {
|
|
||||||
let summary = "Copy a per-lane tensor from host memory into device memory inside a batched core";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor:$deviceTarget,
|
|
||||||
PimTensor:$hostSource,
|
|
||||||
I32Attr:$deviceTargetOffset,
|
|
||||||
I32Attr:$hostSourceOffset,
|
|
||||||
I32Attr:$size
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getDeviceTargetMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||||
let summary = "Copy a memory region from device memory into host memory";
|
let summary = "Copy a memory region from device memory into host memory";
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
|
||||||
@@ -10,8 +11,8 @@ using namespace bufferization;
|
|||||||
|
|
||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
Value materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
|
||||||
return memrefValue;
|
return memrefValue;
|
||||||
|
|
||||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||||
@@ -30,7 +31,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
|
|||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value allocateContiguousMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
|
||||||
if (succeeded(resolveContiguousAddress(memrefValue)))
|
if (succeeded(resolveContiguousAddress(memrefValue)))
|
||||||
return memrefValue;
|
return memrefValue;
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,9 @@
|
|||||||
|
|
||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
mlir::Value materializeContiguousMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
mlir::Value materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||||
mlir::Value allocateContiguousMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
mlir::Value
|
||||||
|
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
|
||||||
|
|
||||||
llvm::FailureOr<mlir::Value> getBufferOrValue(mlir::RewriterBase& rewriter,
|
llvm::FailureOr<mlir::Value> getBufferOrValue(mlir::RewriterBase& rewriter,
|
||||||
mlir::Value value,
|
mlir::Value value,
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ add_pim_library(OMPimBufferization
|
|||||||
PimBufferizationPass.cpp
|
PimBufferizationPass.cpp
|
||||||
BufferizationUtils.hpp
|
BufferizationUtils.hpp
|
||||||
BufferizationUtils.cpp
|
BufferizationUtils.cpp
|
||||||
|
ContiguityPatterns.hpp
|
||||||
|
ContiguityPatterns.cpp
|
||||||
OpBufferizationInterfaces.hpp
|
OpBufferizationInterfaces.hpp
|
||||||
OpBufferizationInterfaces.cpp
|
OpBufferizationInterfaces.cpp
|
||||||
Common.hpp
|
Common.hpp
|
||||||
|
|||||||
@@ -0,0 +1,343 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
|
||||||
|
#include "ContiguityPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir::pim {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static bool isStaticSubviewContiguous(const StaticSubviewInfo& info) {
|
||||||
|
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return isContiguousSubviewWithDynamicOffsets(info.sourceShape, info.offsets, info.sizes, info.strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns `baseOffset + extraOffset`.
///
/// A zero `extraOffset` returns `baseOffset` untouched. A static (attribute)
/// base is folded into a new index attribute; a dynamic (SSA value) base gets
/// an `arith.addi` with an index constant for `extraOffset`.
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
  // Nothing to add: avoid creating IR or attributes.
  if (extraOffset == 0)
    return baseOffset;

  // Static base: the sum is known now, so fold it into an attribute.
  if (auto staticBase = dyn_cast<Attribute>(baseOffset)) {
    auto staticInteger = dyn_cast<IntegerAttr>(staticBase);
    assert(staticInteger && "expected integer offset attribute");
    const int64_t folded = staticInteger.getInt() + extraOffset;
    return rewriter.getIndexAttr(folded);
  }

  // Dynamic base: emit the addition at the current insertion point.
  auto dynamicBase = cast<Value>(baseOffset);
  auto extraConstant = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset);
  auto sum = arith::AddIOp::create(rewriter, dynamicBase.getLoc(), dynamicBase, extraConstant);
  return sum.getResult();
}
|
||||||
|
|
||||||
|
/// Builds a memref.subview of `info.source` that selects one innermost row of
/// the subview described by `info`, using compile-time outer indices.
///
/// Every dimension except the last gets size 1 and its base offset advanced by
/// `outerIndices[dim] * info.strides[dim]`. The last dimension keeps its base
/// offset and its full extent (`info.sizes.back()`). Strides are copied from
/// `info` unchanged.
static Value buildSubviewChunk(const StaticSubviewInfo& info,
                               ArrayRef<int64_t> outerIndices,
                               Location loc,
                               PatternRewriter& rewriter) {
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(info.offsets.size());
  sizes.reserve(info.sizes.size());
  strides.reserve(info.strides.size());

  const size_t rank = info.sizes.size();
  for (size_t dim = 0; dim < rank; ++dim) {
    const bool isOuterDim = dim + 1 < rank;

    // Outer dimensions are pinned to a single index; the innermost one is kept whole.
    int64_t shift = 0;
    int64_t extent = info.sizes.back();
    if (isOuterDim) {
      shift = outerIndices[dim] * info.strides[dim];
      extent = 1;
    }

    offsets.push_back(addConstantOffset(info.offsets[dim], shift, rewriter));
    sizes.push_back(rewriter.getIndexAttr(extent));
    strides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, offsets, sizes, strides);
}
|
||||||
|
|
||||||
|
/// Emits IR that splits the SSA value `linearIndex` into one index per
/// dimension, given the per-dimension `strides`.
///
/// For each dimension, in order, the index is `remaining / stride` (unsigned)
/// and `remaining` becomes `remaining % stride` for the next dimension.
///
/// \param linearIndex Linear index value; its location is reused for the new ops.
/// \param shape       Shape being indexed; only its rank is used, and it must
///                    match the rank of `strides`.
/// \param strides     Stride of each dimension, outermost first.
/// \param rewriter    Rewriter used to create the constants and arith ops.
/// \return One index value per entry of `strides`.
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
                                                ArrayRef<int64_t> shape,
                                                ArrayRef<int64_t> strides,
                                                PatternRewriter& rewriter) {
  assert(shape.size() == strides.size() && "expected one stride per dimension");

  SmallVector<Value> indices;
  indices.reserve(strides.size());

  const Location loc = linearIndex.getLoc();
  Value remaining = linearIndex;
  for (size_t dim = 0, rank = strides.size(); dim < rank; ++dim) {
    auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), strides[dim]);
    Value index = arith::DivUIOp::create(rewriter, loc, remaining, cStride);
    indices.push_back(index);

    // The remainder is only consumed by the next dimension; skipping it after
    // the last one avoids emitting an op whose result is never used.
    if (dim + 1 < rank)
      remaining = arith::RemUIOp::create(rewriter, loc, remaining, cStride);
  }

  return indices;
}
|
||||||
|
|
||||||
|
/// Returns `baseOffset + extraOffset` where `extraOffset` is an SSA value.
///
/// A static base of zero returns `extraOffset` directly. Any other static base
/// is materialized as an index constant and added with `arith.addi`; a dynamic
/// base is added to `extraOffset` as is.
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
  auto staticBase = dyn_cast<Attribute>(baseOffset);
  if (!staticBase) {
    // Dynamic base: both operands are already SSA values.
    auto dynamicBase = cast<Value>(baseOffset);
    return arith::AddIOp::create(rewriter, dynamicBase.getLoc(), dynamicBase, extraOffset).getResult();
  }

  auto staticInteger = cast<IntegerAttr>(staticBase);
  if (staticInteger.getInt() == 0)
    return extraOffset;

  // Non-zero static base: turn it into a constant so it can feed the addition.
  auto baseConstant =
      getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), staticInteger.getInt());
  return arith::AddIOp::create(rewriter, extraOffset.getLoc(), baseConstant, extraOffset).getResult();
}
|
||||||
|
|
||||||
|
/// Builds a memref.subview of `info.source` that selects one innermost row of
/// the subview described by `info`, using SSA-value outer indices.
///
/// Every dimension except the last must have unit stride (asserted); it gets
/// size 1 and offset `info.offsets[dim] + outerIndices[dim]`. The last
/// dimension keeps its base offset and its full extent (`info.sizes.back()`).
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
                                      ArrayRef<Value> outerIndices,
                                      Location loc,
                                      PatternRewriter& rewriter) {
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(info.offsets.size());
  sizes.reserve(info.sizes.size());
  strides.reserve(info.strides.size());

  const size_t rank = info.sizes.size();
  for (size_t dim = 0; dim < rank; ++dim) {
    const bool isOuterDim = dim + 1 < rank;
    if (!isOuterDim) {
      // Innermost dimension: taken whole, starting at its original offset.
      offsets.push_back(info.offsets[dim]);
      sizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
    }
    else {
      // Outer dimension: pinned to the single index chosen by the caller.
      assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
      offsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
      sizes.push_back(rewriter.getIndexAttr(1));
    }
    strides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, offsets, sizes, strides);
}
|
||||||
|
|
||||||
|
/// Builds a unit-stride memref.subview of `source` that selects one innermost
/// row of a buffer shaped like `copyShape`.
///
/// Outer dimensions use `outerIndices[dim]` as offset with size 1; the last
/// dimension starts at 0 and spans `copyShape.back()` elements.
static Value buildContiguousChunk(
    Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
  const size_t rank = copyShape.size();

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);

  for (size_t dim = 0; dim < rank; ++dim) {
    if (dim + 1 < rank) {
      offsets.push_back(OpFoldResult(outerIndices[dim]));
      sizes.push_back(rewriter.getIndexAttr(1));
    }
    else {
      offsets.push_back(rewriter.getIndexAttr(0));
      sizes.push_back(rewriter.getIndexAttr(copyShape.back()));
    }
    strides.push_back(rewriter.getIndexAttr(1));
  }

  return memref::SubViewOp::create(rewriter, loc, source, offsets, sizes, strides);
}
|
||||||
|
|
||||||
|
template <typename CopyOp, typename CreateCopyOp>
|
||||||
|
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||||
|
Value dst,
|
||||||
|
Value src,
|
||||||
|
int64_t dstOffset,
|
||||||
|
int64_t srcOffset,
|
||||||
|
int64_t size,
|
||||||
|
bool allowLoopRewrite,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
CreateCopyOp createCopyOp) {
|
||||||
|
auto srcSubview = getStaticSubviewInfo(src);
|
||||||
|
auto dstSubview = getStaticSubviewInfo(dst);
|
||||||
|
const bool splitSrc = succeeded(srcSubview) && !isStaticSubviewContiguous(*srcSubview);
|
||||||
|
const bool splitDst = succeeded(dstSubview) && !isStaticSubviewContiguous(*dstSubview);
|
||||||
|
if (!splitSrc && !splitDst)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(src.getType());
|
||||||
|
auto dstType = dyn_cast<MemRefType>(dst.getType());
|
||||||
|
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (sourceType.getElementType() != dstType.getElementType())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||||
|
return failure();
|
||||||
|
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||||
|
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (!hasByteSizedElementType(sourceType.getElementType()))
|
||||||
|
return failure();
|
||||||
|
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
||||||
|
|
||||||
|
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||||
|
if (size != totalBytes)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||||
|
if (sliceBytes <= 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||||
|
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||||
|
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||||
|
const bool sourceShapeMatchesCopyShape = llvm::equal(sourceType.getShape(), copyShape);
|
||||||
|
const bool dstShapeMatchesCopyShape = llvm::equal(dstType.getShape(), copyShape);
|
||||||
|
|
||||||
|
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
|
||||||
|
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
|
||||||
|
&& dstType.getRank() == static_cast<int64_t>(copyShape.size()) && (splitSrc || sourceShapeMatchesCopyShape)
|
||||||
|
&& (splitDst || dstShapeMatchesCopyShape)) {
|
||||||
|
auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
||||||
|
auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices);
|
||||||
|
auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1);
|
||||||
|
|
||||||
|
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
|
||||||
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
|
|
||||||
|
SmallVector<Value> outerIndices =
|
||||||
|
outerShape.empty() ? SmallVector<Value> {}
|
||||||
|
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
|
||||||
|
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||||
|
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||||
|
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||||
|
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||||
|
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(copyOp);
|
||||||
|
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||||
|
SmallVector<int64_t> outerIndices =
|
||||||
|
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||||
|
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
|
||||||
|
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
|
||||||
|
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
|
||||||
|
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
|
||||||
|
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rewrites a pim::PimMemCopyOp nested in a PimCoreOp or PimCoreBatchOp into
/// per-row copies when rewriteSubviewCopyLikeOp accepts it, then replaces the
/// original op with its target operand.
struct NormalizeCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
    // Only copies that live inside a core (plain or batched) are handled here.
    const bool insideCore =
        copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>();
    if (!insideCore)
      return failure();

    // Emits one row copy; byte offsets and size are stored as i32 attributes.
    auto emitRowCopy = [&](MemRefType resultType,
                           Value rowDst,
                           Value rowSrc,
                           int64_t dstByteOffset,
                           int64_t srcByteOffset,
                           int64_t sliceBytes) {
      auto dstOffsetAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset));
      auto srcOffsetAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset));
      auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes));
      pim::PimMemCopyOp::create(
          rewriter, copyOp.getLoc(), resultType, rowDst, rowSrc, dstOffsetAttr, srcOffsetAttr, sizeAttr);
    };

    if (failed(rewriteSubviewCopyLikeOp(copyOp,
                                        copyOp.getTarget(),
                                        copyOp.getSource(),
                                        copyOp.getTargetOffset(),
                                        copyOp.getSourceOffset(),
                                        copyOp.getSize(),
                                        /*allowLoopRewrite=*/true,
                                        rewriter,
                                        emitRowCopy)))
      return failure();

    rewriter.replaceOp(copyOp, copyOp.getTarget());
    return success();
  }
};
|
||||||
|
|
||||||
|
/// Rewrites a pim::PimMemCopyHostToDevOp into per-row copies when both of its
/// offset operands resolve through resolveIndexValue and
/// rewriteSubviewCopyLikeOp accepts it, then replaces the original op with its
/// device target operand.
struct NormalizeHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
    // Both offsets must be resolvable to integers before the copy can be split.
    auto resolvedDstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
    auto resolvedSrcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
    if (failed(resolvedDstOffset) || failed(resolvedSrcOffset))
      return failure();

    // Emits one row copy; offsets become index constants, the size an i32 attribute.
    auto emitRowCopy = [&](MemRefType resultType,
                           Value rowDst,
                           Value rowSrc,
                           int64_t dstByteOffset,
                           int64_t srcByteOffset,
                           int64_t sliceBytes) {
      Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
      Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
      auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes));
      pim::PimMemCopyHostToDevOp::create(
          rewriter, copyOp.getLoc(), resultType, dstOffsetValue, srcOffsetValue, rowDst, rowSrc, sizeAttr);
    };

    if (failed(rewriteSubviewCopyLikeOp(copyOp,
                                        copyOp.getDeviceTarget(),
                                        copyOp.getHostSource(),
                                        *resolvedDstOffset,
                                        *resolvedSrcOffset,
                                        copyOp.getSize(),
                                        /*allowLoopRewrite=*/true,
                                        rewriter,
                                        emitRowCopy)))
      return failure();

    rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
    return success();
  }
};
|
||||||
|
|
||||||
|
/// Rewrites a pim::PimMemCopyDevToHostOp into per-row copies when both of its
/// offset operands resolve through resolveIndexValue and
/// rewriteSubviewCopyLikeOp accepts it, then replaces the original op with its
/// host target operand. The loop form is disabled for this pattern, so rows are
/// always emitted as individual copies.
struct NormalizeHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDevToHostOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
    // Both offsets must be resolvable to integers before the copy can be split.
    auto resolvedDstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
    auto resolvedSrcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
    if (failed(resolvedDstOffset) || failed(resolvedSrcOffset))
      return failure();

    // Emits one row copy; offsets become index constants, the size an i32 attribute.
    auto emitRowCopy = [&](MemRefType resultType,
                           Value rowDst,
                           Value rowSrc,
                           int64_t dstByteOffset,
                           int64_t srcByteOffset,
                           int64_t sliceBytes) {
      Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
      Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
      auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes));
      pim::PimMemCopyDevToHostOp::create(
          rewriter, copyOp.getLoc(), resultType, dstOffsetValue, srcOffsetValue, rowDst, rowSrc, sizeAttr);
    };

    if (failed(rewriteSubviewCopyLikeOp(copyOp,
                                        copyOp.getHostTarget(),
                                        copyOp.getDeviceSource(),
                                        *resolvedDstOffset,
                                        *resolvedSrcOffset,
                                        copyOp.getSize(),
                                        /*allowLoopRewrite=*/false,
                                        rewriter,
                                        emitRowCopy)))
      return failure();

    rewriter.replaceOp(copyOp, copyOp.getHostTarget());
    return success();
  }
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populatePimContiguityNormalizationPatterns(RewritePatternSet& patterns) {
|
||||||
|
patterns.add<NormalizeCoreSubviewCopyPattern, NormalizeHostSubviewLoadPattern, NormalizeHostSubviewStorePattern>(
|
||||||
|
patterns.getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir::pim
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
|
void populatePimContiguityNormalizationPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir::pim
|
||||||
@@ -47,32 +47,6 @@ struct MemCopyHostToDevOpInterface
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MemCopyHostToDevBatchOpInterface
|
|
||||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevBatchOpInterface, PimMemCopyHostToDevBatchOp> {
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevBatchOp>(op);
|
|
||||||
auto deviceTargetOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state);
|
|
||||||
if (failed(deviceTargetOpt))
|
|
||||||
return failure();
|
|
||||||
auto hostSourceOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getHostSource(), options, state);
|
|
||||||
if (failed(hostSourceOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevBatchOp>(rewriter,
|
|
||||||
memCopyHostToDevOp,
|
|
||||||
deviceTargetOpt->getType(),
|
|
||||||
*deviceTargetOpt,
|
|
||||||
*hostSourceOpt,
|
|
||||||
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
|
|
||||||
memCopyHostToDevOp.getHostSourceOffsetAttr(),
|
|
||||||
memCopyHostToDevOp.getSizeAttr());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct MemCopyDevToHostOpInterface
|
struct MemCopyDevToHostOpInterface
|
||||||
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
|
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
|
||||||
LogicalResult bufferize(Operation* op,
|
LogicalResult bufferize(Operation* op,
|
||||||
@@ -151,8 +125,9 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
replaceOpWithNewBufferizedOp<PimReceiveOp>(
|
replaceOpWithNewBufferizedOp<PimReceiveOp>(
|
||||||
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
|
rewriter, op, contiguousOutput.getType(), contiguousOutput, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -173,15 +148,16 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
|
|||||||
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
inputs.push_back(materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter));
|
inputs.push_back(materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
replaceOpWithNewBufferizedOp<PimConcatOp>(
|
replaceOpWithNewBufferizedOp<PimConcatOp>(
|
||||||
rewriter, op, outputBufferOpt->getType(), concatOp.getAxisAttr(), ValueRange(inputs), *outputBufferOpt);
|
rewriter, op, contiguousOutput.getType(), concatOp.getAxisAttr(), ValueRange(inputs), contiguousOutput);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -206,7 +182,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
|
|||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimSendOp>(rewriter,
|
replaceOpWithNewBufferizedOp<PimSendOp>(rewriter,
|
||||||
op,
|
op,
|
||||||
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||||
sendOp.getSizeAttr(),
|
sendOp.getSizeAttr(),
|
||||||
sendOp.getTargetCoreId());
|
sendOp.getTargetCoreId());
|
||||||
return success();
|
return success();
|
||||||
@@ -431,8 +407,8 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
replaceOpWithNewBufferizedOp<PimTransposeOp>(
|
||||||
rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput);
|
rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput);
|
||||||
@@ -475,8 +451,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||||
rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput);
|
rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput);
|
||||||
@@ -514,9 +490,9 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
|
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||||
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
|
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<OpTy>(
|
replaceOpWithNewBufferizedOp<OpTy>(
|
||||||
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
|
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
|
||||||
@@ -547,9 +523,9 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInter
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
|
Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||||
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
|
Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimVVDMulOp>(
|
replaceOpWithNewBufferizedOp<PimVVDMulOp>(
|
||||||
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
|
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
|
||||||
@@ -583,8 +559,8 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||||
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput);
|
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput);
|
||||||
return success();
|
return success();
|
||||||
@@ -599,7 +575,6 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
|||||||
PimSendOp::attachInterface<SendOpInterface>(*ctx);
|
PimSendOp::attachInterface<SendOpInterface>(*ctx);
|
||||||
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
|
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
|
||||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||||
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
|
||||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||||
PimMemCopyOp::attachInterface<MemCopyOpInterface>(*ctx);
|
PimMemCopyOp::attachInterface<MemCopyOpInterface>(*ctx);
|
||||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||||
|
|||||||
@@ -6,14 +6,14 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Debug.h"
|
|
||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Compiler/PimCodeGen.hpp"
|
#include "Compiler/PimCodeGen.hpp"
|
||||||
#include "Dialect/Pim/PimOps.hpp"
|
#include "Dialect/Pim/PimOps.hpp"
|
||||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
#include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
@@ -40,6 +40,7 @@ struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<Mo
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
|
||||||
|
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -84,6 +85,20 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// After this pass, executable PIM ops must only use contiguous/addressable memrefs.
|
||||||
|
// Later PIM codegen passes may verify this invariant but must not repair it.
|
||||||
|
RewritePatternSet contiguityPatterns(ctx);
|
||||||
|
populatePimContiguityNormalizationPatterns(contiguityPatterns);
|
||||||
|
if (failed(applyPatternsGreedily(moduleOp, std::move(contiguityPatterns)))) {
|
||||||
|
moduleOp.emitError("failed to normalize PIM runtime operand contiguity during bufferization");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (failed(verifyContiguousRuntimeOperands(moduleOp))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
annotateWeightsMemrefs(moduleOp, funcOp);
|
annotateWeightsMemrefs(moduleOp, funcOp);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
@@ -108,6 +123,75 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
|
|||||||
funcOp.walk([&](PimCoreBatchOp coreBatchOp) { markWeights(coreBatchOp); });
|
funcOp.walk([&](PimCoreBatchOp coreBatchOp) { markWeights(coreBatchOp); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp moduleOp) const {
|
||||||
|
bool hasFailure = false;
|
||||||
|
moduleOp.walk([&](Operation* op) {
|
||||||
|
auto verifyOperand = [&](Value operand, unsigned operandIndex) {
|
||||||
|
if (!isa<BaseMemRefType>(operand.getType()))
|
||||||
|
return;
|
||||||
|
if (succeeded(resolveContiguousAddress(operand)) || succeeded(compileContiguousAddressExpr(operand)))
|
||||||
|
return;
|
||||||
|
op->emitOpError() << "operand #" << operandIndex
|
||||||
|
<< " is not backed by contiguous addressable storage after PIM bufferization";
|
||||||
|
hasFailure = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (auto memCopyOp = dyn_cast<PimMemCopyOp>(op)) {
|
||||||
|
verifyOperand(memCopyOp.getTarget(), 0);
|
||||||
|
verifyOperand(memCopyOp.getSource(), 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (auto loadOp = dyn_cast<PimMemCopyHostToDevOp>(op)) {
|
||||||
|
verifyOperand(loadOp.getDeviceTarget(), 2);
|
||||||
|
verifyOperand(loadOp.getHostSource(), 3);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (auto storeOp = dyn_cast<PimMemCopyDevToHostOp>(op)) {
|
||||||
|
verifyOperand(storeOp.getHostTarget(), 2);
|
||||||
|
verifyOperand(storeOp.getDeviceSource(), 3);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (auto sendOp = dyn_cast<PimSendOp>(op)) {
|
||||||
|
verifyOperand(sendOp.getInput(), 0);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (auto receiveOp = dyn_cast<PimReceiveOp>(op)) {
|
||||||
|
verifyOperand(receiveOp.getOutputBuffer(), 0);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (auto concatOp = dyn_cast<PimConcatOp>(op)) {
|
||||||
|
verifyOperand(concatOp.getOutputBuffer(), 0);
|
||||||
|
for (auto inputAndIndex : llvm::enumerate(concatOp.getInputs()))
|
||||||
|
verifyOperand(inputAndIndex.value(), inputAndIndex.index() + 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (isa<PimTransposeOp,
|
||||||
|
PimVMMOp,
|
||||||
|
PimVVAddOp,
|
||||||
|
PimVVSubOp,
|
||||||
|
PimVVMulOp,
|
||||||
|
PimVVMaxOp,
|
||||||
|
PimVVDMulOp,
|
||||||
|
PimVAvgOp,
|
||||||
|
PimVReluOp,
|
||||||
|
PimVTanhOp,
|
||||||
|
PimVSigmOp,
|
||||||
|
PimVSoftmaxOp>(op)) {
|
||||||
|
for (auto operandAndIndex : llvm::enumerate(op->getOperands())) {
|
||||||
|
if (auto vmmOp = dyn_cast<PimVMMOp>(op); vmmOp && operandAndIndex.index() == 0)
|
||||||
|
continue;
|
||||||
|
verifyOperand(operandAndIndex.value(), operandAndIndex.index());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (hasFailure) {
|
||||||
|
moduleOp.emitError("PIM bufferization must fully normalize executable runtime operand contiguity before codegen");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ using namespace onnx_mlir::compact_asm;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// This pass assumes bufferization has already normalized executable PIM
|
||||||
|
// operands. It only reuses compatible local allocations with non-overlapping
|
||||||
|
// lifetimes; it does not repair memory contiguity.
|
||||||
|
|
||||||
struct CoalescingReportRow {
|
struct CoalescingReportRow {
|
||||||
uint64_t numCandidates = 0;
|
uint64_t numCandidates = 0;
|
||||||
uint64_t numSkipped = 0;
|
uint64_t numSkipped = 0;
|
||||||
|
|||||||
@@ -1,349 +1,12 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
||||||
|
|
||||||
#include "../Common.hpp"
|
#include "../Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static bool isSubviewContiguous(const StaticSubviewInfo& info) {
|
|
||||||
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(info.sizes.rbegin(), info.sizes.rend()),
|
|
||||||
llvm::make_range(info.sourceShape.rbegin(), info.sourceShape.rend()));
|
|
||||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, dimension] = sizeAndShape;
|
|
||||||
return size != dimension;
|
|
||||||
});
|
|
||||||
if (firstDifferentSize == sizesAndShape.end())
|
|
||||||
return true;
|
|
||||||
|
|
||||||
++firstDifferentSize;
|
|
||||||
return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
|
|
||||||
auto [size, _dimension] = sizeAndShape;
|
|
||||||
return size == 1;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
|
|
||||||
if (extraOffset == 0)
|
|
||||||
return baseOffset;
|
|
||||||
|
|
||||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
|
||||||
auto integerAttr = dyn_cast<IntegerAttr>(attr);
|
|
||||||
assert(integerAttr && "expected integer offset attribute");
|
|
||||||
return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto value = cast<Value>(baseOffset);
|
|
||||||
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset);
|
|
||||||
return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value buildSubviewChunk(const StaticSubviewInfo& info,
|
|
||||||
ArrayRef<int64_t> outerIndices,
|
|
||||||
Location loc,
|
|
||||||
PatternRewriter& rewriter) {
|
|
||||||
SmallVector<OpFoldResult> chunkOffsets;
|
|
||||||
SmallVector<OpFoldResult> chunkSizes;
|
|
||||||
SmallVector<OpFoldResult> chunkStrides;
|
|
||||||
chunkOffsets.reserve(info.offsets.size());
|
|
||||||
chunkSizes.reserve(info.sizes.size());
|
|
||||||
chunkStrides.reserve(info.strides.size());
|
|
||||||
|
|
||||||
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
|
||||||
int64_t extraOffset = dim + 1 < info.sizes.size() ? outerIndices[dim] * info.strides[dim] : 0;
|
|
||||||
chunkOffsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
|
|
||||||
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < info.sizes.size() ? 1 : info.sizes.back()));
|
|
||||||
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
|
||||||
}
|
|
||||||
|
|
||||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
|
||||||
}
|
|
||||||
|
|
||||||
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
|
|
||||||
ArrayRef<int64_t> shape,
|
|
||||||
ArrayRef<int64_t> strides,
|
|
||||||
PatternRewriter& rewriter) {
|
|
||||||
SmallVector<Value> indices;
|
|
||||||
indices.reserve(shape.size());
|
|
||||||
|
|
||||||
Value remaining = linearIndex;
|
|
||||||
for (auto [_dim, stride] : llvm::enumerate(strides)) {
|
|
||||||
auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride);
|
|
||||||
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
|
||||||
indices.push_back(index);
|
|
||||||
remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
|
||||||
}
|
|
||||||
|
|
||||||
return indices;
|
|
||||||
}
|
|
||||||
|
|
||||||
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
|
|
||||||
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
|
||||||
auto integerAttr = cast<IntegerAttr>(attr);
|
|
||||||
if (integerAttr.getInt() == 0)
|
|
||||||
return extraOffset;
|
|
||||||
|
|
||||||
auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt());
|
|
||||||
return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto value = cast<Value>(baseOffset);
|
|
||||||
return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
|
|
||||||
ArrayRef<Value> outerIndices,
|
|
||||||
Location loc,
|
|
||||||
PatternRewriter& rewriter) {
|
|
||||||
SmallVector<OpFoldResult> chunkOffsets;
|
|
||||||
SmallVector<OpFoldResult> chunkSizes;
|
|
||||||
SmallVector<OpFoldResult> chunkStrides;
|
|
||||||
chunkOffsets.reserve(info.offsets.size());
|
|
||||||
chunkSizes.reserve(info.sizes.size());
|
|
||||||
chunkStrides.reserve(info.strides.size());
|
|
||||||
|
|
||||||
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
|
||||||
if (dim + 1 < info.sizes.size()) {
|
|
||||||
assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
|
|
||||||
chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
|
|
||||||
chunkSizes.push_back(rewriter.getIndexAttr(1));
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
chunkOffsets.push_back(info.offsets[dim]);
|
|
||||||
chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
|
|
||||||
}
|
|
||||||
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
|
||||||
}
|
|
||||||
|
|
||||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value buildContiguousChunk(
|
|
||||||
Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
|
|
||||||
SmallVector<OpFoldResult> chunkOffsets;
|
|
||||||
SmallVector<OpFoldResult> chunkSizes;
|
|
||||||
SmallVector<OpFoldResult> chunkStrides;
|
|
||||||
chunkOffsets.reserve(copyShape.size());
|
|
||||||
chunkSizes.reserve(copyShape.size());
|
|
||||||
chunkStrides.reserve(copyShape.size());
|
|
||||||
|
|
||||||
for (size_t dim = 0; dim < copyShape.size(); ++dim) {
|
|
||||||
chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0));
|
|
||||||
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back()));
|
|
||||||
chunkStrides.push_back(rewriter.getIndexAttr(1));
|
|
||||||
}
|
|
||||||
|
|
||||||
return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename CopyOp, typename CreateCopyOp>
|
|
||||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
|
||||||
Value dst,
|
|
||||||
Value src,
|
|
||||||
int64_t dstOffset,
|
|
||||||
int64_t srcOffset,
|
|
||||||
int64_t size,
|
|
||||||
bool allowLoopRewrite,
|
|
||||||
PatternRewriter& rewriter,
|
|
||||||
CreateCopyOp createCopyOp) {
|
|
||||||
auto srcSubview = getStaticSubviewInfo(src);
|
|
||||||
auto dstSubview = getStaticSubviewInfo(dst);
|
|
||||||
const bool splitSrc = succeeded(srcSubview) && !isSubviewContiguous(*srcSubview);
|
|
||||||
const bool splitDst = succeeded(dstSubview) && !isSubviewContiguous(*dstSubview);
|
|
||||||
if (!splitSrc && !splitDst)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto sourceType = dyn_cast<MemRefType>(src.getType());
|
|
||||||
auto dstType = dyn_cast<MemRefType>(dst.getType());
|
|
||||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
if (sourceType.getElementType() != dstType.getElementType())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
|
|
||||||
return failure();
|
|
||||||
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
|
||||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (!hasByteSizedElementType(sourceType.getElementType()))
|
|
||||||
return failure();
|
|
||||||
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
|
|
||||||
|
|
||||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
|
||||||
if (size != totalBytes)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
|
||||||
if (sliceBytes <= 0)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
|
||||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
|
||||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
|
||||||
|
|
||||||
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
|
|
||||||
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
|
|
||||||
&& dstType.getRank() == static_cast<int64_t>(copyShape.size())) {
|
|
||||||
auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0);
|
|
||||||
auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices);
|
|
||||||
auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1);
|
|
||||||
|
|
||||||
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
|
|
||||||
rewriter.setInsertionPointToStart(loop.getBody());
|
|
||||||
|
|
||||||
SmallVector<Value> outerIndices =
|
|
||||||
outerShape.empty() ? SmallVector<Value> {}
|
|
||||||
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
|
|
||||||
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
|
|
||||||
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
|
||||||
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
|
|
||||||
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
|
||||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(copyOp);
|
|
||||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
|
||||||
SmallVector<int64_t> outerIndices =
|
|
||||||
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
|
||||||
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
|
|
||||||
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
|
|
||||||
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
|
|
||||||
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
|
|
||||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Splits core copies through subviews into contiguous copy chunks for codegen.
|
|
||||||
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
|
||||||
if (!copyOp->getParentOfType<pim::PimCoreOp>())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto status = rewriteSubviewCopyLikeOp(
|
|
||||||
copyOp,
|
|
||||||
copyOp.getTarget(),
|
|
||||||
copyOp.getSource(),
|
|
||||||
copyOp.getTargetOffset(),
|
|
||||||
copyOp.getSourceOffset(),
|
|
||||||
copyOp.getSize(),
|
|
||||||
/*allowLoopRewrite=*/true,
|
|
||||||
rewriter,
|
|
||||||
[&](
|
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
|
||||||
pim::PimMemCopyOp::create(rewriter,
|
|
||||||
copyOp.getLoc(),
|
|
||||||
resultType,
|
|
||||||
dst,
|
|
||||||
src,
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
|
||||||
});
|
|
||||||
if (failed(status))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOp(copyOp, copyOp.getTarget());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Splits host-to-device subview loads into contiguous copy chunks.
|
|
||||||
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
|
||||||
auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
|
|
||||||
auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
|
|
||||||
if (failed(dstOffset) || failed(srcOffset))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto status = rewriteSubviewCopyLikeOp(
|
|
||||||
copyOp,
|
|
||||||
copyOp.getDeviceTarget(),
|
|
||||||
copyOp.getHostSource(),
|
|
||||||
*dstOffset,
|
|
||||||
*srcOffset,
|
|
||||||
copyOp.getSize(),
|
|
||||||
/*allowLoopRewrite=*/true,
|
|
||||||
rewriter,
|
|
||||||
[&](
|
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
|
||||||
Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
|
|
||||||
Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
|
|
||||||
pim::PimMemCopyHostToDevOp::create(rewriter,
|
|
||||||
copyOp.getLoc(),
|
|
||||||
resultType,
|
|
||||||
dstOffsetValue,
|
|
||||||
srcOffsetValue,
|
|
||||||
dst,
|
|
||||||
src,
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
|
||||||
});
|
|
||||||
if (failed(status))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Splits device-to-host subview stores into contiguous copy chunks.
|
|
||||||
struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDevToHostOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
|
|
||||||
auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
|
|
||||||
auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
|
|
||||||
if (failed(dstOffset) || failed(srcOffset))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto status = rewriteSubviewCopyLikeOp(
|
|
||||||
copyOp,
|
|
||||||
copyOp.getHostTarget(),
|
|
||||||
copyOp.getDeviceSource(),
|
|
||||||
*dstOffset,
|
|
||||||
*srcOffset,
|
|
||||||
copyOp.getSize(),
|
|
||||||
/*allowLoopRewrite=*/false,
|
|
||||||
rewriter,
|
|
||||||
[&](
|
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
|
||||||
Value dstOffset = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
|
|
||||||
Value srcOffset = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
|
|
||||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
|
||||||
copyOp.getLoc(),
|
|
||||||
resultType,
|
|
||||||
dstOffset,
|
|
||||||
srcOffset,
|
|
||||||
dst,
|
|
||||||
src,
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
|
||||||
});
|
|
||||||
if (failed(status))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.replaceOp(copyOp, copyOp.getHostTarget());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Folds constant subviews used as core weights into standalone globals.
|
// Folds constant subviews used as core weights into standalone globals.
|
||||||
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
@@ -392,10 +55,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
|
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
|
||||||
patterns.add<RewriteCoreSubviewCopyPattern,
|
patterns.add<FoldConstantCoreSubviewPattern>(patterns.getContext());
|
||||||
RewriteHostSubviewLoadPattern,
|
|
||||||
RewriteHostSubviewStorePattern,
|
|
||||||
FoldConstantCoreSubviewPattern>(patterns.getContext());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -7,12 +7,9 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/MathExtras.h"
|
#include "llvm/Support/MathExtras.h"
|
||||||
|
|
||||||
#include <type_traits>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
@@ -85,31 +82,18 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
|
|||||||
if (contiguousType != originalType)
|
if (contiguousType != originalType)
|
||||||
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
|
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
|
||||||
|
|
||||||
Value copiedValue;
|
Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
|
||||||
if constexpr (std::is_same_v<CoreOpTy, pim::PimCoreBatchOp>) {
|
Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
|
||||||
copiedValue = pim::PimMemCopyHostToDevBatchOp::create(
|
Value copiedValue =
|
||||||
rewriter,
|
pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||||
op->getLoc(),
|
op->getLoc(),
|
||||||
originalType,
|
originalType,
|
||||||
deviceDst,
|
zeroOffset,
|
||||||
getGlobalOp.getResult(),
|
hostOffset,
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
|
||||||
.getOutput();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
copiedValue = pim::PimMemCopyHostToDevOp::create(
|
|
||||||
rewriter,
|
|
||||||
op->getLoc(),
|
|
||||||
originalType,
|
|
||||||
getOrCreateIndexConstant(constantFolder, op, 0),
|
|
||||||
getOrCreateIndexConstant(constantFolder, op, static_cast<int64_t>(resolvedAddress->byteOffset)),
|
|
||||||
deviceDst,
|
deviceDst,
|
||||||
getGlobalOp.getResult(),
|
getGlobalOp.getResult(),
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
|
||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
|
||||||
|
|
||||||
cachedByType[originalType] = copiedValue;
|
cachedByType[originalType] = copiedValue;
|
||||||
operand.set(copiedValue);
|
operand.set(copiedValue);
|
||||||
|
|||||||
@@ -32,51 +32,30 @@ static bool isAddressOnlyHostOp(Operation* op) {
|
|||||||
memref::CopyOp>(op);
|
memref::CopyOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Looser than isCodegenAddressableValue: follows view ops without requiring contiguity.
|
|
||||||
// Used for memref.copy operands which may be non-contiguous subviews.
|
|
||||||
static bool isBaseAddressableValue(Value value) {
|
|
||||||
while (true) {
|
|
||||||
if (isa<BlockArgument>(value))
|
|
||||||
return true;
|
|
||||||
Operation* defOp = value.getDefiningOp();
|
|
||||||
if (!defOp)
|
|
||||||
return false;
|
|
||||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(defOp))
|
|
||||||
return true;
|
|
||||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
|
||||||
value = subview.getSource();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
|
|
||||||
value = cast.getSource();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
|
||||||
value = collapse.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
|
||||||
value = expand.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool isCodegenAddressableValue(Value value) {
|
static bool isCodegenAddressableValue(Value value) {
|
||||||
auto resolvedAddress = resolveContiguousAddress(value);
|
auto resolvedAddress = resolveContiguousAddress(value);
|
||||||
if (failed(resolvedAddress))
|
if (succeeded(resolvedAddress))
|
||||||
return false;
|
|
||||||
return isa<BlockArgument>(resolvedAddress->base)
|
return isa<BlockArgument>(resolvedAddress->base)
|
||||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||||
|
|
||||||
|
auto compiledAddress = compileContiguousAddressExpr(value);
|
||||||
|
if (failed(compiledAddress))
|
||||||
|
return false;
|
||||||
|
return isa<BlockArgument>(compiledAddress->base)
|
||||||
|
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
|
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
|
||||||
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
|
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
|
||||||
if (failed(resolvedAddress))
|
if (succeeded(resolvedAddress))
|
||||||
return false;
|
|
||||||
return isa<BlockArgument>(resolvedAddress->base)
|
return isa<BlockArgument>(resolvedAddress->base)
|
||||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||||
|
|
||||||
|
auto compiledAddress = compileContiguousAddressExpr(value);
|
||||||
|
if (failed(compiledAddress))
|
||||||
|
return false;
|
||||||
|
return isa<BlockArgument>(compiledAddress->base)
|
||||||
|
|| isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isConstantGlobalView(Value value) {
|
static bool isConstantGlobalView(Value value) {
|
||||||
@@ -138,7 +117,6 @@ static bool isCoreWeightBlockArgument(Value value) {
|
|||||||
|
|
||||||
static bool isSupportedCoreInstructionOp(Operation* op) {
|
static bool isSupportedCoreInstructionOp(Operation* op) {
|
||||||
return isa<pim::PimMemCopyHostToDevOp,
|
return isa<pim::PimMemCopyHostToDevOp,
|
||||||
pim::PimMemCopyHostToDevBatchOp,
|
|
||||||
pim::PimMemCopyDevToHostOp,
|
pim::PimMemCopyDevToHostOp,
|
||||||
pim::PimMemCopyOp,
|
pim::PimMemCopyOp,
|
||||||
pim::PimReceiveOp,
|
pim::PimReceiveOp,
|
||||||
@@ -159,27 +137,6 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
|
|||||||
memref::GetGlobalOp>(op);
|
memref::GetGlobalOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyBatchOpSemantics(Operation& op,
|
|
||||||
const StaticValueKnowledge& knowledge,
|
|
||||||
pim::CappedDiagnosticReporter& diagnostics) {
|
|
||||||
bool hasFailure = false;
|
|
||||||
auto reportFailure = [&](auto emitDiagnostic) {
|
|
||||||
diagnostics.report(&op, [&](Operation* illegalOp) { emitDiagnostic(illegalOp); });
|
|
||||||
hasFailure = true;
|
|
||||||
};
|
|
||||||
|
|
||||||
if (auto memcpHdBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
|
||||||
if (!isCodegenAddressableValue(memcpHdBatchOp.getHostSource(), knowledge)) {
|
|
||||||
reportFailure([](Operation* illegalOp) {
|
|
||||||
illegalOp->emitOpError("host operand #1 is not backed by contiguous addressable storage");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return success(!hasFailure);
|
|
||||||
}
|
|
||||||
|
|
||||||
return success(!hasFailure);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
|
||||||
|
|
||||||
@@ -343,7 +300,7 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
||||||
if (failed(resolvedAddress)) {
|
if (failed(resolvedAddress) && failed(compileContiguousAddressExpr(operand))) {
|
||||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
illegalOp->emitOpError() << "operand #" << operandIndex
|
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||||
<< " is not backed by contiguous addressable storage";
|
<< " is not backed by contiguous addressable storage";
|
||||||
@@ -363,7 +320,9 @@ private:
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
Value addressBase =
|
||||||
|
succeeded(resolvedAddress) ? resolvedAddress->base : compileContiguousAddressExpr(operand)->base;
|
||||||
|
if (!isa<memref::AllocOp>(addressBase.getDefiningOp())) {
|
||||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
illegalOp->emitOpError() << "operand #" << operandIndex
|
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||||
<< " must be backed by device-local memory; materialize host values with "
|
<< " must be backed by device-local memory; materialize host values with "
|
||||||
@@ -392,9 +351,6 @@ private:
|
|||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(verifyBatchOpSemantics(op, knowledge, diagnostics)))
|
|
||||||
hasFailure = true;
|
|
||||||
return success(!hasFailure);
|
return success(!hasFailure);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -409,7 +365,7 @@ private:
|
|||||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
||||||
return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics);
|
return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics);
|
||||||
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
|
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
|
||||||
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) {
|
if (!isMemRefBaseAddressableValue(copyOp.getSource()) || !isMemRefBaseAddressableValue(copyOp.getTarget())) {
|
||||||
diagnostics.report(op, [](Operation* illegalOp) {
|
diagnostics.report(op, [](Operation* illegalOp) {
|
||||||
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
||||||
});
|
});
|
||||||
@@ -432,7 +388,7 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
if (isBaseAddressableValue(source))
|
if (isMemRefBaseAddressableValue(source))
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
diagnostics.report(op, [](Operation* illegalOp) {
|
diagnostics.report(op, [](Operation* illegalOp) {
|
||||||
|
|||||||
Reference in New Issue
Block a user