add pim.vmm verifier and fix vmm lowering
reuse code for subviews
This commit is contained in:
@@ -3,6 +3,7 @@ add_pim_library(OMPimCommon
|
||||
IR/CoreBlockUtils.cpp
|
||||
IR/EntryPointUtils.cpp
|
||||
IR/ShapeUtils.cpp
|
||||
IR/SubviewUtils.cpp
|
||||
IR/WeightUtils.cpp
|
||||
Support/DebugDump.cpp
|
||||
Support/Diagnostics.cpp
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value stripMemRefCasts(Value value) {
|
||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||
value = castOp.getSource();
|
||||
return value;
|
||||
}
|
||||
|
||||
Value stripMemRefViewOps(Value value) {
|
||||
while (true) {
|
||||
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
}
|
||||
|
||||
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
value = stripMemRefViewOps(value);
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
if (!subviewOp)
|
||||
return failure();
|
||||
|
||||
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
|
||||
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
return failure();
|
||||
info.sizes.push_back(*staticSize);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
info.strides.push_back(*staticStride);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
|
||||
SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.reserve(info.offsets.size());
|
||||
for (OpFoldResult offset : info.offsets) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
staticOffsets.push_back(*staticOffset);
|
||||
}
|
||||
return staticOffsets;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct StaticSubviewInfo {
|
||||
mlir::Value source;
|
||||
llvm::SmallVector<int64_t> sourceShape;
|
||||
llvm::SmallVector<mlir::OpFoldResult> offsets;
|
||||
llvm::SmallVector<int64_t> sizes;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
};
|
||||
|
||||
mlir::Value stripMemRefCasts(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||
|
||||
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
|
||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -24,8 +24,8 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Common/IR/CompactAsmUtils.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
@@ -206,9 +206,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu
|
||||
return iter->second.address + resolvedAddress->byteOffset;
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::reportHost() {
|
||||
hostReportRow = hostMem.getReportRow();
|
||||
}
|
||||
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
|
||||
|
||||
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
||||
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
|
||||
@@ -810,7 +808,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
|
||||
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
||||
else {
|
||||
op.emitError("Unsupported codegen for this operation");
|
||||
InFlightDiagnostic diag = op.emitError()
|
||||
<< "unsupported codegen for op '" << op.getName().getStringRef() << "'";
|
||||
if (auto coreOp = op.getParentOfType<pim::PimCoreOp>())
|
||||
diag << " inside pim.core " << coreOp.getCoreId();
|
||||
else if (auto coreBatchOp = op.getParentOfType<pim::PimCoreBatchOp>())
|
||||
diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount();
|
||||
return failure();
|
||||
}
|
||||
processedOperations++;
|
||||
@@ -935,9 +938,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
if (auto err = emitCore(coreOp, false))
|
||||
return err;
|
||||
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
|
||||
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
|
||||
.getReportRow());
|
||||
memory.recordCoreReport(
|
||||
emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
|
||||
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId()))).getReportRow());
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ struct MemoryReportRow {
|
||||
|
||||
bool operator==(const MemoryReportRow& other) const {
|
||||
return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal
|
||||
&& sizeGlobal == other.sizeGlobal;
|
||||
&& sizeGlobal == other.sizeGlobal;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
@@ -30,19 +32,6 @@ struct DenseWeightView {
|
||||
int64_t offset = 0;
|
||||
};
|
||||
|
||||
SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
|
||||
strides[index] = strides[index + 1] * shape[index + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
bool allStaticSubviewParts(memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
}
|
||||
|
||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||
SmallVector<memref::SubViewOp> subviews;
|
||||
mlir::Value current = weight;
|
||||
@@ -55,7 +44,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
||||
break;
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!allStaticSubviewParts(subview))
|
||||
if (!hasAllStaticSubviewParts(subview))
|
||||
return failure();
|
||||
subviews.push_back(subview);
|
||||
current = subview.getSource();
|
||||
@@ -79,7 +68,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
DenseWeightView view;
|
||||
view.denseAttr = denseAttr;
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStridesForShape(view.shape);
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <optional>
|
||||
@@ -47,8 +47,8 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
static Value createPoolFillElement(
|
||||
ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
||||
static Value
|
||||
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
||||
if (!useMinimumValue)
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
|
||||
|
||||
@@ -65,8 +65,10 @@ static Value createPoolFillElement(
|
||||
llvm_unreachable("unsupported pool element type");
|
||||
}
|
||||
|
||||
static Value createPoolFillTensor(
|
||||
ConversionPatternRewriter& rewriter, Location loc, RankedTensorType tensorType, bool useMinimumValue) {
|
||||
static Value createPoolFillTensor(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
RankedTensorType tensorType,
|
||||
bool useMinimumValue) {
|
||||
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
|
||||
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
|
||||
}
|
||||
@@ -90,10 +92,8 @@ static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
|
||||
inputType.getDimSize(3) + padLeft + padRight},
|
||||
inputType.getElementType(),
|
||||
inputType.getEncoding());
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padTop),
|
||||
rewriter.getIndexAttr(padLeft)};
|
||||
SmallVector<OpFoldResult> lowPads = {
|
||||
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padTop), rewriter.getIndexAttr(padLeft)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padBottom),
|
||||
@@ -104,8 +104,8 @@ static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
Value padValue = createPoolFillElement(
|
||||
rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
Value padValue =
|
||||
createPoolFillElement(rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
tensor::YieldOp::create(rewriter, loc, padValue);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
@@ -279,7 +279,8 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
|
||||
Value paddedInput = createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
||||
Value paddedInput =
|
||||
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
||||
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
@@ -307,8 +308,8 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||
Value reducedWindow = createPoolFillTensor(
|
||||
rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
Value reducedWindow =
|
||||
createPoolFillTensor(rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
Value paddedInH = windowBaseH;
|
||||
@@ -324,18 +325,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {batchIndex,
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
paddedInH,
|
||||
paddedInW};
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value windowValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||
@@ -344,36 +341,28 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
SmallVector<OpFoldResult> scaleOffsets = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
outHeightIndex,
|
||||
outWidthIndex};
|
||||
SmallVector<OpFoldResult> scaleOffsets = {
|
||||
rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
||||
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> scaleStrides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> scaleStrides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value scaleSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
||||
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
|
||||
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets = {batchIndex,
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
outHeightIndex,
|
||||
outWidthIndex};
|
||||
SmallVector<OpFoldResult> outputOffsets = {
|
||||
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
||||
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> outputStrides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> outputStrides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
updatedOutput = tensor::InsertSliceOp::create(
|
||||
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
|
||||
}
|
||||
|
||||
@@ -9,12 +9,14 @@
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
@@ -147,6 +149,73 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
|
||||
rewriter.replaceOp(extractRowsOp, replacements);
|
||||
}
|
||||
|
||||
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||
auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
auto zeroAttr = DenseElementsAttr::get(tensorType, rewriter.getZeroAttr(tensorType.getElementType()));
|
||||
|
||||
for (auto globalOp : moduleOp.getOps<memref::GlobalOp>()) {
|
||||
if (!globalOp.getConstant() || globalOp.getType() != memRefType || !globalOp.getInitialValue())
|
||||
continue;
|
||||
if (dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue()) == zeroAttr)
|
||||
return globalOp;
|
||||
}
|
||||
|
||||
std::string nameStem;
|
||||
llvm::raw_string_ostream nameStream(nameStem);
|
||||
nameStream << "__pim_zero_" << tensorType.getRank() << "d_" << tensorType.getNumElements();
|
||||
nameStream.flush();
|
||||
|
||||
std::string symbolName = nameStem;
|
||||
unsigned suffix = 0;
|
||||
while (SymbolTable::lookupSymbolIn(moduleOp, symbolName))
|
||||
symbolName = (nameStem + "_" + Twine(suffix++)).str();
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
return memref::GlobalOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getStringAttr(symbolName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(memRefType),
|
||||
zeroAttr,
|
||||
rewriter.getUnitAttr(),
|
||||
IntegerAttr {});
|
||||
}
|
||||
|
||||
static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
|
||||
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
|
||||
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
||||
|
||||
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
|
||||
return PimMemCopyHostToDevBatchOp::create(
|
||||
rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
||||
.getOutput();
|
||||
|
||||
return PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
|
||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||
ArrayRef<int64_t> shape = vectorType.getShape();
|
||||
assert(isHVectorShape(shape) && "expected a horizontal vector");
|
||||
assert(shape[1] <= static_cast<int64_t>(crossbarSize) && "vector width must fit in one crossbar");
|
||||
|
||||
if (shape[1] == static_cast<int64_t>(crossbarSize))
|
||||
return vector;
|
||||
|
||||
auto paddedType = RankedTensorType::get(
|
||||
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType);
|
||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||
}
|
||||
|
||||
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatConcatOp> concatOps;
|
||||
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
|
||||
@@ -426,54 +495,35 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return;
|
||||
auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
|
||||
if (!dpsDefiningOp)
|
||||
return;
|
||||
auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
return;
|
||||
Value tiedValue = tiedOperand->get();
|
||||
assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
|
||||
tiedValue.setType(newType);
|
||||
self(tiedValue, newType, self);
|
||||
};
|
||||
|
||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||
auto outTensorOperand = vmmOp.getOutputBuffer();
|
||||
auto resultTensor = vmmOp.getOutput();
|
||||
auto outShape = getTensorShape(outTensorOperand);
|
||||
assert(isHVectorShape(outShape));
|
||||
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
|
||||
auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
|
||||
auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
|
||||
if (outTensorOperand == vmmOp.getInput()) {
|
||||
rewriter.setInsertionPoint(vmmOp);
|
||||
auto newOutputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
|
||||
vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
|
||||
}
|
||||
else {
|
||||
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
|
||||
outTensorOperand.setType(newType);
|
||||
}
|
||||
resultTensor.setType(newType);
|
||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
assert(isHVectorShape(outputShape) && "expected a horizontal vector output");
|
||||
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
|
||||
|
||||
IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
|
||||
IntegerAttr oneAttr = rewriter.getIndexAttr(1);
|
||||
IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
|
||||
IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
|
||||
SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
|
||||
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
|
||||
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
|
||||
rewriter.setInsertionPointAfter(vmmOp);
|
||||
auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
|
||||
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
|
||||
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
||||
}
|
||||
rewriter.setInsertionPoint(vmmOp);
|
||||
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
|
||||
auto paddedOutputType = RankedTensorType::get(
|
||||
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
|
||||
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
|
||||
? vmmOp.getOutputBuffer()
|
||||
: createEmptyTensorFromShaped(rewriter, vmmOp.getLoc(), paddedOutputType).getResult();
|
||||
vmmOp.getInputMutable().assign(paddedInput);
|
||||
vmmOp.getOutputBufferMutable().assign(paddedOutputBuffer);
|
||||
|
||||
vmmOp.getOutput().setType(paddedOutputType);
|
||||
|
||||
if (outputShape[1] == static_cast<int64_t>(crossbarSize))
|
||||
return;
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputShape[0]), rewriter.getIndexAttr(outputShape[1])};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
rewriter.setInsertionPointAfter(vmmOp);
|
||||
auto sliceOp =
|
||||
tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), outputType, vmmOp.getOutput(), offsets, sizes, strides);
|
||||
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
|
||||
vmmOp.getOutput().replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -389,6 +389,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -77,6 +78,22 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
|
||||
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
|
||||
if (weightIndex >= coreOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
|
||||
if (weightIndex >= coreBatchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
@@ -104,6 +121,47 @@ LogicalResult PimReceiveTensorBatchOp::verify() {
|
||||
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimVMMOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
|
||||
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
|
||||
|
||||
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
|
||||
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
|
||||
if (!vectorType || !outputType)
|
||||
return emitError("input and output must be shaped types");
|
||||
|
||||
ArrayRef<int64_t> vectorShape = vectorType.getShape();
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
|
||||
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
|
||||
return emitError("matrix, vector and output must have rank 2");
|
||||
|
||||
int64_t N = matrixShape[0];
|
||||
int64_t M = matrixShape[1];
|
||||
if (N <= 0 || M <= 0)
|
||||
return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
|
||||
if (N > static_cast<int64_t>(crossbarSize) || M > static_cast<int64_t>(crossbarSize))
|
||||
return emitError("matrix dimensions must fit in one crossbar");
|
||||
|
||||
int64_t vector1 = vectorShape[0];
|
||||
int64_t vectorWidth = vectorShape[1];
|
||||
if (vector1 != 1 || vectorWidth != static_cast<int64_t>(crossbarSize))
|
||||
return emitError("vector shape must be (1, crossbar-size)");
|
||||
|
||||
int64_t output1 = outputShape[0];
|
||||
int64_t outputWidth = outputShape[1];
|
||||
if (output1 != 1 || outputWidth != static_cast<int64_t>(crossbarSize))
|
||||
return emitError("output shape must be (1, crossbar-size)");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimConcatOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
|
||||
@@ -105,6 +105,37 @@ struct MemCopyDevToHostOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct MemCopyOpInterface : DstBufferizableOpInterfaceExternalModel<MemCopyOpInterface, PimMemCopyOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memCopyOp = cast<PimMemCopyOp>(op);
|
||||
|
||||
auto targetOpt = getBufferOrValue(rewriter, memCopyOp.getTarget(), options, state);
|
||||
if (failed(targetOpt))
|
||||
return failure();
|
||||
|
||||
auto sourceOpt = getBufferOrValue(rewriter, memCopyOp.getSource(), options, state);
|
||||
if (failed(sourceOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyOp>(rewriter,
|
||||
memCopyOp,
|
||||
targetOpt->getType(),
|
||||
*targetOpt,
|
||||
*sourceOpt,
|
||||
memCopyOp.getTargetOffsetAttr(),
|
||||
memCopyOp.getSourceOffsetAttr(),
|
||||
memCopyOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -626,6 +657,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimMemCopyOp::attachInterface<MemCopyOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||
|
||||
|
||||
@@ -52,9 +52,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
printer << " ";
|
||||
printer.printOperand(op.getInput());
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getInput().getType());
|
||||
}
|
||||
@@ -62,9 +63,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
template <typename TensorReceiveOpTy>
|
||||
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getOutput().getType());
|
||||
}
|
||||
|
||||
@@ -41,30 +41,6 @@ struct DenseSubviewKeyInfo {
|
||||
|
||||
} // namespace
|
||||
|
||||
Value stripMemRefCasts(Value value) {
|
||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||
value = castOp.getSource();
|
||||
return value;
|
||||
}
|
||||
|
||||
Value stripMemRefViewOps(Value value) {
|
||||
while (true) {
|
||||
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
Location loc,
|
||||
MemRefType globalType,
|
||||
@@ -177,48 +153,4 @@ FailureOr<DenseElementsAttr> foldDenseSourceToType(ModuleOp moduleOp, Value sour
|
||||
return *denseAttr;
|
||||
}
|
||||
|
||||
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
value = stripMemRefViewOps(value);
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
if (!subviewOp)
|
||||
return failure();
|
||||
|
||||
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
|
||||
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
return failure();
|
||||
info.sizes.push_back(*staticSize);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
info.strides.push_back(*staticStride);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
|
||||
SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.reserve(info.offsets.size());
|
||||
for (OpFoldResult offset : info.offsets) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
staticOffsets.push_back(*staticOffset);
|
||||
}
|
||||
return staticOffsets;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -6,23 +6,12 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct StaticSubviewInfo {
|
||||
mlir::Value source;
|
||||
llvm::SmallVector<int64_t> sourceShape;
|
||||
llvm::SmallVector<mlir::OpFoldResult> offsets;
|
||||
llvm::SmallVector<int64_t> sizes;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
};
|
||||
|
||||
mlir::Value stripMemRefCasts(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||
|
||||
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
|
||||
mlir::Location loc,
|
||||
mlir::MemRefType globalType,
|
||||
@@ -39,9 +28,4 @@ llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp modu
|
||||
llvm::FailureOr<mlir::DenseElementsAttr>
|
||||
foldDenseSourceToType(mlir::ModuleOp moduleOp, mlir::Value source, mlir::MemRefType resultType);
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
|
||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -67,12 +68,6 @@ static bool isCodegenAddressableValue(Value value) {
|
||||
}
|
||||
|
||||
static bool isConstantGlobalView(Value value) {
|
||||
auto allStaticSubviewParts = [](memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
};
|
||||
|
||||
while (true) {
|
||||
Operation* defOp = value.getDefiningOp();
|
||||
if (!defOp)
|
||||
@@ -84,7 +79,7 @@ static bool isConstantGlobalView(Value value) {
|
||||
&& isa<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
}
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!allStaticSubviewParts(subview))
|
||||
if (!hasAllStaticSubviewParts(subview))
|
||||
return false;
|
||||
value = subview.getSource();
|
||||
continue;
|
||||
|
||||
Reference in New Issue
Block a user