Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a7dee5b840 | |||
| 77fe293062 | |||
| 525792e545 | |||
| eade488d13 | |||
| 30ee9640d4 | |||
| 368e340a40 | |||
| e866ec6f87 |
@@ -33,7 +33,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPim) {
|
||||
pm.addPass(createMergeComputeNodePass());
|
||||
pm.addPass(createMergeComputeNodesPass());
|
||||
pm.addPass(createSpatialToPimPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||
@@ -46,9 +46,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||
pm.addPass(createPimConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim constants folded"));
|
||||
pm.addPass(createPimMaterializeConstantsPass());
|
||||
pm.addPass(createPimHostConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim host constants folded"));
|
||||
pm.addPass(createPimMaterializeHostConstantsPass());
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimJsonPass());
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
@@ -22,8 +23,8 @@
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -51,7 +52,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
||||
private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
void mergeSingleChildCompute(func::FuncOp funcOp);
|
||||
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -148,10 +149,10 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
encapsulateGlobalInstruction(*entryFunc);
|
||||
mergeSingleChildCompute(*entryFunc);
|
||||
mergeTriviallyConnectedComputes(*entryFunc);
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "spatial");
|
||||
dumpModule(moduleOp, "spatial0");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -230,66 +231,61 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::mergeSingleChildCompute(func::FuncOp funcOp) {
|
||||
llvm::SmallVector<spatial::SpatWeightedCompute> computeSingleChild;
|
||||
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
Location loc = funcOp.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
SmallVector<spatial::SpatWeightedCompute> trivialComputes;
|
||||
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
||||
if (std::distance(compute->getUses().begin(), compute->getUses().end()) == 1) {
|
||||
if (compute->hasOneUse()) {
|
||||
auto user = *compute->getUsers().begin();
|
||||
if (user->getNumOperands() == 1)
|
||||
if (llvm::isa<spatial::SpatWeightedCompute>(user))
|
||||
computeSingleChild.push_back(compute);
|
||||
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
|
||||
trivialComputes.push_back(compute);
|
||||
}
|
||||
|
||||
IRMapping mapper;
|
||||
while (!computeSingleChild.empty()) {
|
||||
auto compute = computeSingleChild.front();
|
||||
auto child = dyn_cast_if_present<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||
assert(child && "Child required!");
|
||||
while (!trivialComputes.empty()) {
|
||||
auto compute = trivialComputes.front();
|
||||
|
||||
if (compute.use_empty()) {
|
||||
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||
trivialComputes.pop_back();
|
||||
continue;
|
||||
}
|
||||
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||
|
||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||
auto newCompute =
|
||||
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{(int) compute.getWeights().size(), (int) compute.getInputs().size()});
|
||||
llvm::dbgs() << "After Creation\n";
|
||||
newCompute.dump();
|
||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||
|
||||
IRMapping mapper;
|
||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||
llvm::dbgs() << "After Clone\n";
|
||||
newCompute.dump();
|
||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
||||
newTerminator->erase();
|
||||
llvm::dbgs() << "After terminator\n";
|
||||
newCompute.dump();
|
||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||
|
||||
for (auto& op : child.getBody().front())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
child.replaceAllUsesWith(newCompute);
|
||||
assert(child->getUses().empty() && "It's not obvius");
|
||||
llvm::dbgs() << "Node\n";
|
||||
newCompute.dump();
|
||||
toErase.insert(child);
|
||||
|
||||
llvm::dbgs() << "Parent\n";
|
||||
compute.dump();
|
||||
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||
trivialComputes.pop_back();
|
||||
toErase.insert(compute);
|
||||
|
||||
llvm::dbgs() << "Child\n";
|
||||
child.dump();
|
||||
|
||||
child.erase();
|
||||
compute.erase();
|
||||
|
||||
if (std::distance(newCompute->getUses().begin(), newCompute->getUses().end()) == 1) {
|
||||
if (newCompute->hasOneUse()) {
|
||||
auto user = *newCompute->getUsers().begin();
|
||||
if (user->getNumOperands() == 1)
|
||||
if (llvm::isa<spatial::SpatWeightedCompute>(user))
|
||||
computeSingleChild.push_back(newCompute);
|
||||
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
|
||||
trivialComputes.push_back(newCompute);
|
||||
}
|
||||
std::swap(computeSingleChild.front(), computeSingleChild.back());
|
||||
computeSingleChild.pop_back();
|
||||
}
|
||||
|
||||
for (auto compute : toErase) {
|
||||
compute.getResult(0).dropAllUses();
|
||||
compute.erase();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -24,122 +28,150 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
||||
/// Reads entry `idx` of `arr` and returns its value as a signed 64-bit integer.
///
/// The entry is converted with `cast<IntegerAttr>`, so the caller must only
/// pass arrays whose element at `idx` is an IntegerAttr; no bounds or type
/// fallback is performed here.
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) {
  auto integerEntry = cast<IntegerAttr>(arr[idx]);
  return integerEntry.getInt();
}
|
||||
|
||||
// We need to understand what is group
|
||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||
if (biasType.getRank() != 1)
|
||||
return bias;
|
||||
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
const int64_t xWidth = xType.getDimSize(3);
|
||||
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||
const int64_t wHeight = wType.getDimSize(2);
|
||||
const int64_t wWidth = wType.getDimSize(3);
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
auto expandedBiasType = RankedTensorType::get({1, biasType.getDimSize(0)}, biasType.getElementType());
|
||||
return tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedBiasType,
|
||||
bias,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
}
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast<IntegerAttr>(arr[idx]).getInt(); };
|
||||
static Value createPaddedRows(Value tensorValue,
|
||||
RankedTensorType tensorType,
|
||||
int64_t paddedRows,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (tensorType.getDimSize(0) == paddedRows)
|
||||
return tensorValue;
|
||||
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(0)};
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int i = 0; i < 2; i++)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(
|
||||
rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType()));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1;
|
||||
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||
Value wTrans,
|
||||
RankedTensorType wType,
|
||||
int64_t numChannelsIn,
|
||||
int64_t numChannelsOut,
|
||||
int64_t wHeight,
|
||||
int64_t wWidth,
|
||||
int64_t patchSize,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (packFactor == 1)
|
||||
return wTrans;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
auto packedWeightType =
|
||||
RankedTensorType::get({packFactor * patchSize, packFactor * numChannelsOut}, wType.getElementType());
|
||||
SmallVector<Attribute> sourceValues(wDenseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> packedValues(packedWeightType.getNumElements(),
|
||||
cast<Attribute>(rewriter.getZeroAttr(wType.getElementType())));
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64(*padsAttr, 0);
|
||||
padWidthBegin = getI64(*padsAttr, 1);
|
||||
padHeightEnd = getI64(*padsAttr, 2);
|
||||
padWidthEnd = getI64(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
for (int64_t copyId = 0; copyId < packFactor; copyId++) {
|
||||
for (int64_t outChannel = 0; outChannel < numChannelsOut; outChannel++) {
|
||||
for (int64_t inChannel = 0; inChannel < numChannelsIn; inChannel++) {
|
||||
for (int64_t kernelH = 0; kernelH < wHeight; kernelH++) {
|
||||
for (int64_t kernelW = 0; kernelW < wWidth; kernelW++) {
|
||||
const int64_t sourceFlatIndex =
|
||||
(((outChannel * numChannelsIn) + inChannel) * wHeight + kernelH) * wWidth + kernelW;
|
||||
const int64_t patchIndex = ((inChannel * wHeight) + kernelH) * wWidth + kernelW;
|
||||
const int64_t targetRow = copyId * patchSize + patchIndex;
|
||||
const int64_t targetCol = copyId * numChannelsOut + outChannel;
|
||||
packedValues[targetRow * (packFactor * numChannelsOut) + targetCol] = sourceValues[sourceFlatIndex];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||
// Gemm output: [numPatches, cOut]
|
||||
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
||||
const int64_t numPatchesPerBatch = outHeight * outWidth;
|
||||
const int64_t numPatches = batchSize * numPatchesPerBatch;
|
||||
auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
|
||||
}
|
||||
|
||||
/// Selects or builds the bias (C) operand used by the convolution GEMM.
///
/// - Without a bias (`hasBias` false) the caller-provided `gemmBias` is
///   returned untouched.
/// - With a bias and `packFactor == 1` the caller-provided `biasMatrix` is
///   returned untouched.
/// - Otherwise a new `arith.constant` of shape [1, packFactor * numChannelsOut]
///   (element type taken from `outType`) is created whose data is the content
///   of `biasDenseAttr` repeated `packFactor` times back to back.
///
/// `biasDenseAttr` is only read in the last case.
static Value buildPackedBias(bool hasBias,
                             Value gemmBias,
                             Value biasMatrix,
                             DenseElementsAttr biasDenseAttr,
                             RankedTensorType outType,
                             int64_t numChannelsOut,
                             int64_t packFactor,
                             ConversionPatternRewriter& rewriter,
                             Location loc) {
  // Nothing to pack: hand back the operand the caller already prepared.
  if (!hasBias)
    return gemmBias;
  if (packFactor == 1)
    return biasMatrix;

  const int64_t packedWidth = packFactor * numChannelsOut;

  // Tile the original bias values once per packed copy.
  SmallVector<Attribute> biasElements(biasDenseAttr.getValues<Attribute>());
  SmallVector<Attribute> tiledElements;
  tiledElements.reserve(packedWidth);
  for (int64_t remainingCopies = packFactor; remainingCopies > 0; --remainingCopies)
    tiledElements.append(biasElements.begin(), biasElements.end());

  auto tiledBiasType = RankedTensorType::get({1, packedWidth}, outType.getElementType());
  auto tiledBiasAttr = DenseElementsAttr::get(tiledBiasType, tiledElements);
  auto tiledBiasOp = arith::ConstantOp::create(rewriter, loc, tiledBiasType, tiledBiasAttr);
  return tiledBiasOp.getResult();
}
|
||||
|
||||
static Value createIm2colCompute(Value x,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType im2colType,
|
||||
RankedTensorType rowType,
|
||||
int64_t batchSize,
|
||||
int64_t numChannelsIn,
|
||||
int64_t xHeight,
|
||||
int64_t xWidth,
|
||||
int64_t wHeight,
|
||||
int64_t wWidth,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
int64_t outWidth,
|
||||
int64_t patchSize,
|
||||
int64_t numPatches,
|
||||
int64_t numPatchesPerBatch,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto elemType = xType.getElementType();
|
||||
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
||||
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
||||
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
||||
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
||||
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
||||
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
wFlatType,
|
||||
w,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||
|
||||
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
Value gemmC;
|
||||
if (hasB)
|
||||
gemmC = b;
|
||||
else
|
||||
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
|
||||
constexpr size_t numInputs = 1;
|
||||
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
|
||||
Value paddedInput = xArg;
|
||||
@@ -226,23 +258,104 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
Value im2col = im2colLoop.getResult(0);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
||||
});
|
||||
return im2colComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
// Gemm: A @ B + C = im2col @ W^T + b
|
||||
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
|
||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmOutType,
|
||||
im2colComputeOp.getResult(0),
|
||||
wTrans,
|
||||
gemmC,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false));
|
||||
Value gemmOut = gemmOp.getY();
|
||||
/// Concatenates every `packFactor` consecutive im2col rows into one longer row.
///
/// With `packFactor == 1` the input value is returned unchanged. Otherwise a
/// spatial compute taking `im2col` as its single input is emitted; inside it
/// the rows are first extended by `createPaddedRows` to
/// ceil(numPatches / packFactor) * packFactor rows and then reshaped:
///   [paddedRows, patchSize]
///     -> [paddedRows / packFactor, packFactor, patchSize]   (expand_shape)
///     -> [paddedRows / packFactor, packFactor * patchSize]  (collapse_shape)
/// The first result of that compute is returned.
static Value createPackedIm2colRows(Value im2col,
                                    RankedTensorType im2colType,
                                    Type elemType,
                                    int64_t numPatches,
                                    int64_t patchSize,
                                    int64_t packFactor,
                                    ConversionPatternRewriter& rewriter,
                                    Location loc) {
  if (packFactor == 1)
    return im2col;

  const int64_t rowsAfterPacking = ceilIntegerDivide(numPatches, packFactor);
  const int64_t rowsBeforePacking = rowsAfterPacking * packFactor;

  auto splitRowsType = RankedTensorType::get({rowsAfterPacking, packFactor, patchSize}, elemType);
  auto mergedRowsType = RankedTensorType::get({rowsAfterPacking, packFactor * patchSize}, elemType);

  // Dimension groupings for the two reshapes performed inside the compute body.
  const SmallVector<ReassociationIndices> splitRowDim {
    {0, 1},
    {2}
  };
  const SmallVector<ReassociationIndices> mergeColumnDims {
    {0},
    {1, 2}
  };

  auto packingCompute = createSpatCompute<1>(rewriter, loc, mergedRowsType, {}, im2col, [&](Value rowsArg) {
    // Make the row count divisible by packFactor before regrouping.
    Value evenRows = createPaddedRows(rowsArg, im2colType, rowsBeforePacking, rewriter, loc);
    Value splitRows = tensor::ExpandShapeOp::create(rewriter, loc, splitRowsType, evenRows, splitRowDim);
    Value mergedRows = tensor::CollapseShapeOp::create(rewriter, loc, mergedRowsType, splitRows, mergeColumnDims);
    spatial::SpatYieldOp::create(rewriter, loc, mergedRows);
  });
  return packingCompute.getResult(0);
}
|
||||
|
||||
/// Undoes the row packing on the GEMM result.
///
/// With `packFactor == 1` the input value is returned unchanged. Otherwise a
/// spatial compute with result type `gemmOutType` is emitted; inside it the
/// packed output is reshaped
///   [packedRows, packFactor * numChannelsOut]
///     -> [packedRows, packFactor, numChannelsOut]   (expand_shape)
///     -> [packedRows * packFactor, numChannelsOut]  (collapse_shape)
/// and, when packedRows * packFactor differs from `numPatches`, the leading
/// `numPatches` rows are taken with an extract_slice of type `gemmOutType`.
/// The first result of that compute is returned.
static Value createUnpackedOutput(Value packedOutput,
                                  RankedTensorType gemmOutType,
                                  RankedTensorType outType,
                                  int64_t numPatches,
                                  int64_t numChannelsOut,
                                  int64_t packFactor,
                                  ConversionPatternRewriter& rewriter,
                                  Location loc) {
  if (packFactor == 1)
    return packedOutput;

  const int64_t packedRowCount = ceilIntegerDivide(numPatches, packFactor);
  const int64_t unpackedRowCount = packedRowCount * packFactor;
  const bool hasPaddingRows = unpackedRowCount != numPatches;

  auto outElemType = outType.getElementType();
  auto splitColumnsType = RankedTensorType::get({packedRowCount, packFactor, numChannelsOut}, outElemType);
  auto flatRowsType = RankedTensorType::get({unpackedRowCount, numChannelsOut}, outElemType);

  // Dimension groupings for the two reshapes performed inside the compute body.
  const SmallVector<ReassociationIndices> splitColumnDim {
    {0},
    {1, 2}
  };
  const SmallVector<ReassociationIndices> mergeRowDims {
    {0, 1},
    {2}
  };

  auto unpackingCompute = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedArg) {
    Value splitColumns = tensor::ExpandShapeOp::create(rewriter, loc, splitColumnsType, packedArg, splitColumnDim);
    Value flatRows = tensor::CollapseShapeOp::create(rewriter, loc, flatRowsType, splitColumns, mergeRowDims);

    Value result = flatRows;
    if (hasPaddingRows) {
      // Keep only the first numPatches rows; the remaining rows come from padding.
      SmallVector<OpFoldResult> sliceOffsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
      SmallVector<OpFoldResult> sliceSizes = {rewriter.getIndexAttr(numPatches),
                                              rewriter.getIndexAttr(numChannelsOut)};
      SmallVector<OpFoldResult> sliceStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
      result = tensor::ExtractSliceOp::create(
          rewriter, loc, gemmOutType, flatRows, sliceOffsets, sliceSizes, sliceStrides);
    }

    spatial::SpatYieldOp::create(rewriter, loc, result);
  });
  return unpackingCompute.getResult(0);
}
|
||||
|
||||
static Value createCollectedConvOutput(Value gemmOut,
|
||||
Type convType,
|
||||
RankedTensorType nhwcType,
|
||||
RankedTensorType outType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto collectComputeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, convOp.getType(), {}, ValueRange {gemmOut}, [&](Value gemmOutArg) {
|
||||
createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
|
||||
Value gemmOutArg = gemmOutArgs.front();
|
||||
|
||||
// Restore to NCHW layout:
|
||||
// [numPatches, numChannelsOut]
|
||||
// -> [1, outHeight, outWidth, numChannelsOut]
|
||||
@@ -256,11 +369,225 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
{3}
|
||||
});
|
||||
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
|
||||
|
||||
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
||||
});
|
||||
return collectComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
||||
|
||||
// We need to understand what is group
|
||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
||||
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
const int64_t xWidth = xType.getDimSize(3);
|
||||
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||
const int64_t wHeight = wType.getDimSize(2);
|
||||
const int64_t wWidth = wType.getDimSize(3);
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||
// Gemm output: [numPatches, cOut]
|
||||
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
||||
const int64_t numPatchesPerBatch = outHeight * outWidth;
|
||||
const int64_t numPatches = batchSize * numPatchesPerBatch;
|
||||
|
||||
auto elemType = xType.getElementType();
|
||||
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
||||
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
||||
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
||||
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
||||
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
||||
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
||||
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||
auto wDenseAttr = getDenseConstantAttr(w);
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
wFlatType,
|
||||
w,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||
|
||||
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
Value biasMatrix;
|
||||
DenseElementsAttr biasDenseAttr;
|
||||
if (hasB) {
|
||||
gemmC = b;
|
||||
biasDenseAttr = getDenseConstantAttr(b);
|
||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||
}
|
||||
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||
const bool canPackBiasAsConstants = !hasB || static_cast<bool>(biasDenseAttr);
|
||||
const int64_t effectiveMaxParallelPixels =
|
||||
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
|
||||
|
||||
Value im2col = createIm2colCompute(x,
|
||||
xType,
|
||||
im2colType,
|
||||
rowType,
|
||||
batchSize,
|
||||
numChannelsIn,
|
||||
xHeight,
|
||||
xWidth,
|
||||
wHeight,
|
||||
wWidth,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
outWidth,
|
||||
patchSize,
|
||||
numPatches,
|
||||
numPatchesPerBatch,
|
||||
rewriter,
|
||||
loc);
|
||||
|
||||
Value gemmOut;
|
||||
if (effectiveMaxParallelPixels == 1) {
|
||||
// Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels.
|
||||
gemmOut = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmOutType,
|
||||
im2col,
|
||||
wTrans,
|
||||
gemmC,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
}
|
||||
else {
|
||||
// Keep the standard im2col view of convolution:
|
||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||
// but repack several old rows into one new row so we use the available crossbar size better.
|
||||
//
|
||||
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
|
||||
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
|
||||
// containing N copies of W^T and concatenate N im2col rows into one longer row:
|
||||
// A_packed: [ceil(numPatches / N), N * patchSize]
|
||||
// B_packed: [N * patchSize, N * cOut]
|
||||
// Y_packed: [ceil(numPatches / N), N * cOut]
|
||||
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
|
||||
auto packedOutType =
|
||||
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||
|
||||
Value packedA = createPackedIm2colRows(
|
||||
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
|
||||
Value packedB = buildPackedWeight(wDenseAttr,
|
||||
wTrans,
|
||||
wType,
|
||||
numChannelsIn,
|
||||
numChannelsOut,
|
||||
wHeight,
|
||||
wWidth,
|
||||
patchSize,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc);
|
||||
Value packedC = buildPackedBias(
|
||||
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||
Value packedOut = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
packedOutType,
|
||||
packedA,
|
||||
packedB,
|
||||
packedC,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
gemmOut = createUnpackedOutput(
|
||||
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,29 @@
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Reads the integer attribute named `attrName` from `op` and returns its
/// value rebuilt as a 32-bit integer attribute.
///
/// The attribute must exist: its absence is treated as a programming error
/// and is only caught by the assertion below.
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
  IntegerAttr storedAttr = op->getAttrOfType<IntegerAttr>(attrName);
  assert(storedAttr && "required precomputed channel attr is missing");

  auto i32Type = builder.getI32Type();
  return IntegerAttr::get(i32Type, storedAttr.getInt());
}
|
||||
|
||||
} // namespace
|
||||
|
||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
||||
/*
|
||||
EXAMPLE RUN:
|
||||
@@ -54,6 +66,45 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
|
||||
return returnValue;
|
||||
}
|
||||
|
||||
/// Returns the size of `shapedType` in bytes, computed as
/// (number of elements * element bit width) / 8.
///
/// The division is an integer division, so a total bit count that is not a
/// multiple of 8 is truncated.
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
  const auto totalBits = shapedType.getNumElements() * shapedType.getElementTypeBitWidth();
  return totalBits / 8;
}
|
||||
|
||||
/// Builds an i32 attribute holding the byte size of `value`, whose type must
/// be a ShapedType. The byte size is narrowed to int32_t without a range check.
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
  auto shapedType = cast<ShapedType>(value.getType());
  const size_t sizeInBytes = getShapedTypeSizeInBytes(shapedType);
  return builder.getI32IntegerAttr(static_cast<int32_t>(sizeInBytes));
}
|
||||
|
||||
/// Returns, as an i32 attribute, the source core id stored on the
/// spat.channel_new op that defines `channel`.
///
/// `channel` must be produced by a spat.channel_new op (asserted).
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
  spatial::SpatChannelNewOp producer = channel.getDefiningOp<spatial::SpatChannelNewOp>();
  assert(producer && "spatial channel value must come from spat.channel_new");

  Operation* producerOp = producer.getOperation();
  return getRequiredI32Attr(builder, producerOp, kChannelSourceCoreIdAttrName);
}
|
||||
|
||||
/// Returns, as an i32 attribute, the target core id stored on the
/// spat.channel_new op that defines `channel`.
///
/// `channel` must be produced by a spat.channel_new op (asserted).
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
  spatial::SpatChannelNewOp producer = channel.getDefiningOp<spatial::SpatChannelNewOp>();
  assert(producer && "spatial channel value must come from spat.channel_new");

  Operation* producerOp = producer.getOperation();
  return getRequiredI32Attr(builder, producerOp, kChannelTargetCoreIdAttrName);
}
|
||||
|
||||
/// True when `channel` is defined by a spat.channel_new op that carries the
/// precomputed source core id attribute; false otherwise.
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
  if (auto producer = channel.getDefiningOp<spatial::SpatChannelNewOp>())
    return producer->hasAttr(kChannelSourceCoreIdAttrName);
  return false;
}
|
||||
|
||||
/// True when `channel` is defined by a spat.channel_new op that carries the
/// precomputed target core id attribute; false otherwise.
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
  if (auto producer = channel.getDefiningOp<spatial::SpatChannelNewOp>())
    return producer->hasAttr(kChannelTargetCoreIdAttrName);
  return false;
}
|
||||
|
||||
/// Creates a pim receive op that stands in for a spatial channel receive and
/// returns its output value.
///
/// \param rewriter Rewriter used to create the new op (and, possibly, the buffer).
/// \param loc      Location assigned to the new op.
/// \param output   Result of the receive being replaced; its defining op is
///                 used to pick the output buffer and its type gives the size.
/// \param channel  Channel value; must come from a spat.channel_new op that
///                 carries the precomputed source core id.
mlir::Value createPimReceiveFromSpatialChannel(
    PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
  mlir::Value destination = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
  IntegerAttr byteCountAttr = getTensorSizeInBytesAttr(rewriter, output);
  IntegerAttr senderCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);

  auto receiveOp =
      pim::PimReceiveOp::create(rewriter, loc, destination.getType(), destination, byteCountAttr, senderCoreIdAttr);
  return receiveOp.getOutput();
}
|
||||
|
||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||
auto users = value.getUsers();
|
||||
|
||||
|
||||
@@ -2,11 +2,16 @@
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
// Name of the attribute on spat.channel_new that stores the id of the core sending on the channel.
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
|
||||
// Name of the attribute on spat.channel_new that stores the id of the core receiving from the channel.
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
|
||||
|
||||
/**
|
||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||
* its static tensor input.
|
||||
@@ -21,6 +26,21 @@ namespace onnx_mlir {
|
||||
*/
|
||||
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
||||
|
||||
/// \brief Size of a shaped type in bytes: number of elements times element bit width, divided by 8.
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||
|
||||
/// \brief Byte size of \p value (which must have a shaped type) as an i32 integer attribute.
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||
|
||||
/// \brief Source core id stored on the spat.channel_new op defining \p channel, as an i32 attribute.
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||
|
||||
/// \brief Target core id stored on the spat.channel_new op defining \p channel, as an i32 attribute.
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||
|
||||
/// \brief Whether \p channel comes from a spat.channel_new op carrying the source core id attribute.
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
|
||||
|
||||
/// \brief Whether \p channel comes from a spat.channel_new op carrying the target core id attribute.
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
|
||||
|
||||
/// \brief Create a pim receive op for the receive producing \p output on \p channel and return its output.
mlir::Value createPimReceiveFromSpatialChannel(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
|
||||
|
||||
template <class T>
|
||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||
return std::distance(range.begin(), range.end());
|
||||
|
||||
@@ -9,6 +9,17 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
// Pattern guard: holds when onnx_mlir::hasSpatialChannelSourceCoreIdAttr
// returns true for the bound channel value.
def HasSpatialChannelSourceCoreIdAttr: Constraint<
|
||||
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
|
||||
"spatial channel has precomputed source core id">;
|
||||
|
||||
// Pattern guard: holds when onnx_mlir::hasSpatialChannelTargetCoreIdAttr
// returns true for the bound channel value.
def HasSpatialChannelTargetCoreIdAttr: Constraint<
|
||||
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
|
||||
"spatial channel has precomputed target core id">;
|
||||
|
||||
// Native helper call: $0 is the result value of the receive being replaced,
// $1 is the channel value.
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
|
||||
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
|
||||
|
||||
def onnxToPimTranspose : Pat<
|
||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||
(PimTransposeOp $data, $perms,
|
||||
@@ -69,4 +80,18 @@ def spatToPimVSoftmax : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
// Rewrites a spatial channel send into a PimSendOp. The size operand is the
// byte size of $input and the target core id is read from the channel; the
// pattern only applies when the channel carries the target core id attribute.
def spatChannelSendToPimSend : Pat<
|
||||
(SpatChannelSendOp $channel, $input),
|
||||
(PimSendOp $input,
|
||||
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
|
||||
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
|
||||
[(HasSpatialChannelTargetCoreIdAttr $channel)]
|
||||
>;
|
||||
|
||||
// Rewrites a spatial channel receive through the native helper
// createPimReceiveFromSpatialChannel; the pattern only applies when the
// channel carries the source core id attribute.
def spatChannelReceiveToPimReceive : Pat<
|
||||
(SpatChannelReceiveOp:$srcOpRes $channel),
|
||||
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
|
||||
[(HasSpatialChannelSourceCoreIdAttr $channel)]
|
||||
>;
|
||||
|
||||
#endif // SPATIAL_TO_PIM
|
||||
|
||||
@@ -10,8 +10,10 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
@@ -68,6 +70,8 @@ private:
|
||||
bool useBroadcastOp,
|
||||
IRRewriter& rewriter);
|
||||
void markOpToRemove(Operation* op);
|
||||
void annotateChannelCoreIds(func::FuncOp funcOp);
|
||||
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
||||
|
||||
@@ -175,6 +179,16 @@ void SpatialToPimPass::runOnOperation() {
|
||||
runOnComputeOp(computeOp, rewriter);
|
||||
}
|
||||
|
||||
annotateChannelCoreIds(funcOp);
|
||||
lowerBroadcastChannelOps(funcOp, rewriter);
|
||||
|
||||
RewritePatternSet channelPatterns(ctx);
|
||||
populateWithGenerated(channelPatterns);
|
||||
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
replaceReturnOpOperands(returnOp, rewriter);
|
||||
|
||||
@@ -623,6 +637,94 @@ void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
/// Annotates every spat.channel_new op in `funcOp` with the ids of the cores
/// that talk through it, and schedules the op for removal.
///
/// For each channel that has users:
///  - if one of the users is a broadcast send, only the source core id (the
///    core enclosing that broadcast send) is recorded;
///  - otherwise the channel must have a send and a receive, and both the source
///    core id (core enclosing the send) and the target core id (core enclosing
///    the receive) are recorded.
///
/// Core ids are taken from the pim core op that is the direct parent of the
/// send/receive op. Any other kind of user is treated as unreachable.
void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
  funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
    // The channel op is removed later regardless of whether it has users.
    markOpToRemove(channelNewOp);

    if (channelNewOp->use_empty())
      return;

    spatial::SpatChannelSendOp sendOp;
    spatial::SpatChannelReceiveOp receiveOp;
    spatial::SpatChannelBroadcastSendOp broadcastSendOp;

    // Classify the users of the channel; if several users of the same kind
    // exist, the last one visited is kept.
    for (Operation* user : channelNewOp->getUsers()) {
      if (auto op = dyn_cast<spatial::SpatChannelSendOp>(user)) {
        sendOp = op;
        continue;
      }
      if (auto op = dyn_cast<spatial::SpatChannelReceiveOp>(user)) {
        receiveOp = op;
        continue;
      }
      if (auto op = dyn_cast<spatial::SpatChannelBroadcastSendOp>(user)) {
        broadcastSendOp = op;
        continue;
      }
      // Broadcast receivers are legal users but contribute nothing here.
      if (isa<spatial::SpatChannelBroadcastReceiveOp>(user))
        continue;
      llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering");
    }

    // Broadcast channel: only the sender is recorded on the channel.
    if (broadcastSendOp) {
      auto sourceCoreIdAttr = cast<pim::PimCoreOp>(broadcastSendOp->getParentOp()).getCoreIdAttr();
      channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
      return;
    }

    if (!sendOp || !receiveOp)
      llvm_unreachable("spat.channel_new must connect exactly one send and one receive");

    // Point-to-point channel: record both endpoints.
    auto sourceCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
    auto targetCoreIdAttr = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreIdAttr();
    channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
    channelNewOp->setAttr(kChannelTargetCoreIdAttrName, targetCoreIdAttr);
  });
}
|
||||
|
||||
/// Lowers the broadcast channel ops found in `funcOp` to pim send/receive ops.
///
/// Each broadcast send is replaced by one PimSendOp per broadcast receive using
/// the same channel (targeting the core that encloses that receive), then
/// erased. Each broadcast receive is then replaced by a PimReceiveOp writing
/// into a freshly created empty tensor, using the source core id stored on the
/// channel.
void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Collect the sends up front: they are erased while being processed.
  SmallVector<spatial::SpatChannelBroadcastSendOp> pendingSends;
  funcOp.walk([&pendingSends](spatial::SpatChannelBroadcastSendOp bcastSend) { pendingSends.push_back(bcastSend); });

  for (spatial::SpatChannelBroadcastSendOp bcastSend : pendingSends) {
    auto channelDef = cast<spatial::SpatChannelNewOp>(bcastSend.getChannel().getDefiningOp());
    auto payloadSizeAttr = getTensorSizeInBytesAttr(rewriter, bcastSend.getInput());
    rewriter.setInsertionPoint(bcastSend);

    // Emit one point-to-point send per receiver attached to the channel.
    unsigned numReceivers = 0;
    for (Operation* channelUser : channelDef->getUsers()) {
      if (auto bcastReceive = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(channelUser)) {
        ++numReceivers;
        auto receiverCoreIdAttr = cast<pim::PimCoreOp>(bcastReceive->getParentOp()).getCoreIdAttr();
        PimSendOp::create(rewriter, bcastSend.getLoc(), bcastSend.getInput(), payloadSizeAttr, receiverCoreIdAttr);
      }
    }

    if (numReceivers == 0)
      llvm_unreachable("spat.channel_broadcast_send has no matching broadcast receive");

    rewriter.eraseOp(bcastSend);
  }

  // Collect the receives up front: they are replaced while being processed.
  SmallVector<spatial::SpatChannelBroadcastReceiveOp> pendingReceives;
  funcOp.walk(
      [&pendingReceives](spatial::SpatChannelBroadcastReceiveOp bcastReceive) { pendingReceives.push_back(bcastReceive); });

  for (spatial::SpatChannelBroadcastReceiveOp bcastReceive : pendingReceives) {
    rewriter.setInsertionPoint(bcastReceive);

    auto receivedType = cast<ShapedType>(bcastReceive.getResult().getType());
    Value destination = createEmptyTensorFromShaped(rewriter, bcastReceive.getLoc(), receivedType);
    auto byteCountAttr = getTensorSizeInBytesAttr(rewriter, bcastReceive.getResult());
    auto senderCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, bcastReceive.getChannel());

    auto pimReceive = PimReceiveOp::create(
        rewriter, bcastReceive.getLoc(), destination.getType(), destination, byteCountAttr, senderCoreIdAttr);
    Value received = pimReceive.getOutput();
    rewriter.replaceOp(bcastReceive, received);
  }
}
|
||||
|
||||
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
for (auto it : llvm::enumerate(originalOperands)) {
|
||||
|
||||
@@ -97,6 +97,31 @@ struct MemCopyDevToHostOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
/// Bufferization model for PimReceiveOp: the op is rebuilt on top of the
/// buffer obtained for its output-buffer operand.
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
  /// An operand is reported as read unless it is a destination-style init operand.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    const bool isInitOperand = dpsOp.isDpsInit(&opOperand);
    return !isInitOperand;
  }

  /// Replaces the op with a PimReceiveOp on the bufferized output buffer,
  /// keeping the size and source-core-id attributes. Fails if no buffer can be
  /// obtained for the output-buffer operand.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto receiveOp = cast<PimReceiveOp>(op);

    auto destinationBuffer = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state);
    if (succeeded(destinationBuffer)) {
      replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
                                                 op,
                                                 destinationBuffer->getType(),
                                                 *destinationBuffer,
                                                 receiveOp.getSizeAttr(),
                                                 receiveOp.getSourceCoreIdAttr());
      return success();
    }
    return failure();
  }
};
|
||||
|
||||
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -258,6 +283,7 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
||||
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
|
||||
@@ -3,11 +3,10 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
|
||||
|
||||
add_pim_library(SpatialOps
|
||||
SpatialOps.cpp
|
||||
Transforms/SpatialBufferizableOpInterface.cpp
|
||||
Transforms/MergeComputeNode/MergeComputeNodePass.cpp
|
||||
DCPGraph/Graph.cpp
|
||||
DCPGraph/Task.cpp
|
||||
DCPGraph/DCPAnalysis.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Task.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
-1
@@ -8,7 +8,6 @@
|
||||
|
||||
#include <iterator>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "DCPAnalysis.hpp"
|
||||
#include "Graph.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
+1
-1
@@ -6,7 +6,7 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
struct DCPAnalysisResult {
|
||||
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
|
||||
+3
-3
@@ -7,11 +7,11 @@
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include "../../../Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "DCPAnalysis.hpp"
|
||||
#include "Graph.hpp"
|
||||
#include "Task.hpp"
|
||||
#include "Uniqueworklist.hpp"
|
||||
#include "UniqueWorklist.hpp"
|
||||
#include "Utils.hpp"
|
||||
|
||||
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) {
|
||||
@@ -420,7 +420,7 @@ void GraphDCP::to_dot() {
|
||||
std::string outputDir = onnx_mlir::getOutputDir();
|
||||
if (outputDir.empty())
|
||||
return;
|
||||
std::string graphDir = outputDir + "/DCPGraph";
|
||||
std::string graphDir = outputDir + "/dcp_graph";
|
||||
onnx_mlir::createDirectory(graphDir);
|
||||
std::fstream file(graphDir + "/graph_" + std::to_string(index++) + ".dot", std::ios::out);
|
||||
file << "digraph G {\n";
|
||||
+1
-1
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "Graph.hpp"
|
||||
#include "Task.hpp"
|
||||
#include "Uniqueworklist.hpp"
|
||||
#include "UniqueWorklist.hpp"
|
||||
|
||||
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
|
||||
std::optional<Edge_t> oldEdge = std::nullopt;
|
||||
+1
-1
@@ -5,7 +5,7 @@
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "Utils.hpp"
|
||||
|
||||
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
|
||||
+1
-1
@@ -9,7 +9,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
using CPU = int;
|
||||
+7
-7
@@ -20,7 +20,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
|
||||
#include "DCPGraph/DCPAnalysis.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -81,7 +81,7 @@ public:
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
|
||||
};
|
||||
|
||||
struct MergeComputeNodePass : PassWrapper<MergeComputeNodePass, OperationPass<func::FuncOp>> {
|
||||
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
|
||||
|
||||
private:
|
||||
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
|
||||
@@ -89,11 +89,11 @@ private:
|
||||
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
|
||||
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass)
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-merge-node-pass"; }
|
||||
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Merge Spatial-Weighted-Compute-Node in order to reduce the total "
|
||||
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
|
||||
"execution time";
|
||||
}
|
||||
|
||||
@@ -133,7 +133,7 @@ public:
|
||||
computeNodetoRemove.erase();
|
||||
}
|
||||
func::FuncOp func = getOperation();
|
||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "SpatialDCPMerged");
|
||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -346,6 +346,6 @@ private:
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createMergeComputeNodePass() { return std::make_unique<MergeComputeNodePass>(); }
|
||||
/// Factory for the pass that merges spatial weighted compute nodes.
std::unique_ptr<Pass> createMergeComputeNodesPass() {
  auto mergePass = std::make_unique<MergeComputeNodesPass>();
  return mergePass;
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,532 +0,0 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace bufferization;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
/// Allocates a memref that has the shape and element type of `resultType`,
/// which must be a ShapedType, and returns the alloc op.
memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& rewriter) {
  auto shaped = cast<ShapedType>(resultType);
  auto bufferType = MemRefType::get(shaped.getShape(), shaped.getElementType());
  return memref::AllocOp::create(rewriter, loc, bufferType);
}
|
||||
|
||||
/// Returns a memref for which resolveContiguousAddress succeeds.
///
/// If it already succeeds for `memrefValue`, that value is returned unchanged.
/// Otherwise a new buffer of the same type is allocated, the whole content is
/// copied into it with a pim memcopy (both offsets 0, size = the byte size of
/// the type), and the copy's output is returned.
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
  if (failed(resolveContiguousAddress(memrefValue))) {
    auto sourceType = cast<ShapedType>(memrefValue.getType());
    auto stagingBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
    auto byteCount = sourceType.getNumElements() * sourceType.getElementTypeBitWidth() / 8;

    auto zeroOffsetAttr = rewriter.getI32IntegerAttr(0);
    auto copyOp = pim::PimMemCopyOp::create(rewriter,
                                            loc,
                                            stagingBuffer.getType(),
                                            stagingBuffer,
                                            memrefValue,
                                            zeroOffsetAttr,
                                            zeroOffsetAttr,
                                            rewriter.getI32IntegerAttr(byteCount));
    return copyOp.getOutput();
  }

  return memrefValue;
}
|
||||
|
||||
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||
|
||||
/// Finds the operation at the other end of the channel used by `op`.
///
/// `op` must take the channel as its first operand, and the channel must be
/// created by a ChannelNewOp with exactly two users: `op` and its counterpart.
/// The counterpart must be a ChannelSendOp when `opIsReceive` is true and a
/// ChannelReceiveOp otherwise.
///
/// \returns the counterpart operation, or failure (after emitting an error on
///          `op`) when any of the requirements above is violated.
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
  auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
  if (!channelNewOp) {
    op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
    return failure();
  }

  // `op` itself uses the channel, so the user list holds at least one entry.
  auto channelUsers = channelNewOp->getUsers();
  auto usersIterator = channelUsers.begin();
  auto firstUser = *usersIterator;
  ++usersIterator;
  if (usersIterator == channelUsers.end()) {
    // Point at the channel through the diagnostic engine instead of dumping IR.
    auto diag = op->emitError("Operand generated by ChannelNewOp must have two users, only one found.");
    diag.attachNote(channelNewOp->getLoc()) << "channel is created here";
    return failure();
  }

  auto secondUser = *usersIterator;
  ++usersIterator;
  if (usersIterator != channelUsers.end()) {
    op->emitError("Operand generated by ChannelNewOp must have two users, more than two found.");
    return failure();
  }

  // Pick whichever of the two users is not `op`.
  Operation* otherUser = nullptr;
  if (firstUser == op)
    otherUser = secondUser;
  else if (secondUser == op)
    otherUser = firstUser;
  else {
    op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op.");
    return failure();
  }

  if (opIsReceive && !isa<spatial::SpatChannelSendOp>(otherUser)) {
    op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp.");
    return failure();
  }

  if (!opIsReceive && !isa<spatial::SpatChannelReceiveOp>(otherUser)) {
    op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp.");
    return failure();
  }

  return otherUser;
}
|
||||
|
||||
/// Returns the id of the core that hosts the other end of the channel used by
/// `op`, or failure if that other end cannot be determined.
///
/// As a side effect, the id of the core hosting `op` is stored as an attribute
/// on the other end, so that the other end can still answer this query after
/// `op` has been removed.
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
  // Looking up the peer needs the ChannelNewOp and the matching Receive/Send
  // op to still exist. During bufferization, whichever of the two ends is
  // processed first gets removed, so the first query caches the answer for the
  // second one as an attribute on the peer. Use the cached value if present.
  if (Attribute cachedCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME))
    return cast<IntegerAttr>(cachedCoreId).getInt();

  auto peerOrFailure = getOtherEndOfChannel(op, opIsReceive);
  if (failed(peerOrFailure))
    return failure();
  Operation* peer = *peerOrFailure;

  // Leave this op's core id on the peer for its own later lookup.
  auto ownCore = cast<pim::PimCoreOp>(op->getParentOp());
  peer->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, ownCore.getCoreIdAttr());

  auto peerCore = cast<pim::PimCoreOp>(peer->getParentOp());
  return peerCore.getCoreId();
}
|
||||
|
||||
/// Bufferization model for SpatWeightedCompute: operands are reported as
/// read-only and non-aliasing, and bufferizing the op only rewrites the
/// signature of the first block of its first region.
struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpInterface, SpatWeightedCompute> {

  // Every input tensor of the compute op is read into its local memory.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }

  // Input tensors of the compute op are _never_ written.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // No operand is reported as aliasing any other value of the compute op.
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    // Only the block signature is bufferized here.
    Region& bodyRegion = op->getRegion(0);
    Block* entryBlock = &bodyRegion.front();
    return bufferizeBlockSignature(entryBlock, rewriter, options, state);
  }
};
|
||||
|
||||
/*
|
||||
* This can be used for operation that have a single argument, which is a
|
||||
* variadic of tensors, and a single output with the same same shape
|
||||
* Example: VAdd, VSub, VExp
|
||||
*/
|
||||
template <typename InterfaceName, typename OpTy, typename ToTy>
|
||||
struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
|
||||
|
||||
// Input tensors to the OP are always read
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
// Input tensors to the OP are _never_ written
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// In general, no tensor is aliased with any other tensor in the OP
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Cast tensor values into memref values
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
// Turn Tensor Operands into Memref Operands
|
||||
SmallVector<Value> memrefOperands;
|
||||
memrefOperands.reserve(op->getNumOperands());
|
||||
for (auto operand : op->getOperands()) {
|
||||
auto memref = getBuffer(rewriter, operand, options, state);
|
||||
if (failed(memref))
|
||||
return failure();
|
||||
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
|
||||
}
|
||||
|
||||
// TODO: Support addiction with more than 2 operands
|
||||
if (memrefOperands.size() > 2) {
|
||||
op->emitError("VariadicArgumentElementWiseOpInterface only supports OPs "
|
||||
"with 1 or 2 operands, for now.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Alloc an output memref
|
||||
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
memrefOperands.push_back(outputTensor);
|
||||
|
||||
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InterfaceName, typename OpTy, typename ToTy>
|
||||
struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
|
||||
|
||||
// Input tensors to the OP are always read
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||
|
||||
// Input tensors to the OP are _never_ written
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// In general, no tensor is aliased with any other tensor in the OP
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
// Cast tensor value into memref value
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
|
||||
if (failed(memrefOperandOpt))
|
||||
return failure();
|
||||
auto memrefOperand = *memrefOperandOpt;
|
||||
|
||||
// Alloc an output memref
|
||||
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
Value newValue = ToTy::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
cast<OpTy>(op).getWeightIndexAttr(),
|
||||
memrefOperand,
|
||||
outputTensor)
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelReceiveOpInterface
|
||||
: BufferizableOpInterface::ExternalModel<ChannelReceiveOpInterface, SpatChannelReceiveOp> {
|
||||
|
||||
// Input value is the channel (not read/written, its more of an attribute)
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// See above
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// See above
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
// TODO: Is it an empty list or a list of "UNKNOWN" values?
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the channel receive to pim.recv
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
||||
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
|
||||
|
||||
auto srcCoreId = getCoreIdOfOtherEndOfChannel(op, true, rewriter);
|
||||
if (failed(srcCoreId))
|
||||
return failure();
|
||||
|
||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Bufferization external model that turns SpatChannelSendOp into
/// pim::PimSendOp.
struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSendOpInterface, SpatChannelSendOp> {

  /// Operand 0 is the channel (neither read nor written); operand 1 is the
  /// tensor to send, which is read.
  ///
  /// Operand numbers are 0-based, so the tensor is operand 1 -- the same index
  /// bufferize() uses to fetch it. The previous `== 2` comparison never
  /// matched that operand, so the tensor was not reported as read.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return opOperand.getOperandNumber() == 1;
  }

  /// Neither the channel nor the sent tensor is written.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  /// No aliasing values are reported for any operand.
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    // TODO: Is it an empty list or a list of "UNKNOWN" values?
    return {};
  }

  /// Replaces the channel send with a pim send of the buffer behind operand 1.
  ///
  /// \returns failure() if the source buffer cannot be obtained or the core id
  ///          of the receiving end of the channel cannot be determined,
  ///          success() otherwise.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto srcTensor = op->getOperand(1);

    auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
    if (failed(srcTensorOpt))
      return failure();
    auto srcMemRef = *srcTensorOpt;

    // Payload size in bytes: element count times the (integer) byte width of
    // one element.
    auto numElements = cast<ShapedType>(srcTensor.getType()).getNumElements();
    auto elementSize = cast<ShapedType>(srcTensor.getType()).getElementTypeBitWidth() / 8;

    auto dstCoreId = getCoreIdOfOtherEndOfChannel(op, false, rewriter);
    if (failed(dstCoreId))
      return failure();

    replaceOpWithNewBufferizedOp<pim::PimSendOp>(rewriter,
                                                 op,
                                                 srcMemRef,
                                                 rewriter.getI32IntegerAttr(numElements * elementSize),
                                                 rewriter.getI32IntegerAttr(dstCoreId.value()));

    return success();
  }
};
|
||||
|
||||
struct ChannelBroadcastReceiveOpInterface
|
||||
: BufferizableOpInterface::ExternalModel<ChannelBroadcastReceiveOpInterface, SpatChannelBroadcastReceiveOp> {
|
||||
|
||||
// Input value is the channel (not read/written, its more of an attribute)
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// See above
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// See above
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
// TODO: Is it an empty list or a list of "UNKNOWN" values?
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the broadcast receive into a regular pim.receive from the broadcaster.
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
||||
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
|
||||
|
||||
auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
|
||||
if (precomputedOtherCoreId) {
|
||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
cast<IntegerAttr>(precomputedOtherCoreId))
|
||||
.getOutput();
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
op->emitError("ChannelBroadcastReceiveOp does not use a channel as operand");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto srcCoreId = [&]() -> FailureOr<uint32_t> {
|
||||
for (Operation* user : channelNewOp->getUsers()) {
|
||||
auto sendOp = dyn_cast<SpatChannelBroadcastSendOp>(user);
|
||||
if (!sendOp)
|
||||
continue;
|
||||
auto sendCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
|
||||
op->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, sendCoreIdAttr);
|
||||
return cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreId();
|
||||
}
|
||||
op->emitError("ChannelBroadcastReceiveOp has no matching ChannelBroadcastSendOp");
|
||||
return failure();
|
||||
}();
|
||||
if (failed(srcCoreId))
|
||||
return failure();
|
||||
|
||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
||||
.getOutput();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelBroadcastSendOpInterface
|
||||
: BufferizableOpInterface::ExternalModel<ChannelBroadcastSendOpInterface, SpatChannelBroadcastSendOp> {
|
||||
|
||||
// First input is channel (not read/writter) second input is Tensor to send,
|
||||
// which is read
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return opOperand.getOperandNumber() == 2;
|
||||
}
|
||||
|
||||
// See above (both non-written)
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||
|
||||
// See above
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
// TODO: Is it an empty list or a list of "UNKNOWN" values?
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
* Turn the broadcast send into one pim.send per broadcast receiver.
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto srcTensor = op->getOperand(1);
|
||||
|
||||
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
|
||||
if (failed(srcTensorOpt))
|
||||
return failure();
|
||||
auto srcMemRef = *srcTensorOpt;
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
op->emitError("SpatChannelBroadcastSendOp does not use a channel as operand");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto srcType = cast<ShapedType>(srcTensor.getType());
|
||||
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
||||
auto srcCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
|
||||
|
||||
rewriter.setInsertionPoint(op);
|
||||
bool foundReceiver = false;
|
||||
for (Operation* user : channelNewOp->getUsers()) {
|
||||
auto receiveOp = dyn_cast<SpatChannelBroadcastReceiveOp>(user);
|
||||
if (!receiveOp)
|
||||
continue;
|
||||
|
||||
foundReceiver = true;
|
||||
auto dstCoreId = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreId();
|
||||
receiveOp->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, srcCoreIdAttr);
|
||||
pim::PimSendOp::create(rewriter,
|
||||
op->getLoc(),
|
||||
srcMemRef,
|
||||
rewriter.getI32IntegerAttr(sizeInBytes),
|
||||
rewriter.getI32IntegerAttr(dstCoreId));
|
||||
}
|
||||
|
||||
if (!foundReceiver) {
|
||||
op->emitError("SpatChannelBroadcastSendOp has no matching ChannelBroadcastReceiveOp");
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Bufferization model for SpatVAddOp: instantiates the shared element-wise
// template with pim::PimVVAddOp as the op created during bufferization.
struct VAddOpInterfaceFromTemplate
|
||||
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
|
||||
|
||||
// Bufferization model for SpatWeightedVMMOp: instantiates the weighted
// multiplication template with pim::PimVMMOp as the op created during bufferization.
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
|
||||
|
||||
// Bufferization model for SpatWeightedMVMOp: instantiates the weighted
// multiplication template with pim::PimMVMOp as the op created during bufferization.
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
|
||||
|
||||
// Bufferization model for SpatVMaxOp: instantiates the shared element-wise
// template with pim::PimVVMaxOp as the op created during bufferization.
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
||||
|
||||
/// Registers a SpatialDialect extension on `registry` that attaches the
/// bufferization external models defined in this file to the Spatial ops.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  // The unary plus turns the capture-less lambda into a plain function
  // pointer, which is the form addExtension is given here.
  auto attachSpatialModels = +[](MLIRContext* ctx, SpatialDialect* dialect) {
    // Compute and element-wise / multiplication ops.
    SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
    SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
    SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
    SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
    SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);

    // Point-to-point and broadcast channel ops.
    SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
    SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
    SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
    SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
  };

  registry.addExtension(attachSpatialModels);
}
|
||||
|
||||
// Bufferization model for ONNXReluOp: instantiates the shared element-wise
// template with pim::PimVReluOp as the op created during bufferization.
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
|
||||
|
||||
// Bufferization model for ONNXTanhOp: instantiates the shared element-wise
// template with pim::PimVTanhOp as the op created during bufferization.
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
|
||||
|
||||
// Bufferization model for ONNXSigmoidOp: instantiates the shared element-wise
// template with pim::PimVSigmOp as the op created during bufferization.
struct ONNXSigmoidInterface
|
||||
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
|
||||
|
||||
/// Registers an ONNXDialect extension on `registry` that attaches the
/// element-wise bufferization external models to the ONNX Relu, Tanh and
/// Sigmoid ops.
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  // The unary plus turns the capture-less lambda into a plain function
  // pointer, which is the form addExtension is given here.
  auto attachOnnxModels = +[](MLIRContext* ctx, ONNXDialect* dialect) {
    ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
    ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
    ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
  };

  registry.addExtension(attachOnnxModels);
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,15 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
|
||||
|
||||
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,13 +1,13 @@
|
||||
add_pim_library(OMPimPasses
|
||||
CountInstructionPass.cpp
|
||||
MessagePass.cpp
|
||||
Pim/ConstantFolding/Common.cpp
|
||||
Pim/ConstantFolding/Patterns/Constant.cpp
|
||||
Pim/ConstantFolding/ConstantFoldingPass.cpp
|
||||
Pim/ConstantFolding/Patterns/Subview.cpp
|
||||
Pim/MaterializeConstantsPass.cpp
|
||||
Pim/VerificationPass.cpp
|
||||
Pim/EmitPimJsonPass.cpp
|
||||
PimCodegen/HostConstantFolding/Common.cpp
|
||||
PimCodegen/HostConstantFolding/Patterns/Constant.cpp
|
||||
PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp
|
||||
PimCodegen/HostConstantFolding/Patterns/Subview.cpp
|
||||
PimCodegen/MaterializeHostConstantsPass.cpp
|
||||
PimCodegen/VerificationPass.cpp
|
||||
PimCodegen/EmitPimJsonPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createMergeComputeNodePass();
|
||||
std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
||||
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
||||
std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||
|
||||
|
||||
+4
-4
@@ -11,10 +11,10 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
|
||||
struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostConstantFoldingPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
||||
StringRef getArgument() const override { return "pim-host-constant-folding-pass"; }
|
||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||
|
||||
LogicalResult initialize(MLIRContext* context) override {
|
||||
@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
|
||||
std::unique_ptr<Pass> createPimHostConstantFoldingPass() { return std::make_unique<HostConstantFoldingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+6
-4
@@ -31,10 +31,10 @@ static int64_t getValueSizeInBytes(Value value) {
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
|
||||
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
|
||||
|
||||
StringRef getArgument() const override { return "materialize-pim-constants"; }
|
||||
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
||||
}
|
||||
@@ -126,6 +126,8 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
|
||||
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
|
||||
return std::make_unique<MaterializeHostConstantsPass>();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -18,7 +18,6 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Accelerators/PIM/PimAccelerator.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
@@ -67,8 +66,6 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
|
||||
mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
mlir::scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
spatial::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
|
||||
pim::registerOpBufferizationInterfaces(registry);
|
||||
}
|
||||
|
||||
@@ -78,9 +75,9 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
||||
registerPass(createSpatialToGraphvizPass);
|
||||
registerPass(createSpatialToPimPass);
|
||||
registerPass(createPimBufferizationPass);
|
||||
registerPass(createMergeComputeNodePass);
|
||||
registerPass(createPimConstantFoldingPass);
|
||||
registerPass(createPimMaterializeConstantsPass);
|
||||
registerPass(createMergeComputeNodesPass);
|
||||
registerPass(createPimHostConstantFoldingPass);
|
||||
registerPass(createPimMaterializeHostConstantsPass);
|
||||
registerPass(createPimVerificationPass);
|
||||
registerPass(createEmitPimJsonPass);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
|
||||
Rimuovere la logica di bufferizzazione da spatial,
|
||||
Rimuovere la gestione delle send e receive da SpatialToPim (nuovo mergeNode)
|
||||
|
||||
AnalisiDCP
|
||||
|
||||
+54
-13
@@ -12,7 +12,16 @@ from gen_network_runner import gen_network_runner
|
||||
from subprocess_utils import run_command_with_reporter
|
||||
|
||||
|
||||
STAGE_COUNT = 6
|
||||
STAGE_TITLES = (
|
||||
"Compile ONNX",
|
||||
"Build Runner",
|
||||
"Generate Inputs",
|
||||
"Run Reference",
|
||||
"Compile PIM",
|
||||
"Run Simulator",
|
||||
"Compare Outputs",
|
||||
)
|
||||
STAGE_COUNT = len(STAGE_TITLES)
|
||||
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
||||
|
||||
|
||||
@@ -22,6 +31,8 @@ class ProgressReporter:
|
||||
self.stages_per_model = stages_per_model
|
||||
self.total_steps = max(1, total_models * stages_per_model)
|
||||
self.completed_steps = 0
|
||||
self.passed_models = 0
|
||||
self.failed_models = 0
|
||||
self.current_label = ""
|
||||
self.enabled = True
|
||||
self.columns = shutil.get_terminal_size((100, 20)).columns
|
||||
@@ -36,21 +47,42 @@ class ProgressReporter:
|
||||
return
|
||||
bar_width = 24
|
||||
filled = int(bar_width * self.completed_steps / self.total_steps)
|
||||
counts_text = f"P:{self.passed_models} F:{self.failed_models}"
|
||||
prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] {self.completed_steps}/{self.total_steps}"
|
||||
if len(prefix_text) > self.columns:
|
||||
prefix_text = f"{self.completed_steps}/{self.total_steps}"
|
||||
|
||||
label = f" {self.current_label}" if self.current_label else ""
|
||||
available_label_width = max(0, self.columns - len(prefix_text))
|
||||
label = label[:available_label_width]
|
||||
|
||||
if prefix_text.startswith("["):
|
||||
bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled))
|
||||
prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL
|
||||
else:
|
||||
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
|
||||
|
||||
sys.stdout.write("\r" + prefix + label + Style.RESET_ALL)
|
||||
counts = (
|
||||
" "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.GREEN
|
||||
+ f"P:{self.passed_models}"
|
||||
+ Style.RESET_ALL
|
||||
+ " "
|
||||
+ Style.BRIGHT
|
||||
+ Fore.RED
|
||||
+ f"F:{self.failed_models}"
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
model_counter = ""
|
||||
label = ""
|
||||
if self.current_label.startswith("[") and "] " in self.current_label:
|
||||
model_counter, label = self.current_label.split("] ", 1)
|
||||
model_counter = f" {model_counter}]"
|
||||
label = f" {label}"
|
||||
elif self.current_label:
|
||||
label = f" {self.current_label}"
|
||||
|
||||
available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3)
|
||||
label = label[:available_label_width]
|
||||
|
||||
sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL)
|
||||
sys.stdout.flush()
|
||||
|
||||
def log(self, message="", color=None):
|
||||
@@ -70,6 +102,13 @@ class ProgressReporter:
|
||||
self.completed_steps = min(self.total_steps, self.completed_steps + 1)
|
||||
self._render()
|
||||
|
||||
def record_result(self, passed):
|
||||
if passed:
|
||||
self.passed_models += 1
|
||||
else:
|
||||
self.failed_models += 1
|
||||
self._render()
|
||||
|
||||
def suspend(self):
|
||||
self.suspended = True
|
||||
self._clear()
|
||||
@@ -112,13 +151,13 @@ def clean_workspace_artifacts(workspace_dir, model_stem):
|
||||
|
||||
def print_stage(reporter, model_index, model_total, model_name, title):
|
||||
stage_colors = {
|
||||
"Compile ONNX": Fore.BLUE,
|
||||
"Build Runner": Fore.MAGENTA,
|
||||
"Generate Inputs": Fore.YELLOW,
|
||||
"Run Reference": Fore.GREEN,
|
||||
"Compile PIM": Fore.CYAN,
|
||||
"Run Simulator": Fore.MAGENTA,
|
||||
"Compare Outputs": Fore.YELLOW,
|
||||
STAGE_TITLES[0]: Fore.BLUE,
|
||||
STAGE_TITLES[1]: Fore.MAGENTA,
|
||||
STAGE_TITLES[2]: Fore.YELLOW,
|
||||
STAGE_TITLES[3]: Fore.GREEN,
|
||||
STAGE_TITLES[4]: Fore.CYAN,
|
||||
STAGE_TITLES[5]: Fore.MAGENTA,
|
||||
STAGE_TITLES[6]: Fore.YELLOW,
|
||||
}
|
||||
color = stage_colors.get(title, Fore.WHITE)
|
||||
reporter.log(Style.BRIGHT + color + f"[{title}]" + Style.RESET_ALL)
|
||||
@@ -284,11 +323,13 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
|
||||
reporter.resume()
|
||||
reporter.advance()
|
||||
reporter.record_result(passed)
|
||||
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
|
||||
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
||||
return passed
|
||||
except Exception:
|
||||
failed_with_exception = True
|
||||
reporter.record_result(False)
|
||||
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
|
||||
reporter.suspend()
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user