centralize logic for materializing contiguous memory into bufferization
Validate Operations / validate-operations (push) Has been cancelled

fix codegen symlinks overwrite
remove deprecated pim memcp_hd_batch op
This commit is contained in:
NiccoloN
2026-05-30 15:54:24 +02:00
parent 2d5b03c08f
commit cf93caecd5
29 changed files with 642 additions and 823 deletions
+22 -15
View File
@@ -616,31 +616,38 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape()) if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure(); return mlir::failure();
llvm::SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
llvm::SmallVector<int64_t> staticSizes; llvm::SmallVector<int64_t> staticSizes;
staticSizes.reserve(subviewOp.getMixedSizes().size()); staticSizes.reserve(subviewOp.getMixedSizes().size());
llvm::SmallVector<int64_t> staticStrides; llvm::SmallVector<int64_t> staticStrides;
staticStrides.reserve(subviewOp.getMixedStrides().size()); staticStrides.reserve(subviewOp.getMixedStrides().size());
bool allStatic = true; llvm::SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
bool hasOnlyStaticOffsets = true;
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt()); staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else else
allStatic = false; hasOnlyStaticOffsets = false;
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size)) auto attr = mlir::dyn_cast<mlir::Attribute>(size);
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt()); if (!attr)
else return mlir::failure();
allStatic = false; staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) }
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride)) for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt()); auto attr = mlir::dyn_cast<mlir::Attribute>(stride);
else if (!attr)
allStatic = false; return mlir::failure();
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
}
if (allStatic) { if (!isContiguousSubviewWithDynamicOffsets(
sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides)) {
return mlir::failure();
}
if (hasOnlyStaticOffsets) {
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides)) if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
return mlir::failure(); return mlir::failure();
-2
View File
@@ -20,8 +20,6 @@ llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds,
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) { bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op)) if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 3; return operandIndex == 3;
if (mlir::isa<pim::PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op)) if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2; return operandIndex == 2;
return false; return false;
+52
View File
@@ -111,4 +111,56 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
return true; return true;
} }
/// Checks whether a subview with static sizes/strides and possibly dynamic
/// offsets passes this file's contiguity test against a statically shaped
/// source.
///
/// The check fails (returns false) when:
///  - the four arrays do not all have the same length,
///  - any stride differs from 1,
///  - the innermost dimension whose offset is dynamic or a non-zero constant
///    has a constant offset that leaves fewer than `size` elements in the
///    source dimension, or any dimension outside of it has a size other
///    than 1,
///  - any dimension outside the innermost one whose size differs from the
///    source dimension has a size other than 1.
/// Otherwise it returns true. Dynamic offsets are never bounds-checked here.
/// NOTE(review): the dimension ordering presumably assumes a row-major source
/// layout — confirm against callers.
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
                                           llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
                                           llvm::ArrayRef<int64_t> staticSizes,
                                           llvm::ArrayRef<int64_t> staticStrides) {
  const size_t rank = sourceShape.size();
  if (mixedOffsets.size() != rank || staticSizes.size() != rank || staticStrides.size() != rank)
    return false;

  for (int64_t stride : staticStrides)
    if (stride != 1)
      return false;

  // True when every dimension strictly outside `dim` is sliced with size 1.
  auto outerSizesAreUnit = [&](size_t dim) {
    for (size_t outer = 0; outer < dim; ++outer)
      if (staticSizes[outer] != 1)
        return false;
    return true;
  };

  // Walk from the innermost dimension outwards and stop at the first offset
  // that is either dynamic or a non-zero constant.
  for (size_t remaining = rank; remaining > 0; --remaining) {
    const size_t dim = remaining - 1;
    mlir::OpFoldResult offset = mixedOffsets[dim];
    if (auto offsetAttr = mlir::dyn_cast<mlir::Attribute>(offset)) {
      const int64_t staticOffset = mlir::cast<mlir::IntegerAttr>(offsetAttr).getInt();
      if (staticOffset == 0)
        continue;
      // A constant offset must leave room for the requested size.
      if (staticSizes[dim] > sourceShape[dim] - staticOffset)
        return false;
    }
    if (!outerSizesAreUnit(dim))
      return false;
    break;
  }

  // Walk from the innermost dimension outwards again and stop at the first
  // size that does not cover the whole source dimension.
  for (size_t remaining = rank; remaining > 0; --remaining) {
    const size_t dim = remaining - 1;
    if (staticSizes[dim] == sourceShape[dim])
      continue;
    return outerSizesAreUnit(dim);
  }

  return true;
}
} // namespace onnx_mlir } // namespace onnx_mlir
+6
View File
@@ -1,6 +1,7 @@
#pragma once #pragma once
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
@@ -30,4 +31,9 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides); llvm::ArrayRef<int64_t> strides);
/// Contiguity test for a subview whose sizes and strides are static but whose
/// offsets may be dynamic (`mixedOffsets` holds attributes or values).
/// Returns false when the array lengths differ, when any stride is not 1, or
/// when a dimension outside the innermost partially covered / offset
/// dimension has a size other than 1; returns true otherwise.
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
llvm::ArrayRef<int64_t> staticSizes,
llvm::ArrayRef<int64_t> staticStrides);
} // namespace onnx_mlir } // namespace onnx_mlir
+22
View File
@@ -31,6 +31,19 @@ Value stripMemRefViewOps(Value value) {
} }
} }
/// Walks up the producer chain of `value`, looking through memref::SubViewOp
/// (via its source) and through whatever stripMemRefViewOps removes, until a
/// step no longer changes the value. Returns the value reached.
Value stripMemRefAddressingOps(Value value) {
  for (;;) {
    Value next;
    if (auto subview = value.getDefiningOp<memref::SubViewOp>())
      next = subview.getSource();
    else
      next = stripMemRefViewOps(value);

    // A subview result is never its own source, so this only fires once
    // stripMemRefViewOps has nothing left to peel.
    if (next == value)
      return value;
    value = next;
  }
}
bool hasAllStaticSubviewParts(memref::SubViewOp subview) { bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); }) return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); }) && llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
@@ -81,4 +94,13 @@ FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo&
return staticOffsets; return staticOffsets;
} }
/// Returns true when, after stripMemRefAddressingOps, `value` is a block
/// argument or is produced by a memref::AllocOp or memref::GetGlobalOp.
bool isMemRefBaseAddressableValue(Value value) {
  Value base = stripMemRefAddressingOps(value);
  if (isa<BlockArgument>(base))
    return true;

  Operation* producer = base.getDefiningOp();
  if (!producer)
    return false;
  return isa<memref::AllocOp, memref::GetGlobalOp>(producer);
}
} // namespace onnx_mlir } // namespace onnx_mlir
+4
View File
@@ -20,6 +20,8 @@ mlir::Value stripMemRefCasts(mlir::Value value);
mlir::Value stripMemRefViewOps(mlir::Value value); mlir::Value stripMemRefViewOps(mlir::Value value);
/// Like stripMemRefViewOps, but additionally looks through memref.subview ops
/// (taking their source) until the value no longer changes.
mlir::Value stripMemRefAddressingOps(mlir::Value value);
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview); bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value); llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
@@ -27,4 +29,6 @@ llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic. /// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info); llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
/// Returns true if `value`, after stripMemRefAddressingOps, is a block
/// argument or the result of a memref::AllocOp / memref::GetGlobalOp.
bool isMemRefBaseAddressableValue(mlir::Value value);
} // namespace onnx_mlir } // namespace onnx_mlir
+1 -23
View File
@@ -47,28 +47,6 @@ CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs)); return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
} }
mlir::Value stripWeightViewOps(mlir::Value value) {
while (true) {
if (auto subviewOp = value.getDefiningOp<mlir::memref::SubViewOp>()) {
value = subviewOp.getSource();
continue;
}
if (auto castOp = value.getDefiningOp<mlir::memref::CastOp>()) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = value.getDefiningOp<mlir::memref::CollapseShapeOp>()) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = value.getDefiningOp<mlir::memref::ExpandShapeOp>()) {
value = expandOp.getSrc();
continue;
}
return value;
}
}
template <typename VMMOpTy, typename ParentOpTy> template <typename VMMOpTy, typename ParentOpTy>
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex); auto weightArg = parentOp.getWeightArgument(weightIndex);
@@ -159,7 +137,7 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
} }
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) { std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
weight = stripWeightViewOps(weight); weight = stripMemRefAddressingOps(weight);
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) { if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
+6 -21
View File
@@ -479,16 +479,6 @@ void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticVa
loadOp.getSize()); loadOp.getSize());
} }
void PimCodeGen::codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp,
const StaticValueKnowledge& knowledge) const {
emitMemCopyOp("ld",
addressOf(loadOp.getDeviceTarget(), knowledge),
loadOp.getDeviceTargetOffset(),
addressOf(loadOp.getHostSource(), knowledge),
loadOp.getHostSourceOffset(),
loadOp.getSize());
}
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
auto hostTargetOffset = indexOf(storeOp.getHostTargetOffset(), knowledge); auto hostTargetOffset = indexOf(storeOp.getHostTargetOffset(), knowledge);
auto deviceSourceOffset = indexOf(storeOp.getDeviceSourceOffset(), knowledge); auto deviceSourceOffset = indexOf(storeOp.getDeviceSourceOffset(), knowledge);
@@ -848,7 +838,6 @@ public:
enum class CompiledCoreOpKind : uint8_t { enum class CompiledCoreOpKind : uint8_t {
Load, Load,
LoadBatch,
Store, Store,
Lmv, Lmv,
Receive, Receive,
@@ -887,8 +876,6 @@ struct CompiledCoreNode {
static FailureOr<CompiledCoreOpKind> classifyCompiledCoreOpKind(Operation& op) { static FailureOr<CompiledCoreOpKind> classifyCompiledCoreOpKind(Operation& op) {
if (isa<pim::PimMemCopyHostToDevOp>(op)) if (isa<pim::PimMemCopyHostToDevOp>(op))
return CompiledCoreOpKind::Load; return CompiledCoreOpKind::Load;
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
return CompiledCoreOpKind::LoadBatch;
if (isa<pim::PimMemCopyDevToHostOp>(op)) if (isa<pim::PimMemCopyDevToHostOp>(op))
return CompiledCoreOpKind::Store; return CompiledCoreOpKind::Store;
if (isa<pim::PimMemCopyOp>(op)) if (isa<pim::PimMemCopyOp>(op))
@@ -1027,9 +1014,6 @@ static LogicalResult executeCompiledCorePlan(
case CompiledCoreOpKind::Load: case CompiledCoreOpKind::Load:
coreCodeGen.codeGenLoadOp(cast<pim::PimMemCopyHostToDevOp>(node.op), knowledge); coreCodeGen.codeGenLoadOp(cast<pim::PimMemCopyHostToDevOp>(node.op), knowledge);
break; break;
case CompiledCoreOpKind::LoadBatch:
coreCodeGen.codeGenLoadBatchOp(cast<pim::PimMemCopyHostToDevBatchOp>(node.op), knowledge);
break;
case CompiledCoreOpKind::Store: case CompiledCoreOpKind::Store:
coreCodeGen.codeGenStoreOp(cast<pim::PimMemCopyDevToHostOp>(node.op), knowledge); coreCodeGen.codeGenStoreOp(cast<pim::PimMemCopyDevToHostOp>(node.op), knowledge);
break; break;
@@ -1213,17 +1197,18 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
auto linkCoreWeights = auto linkCoreWeights =
[&](size_t coreId, ArrayRef<std::string> weightFiles, json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes { [&](size_t coreId, ArrayRef<std::string> weightFiles, json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId); auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) { if (auto error = sys::fs::create_directory(coreWeightsDirPath); error && error != std::errc::file_exists) {
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n'; errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
return InvalidOutputFileAccess; return InvalidOutputFileAccess;
} }
for (auto [slot, fileName] : llvm::enumerate(weightFiles)) { for (auto [slot, fileName] : llvm::enumerate(weightFiles)) {
xbarsPerGroup.push_back(static_cast<int64_t>(slot)); xbarsPerGroup.push_back(static_cast<int64_t>(slot));
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName, std::string sourcePath = outputDirPath + "/weights/" + fileName;
coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin")) { std::string targetPath = coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin";
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to " sys::fs::remove(targetPath);
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin") << "\nError:" << error.message() if (auto error = sys::fs::create_link(sourcePath, targetPath)) {
errs() << "Error creating link file: " << sourcePath << " to " << targetPath << "\nError:" << error.message()
<< '\n'; << '\n';
return InvalidOutputFileAccess; return InvalidOutputFileAccess;
} }
-1
View File
@@ -175,7 +175,6 @@ public:
} }
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const; void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const; void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const; void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
@@ -1,11 +1,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <functional> #include <functional>
#include "IndexingUtils.hpp" #include "IndexingUtils.hpp"
@@ -20,35 +17,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
APInt lhsConst;
if (matchPattern(lhs, m_ConstantInt(&lhsConst)) && lhsConst.isZero())
return rhs;
APInt rhsConst;
if (matchPattern(rhs, m_ConstantInt(&rhsConst)) && rhsConst.isZero())
return lhs;
return arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult();
}
static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatternRewriter& rewriter, Location loc) {
APInt factorConst;
if (auto attr = dyn_cast<Attribute>(factor))
factorConst = cast<IntegerAttr>(attr).getValue();
else if (!matchPattern(cast<Value>(factor), m_ConstantInt(&factorConst)))
return arith::MulIOp::create(rewriter, loc, value, cast<Value>(factor)).getResult();
if (factorConst.isZero())
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
if (factorConst.isOne())
return value;
auto factorValue =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), factorConst.getSExtValue());
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
}
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) { bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
} }
@@ -124,39 +92,6 @@ SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int
return sizes; return sizes;
} }
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()
|| sourceType.getRank() != resultType.getRank())
return false;
for (OpFoldResult stride : strides) {
APInt strideValue;
if (auto attr = dyn_cast<Attribute>(stride)) {
if (cast<IntegerAttr>(attr).getInt() != 1)
return false;
continue;
}
if (!matchPattern(cast<Value>(stride), m_ConstantInt(&strideValue)) || !strideValue.isOne())
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(resultType.getShape().rbegin(), resultType.getShape().rend()),
llvm::make_range(sourceType.getShape().rbegin(), sourceType.getShape().rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize == sizesAndShape.end())
return true;
++firstDifferentSize;
return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
auto [size, _dimension] = sizeAndShape;
return size == 1;
});
}
SmallVector<Value> sliceTensor( SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice); ArrayRef<long> shape = getTensorShape(tensorToSlice);
@@ -222,90 +157,6 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewri
return slicesPerCore; return slicesPerCore;
} }
Value materializeContiguousTensorSlice(Value source,
RankedTensorType resultType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> strides,
ConversionPatternRewriter& rewriter,
Location loc) {
assert(resultType.hasStaticShape() && "expected static result type");
size_t rank = static_cast<size_t>(resultType.getRank());
assert(offsets.size() == rank && "expected rank-matching offsets");
assert(strides.size() == rank && "expected rank-matching strides");
SmallVector<OpFoldResult> sizes;
sizes.reserve(resultType.getRank());
for (int64_t size : resultType.getShape())
sizes.push_back(rewriter.getIndexAttr(size));
if (isContiguousTensorSlice(source, resultType, strides))
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
if (resultType.getRank() == 0)
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
Value init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
SmallVector<Value> zeroIndices(resultType.getRank());
for (Value& zeroIndex : zeroIndices)
zeroIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
SmallVector<Value> resultIndices;
resultIndices.reserve(resultType.getRank());
auto buildLoopNest = [&](auto&& self, unsigned dim, Value accumulator) -> Value {
if (dim == resultType.getRank()) {
SmallVector<Value> sourceIndices;
sourceIndices.reserve(resultType.getRank());
for (unsigned idx = 0; idx < resultType.getRank(); ++idx) {
Value offsetValue = getOrMaterializeIndexValue(rewriter, offsets[idx]);
Value scaledIndex = multiplyIndexValue(resultIndices[idx], strides[idx], rewriter, loc);
sourceIndices.push_back(addIndexValues(offsetValue, scaledIndex, rewriter, loc));
}
SmallVector<OpFoldResult> sourceOffsets;
SmallVector<OpFoldResult> destinationOffsets;
SmallVector<OpFoldResult> unitSizes;
SmallVector<OpFoldResult> unitStrides;
sourceOffsets.reserve(resultType.getRank());
destinationOffsets.reserve(resultType.getRank());
unitSizes.reserve(resultType.getRank());
unitStrides.reserve(resultType.getRank());
for (Value index : sourceIndices)
sourceOffsets.push_back(index);
for (Value index : resultIndices)
destinationOffsets.push_back(index);
for (int64_t idx = 0; idx < resultType.getRank(); ++idx) {
unitSizes.push_back(rewriter.getIndexAttr(1));
unitStrides.push_back(rewriter.getIndexAttr(1));
}
auto elementTensorType =
RankedTensorType::get(SmallVector<int64_t>(resultType.getRank(), 1), resultType.getElementType());
Value elementSlice =
tensor::ExtractSliceOp::create(rewriter, loc, elementTensorType, source, sourceOffsets, unitSizes, unitStrides)
.getResult();
return tensor::InsertSliceOp::create(
rewriter, loc, elementSlice, accumulator, destinationOffsets, unitSizes, unitStrides)
.getResult();
}
Value lower = zeroIndices[dim];
Value upper =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim));
Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator});
rewriter.setInsertionPointToStart(loop.getBody());
resultIndices.push_back(loop.getInductionVar());
Value updated = self(self, dim + 1, loop.getRegionIterArgs().front());
resultIndices.pop_back();
scf::YieldOp::create(rewriter, loc, updated);
rewriter.setInsertionPointAfter(loop);
return loop.getResult(0);
};
return buildLoopNest(buildLoopNest, 0, init);
}
Value extractAxisSlice( Value extractAxisSlice(
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) { PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
auto sourceType = cast<RankedTensorType>(source.getType()); auto sourceType = cast<RankedTensorType>(source.getType());
@@ -108,13 +108,6 @@ llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore( llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
mlir::Value materializeContiguousTensorSlice(mlir::Value source,
mlir::RankedTensorType resultType,
llvm::ArrayRef<mlir::OpFoldResult> offsets,
llvm::ArrayRef<mlir::OpFoldResult> strides,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value extractAxisSlice( mlir::Value extractAxisSlice(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size); mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
@@ -303,9 +303,11 @@ createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewri
static Value extractDynamicGemmBColumn( static Value extractDynamicGemmBColumn(
Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) { Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column}; SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType()); auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType());
Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc); Value columnSlice =
tensor::ExtractSliceOp::create(rewriter, loc, columnSliceType, matrix, offsets, sizes, strides).getResult();
SmallVector<ReassociationIndices> collapseReassociation { SmallVector<ReassociationIndices> collapseReassociation {
ReassociationIndices {0, 1} ReassociationIndices {0, 1}
}; };
@@ -23,7 +23,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) { static Value materializeTileTensor(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType()); auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType()); Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank())); return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank()));
@@ -319,7 +319,7 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value windowValue = Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides); tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue); windowValue = materializeTileTensor(rewriter, loc, windowValue);
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue); reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
} }
} }
@@ -335,7 +335,7 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scaleSlice = tensor::ExtractSliceOp::create( Value scaleSlice = tensor::ExtractSliceOp::create(
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides); rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice); scaleSlice = materializeTileTensor(rewriter, loc, scaleSlice);
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice); reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
} }
@@ -59,10 +59,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
ShapedType destinationType, ShapedType destinationType,
IRMapping& mapper) { IRMapping& mapper) {
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType())); int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
SmallVector<int64_t> strides(destinationType.getRank(), 1); SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
ArrayRef<int64_t> shape = destinationType.getShape();
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
Value totalOffset; Value totalOffset;
Location loc = insertSlice.getLoc(); Location loc = insertSlice.getLoc();
@@ -162,14 +159,15 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex); BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
auto newArgType = cast<ShapedType>(newArg.getType()); auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
loc, auto copied = pim::PimMemCopyHostToDevOp::create(rewriter,
outputBuffer.getType(), loc,
outputBuffer, outputBuffer.getType(),
newArg, zeroOffset,
rewriter.getI32IntegerAttr(0), zeroOffset,
rewriter.getI32IntegerAttr(0), outputBuffer,
getTensorSizeInBytesAttr(rewriter, newArg)) newArg,
getTensorSizeInBytesAttr(rewriter, newArg))
.getOutput(); .getOutput();
mapper.map(*oldArg, copied); mapper.map(*oldArg, copied);
} }
@@ -233,14 +231,15 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
} }
auto clonedType = cast<ShapedType>(clonedTensor.getType()); auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
loc, auto copied = pim::PimMemCopyHostToDevOp::create(rewriter,
outputBuffer.getType(), loc,
outputBuffer, outputBuffer.getType(),
clonedTensor, zeroOffset,
rewriter.getI32IntegerAttr(0), zeroOffset,
rewriter.getI32IntegerAttr(0), outputBuffer,
getTensorSizeInBytesAttr(rewriter, clonedTensor)) clonedTensor,
getTensorSizeInBytesAttr(rewriter, clonedTensor))
.getOutput(); .getOutput();
mapper.map(toTensorOp.getResult(), copied); mapper.map(toTensorOp.getResult(), copied);
continue; continue;
@@ -1,10 +1,8 @@
#include "mlir/IR/ValueRange.h" #include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include <cassert> #include <cassert>
#include <cstddef>
#include "Common.hpp" #include "Common.hpp"
@@ -13,48 +11,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
/// Linearizes the static offsets of `sliceOp` against the dimension sizes of
/// `inputShape`, walking from the innermost dimension outwards (row-major order).
/// The result is counted in elements; no element byte width is applied here.
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
  /*
    Worked example:
      inputShape = [1, 10, 3, 4]
      offsets    = [0,  2, 1, 3]

      dim 3: offset = 0  + 1   * 3 = 3,   extent = 1   * 4  = 4
      dim 2: offset = 3  + 4   * 1 = 7,   extent = 4   * 3  = 12
      dim 1: offset = 7  + 12  * 2 = 31,  extent = 12  * 10 = 120
      dim 0: offset = 31 + 120 * 0 = 31,  extent = 120 * 1  = 120

      result = 31
  */
  auto offsets = sliceOp.getStaticOffsets();
  auto dimSizes = inputShape.getShape();
  assert(offsets.size() == dimSizes.size());

  size_t linearOffset = 0;
  // Number of elements spanned by one step along the dimension currently visited.
  size_t innerExtent = 1;

  // Visit (offset, size) pairs from the last dimension to the first.
  for (auto [offset, dimSize] : reverse(zip(offsets, dimSizes))) {
    linearOffset += innerExtent * offset;
    innerExtent *= dimSize;
  }

  return linearOffset;
}
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) { IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType())))); return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
} }
@@ -6,20 +6,6 @@
namespace onnx_mlir { namespace onnx_mlir {
/**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input.
*
* The static offsets represent the starting position of the slice in each
* dimension, while the static tensor input gives its dimension size.
*
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
* calculated.
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
* \return The actual offset of the ExtractSliceOp.
*/
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value); mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
template <class T> template <class T>
@@ -83,18 +83,6 @@ static Value createZeroedDeviceHVector(IRRewriter& rewriter,
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0); auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0);
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType))); auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
return PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
tensorType,
outputBuffer,
zeroValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
sizeAttr)
.getOutput();
return PimMemCopyHostToDevOp::create( return PimMemCopyHostToDevOp::create(
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr) rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
.getOutput(); .getOutput();
-26
View File
@@ -144,32 +144,6 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
}]; }];
} }
// `pim.memcp_hd_batch`: destination-style host-to-device copy used inside a batched core
// (see the summary string below). Unlike `PimMemCopyHostToDevOp` call sites in this change,
// which pass offsets as index values, this op carries both offsets as I32 attributes.
def PimMemCopyHostToDevBatchOp : PimOp<"memcp_hd_batch", [DestinationStyleOpInterface]> {
let summary = "Copy a per-lane tensor from host memory into device memory inside a batched core";
// Operands: the device-side destination and the host-side source.
// Attributes: destination offset, source offset and copy size.
// NOTE(review): callers visible in this change pass 0 for both offsets and the tensor size
// in bytes for `size`; the unit of the offsets is not stated here - confirm before relying on it.
let arguments = (ins
PimTensor:$deviceTarget,
PimTensor:$hostSource,
I32Attr:$deviceTargetOffset,
I32Attr:$hostSourceOffset,
I32Attr:$size
);
// Single result: the tensor produced by the copy.
let results = (outs
PimTensor:$output
);
// DestinationStyleOpInterface hook: `deviceTarget` is the only DPS init operand.
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getDeviceTargetMutable();
}
}];
// Custom assembly: operands in parentheses, then the attribute dictionary, then the types.
let assemblyFormat = [{
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
}];
}
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> { def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from device memory into host memory"; let summary = "Copy a memory region from device memory into host memory";
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp"
@@ -10,8 +11,8 @@ using namespace bufferization;
namespace onnx_mlir::pim { namespace onnx_mlir::pim {
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { Value materializeContiguousInputMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue))) if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue)))
return memrefValue; return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType()); auto shapedType = cast<ShapedType>(memrefValue.getType());
@@ -30,7 +31,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
.getOutput(); .getOutput();
} }
Value allocateContiguousMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) { Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue))) if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue; return memrefValue;
@@ -5,8 +5,9 @@
namespace onnx_mlir::pim { namespace onnx_mlir::pim {
mlir::Value materializeContiguousMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); mlir::Value materializeContiguousInputMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
mlir::Value allocateContiguousMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); mlir::Value
allocateContiguousResultMemRefLike(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter);
llvm::FailureOr<mlir::Value> getBufferOrValue(mlir::RewriterBase& rewriter, llvm::FailureOr<mlir::Value> getBufferOrValue(mlir::RewriterBase& rewriter,
mlir::Value value, mlir::Value value,
@@ -6,6 +6,8 @@ add_pim_library(OMPimBufferization
PimBufferizationPass.cpp PimBufferizationPass.cpp
BufferizationUtils.hpp BufferizationUtils.hpp
BufferizationUtils.cpp BufferizationUtils.cpp
ContiguityPatterns.hpp
ContiguityPatterns.cpp
OpBufferizationInterfaces.hpp OpBufferizationInterfaces.hpp
OpBufferizationInterfaces.cpp OpBufferizationInterfaces.cpp
Common.hpp Common.hpp
@@ -0,0 +1,343 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir::pim {
namespace {
/// Returns true only when every stride of the subview is 1 and
/// isContiguousSubviewWithDynamicOffsets accepts its shape/offset/size/stride description.
static bool isStaticSubviewContiguous(const StaticSubviewInfo& info) {
  const bool hasOnlyUnitStrides = llvm::all_of(info.strides, [](int64_t stride) { return stride == 1; });
  if (!hasOnlyUnitStrides)
    return false;

  return isContiguousSubviewWithDynamicOffsets(info.sourceShape, info.offsets, info.sizes, info.strides);
}
/// Returns `baseOffset + extraOffset`.
///
/// A zero `extraOffset` returns `baseOffset` untouched. A static (attribute) base is
/// folded into a new index attribute; a dynamic (SSA value) base gets an `arith.addi`
/// with an index constant for `extraOffset`.
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
  if (extraOffset == 0)
    return baseOffset;

  auto staticBase = dyn_cast<Attribute>(baseOffset);
  if (!staticBase) {
    // Dynamic base: materialize the constant and emit the addition at the current insertion point.
    auto dynamicBase = cast<Value>(baseOffset);
    auto extraConstant = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset);
    return arith::AddIOp::create(rewriter, dynamicBase.getLoc(), dynamicBase, extraConstant).getResult();
  }

  // Static base: fold the sum directly into an attribute, no IR is created.
  auto integerAttr = dyn_cast<IntegerAttr>(staticBase);
  assert(integerAttr && "expected integer offset attribute");
  return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
}
/// Creates a `memref.subview` of `info.source` that selects a single run along the last
/// dimension of the subview described by `info`.
///
/// Every dimension except the last gets size 1 and its original offset advanced by
/// `outerIndices[dim] * info.strides[dim]`; the last dimension keeps its original offset
/// and its full size. Strides are copied from `info`.
static Value buildSubviewChunk(const StaticSubviewInfo& info,
                               ArrayRef<int64_t> outerIndices,
                               Location loc,
                               PatternRewriter& rewriter) {
  const size_t rank = info.sizes.size();

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(info.offsets.size());
  sizes.reserve(info.sizes.size());
  strides.reserve(info.strides.size());

  for (size_t dim = 0; dim < rank; ++dim) {
    const bool isLastDim = dim + 1 == rank;
    if (isLastDim) {
      // Innermost dimension: take the whole run starting at the original offset.
      offsets.push_back(info.offsets[dim]);
      sizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
    }
    else {
      // Outer dimension: pin it to the single position selected by outerIndices.
      const int64_t extraOffset = outerIndices[dim] * info.strides[dim];
      offsets.push_back(addConstantOffset(info.offsets[dim], extraOffset, rewriter));
      sizes.push_back(rewriter.getIndexAttr(1));
    }
    strides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, offsets, sizes, strides);
}
/// Emits IR that splits `linearIndex` into one index value per entry of `strides`.
///
/// For each stride, in order, the current remainder is divided by the stride
/// (`arith.divui`) to produce that dimension's index, and replaced by the remainder of
/// the same division (`arith.remui`). `shape` is only used to pre-size the result vector.
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
                                                ArrayRef<int64_t> shape,
                                                ArrayRef<int64_t> strides,
                                                PatternRewriter& rewriter) {
  SmallVector<Value> perDimIndices;
  perDimIndices.reserve(shape.size());

  Location loc = linearIndex.getLoc();
  Value remainder = linearIndex;
  for (int64_t stride : strides) {
    auto strideConstant = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride);
    Value quotient = arith::DivUIOp::create(rewriter, loc, remainder, strideConstant);
    perDimIndices.push_back(quotient);
    remainder = arith::RemUIOp::create(rewriter, loc, remainder, strideConstant);
  }

  return perDimIndices;
}
/// Returns `baseOffset + extraOffset` where `extraOffset` is an SSA value.
///
/// A static base of 0 returns `extraOffset` as-is; any other static base is turned into
/// an index constant and added with `arith.addi`; a dynamic base is added directly.
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
  auto staticBase = dyn_cast<Attribute>(baseOffset);
  if (!staticBase) {
    auto dynamicBase = cast<Value>(baseOffset);
    return arith::AddIOp::create(rewriter, dynamicBase.getLoc(), dynamicBase, extraOffset).getResult();
  }

  const int64_t baseValue = cast<IntegerAttr>(staticBase).getInt();
  if (baseValue == 0)
    return extraOffset;

  auto baseConstant = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), baseValue);
  return arith::AddIOp::create(rewriter, extraOffset.getLoc(), baseConstant, extraOffset).getResult();
}
/// Variant of the chunk builder whose outer positions are SSA values (e.g. computed from
/// a loop induction variable) instead of compile-time integers.
///
/// Every dimension except the last gets size 1 and its original offset plus
/// `outerIndices[dim]`; the last dimension keeps its original offset and full size.
/// Outer dimensions are asserted to have unit stride, because the index is added to the
/// offset without being scaled.
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
                                      ArrayRef<Value> outerIndices,
                                      Location loc,
                                      PatternRewriter& rewriter) {
  const size_t rank = info.sizes.size();

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(info.offsets.size());
  sizes.reserve(info.sizes.size());
  strides.reserve(info.strides.size());

  for (size_t dim = 0; dim < rank; ++dim) {
    const bool isOuterDim = dim + 1 < rank;
    if (!isOuterDim) {
      offsets.push_back(info.offsets[dim]);
      sizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
    }
    else {
      assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
      offsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
      sizes.push_back(rewriter.getIndexAttr(1));
    }
    strides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, offsets, sizes, strides);
}
/// Creates a unit-stride `memref.subview` of `source` selecting one run along the last
/// dimension of `copyShape`: outer dimensions use offset `outerIndices[dim]` and size 1,
/// the last dimension uses offset 0 and size `copyShape.back()`.
static Value buildContiguousChunk(
    Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
  const size_t rank = copyShape.size();

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);

  for (size_t dim = 0; dim < rank; ++dim) {
    if (dim + 1 < rank) {
      offsets.push_back(OpFoldResult(outerIndices[dim]));
      sizes.push_back(rewriter.getIndexAttr(1));
    }
    else {
      offsets.push_back(rewriter.getIndexAttr(0));
      sizes.push_back(rewriter.getIndexAttr(copyShape.back()));
    }
    strides.push_back(rewriter.getIndexAttr(1));
  }

  return memref::SubViewOp::create(rewriter, loc, source, offsets, sizes, strides);
}
/// Splits a copy whose source and/or destination is a non-contiguous static subview into
/// copies of one innermost-dimension run each.
///
/// \param copyOp            The copy being rewritten; used for its location and as the
///                          anchor for index constants. It is NOT erased or replaced here,
///                          the caller does that on success.
/// \param dst, src          Destination and source memrefs of the copy.
/// \param dstOffset         Byte offset applied to `dst` by the original copy.
/// \param srcOffset         Byte offset applied to `src` by the original copy.
/// \param size              Byte size of the original copy; must equal the byte size of the
///                          copied shape.
/// \param allowLoopRewrite  When true and the shapes allow it, emit one `scf.for` with a
///                          single copy in its body instead of unrolling.
/// \param rewriter          Rewriter used to create the new ops.
/// \param createCopyOp      Callback invoked as
///                          `createCopyOp(resultType, dst, src, dstByteOffset, srcByteOffset, sliceBytes)`
///                          to emit each per-run copy at the current insertion point.
/// \return failure() when neither side needs splitting or a precondition is not met (all
///         failure returns happen before any loop or chunk op is created); success() once
///         the replacement copies have been emitted.
template <typename CopyOp, typename CreateCopyOp>
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
                                              Value dst,
                                              Value src,
                                              int64_t dstOffset,
                                              int64_t srcOffset,
                                              int64_t size,
                                              bool allowLoopRewrite,
                                              PatternRewriter& rewriter,
                                              CreateCopyOp createCopyOp) {
  auto srcSubview = getStaticSubviewInfo(src);
  auto dstSubview = getStaticSubviewInfo(dst);
  // A side is split only if it is a static subview that is not already contiguous.
  const bool splitSrc = succeeded(srcSubview) && !isStaticSubviewContiguous(*srcSubview);
  const bool splitDst = succeeded(dstSubview) && !isStaticSubviewContiguous(*dstSubview);
  if (!splitSrc && !splitDst)
    return failure();

  auto sourceType = dyn_cast<MemRefType>(src.getType());
  auto dstType = dyn_cast<MemRefType>(dst.getType());
  if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();
  if (sourceType.getElementType() != dstType.getElementType())
    return failure();

  // A split side must start at byte offset 0 and use unit strides only.
  if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
    return failure();
  if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
    return failure();

  ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
  if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
    return failure();
  // Rank-0 shapes have no innermost dimension to slice along; bail out before
  // `copyShape.back()` / `copyShape.end() - 1` below, which are invalid on an empty range.
  if (copyShape.empty())
    return failure();
  if (!hasByteSizedElementType(sourceType.getElementType()))
    return failure();

  const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
  const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
  // The original copy must cover exactly the subview, otherwise per-run copies would not be equivalent.
  if (size != totalBytes)
    return failure();

  const int64_t sliceBytes = copyShape.back() * elementByteWidth;
  if (sliceBytes <= 0)
    return failure();

  // All dimensions but the last enumerate the runs to copy.
  SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
  auto outerStrides = computeRowMajorStrides(outerShape);
  const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
  const bool sourceShapeMatchesCopyShape = llvm::equal(sourceType.getShape(), copyShape);
  const bool dstShapeMatchesCopyShape = llvm::equal(dstType.getShape(), copyShape);

  // Both paths below move the insertion point (into the loop body, or in front of copyOp);
  // restore the caller's insertion point when this function returns.
  OpBuilder::InsertionGuard insertionGuard(rewriter);

  // Loop form: one scf.for over the linearized outer index, one copy in the body. It is
  // only used when the unsplit side (if any) has exactly the copied shape, so that it can
  // be chunked with plain subviews indexed by the same outer indices.
  if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
      && sourceType.getRank() == static_cast<int64_t>(copyShape.size())
      && dstType.getRank() == static_cast<int64_t>(copyShape.size()) && (splitSrc || sourceShapeMatchesCopyShape)
      && (splitDst || dstShapeMatchesCopyShape)) {
    auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0);
    auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices);
    auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1);
    auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
    rewriter.setInsertionPointToStart(loop.getBody());
    SmallVector<Value> outerIndices =
        outerShape.empty() ? SmallVector<Value> {}
                           : delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
    Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
                              : buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
    Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
                              : buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
    createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
    return success();
  }

  // Unrolled form: one copy per run. A split side is addressed through a per-run subview
  // with byte offset 0; an unsplit side keeps its memref and advances its byte offset.
  rewriter.setInsertionPoint(copyOp);
  for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
    SmallVector<int64_t> outerIndices =
        outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
    Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
    Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
    const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
    const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
    createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
  }
  return success();
}
/// Rewrites a `pim::PimMemCopyOp` nested in a `PimCoreOp` or `PimCoreBatchOp` whose target
/// or source is a non-contiguous subview into per-run copies (loop form allowed), then
/// replaces the original op with its target operand.
struct NormalizeCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
    // Only copies that live inside a core region are handled by this pattern.
    const bool insideCore =
        copyOp->getParentOfType<pim::PimCoreOp>() || copyOp->getParentOfType<pim::PimCoreBatchOp>();
    if (!insideCore)
      return failure();

    // Emits one replacement copy; offsets and size are stored as I32 attributes.
    auto emitChunkCopy = [&](MemRefType resultType,
                             Value dst,
                             Value src,
                             int64_t dstByteOffset,
                             int64_t srcByteOffset,
                             int64_t sliceBytes) {
      auto dstOffsetAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset));
      auto srcOffsetAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset));
      auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes));
      pim::PimMemCopyOp::create(
          rewriter, copyOp.getLoc(), resultType, dst, src, dstOffsetAttr, srcOffsetAttr, sizeAttr);
    };

    if (failed(rewriteSubviewCopyLikeOp(copyOp,
                                        copyOp.getTarget(),
                                        copyOp.getSource(),
                                        copyOp.getTargetOffset(),
                                        copyOp.getSourceOffset(),
                                        copyOp.getSize(),
                                        /*allowLoopRewrite=*/true,
                                        rewriter,
                                        emitChunkCopy)))
      return failure();

    rewriter.replaceOp(copyOp, copyOp.getTarget());
    return success();
  }
};
/// Rewrites a `pim::PimMemCopyHostToDevOp` whose device target or host source is a
/// non-contiguous subview into per-run copies (loop form allowed), then replaces the
/// original op with its device target operand. Both offsets must resolve to constants.
struct NormalizeHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
    auto deviceOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
    auto hostOffset = resolveIndexValue(copyOp.getHostSourceOffset());
    if (failed(deviceOffset) || failed(hostOffset))
      return failure();

    // Emits one replacement load; offsets become index constants, the size an I32 attribute.
    auto emitChunkLoad = [&](MemRefType resultType,
                             Value dst,
                             Value src,
                             int64_t dstByteOffset,
                             int64_t srcByteOffset,
                             int64_t sliceBytes) {
      Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
      Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
      auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes));
      pim::PimMemCopyHostToDevOp::create(
          rewriter, copyOp.getLoc(), resultType, dstOffsetValue, srcOffsetValue, dst, src, sizeAttr);
    };

    if (failed(rewriteSubviewCopyLikeOp(copyOp,
                                        copyOp.getDeviceTarget(),
                                        copyOp.getHostSource(),
                                        *deviceOffset,
                                        *hostOffset,
                                        copyOp.getSize(),
                                        /*allowLoopRewrite=*/true,
                                        rewriter,
                                        emitChunkLoad)))
      return failure();

    rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
    return success();
  }
};
/// Rewrites a `pim::PimMemCopyDevToHostOp` whose host target or device source is a
/// non-contiguous subview into unrolled per-run copies (the loop form is disabled for this
/// pattern), then replaces the original op with its host target operand. Both offsets must
/// resolve to constants.
struct NormalizeHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDevToHostOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
    auto hostOffset = resolveIndexValue(copyOp.getHostTargetOffset());
    auto deviceOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
    if (failed(hostOffset) || failed(deviceOffset))
      return failure();

    // Emits one replacement store; offsets become index constants, the size an I32 attribute.
    auto emitChunkStore = [&](MemRefType resultType,
                              Value dst,
                              Value src,
                              int64_t dstByteOffset,
                              int64_t srcByteOffset,
                              int64_t sliceBytes) {
      Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
      Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
      auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes));
      pim::PimMemCopyDevToHostOp::create(
          rewriter, copyOp.getLoc(), resultType, dstOffsetValue, srcOffsetValue, dst, src, sizeAttr);
    };

    if (failed(rewriteSubviewCopyLikeOp(copyOp,
                                        copyOp.getHostTarget(),
                                        copyOp.getDeviceSource(),
                                        *hostOffset,
                                        *deviceOffset,
                                        copyOp.getSize(),
                                        /*allowLoopRewrite=*/false,
                                        rewriter,
                                        emitChunkStore)))
      return failure();

    rewriter.replaceOp(copyOp, copyOp.getHostTarget());
    return success();
  }
};
} // namespace
/// Adds the three subview-copy normalization patterns (core copy, host load, host store)
/// to `patterns`, all constructed with the pattern set's own context.
void populatePimContiguityNormalizationPatterns(RewritePatternSet& patterns) {
  MLIRContext* context = patterns.getContext();
  patterns.add<NormalizeCoreSubviewCopyPattern>(context);
  patterns.add<NormalizeHostSubviewLoadPattern>(context);
  patterns.add<NormalizeHostSubviewStorePattern>(context);
}
} // namespace onnx_mlir::pim
@@ -0,0 +1,9 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir::pim {
/// Registers the rewrite patterns that split PIM copy ops operating on non-contiguous
/// subviews into copies of contiguous runs.
///
/// \param patterns Pattern set that receives the patterns; its context is used to build them.
void populatePimContiguityNormalizationPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir::pim
@@ -47,32 +47,6 @@ struct MemCopyHostToDevOpInterface
} }
}; };
/// Bufferization external model for `PimMemCopyHostToDevBatchOp`: resolves both operands
/// through `getBufferOrValue` and recreates the op on the results, forwarding the offset
/// and size attributes unchanged.
struct MemCopyHostToDevBatchOpInterface
    : DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevBatchOpInterface, PimMemCopyHostToDevBatchOp> {
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto batchCopyOp = cast<PimMemCopyHostToDevBatchOp>(op);

    // Destination first, then source; either lookup failing aborts the bufferization of this op.
    auto deviceTarget = getBufferOrValue(rewriter, batchCopyOp.getDeviceTarget(), options, state);
    if (failed(deviceTarget))
      return failure();

    auto hostSource = getBufferOrValue(rewriter, batchCopyOp.getHostSource(), options, state);
    if (failed(hostSource))
      return failure();

    // The new op's result type is the type of the resolved destination.
    replaceOpWithNewBufferizedOp<PimMemCopyHostToDevBatchOp>(rewriter,
                                                             batchCopyOp,
                                                             deviceTarget->getType(),
                                                             *deviceTarget,
                                                             *hostSource,
                                                             batchCopyOp.getDeviceTargetOffsetAttr(),
                                                             batchCopyOp.getHostSourceOffsetAttr(),
                                                             batchCopyOp.getSizeAttr());
    return success();
  }
};
struct MemCopyDevToHostOpInterface struct MemCopyDevToHostOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> { : DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
LogicalResult bufferize(Operation* op, LogicalResult bufferize(Operation* op,
@@ -151,8 +125,9 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimReceiveOp>( replaceOpWithNewBufferizedOp<PimReceiveOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId()); rewriter, op, contiguousOutput.getType(), contiguousOutput, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
return success(); return success();
} }
}; };
@@ -173,15 +148,16 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
auto inputOpt = getBufferOrValue(rewriter, input, options, state); auto inputOpt = getBufferOrValue(rewriter, input, options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
inputs.push_back(materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter)); inputs.push_back(materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter));
} }
auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state); auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimConcatOp>( replaceOpWithNewBufferizedOp<PimConcatOp>(
rewriter, op, outputBufferOpt->getType(), concatOp.getAxisAttr(), ValueRange(inputs), *outputBufferOpt); rewriter, op, contiguousOutput.getType(), concatOp.getAxisAttr(), ValueRange(inputs), contiguousOutput);
return success(); return success();
} }
}; };
@@ -206,7 +182,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
replaceOpWithNewBufferizedOp<PimSendOp>(rewriter, replaceOpWithNewBufferizedOp<PimSendOp>(rewriter,
op, op,
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(), sendOp.getSizeAttr(),
sendOp.getTargetCoreId()); sendOp.getTargetCoreId());
return success(); return success();
@@ -431,8 +407,8 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter); Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimTransposeOp>( replaceOpWithNewBufferizedOp<PimTransposeOp>(
rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput); rewriter, op, contiguousOutput.getType(), contiguousInput, transposeOp.getPermutation(), contiguousOutput);
@@ -475,8 +451,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter); Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVMMOp>( replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput); rewriter, op, contiguousOutput.getType(), *weightOpt, contiguousInput, contiguousOutput);
@@ -514,9 +490,9 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter); Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter); Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>( replaceOpWithNewBufferizedOp<OpTy>(
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput); rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
@@ -547,9 +523,9 @@ struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInter
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter); Value contiguousLhs = materializeContiguousInputMemRef(*lhsOpt, op->getLoc(), rewriter);
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter); Value contiguousRhs = materializeContiguousInputMemRef(*rhsOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVVDMulOp>( replaceOpWithNewBufferizedOp<PimVVDMulOp>(
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput); rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
@@ -583,8 +559,8 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter); Value contiguousInput = materializeContiguousInputMemRef(*inputOpt, op->getLoc(), rewriter);
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); Value contiguousOutput = allocateContiguousResultMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput); replaceOpWithNewBufferizedOp<OpTy>(rewriter, op, contiguousOutput.getType(), contiguousInput, contiguousOutput);
return success(); return success();
@@ -599,7 +575,6 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimSendOp::attachInterface<SendOpInterface>(*ctx); PimSendOp::attachInterface<SendOpInterface>(*ctx);
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx); PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx); PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx); PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimMemCopyOp::attachInterface<MemCopyOpInterface>(*ctx); PimMemCopyOp::attachInterface<MemCopyOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx); PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
@@ -6,14 +6,14 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Compiler/PimCodeGen.hpp" #include "Compiler/PimCodeGen.hpp"
#include "Dialect/Pim/PimOps.hpp" #include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/ContiguityPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
@@ -40,6 +40,7 @@ struct PimBufferizationPass : PassWrapper<PimBufferizationPass, OperationPass<Mo
private: private:
void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const; void annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const;
LogicalResult verifyContiguousRuntimeOperands(ModuleOp moduleOp) const;
}; };
} // namespace } // namespace
@@ -84,6 +85,20 @@ void PimBufferizationPass::runOnOperation() {
return; return;
} }
// After this pass, executable PIM ops must only use contiguous/addressable memrefs.
// Later PIM codegen passes may verify this invariant but must not repair it.
RewritePatternSet contiguityPatterns(ctx);
populatePimContiguityNormalizationPatterns(contiguityPatterns);
if (failed(applyPatternsGreedily(moduleOp, std::move(contiguityPatterns)))) {
moduleOp.emitError("failed to normalize PIM runtime operand contiguity during bufferization");
signalPassFailure();
return;
}
if (failed(verifyContiguousRuntimeOperands(moduleOp))) {
signalPassFailure();
return;
}
annotateWeightsMemrefs(moduleOp, funcOp); annotateWeightsMemrefs(moduleOp, funcOp);
// Dump to file for debug // Dump to file for debug
@@ -108,6 +123,75 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
funcOp.walk([&](PimCoreBatchOp coreBatchOp) { markWeights(coreBatchOp); }); funcOp.walk([&](PimCoreBatchOp coreBatchOp) { markWeights(coreBatchOp); });
} }
// Walks every operation in the module and checks the memref operands of the
// executable PIM ops listed below. An operand passes when either
// resolveContiguousAddress or compileContiguousAddressExpr succeeds on it.
// One op error is emitted per failing operand, followed by a single module
// level error; the result is failure() iff at least one operand failed.
LogicalResult PimBufferizationPass::verifyContiguousRuntimeOperands(ModuleOp moduleOp) const {
  bool sawViolation = false;

  moduleOp.walk([&](Operation* op) {
    // Checks a single operand; non-memref values are ignored.
    auto checkOperand = [&](Value value, unsigned index) {
      if (!isa<BaseMemRefType>(value.getType()))
        return;
      const bool isContiguous =
          succeeded(resolveContiguousAddress(value)) || succeeded(compileContiguousAddressExpr(value));
      if (isContiguous)
        return;
      op->emitOpError() << "operand #" << index
                        << " is not backed by contiguous addressable storage after PIM bufferization";
      sawViolation = true;
    };

    if (auto memCopyOp = dyn_cast<PimMemCopyOp>(op)) {
      checkOperand(memCopyOp.getTarget(), 0);
      checkOperand(memCopyOp.getSource(), 1);
    }
    else if (auto loadOp = dyn_cast<PimMemCopyHostToDevOp>(op)) {
      checkOperand(loadOp.getDeviceTarget(), 2);
      checkOperand(loadOp.getHostSource(), 3);
    }
    else if (auto storeOp = dyn_cast<PimMemCopyDevToHostOp>(op)) {
      checkOperand(storeOp.getHostTarget(), 2);
      checkOperand(storeOp.getDeviceSource(), 3);
    }
    else if (auto sendOp = dyn_cast<PimSendOp>(op)) {
      checkOperand(sendOp.getInput(), 0);
    }
    else if (auto receiveOp = dyn_cast<PimReceiveOp>(op)) {
      checkOperand(receiveOp.getOutputBuffer(), 0);
    }
    else if (auto concatOp = dyn_cast<PimConcatOp>(op)) {
      checkOperand(concatOp.getOutputBuffer(), 0);
      // Inputs are reported with indices shifted by one, after the output buffer.
      for (auto [position, input] : llvm::enumerate(concatOp.getInputs()))
        checkOperand(input, position + 1);
    }
    else if (isa<PimTransposeOp,
                 PimVMMOp,
                 PimVVAddOp,
                 PimVVSubOp,
                 PimVVMulOp,
                 PimVVMaxOp,
                 PimVVDMulOp,
                 PimVAvgOp,
                 PimVReluOp,
                 PimVTanhOp,
                 PimVSigmOp,
                 PimVSoftmaxOp>(op)) {
      // Operand #0 of PimVMMOp is exempt from the check; every other operand
      // of these ops is verified.
      const unsigned firstChecked = isa<PimVMMOp>(op) ? 1 : 0;
      for (unsigned index = firstChecked; index < op->getNumOperands(); ++index)
        checkOperand(op->getOperand(index), index);
    }
  });

  if (!sawViolation)
    return success();

  moduleOp.emitError("PIM bufferization must fully normalize executable runtime operand contiguity before codegen");
  return failure();
}
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); } std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -22,6 +22,10 @@ using namespace onnx_mlir::compact_asm;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
// This pass assumes bufferization has already normalized executable PIM
// operands. It only reuses compatible local allocations with non-overlapping
// lifetimes; it does not repair memory contiguity.
struct CoalescingReportRow { struct CoalescingReportRow {
uint64_t numCandidates = 0; uint64_t numCandidates = 0;
uint64_t numSkipped = 0; uint64_t numSkipped = 0;
@@ -1,349 +1,12 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "../Common.hpp" #include "../Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static bool isSubviewContiguous(const StaticSubviewInfo& info) {
if (llvm::any_of(info.strides, [](int64_t stride) { return stride != 1; }))
return false;
auto sizesAndShape = llvm::zip_equal(llvm::make_range(info.sizes.rbegin(), info.sizes.rend()),
llvm::make_range(info.sourceShape.rbegin(), info.sourceShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize == sizesAndShape.end())
return true;
++firstDifferentSize;
return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
auto [size, _dimension] = sizeAndShape;
return size == 1;
});
}
// Returns `baseOffset + extraOffset`. A zero `extraOffset` returns the base
// unchanged; an attribute base is folded into a new index attribute; a Value
// base gets an arith.addi with an index constant.
static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffset, PatternRewriter& rewriter) {
  if (extraOffset == 0)
    return baseOffset;

  auto staticBase = dyn_cast<Attribute>(baseOffset);
  if (!staticBase) {
    auto dynamicBase = cast<Value>(baseOffset);
    auto extraConstant =
        getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset);
    return arith::AddIOp::create(rewriter, dynamicBase.getLoc(), dynamicBase, extraConstant).getResult();
  }

  auto integerAttr = dyn_cast<IntegerAttr>(staticBase);
  assert(integerAttr && "expected integer offset attribute");
  return rewriter.getIndexAttr(integerAttr.getInt() + extraOffset);
}
// Creates a memref.subview of `info.source` selecting one innermost-dimension
// slice: each outer dimension has size 1 and is offset by
// `outerIndices[dim] * info.strides[dim]`, while the last dimension keeps the
// original offset and spans `info.sizes.back()` elements.
static Value buildSubviewChunk(const StaticSubviewInfo& info,
                               ArrayRef<int64_t> outerIndices,
                               Location loc,
                               PatternRewriter& rewriter) {
  const size_t rank = info.sizes.size();

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(info.offsets.size());
  sizes.reserve(rank);
  strides.reserve(info.strides.size());

  for (size_t dim = 0; dim < rank; ++dim) {
    const bool isInnermost = dim + 1 == rank;
    const int64_t shift = isInnermost ? 0 : outerIndices[dim] * info.strides[dim];
    const int64_t extent = isInnermost ? info.sizes.back() : 1;

    offsets.push_back(addConstantOffset(info.offsets[dim], shift, rewriter));
    sizes.push_back(rewriter.getIndexAttr(extent));
    strides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, offsets, sizes, strides);
}
// Emits arith.divui / arith.remui ops that split `linearIndex` into one index
// per entry of `strides`: each index is the running remainder divided by the
// stride, and the remainder is then reduced modulo that stride. `shape` is
// only used to size the result vector.
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
                                                ArrayRef<int64_t> shape,
                                                ArrayRef<int64_t> strides,
                                                PatternRewriter& rewriter) {
  SmallVector<Value> result;
  result.reserve(shape.size());

  Value rest = linearIndex;
  for (int64_t stride : strides) {
    auto strideConstant = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride);
    Value quotient = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), rest, strideConstant);
    result.push_back(quotient);
    rest = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), rest, strideConstant);
  }
  return result;
}
// Returns `baseOffset + extraOffset` where the extra part is an SSA value.
// A constant-zero base yields `extraOffset` itself; any other base produces an
// arith.addi (materializing an index constant first when the base is static).
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
  auto staticBase = dyn_cast<Attribute>(baseOffset);
  if (!staticBase) {
    auto dynamicBase = cast<Value>(baseOffset);
    return arith::AddIOp::create(rewriter, dynamicBase.getLoc(), dynamicBase, extraOffset).getResult();
  }

  auto integerAttr = cast<IntegerAttr>(staticBase);
  if (integerAttr.getInt() == 0)
    return extraOffset;

  auto baseConstant =
      getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt());
  return arith::AddIOp::create(rewriter, extraOffset.getLoc(), baseConstant, extraOffset).getResult();
}
// Variant of the chunk builder that takes SSA outer indices: every outer
// dimension becomes a size-1 slice at `info.offsets[dim] + outerIndices[dim]`
// (unit stride is asserted), and the last dimension keeps its original offset
// with extent `info.sizes.back()`.
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
                                      ArrayRef<Value> outerIndices,
                                      Location loc,
                                      PatternRewriter& rewriter) {
  const size_t rank = info.sizes.size();

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(info.offsets.size());
  sizes.reserve(rank);
  strides.reserve(info.strides.size());

  for (size_t dim = 0; dim < rank; ++dim) {
    const bool isInnermost = dim + 1 == rank;
    if (isInnermost) {
      offsets.push_back(info.offsets[dim]);
      sizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
    }
    else {
      assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
      offsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
      sizes.push_back(rewriter.getIndexAttr(1));
    }
    strides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, offsets, sizes, strides);
}
// Creates a unit-stride memref.subview of `source` selecting one
// innermost-dimension slice of `copyShape`: outer dimensions are size-1 slices
// at `outerIndices[dim]`, the last dimension starts at 0 and spans
// `copyShape.back()` elements.
static Value buildContiguousChunk(
    Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
  const size_t rank = copyShape.size();

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);

  for (size_t dim = 0; dim < rank; ++dim) {
    const bool isInnermost = dim + 1 == rank;
    if (isInnermost) {
      offsets.push_back(rewriter.getIndexAttr(0));
      sizes.push_back(rewriter.getIndexAttr(copyShape.back()));
    }
    else {
      offsets.push_back(OpFoldResult(outerIndices[dim]));
      sizes.push_back(rewriter.getIndexAttr(1));
    }
    strides.push_back(rewriter.getIndexAttr(1));
  }

  return memref::SubViewOp::create(rewriter, loc, source, offsets, sizes, strides);
}
// Splits a copy-like op whose source and/or destination is a non-contiguous
// static subview into copies of innermost-dimension slices.
//
// `dst`/`src` are the copy operands, `dstOffset`/`srcOffset` their byte
// offsets, and `size` the byte count of the original copy. `createCopyOp` is
// invoked once per emitted chunk as
//   createCopyOp(chunkDstType, chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes).
//
// Returns failure() without creating any op when neither operand needs
// splitting or when one of the preconditions checked below does not hold.
// This function does not replace or erase `copyOp`.
template <typename CopyOp, typename CreateCopyOp>
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
Value dst,
Value src,
int64_t dstOffset,
int64_t srcOffset,
int64_t size,
bool allowLoopRewrite,
PatternRewriter& rewriter,
CreateCopyOp createCopyOp) {
auto srcSubview = getStaticSubviewInfo(src);
auto dstSubview = getStaticSubviewInfo(dst);
// A side is split only if it is a static subview that is not contiguous.
const bool splitSrc = succeeded(srcSubview) && !isSubviewContiguous(*srcSubview);
const bool splitDst = succeeded(dstSubview) && !isSubviewContiguous(*dstSubview);
if (!splitSrc && !splitDst)
return failure();
auto sourceType = dyn_cast<MemRefType>(src.getType());
auto dstType = dyn_cast<MemRefType>(dst.getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
return failure();
// A split side must start at byte offset 0 and use unit strides only.
if (splitSrc && (srcOffset != 0 || llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })))
return failure();
if (splitDst && (dstOffset != 0 || llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })))
return failure();
// The copied shape is taken from the split side; when both sides are split
// their subview sizes must agree.
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
if (!hasByteSizedElementType(sourceType.getElementType()))
return failure();
const int64_t elementByteWidth = static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
// Only whole-subview copies are rewritten.
if (size != totalBytes)
return failure();
// One chunk covers the innermost dimension of the copied shape.
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
// All dimensions except the innermost one enumerate the chunks.
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
// Loop form: a single scf.for over the linearized chunk index, with one
// chunk copy in its body. Requires both operands to have the same rank as
// the copied shape so a chunk view can be built on either side.
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
&& dstType.getRank() == static_cast<int64_t>(copyShape.size())) {
auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0);
auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices);
auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1);
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
// Everything created from here on, including the op built by createCopyOp,
// is inserted into the loop body.
rewriter.setInsertionPointToStart(loop.getBody());
SmallVector<Value> outerIndices =
outerShape.empty() ? SmallVector<Value> {}
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
return success();
}
// Unrolled form: one copy per chunk, emitted right before the original op.
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t> {} : delinearizeIndex(linearIndex, outerShape, outerStrides);
Value chunkDst = splitDst ? buildSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) : dst;
Value chunkSrc = splitSrc ? buildSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) : src;
// A split side is addressed through its chunk view (offset 0); a side that
// is not split keeps the original operand and advances by byte offset.
const int64_t srcByteOffset = splitSrc ? 0 : srcOffset + linearIndex * sliceBytes;
const int64_t dstByteOffset = splitDst ? 0 : dstOffset + linearIndex * sliceBytes;
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, dstByteOffset, srcByteOffset, sliceBytes);
}
return success();
}
// Rewrites a pim.memcp nested in a PimCoreOp whose operands go through
// non-contiguous subviews into per-slice pim.memcp ops, then replaces the
// original op with its target operand.
struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
    // Only copies that live inside a core are handled by this pattern.
    if (!copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();

    // Emits one chunk copy; byte offsets and size are stored as i32 attributes.
    auto emitChunkCopy = [&](MemRefType chunkType,
                             Value chunkDst,
                             Value chunkSrc,
                             int64_t dstBytes,
                             int64_t srcBytes,
                             int64_t chunkBytes) {
      auto asI32Attr = [&](int64_t value) { return rewriter.getI32IntegerAttr(static_cast<int32_t>(value)); };
      pim::PimMemCopyOp::create(rewriter,
                                copyOp.getLoc(),
                                chunkType,
                                chunkDst,
                                chunkSrc,
                                asI32Attr(dstBytes),
                                asI32Attr(srcBytes),
                                asI32Attr(chunkBytes));
    };

    if (failed(rewriteSubviewCopyLikeOp(copyOp,
                                        copyOp.getTarget(),
                                        copyOp.getSource(),
                                        copyOp.getTargetOffset(),
                                        copyOp.getSourceOffset(),
                                        copyOp.getSize(),
                                        /*allowLoopRewrite=*/true,
                                        rewriter,
                                        emitChunkCopy)))
      return failure();

    rewriter.replaceOp(copyOp, copyOp.getTarget());
    return success();
  }
};
// Rewrites a host-to-device copy whose operands go through non-contiguous
// subviews into per-slice host-to-device copies, then replaces the original op
// with its device target operand. Bails out when either offset operand cannot
// be resolved to a constant.
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
    auto resolvedDstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
    auto resolvedSrcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
    if (failed(resolvedDstOffset) || failed(resolvedSrcOffset))
      return failure();

    // Emits one chunk copy; offsets become index constants, the size an i32 attribute.
    auto emitChunkCopy = [&](MemRefType chunkType,
                             Value chunkDst,
                             Value chunkSrc,
                             int64_t dstBytes,
                             int64_t srcBytes,
                             int64_t chunkBytes) {
      Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstBytes);
      Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcBytes);
      pim::PimMemCopyHostToDevOp::create(rewriter,
                                         copyOp.getLoc(),
                                         chunkType,
                                         dstOffsetValue,
                                         srcOffsetValue,
                                         chunkDst,
                                         chunkSrc,
                                         rewriter.getI32IntegerAttr(static_cast<int32_t>(chunkBytes)));
    };

    if (failed(rewriteSubviewCopyLikeOp(copyOp,
                                        copyOp.getDeviceTarget(),
                                        copyOp.getHostSource(),
                                        *resolvedDstOffset,
                                        *resolvedSrcOffset,
                                        copyOp.getSize(),
                                        /*allowLoopRewrite=*/true,
                                        rewriter,
                                        emitChunkCopy)))
      return failure();

    rewriter.replaceOp(copyOp, copyOp.getDeviceTarget());
    return success();
  }
};
// Rewrites a device-to-host copy whose operands go through non-contiguous
// subviews into per-slice device-to-host copies, then replaces the original op
// with its host target operand. Bails out when either offset operand cannot be
// resolved to a constant. Unlike the load pattern, the loop-based rewrite is
// disabled here, so chunks are always emitted unrolled.
struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDevToHostOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
    auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
    auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
    if (failed(dstOffset) || failed(srcOffset))
      return failure();

    auto status = rewriteSubviewCopyLikeOp(
        copyOp,
        copyOp.getHostTarget(),
        copyOp.getDeviceSource(),
        *dstOffset,
        *srcOffset,
        copyOp.getSize(),
        /*allowLoopRewrite=*/false,
        rewriter,
        [&](
            MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
          // Named *Value so they do not shadow the resolved offsets captured
          // from the enclosing scope (same naming as the load pattern).
          Value dstOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, dstByteOffset);
          Value srcOffsetValue = getOrCreateIndexConstant(rewriter, copyOp, srcByteOffset);
          pim::PimMemCopyDevToHostOp::create(rewriter,
                                             copyOp.getLoc(),
                                             resultType,
                                             dstOffsetValue,
                                             srcOffsetValue,
                                             dst,
                                             src,
                                             rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
        });
    if (failed(status))
      return failure();

    rewriter.replaceOp(copyOp, copyOp.getHostTarget());
    return success();
  }
};
// Folds constant subviews used as core weights into standalone globals. // Folds constant subviews used as core weights into standalone globals.
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> { struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
@@ -392,10 +55,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
} // namespace } // namespace
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) { void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
patterns.add<RewriteCoreSubviewCopyPattern, patterns.add<FoldConstantCoreSubviewPattern>(patterns.getContext());
RewriteHostSubviewLoadPattern,
RewriteHostSubviewStorePattern,
FoldConstantCoreSubviewPattern>(patterns.getContext());
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -7,12 +7,9 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h" #include "llvm/Support/MathExtras.h"
#include <type_traits>
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -85,31 +82,18 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp,
if (contiguousType != originalType) if (contiguousType != originalType)
deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc); deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc);
Value copiedValue; Value zeroOffset = getOrCreateIndexConstant(constantFolder, op, 0);
if constexpr (std::is_same_v<CoreOpTy, pim::PimCoreBatchOp>) { Value hostOffset = getOrCreateIndexConstant(constantFolder, op, resolvedAddress->byteOffset);
copiedValue = pim::PimMemCopyHostToDevBatchOp::create( Value copiedValue =
rewriter, pim::PimMemCopyHostToDevOp::create(rewriter,
op->getLoc(), op->getLoc(),
originalType, originalType,
deviceDst, zeroOffset,
getGlobalOp.getResult(), hostOffset,
rewriter.getI32IntegerAttr(0), deviceDst,
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)), getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes))) rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
.getOutput(); .getOutput();
}
else {
copiedValue = pim::PimMemCopyHostToDevOp::create(
rewriter,
op->getLoc(),
originalType,
getOrCreateIndexConstant(constantFolder, op, 0),
getOrCreateIndexConstant(constantFolder, op, static_cast<int64_t>(resolvedAddress->byteOffset)),
deviceDst,
getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
.getOutput();
}
cachedByType[originalType] = copiedValue; cachedByType[originalType] = copiedValue;
operand.set(copiedValue); operand.set(copiedValue);
+22 -66
View File
@@ -32,51 +32,30 @@ static bool isAddressOnlyHostOp(Operation* op) {
memref::CopyOp>(op); memref::CopyOp>(op);
} }
// Looser than isCodegenAddressableValue: follows view ops without requiring contiguity.
// Used for memref.copy operands which may be non-contiguous subviews.
// Walks back through subview / cast / collapse_shape / expand_shape producers
// and succeeds when the chain ends at a block argument, a memref.alloc or a
// memref.get_global.
static bool isBaseAddressableValue(Value value) {
  Value current = value;
  while (!isa<BlockArgument>(current)) {
    Operation* producer = current.getDefiningOp();
    if (!producer)
      return false;
    if (isa<memref::AllocOp, memref::GetGlobalOp>(producer))
      return true;

    // Step to the source of a view-like op; anything else ends the search.
    if (auto subviewOp = dyn_cast<memref::SubViewOp>(producer))
      current = subviewOp.getSource();
    else if (auto castOp = dyn_cast<memref::CastOp>(producer))
      current = castOp.getSource();
    else if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(producer))
      current = collapseOp.getSrc();
    else if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(producer))
      current = expandOp.getSrc();
    else
      return false;
  }
  return true;
}
static bool isCodegenAddressableValue(Value value) { static bool isCodegenAddressableValue(Value value) {
auto resolvedAddress = resolveContiguousAddress(value); auto resolvedAddress = resolveContiguousAddress(value);
if (failed(resolvedAddress)) if (succeeded(resolvedAddress))
return isa<BlockArgument>(resolvedAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
auto compiledAddress = compileContiguousAddressExpr(value);
if (failed(compiledAddress))
return false; return false;
return isa<BlockArgument>(resolvedAddress->base) return isa<BlockArgument>(compiledAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp()); || isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
} }
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) { static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
auto resolvedAddress = resolveContiguousAddress(value, knowledge); auto resolvedAddress = resolveContiguousAddress(value, knowledge);
if (failed(resolvedAddress)) if (succeeded(resolvedAddress))
return isa<BlockArgument>(resolvedAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
auto compiledAddress = compileContiguousAddressExpr(value);
if (failed(compiledAddress))
return false; return false;
return isa<BlockArgument>(resolvedAddress->base) return isa<BlockArgument>(compiledAddress->base)
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp()); || isa<memref::AllocOp, memref::GetGlobalOp>(compiledAddress->base.getDefiningOp());
} }
static bool isConstantGlobalView(Value value) { static bool isConstantGlobalView(Value value) {
@@ -138,7 +117,6 @@ static bool isCoreWeightBlockArgument(Value value) {
static bool isSupportedCoreInstructionOp(Operation* op) { static bool isSupportedCoreInstructionOp(Operation* op) {
return isa<pim::PimMemCopyHostToDevOp, return isa<pim::PimMemCopyHostToDevOp,
pim::PimMemCopyHostToDevBatchOp,
pim::PimMemCopyDevToHostOp, pim::PimMemCopyDevToHostOp,
pim::PimMemCopyOp, pim::PimMemCopyOp,
pim::PimReceiveOp, pim::PimReceiveOp,
@@ -159,27 +137,6 @@ static bool isSupportedCoreInstructionOp(Operation* op) {
memref::GetGlobalOp>(op); memref::GetGlobalOp>(op);
} }
// Batch-op specific verification. Only PimMemCopyHostToDevBatchOp is checked:
// its host source must be a codegen-addressable value under `knowledge`,
// otherwise a diagnostic is reported through `diagnostics` and failure() is
// returned. Every other op succeeds.
static LogicalResult verifyBatchOpSemantics(Operation& op,
                                            const StaticValueKnowledge& knowledge,
                                            pim::CappedDiagnosticReporter& diagnostics) {
  auto batchCopyOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op);
  if (!batchCopyOp)
    return success();

  if (isCodegenAddressableValue(batchCopyOp.getHostSource(), knowledge))
    return success();

  diagnostics.report(&op, [](Operation* illegalOp) {
    illegalOp->emitOpError("host operand #1 is not backed by contiguous addressable storage");
  });
  return failure();
}
struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> { struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VerificationPass)
@@ -343,7 +300,7 @@ private:
} }
auto resolvedAddress = resolveContiguousAddress(operand, knowledge); auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
if (failed(resolvedAddress)) { if (failed(resolvedAddress) && failed(compileContiguousAddressExpr(operand))) {
diagnostics.report(&op, [&](Operation* illegalOp) { diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << "operand #" << operandIndex illegalOp->emitOpError() << "operand #" << operandIndex
<< " is not backed by contiguous addressable storage"; << " is not backed by contiguous addressable storage";
@@ -363,7 +320,9 @@ private:
continue; continue;
} }
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) { Value addressBase =
succeeded(resolvedAddress) ? resolvedAddress->base : compileContiguousAddressExpr(operand)->base;
if (!isa<memref::AllocOp>(addressBase.getDefiningOp())) {
diagnostics.report(&op, [&](Operation* illegalOp) { diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << "operand #" << operandIndex illegalOp->emitOpError() << "operand #" << operandIndex
<< " must be backed by device-local memory; materialize host values with " << " must be backed by device-local memory; materialize host values with "
@@ -392,9 +351,6 @@ private:
hasFailure = true; hasFailure = true;
} }
} }
if (failed(verifyBatchOpSemantics(op, knowledge, diagnostics)))
hasFailure = true;
return success(!hasFailure); return success(!hasFailure);
}); });
} }
@@ -409,7 +365,7 @@ private:
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op)) if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics); return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics);
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) { if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) { if (!isMemRefBaseAddressableValue(copyOp.getSource()) || !isMemRefBaseAddressableValue(copyOp.getTarget())) {
diagnostics.report(op, [](Operation* illegalOp) { diagnostics.report(op, [](Operation* illegalOp) {
illegalOp->emitOpError("depends on a value that is not backed by addressable storage"); illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
}); });
@@ -432,7 +388,7 @@ private:
} }
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) { static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
if (isBaseAddressableValue(source)) if (isMemRefBaseAddressableValue(source))
return success(); return success();
diagnostics.report(op, [](Operation* illegalOp) { diagnostics.report(op, [](Operation* illegalOp) {