add pim.vmm verifier and fix vmm lowering
reuse code for subviews
This commit is contained in:
@@ -3,6 +3,7 @@ add_pim_library(OMPimCommon
|
|||||||
IR/CoreBlockUtils.cpp
|
IR/CoreBlockUtils.cpp
|
||||||
IR/EntryPointUtils.cpp
|
IR/EntryPointUtils.cpp
|
||||||
IR/ShapeUtils.cpp
|
IR/ShapeUtils.cpp
|
||||||
|
IR/SubviewUtils.cpp
|
||||||
IR/WeightUtils.cpp
|
IR/WeightUtils.cpp
|
||||||
Support/DebugDump.cpp
|
Support/DebugDump.cpp
|
||||||
Support/Diagnostics.cpp
|
Support/Diagnostics.cpp
|
||||||
|
|||||||
@@ -0,0 +1,85 @@
|
|||||||
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/// Follows a chain of `memref.cast` producers upwards and returns the first
/// value that is not the result of a `memref.cast`.
Value stripMemRefCasts(Value value) {
  for (;;) {
    auto cast = value.getDefiningOp<memref::CastOp>();
    if (!cast)
      return value;
    value = cast.getSource();
  }
}
|
||||||
|
|
||||||
|
/// Follows producers upwards through `memref.cast`, `memref.collapse_shape`
/// and `memref.expand_shape` and returns the first value that is not produced
/// by any of those ops.
Value stripMemRefViewOps(Value value) {
  for (;;) {
    Value producerSource;
    if (auto cast = value.getDefiningOp<memref::CastOp>())
      producerSource = cast.getSource();
    else if (auto collapse = value.getDefiningOp<memref::CollapseShapeOp>())
      producerSource = collapse.getSrc();
    else if (auto expand = value.getDefiningOp<memref::ExpandShapeOp>())
      producerSource = expand.getSrc();
    else
      return value; // Not produced by one of the stripped ops: done.
    value = producerSource;
  }
}
|
||||||
|
|
||||||
|
/// Returns true when none of the static offsets, sizes and strides recorded on
/// `subview` is the dynamic sentinel.
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
  auto isStaticEntry = [](int64_t entry) { return !ShapedType::isDynamic(entry); };

  if (!llvm::all_of(subview.getStaticOffsets(), isStaticEntry))
    return false;
  if (!llvm::all_of(subview.getStaticSizes(), isStaticEntry))
    return false;
  return llvm::all_of(subview.getStaticStrides(), isStaticEntry);
}
|
||||||
|
|
||||||
|
/// Describes the `memref.subview` that produces `value`, looking through
/// cast/collapse_shape/expand_shape ops on the way to it.
///
/// Fails when no subview is found, when the subview source (with casts
/// stripped) or the subview result is not a statically shaped memref, or when
/// any size or stride is not a constant. Offsets are kept as OpFoldResult and
/// may therefore still be dynamic.
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
  auto subviewOp = stripMemRefViewOps(value).getDefiningOp<memref::SubViewOp>();
  if (!subviewOp)
    return failure();

  Value base = stripMemRefCasts(subviewOp.getSource());
  auto baseType = dyn_cast<MemRefType>(base.getType());
  auto viewType = dyn_cast<MemRefType>(subviewOp.getType());
  if (!baseType || !viewType)
    return failure();
  if (!baseType.hasStaticShape() || !viewType.hasStaticShape())
    return failure();

  StaticSubviewInfo info;
  info.source = base;
  info.sourceShape.assign(baseType.getShape().begin(), baseType.getShape().end());

  SmallVector<OpFoldResult> offsetParts = subviewOp.getMixedOffsets();
  info.offsets.assign(offsetParts.begin(), offsetParts.end());

  // Appends the constant value of every entry to `out`; stops and reports
  // false at the first entry that is not a constant.
  auto appendConstants = [](llvm::ArrayRef<OpFoldResult> parts, llvm::SmallVectorImpl<int64_t>& out) {
    for (OpFoldResult part : parts) {
      auto constant = getConstantIntValue(part);
      if (!constant)
        return false;
      out.push_back(*constant);
    }
    return true;
  };

  if (!appendConstants(subviewOp.getMixedSizes(), info.sizes))
    return failure();
  if (!appendConstants(subviewOp.getMixedStrides(), info.strides))
    return failure();
  return info;
}
|
||||||
|
|
||||||
|
/// Converts every offset stored in `info` to its constant integer value.
/// Fails as soon as one offset is not a constant.
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
  SmallVector<int64_t> constants;
  constants.reserve(info.offsets.size());

  for (OpFoldResult entry : info.offsets) {
    auto constant = getConstantIntValue(entry);
    if (!constant.has_value())
      return failure();
    constants.push_back(constant.value());
  }

  return constants;
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
struct StaticSubviewInfo {
|
||||||
|
mlir::Value source;
|
||||||
|
llvm::SmallVector<int64_t> sourceShape;
|
||||||
|
llvm::SmallVector<mlir::OpFoldResult> offsets;
|
||||||
|
llvm::SmallVector<int64_t> sizes;
|
||||||
|
llvm::SmallVector<int64_t> strides;
|
||||||
|
};
|
||||||
|
|
||||||
|
mlir::Value stripMemRefCasts(mlir::Value value);
|
||||||
|
|
||||||
|
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||||
|
|
||||||
|
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
|
||||||
|
|
||||||
|
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||||
|
|
||||||
|
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||||
|
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -24,8 +24,8 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
|
||||||
#include "Common/IR/CompactAsmUtils.hpp"
|
#include "Common/IR/CompactAsmUtils.hpp"
|
||||||
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||||
@@ -206,9 +206,7 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu
|
|||||||
return iter->second.address + resolvedAddress->byteOffset;
|
return iter->second.address + resolvedAddress->byteOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimAcceleratorMemory::reportHost() {
|
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
|
||||||
hostReportRow = hostMem.getReportRow();
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
||||||
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
|
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
|
||||||
@@ -810,7 +808,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
|||||||
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
|
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
|
||||||
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
||||||
else {
|
else {
|
||||||
op.emitError("Unsupported codegen for this operation");
|
InFlightDiagnostic diag = op.emitError()
|
||||||
|
<< "unsupported codegen for op '" << op.getName().getStringRef() << "'";
|
||||||
|
if (auto coreOp = op.getParentOfType<pim::PimCoreOp>())
|
||||||
|
diag << " inside pim.core " << coreOp.getCoreId();
|
||||||
|
else if (auto coreBatchOp = op.getParentOfType<pim::PimCoreBatchOp>())
|
||||||
|
diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount();
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
processedOperations++;
|
processedOperations++;
|
||||||
@@ -935,9 +938,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||||
if (auto err = emitCore(coreOp, false))
|
if (auto err = emitCore(coreOp, false))
|
||||||
return err;
|
return err;
|
||||||
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
|
memory.recordCoreReport(
|
||||||
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
|
emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
|
||||||
.getReportRow());
|
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId()))).getReportRow());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,8 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
@@ -30,19 +32,6 @@ struct DenseWeightView {
|
|||||||
int64_t offset = 0;
|
int64_t offset = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
|
|
||||||
SmallVector<int64_t> strides(shape.size(), 1);
|
|
||||||
for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
|
|
||||||
strides[index] = strides[index + 1] * shape[index + 1];
|
|
||||||
return strides;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool allStaticSubviewParts(memref::SubViewOp subview) {
|
|
||||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
|
||||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
|
||||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||||
SmallVector<memref::SubViewOp> subviews;
|
SmallVector<memref::SubViewOp> subviews;
|
||||||
mlir::Value current = weight;
|
mlir::Value current = weight;
|
||||||
@@ -55,7 +44,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
|||||||
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
||||||
break;
|
break;
|
||||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||||
if (!allStaticSubviewParts(subview))
|
if (!hasAllStaticSubviewParts(subview))
|
||||||
return failure();
|
return failure();
|
||||||
subviews.push_back(subview);
|
subviews.push_back(subview);
|
||||||
current = subview.getSource();
|
current = subview.getSource();
|
||||||
@@ -79,7 +68,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
|||||||
DenseWeightView view;
|
DenseWeightView view;
|
||||||
view.denseAttr = denseAttr;
|
view.denseAttr = denseAttr;
|
||||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||||
view.strides = computeRowMajorStridesForShape(view.shape);
|
view.strides = computeRowMajorStrides(view.shape);
|
||||||
|
|
||||||
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
||||||
SmallVector<int64_t> nextStrides;
|
SmallVector<int64_t> nextStrides;
|
||||||
|
|||||||
@@ -4,9 +4,9 @@
|
|||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
#include "llvm/ADT/APFloat.h"
|
#include "llvm/ADT/APFloat.h"
|
||||||
#include "llvm/ADT/APInt.h"
|
#include "llvm/ADT/APInt.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
@@ -47,8 +47,8 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
|||||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createPoolFillElement(
|
static Value
|
||||||
ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
||||||
if (!useMinimumValue)
|
if (!useMinimumValue)
|
||||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
|
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
|
||||||
|
|
||||||
@@ -65,8 +65,10 @@ static Value createPoolFillElement(
|
|||||||
llvm_unreachable("unsupported pool element type");
|
llvm_unreachable("unsupported pool element type");
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createPoolFillTensor(
|
static Value createPoolFillTensor(ConversionPatternRewriter& rewriter,
|
||||||
ConversionPatternRewriter& rewriter, Location loc, RankedTensorType tensorType, bool useMinimumValue) {
|
Location loc,
|
||||||
|
RankedTensorType tensorType,
|
||||||
|
bool useMinimumValue) {
|
||||||
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
|
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
|
||||||
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
|
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
|
||||||
}
|
}
|
||||||
@@ -90,10 +92,8 @@ static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
|
|||||||
inputType.getDimSize(3) + padLeft + padRight},
|
inputType.getDimSize(3) + padLeft + padRight},
|
||||||
inputType.getElementType(),
|
inputType.getElementType(),
|
||||||
inputType.getEncoding());
|
inputType.getEncoding());
|
||||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
SmallVector<OpFoldResult> lowPads = {
|
||||||
rewriter.getIndexAttr(0),
|
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padTop), rewriter.getIndexAttr(padLeft)};
|
||||||
rewriter.getIndexAttr(padTop),
|
|
||||||
rewriter.getIndexAttr(padLeft)};
|
|
||||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(0),
|
rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(padBottom),
|
rewriter.getIndexAttr(padBottom),
|
||||||
@@ -104,8 +104,8 @@ static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
|
|||||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||||
padOp.getRegion().push_back(padBlock);
|
padOp.getRegion().push_back(padBlock);
|
||||||
rewriter.setInsertionPointToStart(padBlock);
|
rewriter.setInsertionPointToStart(padBlock);
|
||||||
Value padValue = createPoolFillElement(
|
Value padValue =
|
||||||
rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
createPoolFillElement(rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||||
tensor::YieldOp::create(rewriter, loc, padValue);
|
tensor::YieldOp::create(rewriter, loc, padValue);
|
||||||
rewriter.setInsertionPointAfter(padOp);
|
rewriter.setInsertionPointAfter(padOp);
|
||||||
return padOp.getResult();
|
return padOp.getResult();
|
||||||
@@ -279,7 +279,8 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
auto computeOp =
|
auto computeOp =
|
||||||
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
|
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
|
||||||
Value paddedInput = createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
Value paddedInput =
|
||||||
|
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
||||||
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
|
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
|
||||||
|
|
||||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||||
@@ -307,8 +308,8 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||||
Value reducedWindow = createPoolFillTensor(
|
Value reducedWindow =
|
||||||
rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
createPoolFillTensor(rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||||
|
|
||||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||||
Value paddedInH = windowBaseH;
|
Value paddedInH = windowBaseH;
|
||||||
@@ -324,18 +325,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<OpFoldResult> offsets = {batchIndex,
|
SmallVector<OpFoldResult> offsets = {
|
||||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW};
|
||||||
paddedInH,
|
|
||||||
paddedInW};
|
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(tileChannels),
|
rewriter.getIndexAttr(tileChannels),
|
||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(1)};
|
rewriter.getIndexAttr(1)};
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> strides = {
|
||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(1)};
|
|
||||||
Value windowValue =
|
Value windowValue =
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
||||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||||
@@ -344,36 +341,28 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||||
SmallVector<OpFoldResult> scaleOffsets = {rewriter.getIndexAttr(0),
|
SmallVector<OpFoldResult> scaleOffsets = {
|
||||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
||||||
outHeightIndex,
|
|
||||||
outWidthIndex};
|
|
||||||
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(tileChannels),
|
rewriter.getIndexAttr(tileChannels),
|
||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(1)};
|
rewriter.getIndexAttr(1)};
|
||||||
SmallVector<OpFoldResult> scaleStrides = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> scaleStrides = {
|
||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(1)};
|
|
||||||
Value scaleSlice = tensor::ExtractSliceOp::create(
|
Value scaleSlice = tensor::ExtractSliceOp::create(
|
||||||
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
||||||
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
|
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
|
||||||
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<OpFoldResult> outputOffsets = {batchIndex,
|
SmallVector<OpFoldResult> outputOffsets = {
|
||||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
||||||
outHeightIndex,
|
|
||||||
outWidthIndex};
|
|
||||||
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(tileChannels),
|
rewriter.getIndexAttr(tileChannels),
|
||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(1)};
|
rewriter.getIndexAttr(1)};
|
||||||
SmallVector<OpFoldResult> outputStrides = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> outputStrides = {
|
||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
rewriter.getIndexAttr(1),
|
|
||||||
rewriter.getIndexAttr(1)};
|
|
||||||
updatedOutput = tensor::InsertSliceOp::create(
|
updatedOutput = tensor::InsertSliceOp::create(
|
||||||
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
|
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,12 +9,14 @@
|
|||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/SymbolTable.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@@ -147,6 +149,73 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
|
|||||||
rewriter.replaceOp(extractRowsOp, replacements);
|
rewriter.replaceOp(extractRowsOp, replacements);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a constant `memref.global` in the enclosing module whose type
/// matches `tensorType`'s shape and element type and whose initial value is
/// all zeros. An existing matching global is reused; otherwise a new private
/// one is created at the start of the module body under a fresh symbol name.
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
  auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
  auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  auto zeroAttr = DenseElementsAttr::get(tensorType, rewriter.getZeroAttr(tensorType.getElementType()));

  // Reuse an already materialized zero constant of the same type, if any.
  for (memref::GlobalOp candidate : moduleOp.getOps<memref::GlobalOp>()) {
    const bool comparable =
        candidate.getConstant() && candidate.getType() == memRefType && candidate.getInitialValue().has_value();
    if (comparable && dyn_cast<DenseElementsAttr>(*candidate.getInitialValue()) == zeroAttr)
      return candidate;
  }

  // The stem does not encode the element type, so uniqueness relies on the
  // numeric suffix probing below.
  const std::string nameStem =
      (Twine("__pim_zero_") + Twine(tensorType.getRank()) + "d_" + Twine(tensorType.getNumElements())).str();

  std::string symbolName = nameStem;
  for (unsigned suffix = 0; SymbolTable::lookupSymbolIn(moduleOp, symbolName); ++suffix)
    symbolName = (nameStem + "_" + Twine(suffix)).str();

  // Create the global at module scope without disturbing the caller's insertion point.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());
  return memref::GlobalOp::create(rewriter,
                                  loc,
                                  rewriter.getStringAttr(symbolName),
                                  rewriter.getStringAttr("private"),
                                  TypeAttr::get(memRefType),
                                  zeroAttr,
                                  rewriter.getUnitAttr(),
                                  IntegerAttr {});
}
|
||||||
|
|
||||||
|
/// Creates an empty tensor of `tensorType` and fills it by copying a
/// zero-initialized global into it with a host-to-device copy. The batch copy
/// op is used when the new buffer is nested in a PimCoreBatchOp; the plain
/// copy op is used otherwise.
static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
  auto destination = createEmptyTensorFromShaped(rewriter, loc, tensorType);
  auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
  auto zeroSource = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());

  // Both copy offsets are zero; the copy length is the whole tensor in bytes.
  auto offsetAttr = rewriter.getI32IntegerAttr(0);
  auto byteCountAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));

  if (destination->getParentOfType<PimCoreBatchOp>()) {
    auto batchCopy = PimMemCopyHostToDevBatchOp::create(
        rewriter, loc, tensorType, destination, zeroSource, offsetAttr, offsetAttr, byteCountAttr);
    return batchCopy.getOutput();
  }

  auto copy = PimMemCopyHostToDevOp::create(
      rewriter, loc, tensorType, destination, zeroSource, offsetAttr, offsetAttr, byteCountAttr);
  return copy.getOutput();
}
|
||||||
|
|
||||||
|
/// Widens the horizontal vector `vector` to `crossbarSize` columns. A vector
/// that already has that width is returned unchanged; otherwise a zeroed
/// buffer of the padded type is created and the bytes of `vector` are copied
/// to its start.
static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
  auto inputType = cast<RankedTensorType>(vector.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  assert(isHVectorShape(inputShape) && "expected a horizontal vector");
  assert(inputShape[1] <= static_cast<int64_t>(crossbarSize) && "vector width must fit in one crossbar");

  const bool alreadyFullWidth = inputShape[1] == static_cast<int64_t>(crossbarSize);
  if (alreadyFullWidth)
    return vector;

  auto paddedType = RankedTensorType::get(
      {inputShape[0], static_cast<int64_t>(crossbarSize)}, inputType.getElementType(), inputType.getEncoding());
  Value zeroedBuffer = createZeroedDeviceHVector(rewriter, loc, paddedType);

  // Copy only the original payload; the rest of the padded buffer stays zero.
  auto offsetAttr = rewriter.getI32IntegerAttr(0);
  auto payloadBytesAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(inputType)));
  auto copy = PimMemCopyOp::create(rewriter, loc, paddedType, zeroedBuffer, vector, offsetAttr, offsetAttr, payloadBytesAttr);
  return copy.getOutput();
}
|
||||||
|
|
||||||
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
SmallVector<spatial::SpatConcatOp> concatOps;
|
SmallVector<spatial::SpatConcatOp> concatOps;
|
||||||
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
|
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
|
||||||
@@ -426,54 +495,35 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
|
|
||||||
auto* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return;
|
|
||||||
auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
|
|
||||||
if (!dpsDefiningOp)
|
|
||||||
return;
|
|
||||||
auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
|
|
||||||
if (!tiedOperand)
|
|
||||||
return;
|
|
||||||
Value tiedValue = tiedOperand->get();
|
|
||||||
assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
|
|
||||||
tiedValue.setType(newType);
|
|
||||||
self(tiedValue, newType, self);
|
|
||||||
};
|
|
||||||
|
|
||||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||||
auto outTensorOperand = vmmOp.getOutputBuffer();
|
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||||
auto resultTensor = vmmOp.getOutput();
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||||
auto outShape = getTensorShape(outTensorOperand);
|
assert(isHVectorShape(outputShape) && "expected a horizontal vector output");
|
||||||
assert(isHVectorShape(outShape));
|
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
|
||||||
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
|
|
||||||
auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
|
|
||||||
auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
|
|
||||||
if (outTensorOperand == vmmOp.getInput()) {
|
|
||||||
rewriter.setInsertionPoint(vmmOp);
|
|
||||||
auto newOutputBuffer =
|
|
||||||
tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
|
|
||||||
vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
|
|
||||||
outTensorOperand.setType(newType);
|
|
||||||
}
|
|
||||||
resultTensor.setType(newType);
|
|
||||||
|
|
||||||
IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
|
rewriter.setInsertionPoint(vmmOp);
|
||||||
IntegerAttr oneAttr = rewriter.getIndexAttr(1);
|
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
|
||||||
IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
|
auto paddedOutputType = RankedTensorType::get(
|
||||||
IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
|
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
|
||||||
SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
|
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
|
||||||
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
|
? vmmOp.getOutputBuffer()
|
||||||
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
|
: createEmptyTensorFromShaped(rewriter, vmmOp.getLoc(), paddedOutputType).getResult();
|
||||||
|
vmmOp.getInputMutable().assign(paddedInput);
|
||||||
|
vmmOp.getOutputBufferMutable().assign(paddedOutputBuffer);
|
||||||
|
|
||||||
|
vmmOp.getOutput().setType(paddedOutputType);
|
||||||
|
|
||||||
|
if (outputShape[1] == static_cast<int64_t>(crossbarSize))
|
||||||
|
return;
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputShape[0]), rewriter.getIndexAttr(outputShape[1])};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
rewriter.setInsertionPointAfter(vmmOp);
|
rewriter.setInsertionPointAfter(vmmOp);
|
||||||
auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
|
auto sliceOp =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), outputType, vmmOp.getOutput(), offsets, sizes, strides);
|
||||||
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
|
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
|
||||||
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
vmmOp.getOutput().replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -389,6 +389,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
|||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||||
}];
|
}];
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "llvm/Support/LogicalResult.h"
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -77,6 +78,22 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Looks up the shape of the weight at position `weightIndex` on an op that
/// encloses `op`.
///
/// The closest enclosing `PimCoreOp` is consulted first; only when there is
/// none is the closest enclosing `PimCoreBatchOp` consulted. The lookup fails
/// when neither kind of ancestor exists, or when `weightIndex` is out of range
/// for the ancestor that was found (no fallback to the other kind happens in
/// that case).
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
  // Shared by both ancestor kinds: bounds-check the index, then read the
  // shape off the selected weight's type.
  auto shapeOfSelectedWeight = [weightIndex](auto enclosingOp) -> FailureOr<ArrayRef<int64_t>> {
    auto weights = enclosingOp.getWeights();
    if (weightIndex >= weights.size())
      return failure();
    return cast<ShapedType>(weights[weightIndex].getType()).getShape();
  };

  if (auto enclosingCore = op->getParentOfType<PimCoreOp>())
    return shapeOfSelectedWeight(enclosingCore);

  if (auto enclosingCoreBatch = op->getParentOfType<PimCoreBatchOp>())
    return shapeOfSelectedWeight(enclosingCoreBatch);

  return failure();
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult PimSendTensorOp::verify() {
|
LogicalResult PimSendTensorOp::verify() {
|
||||||
@@ -104,6 +121,47 @@ LogicalResult PimReceiveTensorBatchOp::verify() {
|
|||||||
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
|
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Verifier for `pim.vmm`.
///
/// Checks, in this order (the first failing check produces the diagnostic):
///   1. the output buffer type and the result type are compatible;
///   2. the op sits inside a `pim.core` / `pim.core_batch` whose weight list
///      contains the op's weight index;
///   3. input and result are shaped types;
///   4. weight, input and result all have rank 2;
///   5. both weight dimensions are strictly positive;
///   6. both weight dimensions are at most `crossbarSize`;
///   7. the input has shape (1, crossbarSize);
///   8. the result has shape (1, crossbarSize).
LogicalResult PimVMMOp::verify() {
  Operation* self = getOperation();

  if (failed(verifyCompatibleShapedTypes(
          self, getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
    return failure();

  auto weightShape = getWeightShapeForVMM(self, getWeightIndex());
  if (failed(weightShape))
    return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");

  auto inputType = dyn_cast<ShapedType>(getInput().getType());
  auto resultType = dyn_cast<ShapedType>(getOutput().getType());
  if (!(inputType && resultType))
    return emitError("input and output must be shaped types");

  ArrayRef<int64_t> weightDims = *weightShape;
  ArrayRef<int64_t> inputDims = inputType.getShape();
  ArrayRef<int64_t> resultDims = resultType.getShape();

  // Everything below indexes dimensions 0 and 1, so the rank check comes first.
  const bool allRankTwo = weightDims.size() == 2 && inputDims.size() == 2 && resultDims.size() == 2;
  if (!allRankTwo)
    return emitError("matrix, vector and output must have rank 2");

  const int64_t weightRows = weightDims[0];
  const int64_t weightCols = weightDims[1];
  if (weightRows <= 0 || weightCols <= 0)
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  const int64_t crossbarExtent = static_cast<int64_t>(crossbarSize);
  if (weightRows > crossbarExtent || weightCols > crossbarExtent)
    return emitError("matrix dimensions must fit in one crossbar");

  // True for a rank-2 shape that is exactly one row of `crossbarExtent` elements.
  auto isCrossbarWideRow = [crossbarExtent](ArrayRef<int64_t> dims) {
    return dims[0] == 1 && dims[1] == crossbarExtent;
  };

  if (!isCrossbarWideRow(inputDims))
    return emitError("vector shape must be (1, crossbar-size)");

  if (!isCrossbarWideRow(resultDims))
    return emitError("output shape must be (1, crossbar-size)");

  return success();
}
|
||||||
|
|
||||||
LogicalResult PimConcatOp::verify() {
|
LogicalResult PimConcatOp::verify() {
|
||||||
if (getInputs().empty())
|
if (getInputs().empty())
|
||||||
return emitError("requires at least one input");
|
return emitError("requires at least one input");
|
||||||
|
|||||||
@@ -105,6 +105,37 @@ struct MemCopyDevToHostOpInterface
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Bufferization external model attached to `PimMemCopyOp`.
struct MemCopyOpInterface : DstBufferizableOpInterfaceExternalModel<MemCopyOpInterface, PimMemCopyOp> {
  /// Reports an operand as read unless it is a destination-style init operand
  /// of `op`.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    const bool isInitOperand = dpsOp.isDpsInit(&opOperand);
    return !isInitOperand;
  }

  /// Rewrites the op onto buffers.
  ///
  /// The target is resolved first, then the source, both through
  /// `getBufferOrValue`; a failure of either aborts bufferization of this op.
  /// The op is then replaced by a new `PimMemCopyOp` whose result type is the
  /// resolved target's type and which keeps the original target-offset,
  /// source-offset and size attributes.
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto copyOp = cast<PimMemCopyOp>(op);

    auto resolvedTarget = getBufferOrValue(rewriter, copyOp.getTarget(), options, state);
    if (failed(resolvedTarget))
      return failure();

    auto resolvedSource = getBufferOrValue(rewriter, copyOp.getSource(), options, state);
    if (failed(resolvedSource))
      return failure();

    replaceOpWithNewBufferizedOp<PimMemCopyOp>(rewriter,
        copyOp,
        resolvedTarget->getType(),
        *resolvedTarget,
        *resolvedSource,
        copyOp.getTargetOffsetAttr(),
        copyOp.getSourceOffsetAttr(),
        copyOp.getSizeAttr());
    return success();
  }
};
|
||||||
|
|
||||||
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
|
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
@@ -626,6 +657,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
|||||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||||
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
||||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||||
|
PimMemCopyOp::attachInterface<MemCopyOpInterface>(*ctx);
|
||||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||||
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||||
|
|
||||||
|
|||||||
@@ -52,9 +52,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
|||||||
printer << " ";
|
printer << " ";
|
||||||
printer.printOperand(op.getInput());
|
printer.printOperand(op.getInput());
|
||||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||||
printer.printOptionalAttrDict(
|
printer.printOptionalAttrDict(op->getAttrs(),
|
||||||
op->getAttrs(),
|
{op.getChannelIdsAttrName().getValue(),
|
||||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
op.getSourceCoreIdsAttrName().getValue(),
|
||||||
|
op.getTargetCoreIdsAttrName().getValue()});
|
||||||
printer << " : ";
|
printer << " : ";
|
||||||
printer.printType(op.getInput().getType());
|
printer.printType(op.getInput().getType());
|
||||||
}
|
}
|
||||||
@@ -62,9 +63,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
|||||||
template <typename TensorReceiveOpTy>
|
template <typename TensorReceiveOpTy>
|
||||||
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
||||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||||
printer.printOptionalAttrDict(
|
printer.printOptionalAttrDict(op->getAttrs(),
|
||||||
op->getAttrs(),
|
{op.getChannelIdsAttrName().getValue(),
|
||||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
op.getSourceCoreIdsAttrName().getValue(),
|
||||||
|
op.getTargetCoreIdsAttrName().getValue()});
|
||||||
printer << " : ";
|
printer << " : ";
|
||||||
printer.printType(op.getOutput().getType());
|
printer.printType(op.getOutput().getType());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,30 +41,6 @@ struct DenseSubviewKeyInfo {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/// Walks backwards through a chain of `memref.cast` ops and returns the first
/// value that is not produced by one.
Value stripMemRefCasts(Value value) {
  Value current = value;
  for (auto memrefCast = current.getDefiningOp<memref::CastOp>(); memrefCast;
       memrefCast = current.getDefiningOp<memref::CastOp>())
    current = memrefCast.getSource();
  return current;
}
|
|
||||||
|
|
||||||
/// Walks backwards through any chain of `memref.cast`, `memref.collapse_shape`
/// and `memref.expand_shape` ops and returns the first value that is not
/// produced by one of them.
Value stripMemRefViewOps(Value value) {
  Value current = value;
  for (;;) {
    Operation* producer = current.getDefiningOp();
    // Values without a defining op end the walk.
    if (!producer)
      return current;

    if (auto memrefCast = dyn_cast<memref::CastOp>(producer))
      current = memrefCast.getSource();
    else if (auto collapse = dyn_cast<memref::CollapseShapeOp>(producer))
      current = collapse.getSrc();
    else if (auto expand = dyn_cast<memref::ExpandShapeOp>(producer))
      current = expand.getSrc();
    else
      return current;
  }
}
|
|
||||||
|
|
||||||
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||||
Location loc,
|
Location loc,
|
||||||
MemRefType globalType,
|
MemRefType globalType,
|
||||||
@@ -177,48 +153,4 @@ FailureOr<DenseElementsAttr> foldDenseSourceToType(ModuleOp moduleOp, Value sour
|
|||||||
return *denseAttr;
|
return *denseAttr;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Describes the `memref.subview` that produces `value`, looking through
/// cast / collapse_shape / expand_shape ops on the way to it.
///
/// Fails when no subview is found, when the subview's source (with casts
/// stripped) or its result is not a statically shaped `MemRefType`, or when
/// any size or stride of the subview is not a constant. Offsets are copied as
/// they are and may still be dynamic.
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
  auto subview = stripMemRefViewOps(value).getDefiningOp<memref::SubViewOp>();
  if (!subview)
    return failure();

  Value baseMemRef = stripMemRefCasts(subview.getSource());
  auto baseType = dyn_cast<MemRefType>(baseMemRef.getType());
  auto viewType = dyn_cast<MemRefType>(subview.getType());
  const bool bothStaticMemRefs = baseType && viewType && baseType.hasStaticShape() && viewType.hasStaticShape();
  if (!bothStaticMemRefs)
    return failure();

  // Appends the constant value of every entry to `out`; stops and returns
  // false at the first entry that is not a constant.
  auto appendConstants = [](ArrayRef<OpFoldResult> entries, SmallVector<int64_t>& out) {
    for (OpFoldResult entry : entries) {
      auto constant = getConstantIntValue(entry);
      if (!constant)
        return false;
      out.push_back(*constant);
    }
    return true;
  };

  StaticSubviewInfo result;
  result.source = baseMemRef;
  ArrayRef<int64_t> baseShape = baseType.getShape();
  result.sourceShape.assign(baseShape.begin(), baseShape.end());
  result.offsets = subview.getMixedOffsets();

  if (!appendConstants(subview.getMixedSizes(), result.sizes))
    return failure();
  if (!appendConstants(subview.getMixedStrides(), result.strides))
    return failure();

  return result;
}
|
|
||||||
|
|
||||||
/// Converts the offsets stored in `info` to plain integers.
///
/// Fails as soon as one offset is not a constant; otherwise the returned
/// vector has one entry per offset, in the same order.
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
  const size_t offsetCount = info.offsets.size();

  SmallVector<int64_t> constants;
  constants.reserve(offsetCount);

  for (size_t idx = 0; idx != offsetCount; ++idx) {
    auto constant = getConstantIntValue(info.offsets[idx]);
    if (!constant)
      return failure();
    constants.push_back(*constant);
  }

  return constants;
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -6,23 +6,12 @@
|
|||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
struct StaticSubviewInfo {
|
|
||||||
mlir::Value source;
|
|
||||||
llvm::SmallVector<int64_t> sourceShape;
|
|
||||||
llvm::SmallVector<mlir::OpFoldResult> offsets;
|
|
||||||
llvm::SmallVector<int64_t> sizes;
|
|
||||||
llvm::SmallVector<int64_t> strides;
|
|
||||||
};
|
|
||||||
|
|
||||||
mlir::Value stripMemRefCasts(mlir::Value value);
|
|
||||||
|
|
||||||
mlir::Value stripMemRefViewOps(mlir::Value value);
|
|
||||||
|
|
||||||
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
|
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
mlir::MemRefType globalType,
|
mlir::MemRefType globalType,
|
||||||
@@ -39,9 +28,4 @@ llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp modu
|
|||||||
llvm::FailureOr<mlir::DenseElementsAttr>
|
llvm::FailureOr<mlir::DenseElementsAttr>
|
||||||
foldDenseSourceToType(mlir::ModuleOp moduleOp, mlir::Value source, mlir::MemRefType resultType);
|
foldDenseSourceToType(mlir::ModuleOp moduleOp, mlir::Value source, mlir::MemRefType resultType);
|
||||||
|
|
||||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
|
||||||
|
|
||||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
|
||||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -67,12 +68,6 @@ static bool isCodegenAddressableValue(Value value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool isConstantGlobalView(Value value) {
|
static bool isConstantGlobalView(Value value) {
|
||||||
auto allStaticSubviewParts = [](memref::SubViewOp subview) {
|
|
||||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
|
||||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
|
||||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
|
||||||
};
|
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
Operation* defOp = value.getDefiningOp();
|
Operation* defOp = value.getDefiningOp();
|
||||||
if (!defOp)
|
if (!defOp)
|
||||||
@@ -84,7 +79,7 @@ static bool isConstantGlobalView(Value value) {
|
|||||||
&& isa<DenseElementsAttr>(*globalOp.getInitialValue());
|
&& isa<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||||
}
|
}
|
||||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||||
if (!allStaticSubviewParts(subview))
|
if (!hasAllStaticSubviewParts(subview))
|
||||||
return false;
|
return false;
|
||||||
value = subview.getSource();
|
value = subview.getSource();
|
||||||
continue;
|
continue;
|
||||||
|
|||||||
Reference in New Issue
Block a user