Compare commits

7 Commits

Author SHA1 Message Date
NiccoloN a7dee5b840 fix mergeTriviallyConnectedComputes
Validate Operations / validate-operations (push) Successful in 16m48s
2026-04-14 15:13:45 +02:00
NiccoloN 77fe293062 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-04-14 13:45:10 +02:00
NiccoloN 525792e545 better progress bar in validate_one.py 2026-04-14 13:13:56 +02:00
NiccoloN eade488d13 fix missed failing tests for channels
moderate refactor
2026-04-14 12:26:41 +02:00
NiccoloN 30ee9640d4 remove spatial bufferization logic (didn't make much sense) and move channel lowering to SpatialToPim 2026-04-14 11:55:19 +02:00
NiccoloN 368e340a40 minor refactor 2026-04-14 11:21:11 +02:00
NiccoloN e866ec6f87 convolution uses crossbar size better 2026-04-14 11:06:35 +02:00
34 changed files with 802 additions and 766 deletions
+4 -4
View File
@@ -33,7 +33,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
}
if (pimEmissionTarget >= EmitPim) {
pm.addPass(createMergeComputeNodePass());
pm.addPass(createMergeComputeNodesPass());
pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim"));
@@ -46,9 +46,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
}
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimConstantFoldingPass());
pm.addPass(createMessagePass("Pim constants folded"));
pm.addPass(createPimMaterializeConstantsPass());
pm.addPass(createPimHostConstantFoldingPass());
pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimMaterializeHostConstantsPass());
pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass());
@@ -8,6 +8,7 @@
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
@@ -22,8 +23,8 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -51,7 +52,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp);
void mergeSingleChildCompute(func::FuncOp funcOp);
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
};
} // namespace
@@ -148,10 +149,10 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc);
mergeSingleChildCompute(*entryFunc);
mergeTriviallyConnectedComputes(*entryFunc);
// Dump to file for debug
dumpModule(moduleOp, "spatial");
dumpModule(moduleOp, "spatial0");
}
template <typename T>
@@ -230,66 +231,61 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
}
}
void ONNXToSpatialPass::mergeSingleChildCompute(func::FuncOp funcOp) {
llvm::SmallVector<spatial::SpatWeightedCompute> computeSingleChild;
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatWeightedCompute> trivialComputes;
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
if (std::distance(compute->getUses().begin(), compute->getUses().end()) == 1) {
if (compute->hasOneUse()) {
auto user = *compute->getUsers().begin();
if (user->getNumOperands() == 1)
if (llvm::isa<spatial::SpatWeightedCompute>(user))
computeSingleChild.push_back(compute);
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
trivialComputes.push_back(compute);
}
IRMapping mapper;
while (!computeSingleChild.empty()) {
auto compute = computeSingleChild.front();
auto child = dyn_cast_if_present<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
assert(child && "Child required!");
while (!trivialComputes.empty()) {
auto compute = trivialComputes.front();
if (compute.use_empty()) {
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
continue;
}
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes(
{(int) compute.getWeights().size(), (int) compute.getInputs().size()});
llvm::dbgs() << "After Creation\n";
newCompute.dump();
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
IRMapping mapper;
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
llvm::dbgs() << "After Clone\n";
newCompute.dump();
auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
newTerminator->erase();
llvm::dbgs() << "After terminator\n";
newCompute.dump();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front())
rewriter.clone(op, mapper);
child.replaceAllUsesWith(newCompute);
assert(child->getUses().empty() && "It's not obvius");
llvm::dbgs() << "Node\n";
newCompute.dump();
toErase.insert(child);
llvm::dbgs() << "Parent\n";
compute.dump();
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
toErase.insert(compute);
llvm::dbgs() << "Child\n";
child.dump();
child.erase();
compute.erase();
if (std::distance(newCompute->getUses().begin(), newCompute->getUses().end()) == 1) {
if (newCompute->hasOneUse()) {
auto user = *newCompute->getUsers().begin();
if (user->getNumOperands() == 1)
if (llvm::isa<spatial::SpatWeightedCompute>(user))
computeSingleChild.push_back(newCompute);
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
trivialComputes.push_back(newCompute);
}
std::swap(computeSingleChild.front(), computeSingleChild.back());
computeSingleChild.pop_back();
}
for (auto compute : toErase) {
compute.getResult(0).dropAllUses();
compute.erase();
}
}
@@ -1,12 +1,16 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -24,122 +28,150 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override;
};
} // namespace
static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = convOp.getLoc();
Value x = convOpAdaptor.getX();
Value w = convOpAdaptor.getW();
Value b = convOpAdaptor.getB();
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
auto xType = cast<RankedTensorType>(x.getType());
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
return nullptr;
}
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
assert("Only support 2D convolution" && xType.getRank() == 4);
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
// We need to understand what is group
assert("Only support group=1" && convOp.getGroup() == 1);
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
auto biasType = cast<RankedTensorType>(bias.getType());
if (biasType.getRank() != 1)
return bias;
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2);
const int64_t xWidth = xType.getDimSize(3);
const int64_t numChannelsOut = wType.getDimSize(0);
const int64_t wHeight = wType.getDimSize(2);
const int64_t wWidth = wType.getDimSize(3);
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
auto expandedBiasType = RankedTensorType::get({1, biasType.getDimSize(0)}, biasType.getElementType());
return tensor::ExpandShapeOp::create(rewriter,
loc,
expandedBiasType,
bias,
SmallVector<ReassociationIndices> {
{0, 1}
});
}
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast<IntegerAttr>(arr[idx]).getInt(); };
static Value createPaddedRows(Value tensorValue,
RankedTensorType tensorType,
int64_t paddedRows,
ConversionPatternRewriter& rewriter,
Location loc) {
if (tensorType.getDimSize(0) == paddedRows)
return tensorValue;
const auto stridesAttr = convOp.getStrides();
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
rewriter.getIndexAttr(0)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
auto* padBlock = new Block();
for (int i = 0; i < 2; i++)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(
rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType()));
tensor::YieldOp::create(rewriter, loc, zero.getResult());
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1;
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
Value wTrans,
RankedTensorType wType,
int64_t numChannelsIn,
int64_t numChannelsOut,
int64_t wHeight,
int64_t wWidth,
int64_t patchSize,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return wTrans;
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
int64_t padWidthBegin = 0;
int64_t padWidthEnd = 0;
auto packedWeightType =
RankedTensorType::get({packFactor * patchSize, packFactor * numChannelsOut}, wType.getElementType());
SmallVector<Attribute> sourceValues(wDenseAttr.getValues<Attribute>());
SmallVector<Attribute> packedValues(packedWeightType.getNumElements(),
cast<Attribute>(rewriter.getZeroAttr(wType.getElementType())));
if (padsAttr) {
padHeightBegin = getI64(*padsAttr, 0);
padWidthBegin = getI64(*padsAttr, 1);
padHeightEnd = getI64(*padsAttr, 2);
padWidthEnd = getI64(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
const auto autoPad = convOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
for (int64_t copyId = 0; copyId < packFactor; copyId++) {
for (int64_t outChannel = 0; outChannel < numChannelsOut; outChannel++) {
for (int64_t inChannel = 0; inChannel < numChannelsIn; inChannel++) {
for (int64_t kernelH = 0; kernelH < wHeight; kernelH++) {
for (int64_t kernelW = 0; kernelW < wWidth; kernelW++) {
const int64_t sourceFlatIndex =
(((outChannel * numChannelsIn) + inChannel) * wHeight + kernelH) * wWidth + kernelW;
const int64_t patchIndex = ((inChannel * wHeight) + kernelH) * wWidth + kernelW;
const int64_t targetRow = copyId * patchSize + patchIndex;
const int64_t targetCol = copyId * numChannelsOut + outChannel;
packedValues[targetRow * (packFactor * numChannelsOut) + targetCol] = sourceValues[sourceFlatIndex];
}
}
}
}
// "NOTSET" or "VALID" -> all pads stay 0
}
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// Gemm output: [numPatches, cOut]
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
const int64_t numPatchesPerBatch = outHeight * outWidth;
const int64_t numPatches = batchSize * numPatchesPerBatch;
auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
}
static Value buildPackedBias(bool hasBias,
Value gemmBias,
Value biasMatrix,
DenseElementsAttr biasDenseAttr,
RankedTensorType outType,
int64_t numChannelsOut,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (!hasBias)
return gemmBias;
if (packFactor == 1)
return biasMatrix;
SmallVector<Attribute> sourceValues(biasDenseAttr.getValues<Attribute>());
SmallVector<Attribute> packedValues;
packedValues.reserve(packFactor * numChannelsOut);
for (int64_t copyId = 0; copyId < packFactor; copyId++)
packedValues.append(sourceValues.begin(), sourceValues.end());
auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType());
auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues);
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
}
static Value createIm2colCompute(Value x,
RankedTensorType xType,
RankedTensorType im2colType,
RankedTensorType rowType,
int64_t batchSize,
int64_t numChannelsIn,
int64_t xHeight,
int64_t xWidth,
int64_t wHeight,
int64_t wWidth,
int64_t padHeightBegin,
int64_t padHeightEnd,
int64_t padWidthBegin,
int64_t padWidthEnd,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t outWidth,
int64_t patchSize,
int64_t numPatches,
int64_t numPatchesPerBatch,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elemType = xType.getElementType();
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
w,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC;
if (hasB)
gemmC = b;
else
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
constexpr size_t numInputs = 1;
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
Value paddedInput = xArg;
@@ -226,23 +258,104 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
Value im2col = im2colLoop.getResult(0);
spatial::SpatYieldOp::create(rewriter, loc, im2col);
});
return im2colComputeOp.getResult(0);
}
// Gemm: A @ B + C = im2col @ W^T + b
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
gemmOutType,
im2colComputeOp.getResult(0),
wTrans,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
Value gemmOut = gemmOp.getY();
static Value createPackedIm2colRows(Value im2col,
RankedTensorType im2colType,
Type elemType,
int64_t numPatches,
int64_t patchSize,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return im2col;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) {
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
});
return packedComputeOp.getResult(0);
}
static Value createUnpackedOutput(Value packedOutput,
RankedTensorType gemmOutType,
RankedTensorType outType,
int64_t numPatches,
int64_t numChannelsOut,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return packedOutput;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) {
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedOutputArg,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
loc,
paddedType,
expandedOutput,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value unpackedOutput = paddedOutput;
if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
unpackedOutput =
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
}
spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput);
});
return unpackComputeOp.getResult(0);
}
static Value createCollectedConvOutput(Value gemmOut,
Type convType,
RankedTensorType nhwcType,
RankedTensorType outType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto collectComputeOp =
createSpatCompute<numInputs>(rewriter, loc, convOp.getType(), {}, ValueRange {gemmOut}, [&](Value gemmOutArg) {
createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
Value gemmOutArg = gemmOutArgs.front();
// Restore to NCHW layout:
// [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut]
@@ -256,11 +369,225 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
{3}
});
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
});
return collectComputeOp.getResult(0);
}
rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
} // namespace
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = convOp.getLoc();
Value x = convOpAdaptor.getX();
Value w = convOpAdaptor.getW();
Value b = convOpAdaptor.getB();
auto xType = cast<RankedTensorType>(x.getType());
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
assert("Only support 2D convolution" && xType.getRank() == 4);
// We need to understand what is group
assert("Only support group=1" && convOp.getGroup() == 1);
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2);
const int64_t xWidth = xType.getDimSize(3);
const int64_t numChannelsOut = wType.getDimSize(0);
const int64_t wHeight = wType.getDimSize(2);
const int64_t wWidth = wType.getDimSize(3);
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
const auto stridesAttr = convOp.getStrides();
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
int64_t padWidthBegin = 0;
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
const auto autoPad = convOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
}
}
// "NOTSET" or "VALID" -> all pads stay 0
}
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// Gemm output: [numPatches, cOut]
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
const int64_t numPatchesPerBatch = outHeight * outWidth;
const int64_t numPatches = batchSize * numPatchesPerBatch;
auto elemType = xType.getElementType();
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
auto wDenseAttr = getDenseConstantAttr(w);
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
w,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
Value biasMatrix;
DenseElementsAttr biasDenseAttr;
if (hasB) {
gemmC = b;
biasDenseAttr = getDenseConstantAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
}
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
const bool canPackBiasAsConstants = !hasB || static_cast<bool>(biasDenseAttr);
const int64_t effectiveMaxParallelPixels =
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
Value im2col = createIm2colCompute(x,
xType,
im2colType,
rowType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
rewriter,
loc);
Value gemmOut;
if (effectiveMaxParallelPixels == 1) {
// Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels.
gemmOut = ONNXGemmOp::create(rewriter,
loc,
gemmOutType,
im2col,
wTrans,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
}
else {
// Keep the standard im2col view of convolution:
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// but repack several old rows into one new row so we use the available crossbar size better.
//
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
auto packedOutType =
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
Value packedA = createPackedIm2colRows(
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
Value packedB = buildPackedWeight(wDenseAttr,
wTrans,
wType,
numChannelsIn,
numChannelsOut,
wHeight,
wWidth,
patchSize,
effectiveMaxParallelPixels,
rewriter,
loc);
Value packedC = buildPackedBias(
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
Value packedOut = ONNXGemmOp::create(rewriter,
loc,
packedOutType,
packedA,
packedB,
packedC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
gemmOut = createUnpackedOutput(
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
}
rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc));
return success();
}
@@ -1,17 +1,29 @@
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include <cassert>
#include <cstddef>
#include "Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
namespace {
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
auto attr = op->getAttrOfType<IntegerAttr>(attrName);
assert(attr && "required precomputed channel attr is missing");
return IntegerAttr::get(builder.getI32Type(), attr.getInt());
}
} // namespace
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
/*
EXAMPLE RUN:
@@ -54,6 +66,45 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
return returnValue;
}
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
}
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
}
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName);
}
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
}
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
}
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
}
mlir::Value createPimReceiveFromSpatialChannel(
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
}
Operation* getEarliestUserWithinBlock(mlir::Value value) {
auto users = value.getUsers();
@@ -2,11 +2,16 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
/**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input.
@@ -21,6 +26,21 @@ namespace onnx_mlir {
*/
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
mlir::Value createPimReceiveFromSpatialChannel(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
template <class T>
size_t rangeLength(const mlir::iterator_range<T> range) {
return std::distance(range.begin(), range.end());
@@ -9,6 +9,17 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE
def HasSpatialChannelSourceCoreIdAttr: Constraint<
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
"spatial channel has precomputed source core id">;
def HasSpatialChannelTargetCoreIdAttr: Constraint<
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
"spatial channel has precomputed target core id">;
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
def onnxToPimTranspose : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms,
@@ -69,4 +80,18 @@ def spatToPimVSoftmax : Pat<
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatChannelSendToPimSend : Pat<
(SpatChannelSendOp $channel, $input),
(PimSendOp $input,
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
[(HasSpatialChannelTargetCoreIdAttr $channel)]
>;
def spatChannelReceiveToPimReceive : Pat<
(SpatChannelReceiveOp:$srcOpRes $channel),
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
[(HasSpatialChannelSourceCoreIdAttr $channel)]
>;
#endif // SPATIAL_TO_PIM
@@ -10,8 +10,10 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_os_ostream.h"
@@ -68,6 +70,8 @@ private:
bool useBroadcastOp,
IRRewriter& rewriter);
void markOpToRemove(Operation* op);
void annotateChannelCoreIds(func::FuncOp funcOp);
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
@@ -175,6 +179,16 @@ void SpatialToPimPass::runOnOperation() {
runOnComputeOp(computeOp, rewriter);
}
annotateChannelCoreIds(funcOp);
lowerBroadcastChannelOps(funcOp, rewriter);
RewritePatternSet channelPatterns(ctx);
populateWithGenerated(channelPatterns);
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
signalPassFailure();
return;
}
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnOpOperands(returnOp, rewriter);
@@ -623,6 +637,94 @@ void SpatialToPimPass::markOpToRemove(Operation* op) {
operationsToRemove.push_back(op);
}
void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
markOpToRemove(channelNewOp);
if (channelNewOp->use_empty())
return;
spatial::SpatChannelSendOp sendOp;
spatial::SpatChannelReceiveOp receiveOp;
spatial::SpatChannelBroadcastSendOp broadcastSendOp;
for (Operation* user : channelNewOp->getUsers()) {
if (auto op = dyn_cast<spatial::SpatChannelSendOp>(user)) {
sendOp = op;
continue;
}
if (auto op = dyn_cast<spatial::SpatChannelReceiveOp>(user)) {
receiveOp = op;
continue;
}
if (auto op = dyn_cast<spatial::SpatChannelBroadcastSendOp>(user)) {
broadcastSendOp = op;
continue;
}
if (auto op = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user)) {
continue;
}
llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering");
}
if (broadcastSendOp) {
auto sourceCoreIdAttr = cast<pim::PimCoreOp>(broadcastSendOp->getParentOp()).getCoreIdAttr();
channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
return;
}
if (!sendOp || !receiveOp)
llvm_unreachable("spat.channel_new must connect exactly one send and one receive");
auto sourceCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
auto targetCoreIdAttr = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreIdAttr();
channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
channelNewOp->setAttr(kChannelTargetCoreIdAttrName, targetCoreIdAttr);
});
}
void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatChannelBroadcastSendOp> broadcastSendOps;
funcOp.walk([&](spatial::SpatChannelBroadcastSendOp op) { broadcastSendOps.push_back(op); });
for (auto sendOp : broadcastSendOps) {
auto channelNewOp = cast<spatial::SpatChannelNewOp>(sendOp.getChannel().getDefiningOp());
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput());
rewriter.setInsertionPoint(sendOp);
bool foundReceiver = false;
for (Operation* user : channelNewOp->getUsers()) {
auto receiveOp = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user);
if (!receiveOp)
continue;
foundReceiver = true;
auto targetCoreIdAttr = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreIdAttr();
PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr);
}
if (!foundReceiver)
llvm_unreachable("spat.channel_broadcast_send has no matching broadcast receive");
rewriter.eraseOp(sendOp);
}
SmallVector<spatial::SpatChannelBroadcastReceiveOp> broadcastReceiveOps;
funcOp.walk([&](spatial::SpatChannelBroadcastReceiveOp op) { broadcastReceiveOps.push_back(op); });
for (auto receiveOp : broadcastReceiveOps) {
rewriter.setInsertionPoint(receiveOp);
auto outputType = cast<ShapedType>(receiveOp.getResult().getType());
Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, receiveOp.getChannel());
Value receivedValue =
PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
rewriter.replaceOp(receiveOp, receivedValue);
}
}
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
for (auto it : llvm::enumerate(originalOperands)) {
@@ -97,6 +97,31 @@ struct MemCopyDevToHostOpInterface
}
};
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto receiveOp = cast<PimReceiveOp>(op);
auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state);
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
op,
outputBufferOpt->getType(),
*outputBufferOpt,
receiveOp.getSizeAttr(),
receiveOp.getSourceCoreIdAttr());
return success();
}
};
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -258,6 +283,7 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
+4 -5
View File
@@ -3,11 +3,10 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
add_pim_library(SpatialOps
SpatialOps.cpp
Transforms/SpatialBufferizableOpInterface.cpp
Transforms/MergeComputeNode/MergeComputeNodePass.cpp
DCPGraph/Graph.cpp
DCPGraph/Task.cpp
DCPGraph/DCPAnalysis.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
Transforms/MergeComputeNodes/DCPGraph/Task.cpp
Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp
EXCLUDE_FROM_OM_LIBS
@@ -8,7 +8,6 @@
#include <iterator>
#include "../SpatialOps.hpp"
#include "DCPAnalysis.hpp"
#include "Graph.hpp"
#include "src/Support/TypeUtilities.hpp"
@@ -6,7 +6,7 @@
#include <vector>
#include "../SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
struct DCPAnalysisResult {
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
@@ -7,11 +7,11 @@
#include <fstream>
#include <vector>
#include "../../../Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "DCPAnalysis.hpp"
#include "Graph.hpp"
#include "Task.hpp"
#include "Uniqueworklist.hpp"
#include "UniqueWorklist.hpp"
#include "Utils.hpp"
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) {
@@ -420,7 +420,7 @@ void GraphDCP::to_dot() {
std::string outputDir = onnx_mlir::getOutputDir();
if (outputDir.empty())
return;
std::string graphDir = outputDir + "/DCPGraph";
std::string graphDir = outputDir + "/dcp_graph";
onnx_mlir::createDirectory(graphDir);
std::fstream file(graphDir + "/graph_" + std::to_string(index++) + ".dot", std::ios::out);
file << "digraph G {\n";
@@ -2,7 +2,7 @@
#include "Graph.hpp"
#include "Task.hpp"
#include "Uniqueworklist.hpp"
#include "UniqueWorklist.hpp"
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
std::optional<Edge_t> oldEdge = std::nullopt;
@@ -5,7 +5,7 @@
#include <optional>
#include <vector>
#include "../SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "Utils.hpp"
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
@@ -9,7 +9,7 @@
#include <utility>
#include <vector>
#include "../SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Support/TypeUtilities.hpp"
using CPU = int;
@@ -20,7 +20,7 @@
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
#include "DCPGraph/DCPAnalysis.hpp"
using namespace mlir;
@@ -81,7 +81,7 @@ public:
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
};
struct MergeComputeNodePass : PassWrapper<MergeComputeNodePass, OperationPass<func::FuncOp>> {
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
private:
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
@@ -89,11 +89,11 @@ private:
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass)
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
StringRef getArgument() const override { return "pim-merge-node-pass"; }
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
StringRef getDescription() const override {
return "Merge Spatial-Weighted-Compute-Node in order to reduce the total "
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
"execution time";
}
@@ -133,7 +133,7 @@ public:
computeNodetoRemove.erase();
}
func::FuncOp func = getOperation();
dumpModule(cast<ModuleOp>(func->getParentOp()), "SpatialDCPMerged");
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
}
private:
@@ -346,6 +346,6 @@ private:
} // namespace
std::unique_ptr<Pass> createMergeComputeNodePass() { return std::make_unique<MergeComputeNodePass>(); }
std::unique_ptr<Pass> createMergeComputeNodesPass() { return std::make_unique<MergeComputeNodesPass>(); }
} // namespace onnx_mlir
@@ -1,532 +0,0 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace bufferization;
namespace onnx_mlir {
namespace spatial {
memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& rewriter) {
auto resultShape = cast<ShapedType>(resultType);
auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType());
// Alloc an output memref
return memref::AllocOp::create(rewriter, loc, memrefResultType);
}
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
if (succeeded(resolveContiguousAddress(memrefValue)))
return memrefValue;
auto shapedType = cast<ShapedType>(memrefValue.getType());
auto contiguousBuffer = createEmptyFromType(memrefValue.getType(), loc, rewriter);
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
return pim::PimMemCopyOp::create(rewriter,
loc,
contiguousBuffer.getType(),
contiguousBuffer,
memrefValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
++usersIterator;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
++usersIterator;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, more than two found.");
return failure();
}
Operation* otherUser = nullptr;
if (firstUser == op)
otherUser = secondUser;
else if (secondUser == op)
otherUser = firstUser;
else {
op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op.");
return failure();
}
if (opIsReceive && !isa<spatial::SpatChannelSendOp>(otherUser)) {
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp.");
return failure();
}
if (!opIsReceive && !isa<spatial::SpatChannelReceiveOp>(otherUser)) {
op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp.");
return failure();
}
return otherUser;
}
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
// This function requires the existence of ChannelNewOp and the other
// Receive/Send operation. However, during bufferization, the first of the
// Receive/Send operation that is processed gets removed. As such, we need to
// "precompute" the coreId needed for the other op, and save it as attribute
auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
if (precomputedOtherCoreId)
return cast<IntegerAttr>(precomputedOtherCoreId).getInt();
auto notOpUserOpt = getOtherEndOfChannel(op, opIsReceive);
if (failed(notOpUserOpt))
return failure();
Operation* notOpUser = *notOpUserOpt;
// Save the coreId for this op into the other op as attribute
auto opCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
notOpUser->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, opCoreIdAttr);
return cast<pim::PimCoreOp>(notOpUser->getParentOp()).getCoreId();
}
struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpInterface, SpatWeightedCompute> {
// Input tensor to the compute OP are always read into its local memory
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
// Input tensor to the compute OP are _never_ written into its local memory
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// In general, no tensor is aliased with any other tensor in the compute OP
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Bufferize its block
auto& block = op->getRegion(0).front();
return bufferizeBlockSignature(&block, rewriter, options, state);
}
};
/*
* This can be used for operation that have a single argument, which is a
* variadic of tensors, and a single output with the same same shape
* Example: VAdd, VSub, VExp
*/
template <typename InterfaceName, typename OpTy, typename ToTy>
struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
// Input tensors to the OP are always read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
// Input tensors to the OP are _never_ written
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// In general, no tensor is aliased with any other tensor in the OP
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
// Cast tensor values into memref values
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Turn Tensor Operands into Memref Operands
SmallVector<Value> memrefOperands;
memrefOperands.reserve(op->getNumOperands());
for (auto operand : op->getOperands()) {
auto memref = getBuffer(rewriter, operand, options, state);
if (failed(memref))
return failure();
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
}
// TODO: Support addiction with more than 2 operands
if (memrefOperands.size() > 2) {
op->emitError("VariadicArgumentElementWiseOpInterface only supports OPs "
"with 1 or 2 operands, for now.");
return failure();
}
// Alloc an output memref
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
memrefOperands.push_back(outputTensor);
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
};
template <typename InterfaceName, typename OpTy, typename ToTy>
struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
// Input tensors to the OP are always read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
// Input tensors to the OP are _never_ written
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// In general, no tensor is aliased with any other tensor in the OP
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return {};
}
// Cast tensor value into memref value
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
if (failed(memrefOperandOpt))
return failure();
auto memrefOperand = *memrefOperandOpt;
// Alloc an output memref
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
Value newValue = ToTy::create(rewriter,
op->getLoc(),
outputTensor.getType(),
cast<OpTy>(op).getWeightIndexAttr(),
memrefOperand,
outputTensor)
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
};
struct ChannelReceiveOpInterface
: BufferizableOpInterface::ExternalModel<ChannelReceiveOpInterface, SpatChannelReceiveOp> {
// Input value is the channel (not read/written, its more of an attribute)
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the channel receive to pim.recv
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
auto srcCoreId = getCoreIdOfOtherEndOfChannel(op, true, rewriter);
if (failed(srcCoreId))
return failure();
Value newValue = pim::PimReceiveOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(srcCoreId.value()))
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
};
struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSendOpInterface, SpatChannelSendOp> {
// First input is channel (not read/writter) second input is Tensor to send,
// which is read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return opOperand.getOperandNumber() == 2;
}
// See above (both non-written)
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the channel send to pim.send
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto srcTensor = op->getOperand(1);
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
if (failed(srcTensorOpt))
return failure();
auto srcMemRef = *srcTensorOpt;
auto numElements = cast<ShapedType>(srcTensor.getType()).getNumElements();
auto elementSize = cast<ShapedType>(srcTensor.getType()).getElementTypeBitWidth() / 8;
auto dstCoreId = getCoreIdOfOtherEndOfChannel(op, false, rewriter);
if (failed(dstCoreId))
return failure();
replaceOpWithNewBufferizedOp<pim::PimSendOp>(rewriter,
op,
srcMemRef,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(dstCoreId.value()));
return success();
}
};
struct ChannelBroadcastReceiveOpInterface
: BufferizableOpInterface::ExternalModel<ChannelBroadcastReceiveOpInterface, SpatChannelBroadcastReceiveOp> {
// Input value is the channel (not read/written, its more of an attribute)
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the broadcast receive into a regular pim.receive from the broadcaster.
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
if (precomputedOtherCoreId) {
Value newValue = pim::PimReceiveOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
cast<IntegerAttr>(precomputedOtherCoreId))
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("ChannelBroadcastReceiveOp does not use a channel as operand");
return failure();
}
auto srcCoreId = [&]() -> FailureOr<uint32_t> {
for (Operation* user : channelNewOp->getUsers()) {
auto sendOp = dyn_cast<SpatChannelBroadcastSendOp>(user);
if (!sendOp)
continue;
auto sendCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
op->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, sendCoreIdAttr);
return cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreId();
}
op->emitError("ChannelBroadcastReceiveOp has no matching ChannelBroadcastSendOp");
return failure();
}();
if (failed(srcCoreId))
return failure();
Value newValue = pim::PimReceiveOp::create(rewriter,
op->getLoc(),
outputTensor.getType(),
outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(srcCoreId.value()))
.getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue);
return success();
}
};
struct ChannelBroadcastSendOpInterface
: BufferizableOpInterface::ExternalModel<ChannelBroadcastSendOpInterface, SpatChannelBroadcastSendOp> {
// First input is channel (not read/writter) second input is Tensor to send,
// which is read
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return opOperand.getOperandNumber() == 2;
}
// See above (both non-written)
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
// See above
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
// TODO: Is it an empty list or a list of "UNKNOWN" values?
return {};
}
/*
* Turn the broadcast send into one pim.send per broadcast receiver.
*/
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto srcTensor = op->getOperand(1);
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
if (failed(srcTensorOpt))
return failure();
auto srcMemRef = *srcTensorOpt;
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("SpatChannelBroadcastSendOp does not use a channel as operand");
return failure();
}
auto srcType = cast<ShapedType>(srcTensor.getType());
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
auto srcCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
rewriter.setInsertionPoint(op);
bool foundReceiver = false;
for (Operation* user : channelNewOp->getUsers()) {
auto receiveOp = dyn_cast<SpatChannelBroadcastReceiveOp>(user);
if (!receiveOp)
continue;
foundReceiver = true;
auto dstCoreId = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreId();
receiveOp->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, srcCoreIdAttr);
pim::PimSendOp::create(rewriter,
op->getLoc(),
srcMemRef,
rewriter.getI32IntegerAttr(sizeInBytes),
rewriter.getI32IntegerAttr(dstCoreId));
}
if (!foundReceiver) {
op->emitError("SpatChannelBroadcastSendOp has no matching ChannelBroadcastReceiveOp");
return failure();
}
rewriter.eraseOp(op);
return success();
}
};
struct VAddOpInterfaceFromTemplate
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
SpatWeightedCompute::attachInterface<WComputeOpInterface>(*ctx);
SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(*ctx);
SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(*ctx);
SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(*ctx);
SpatVMaxOp::attachInterface<VMaxOpInterface>(*ctx);
SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(*ctx);
SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(*ctx);
SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(*ctx);
SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(*ctx);
});
}
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
struct ONNXSigmoidInterface
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
ONNXReluOp::attachInterface<ONNXReluInterface>(*ctx);
ONNXTanhOp::attachInterface<ONNXTanhInterface>(*ctx);
ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(*ctx);
});
}
} // namespace spatial
} // namespace onnx_mlir
@@ -1,15 +0,0 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace spatial {
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
} // namespace spatial
} // namespace onnx_mlir
+7 -7
View File
@@ -1,13 +1,13 @@
add_pim_library(OMPimPasses
CountInstructionPass.cpp
MessagePass.cpp
Pim/ConstantFolding/Common.cpp
Pim/ConstantFolding/Patterns/Constant.cpp
Pim/ConstantFolding/ConstantFoldingPass.cpp
Pim/ConstantFolding/Patterns/Subview.cpp
Pim/MaterializeConstantsPass.cpp
Pim/VerificationPass.cpp
Pim/EmitPimJsonPass.cpp
PimCodegen/HostConstantFolding/Common.cpp
PimCodegen/HostConstantFolding/Patterns/Constant.cpp
PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp
PimCodegen/HostConstantFolding/Patterns/Subview.cpp
PimCodegen/MaterializeHostConstantsPass.cpp
PimCodegen/VerificationPass.cpp
PimCodegen/EmitPimJsonPass.cpp
EXCLUDE_FROM_OM_LIBS
+3 -3
View File
@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
std::unique_ptr<mlir::Pass> createMergeComputeNodePass();
std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
std::unique_ptr<mlir::Pass> createPimVerificationPass();
@@ -11,10 +11,10 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostConstantFoldingPass)
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
StringRef getArgument() const override { return "pim-host-constant-folding-pass"; }
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
LogicalResult initialize(MLIRContext* context) override {
@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
} // namespace
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
std::unique_ptr<Pass> createPimHostConstantFoldingPass() { return std::make_unique<HostConstantFoldingPass>(); }
} // namespace onnx_mlir
@@ -31,10 +31,10 @@ static int64_t getValueSizeInBytes(Value value) {
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
StringRef getArgument() const override { return "materialize-pim-constants"; }
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
StringRef getDescription() const override {
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
}
@@ -126,6 +126,8 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
} // namespace
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
return std::make_unique<MaterializeHostConstantsPass>();
}
} // namespace onnx_mlir
+3 -6
View File
@@ -18,7 +18,6 @@
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Accelerators/PIM/PimAccelerator.hpp"
#include "src/Compiler/CompilerUtils.hpp"
@@ -67,8 +66,6 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
mlir::scf::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerBufferizableOpInterfaceExternalModels(registry);
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
pim::registerOpBufferizationInterfaces(registry);
}
@@ -78,9 +75,9 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass);
registerPass(createPimBufferizationPass);
registerPass(createMergeComputeNodePass);
registerPass(createPimConstantFoldingPass);
registerPass(createPimMaterializeConstantsPass);
registerPass(createMergeComputeNodesPass);
registerPass(createPimHostConstantFoldingPass);
registerPass(createPimMaterializeHostConstantsPass);
registerPass(createPimVerificationPass);
registerPass(createEmitPimJsonPass);
}
-2
View File
@@ -1,5 +1,3 @@
Rimuovere la logica di bufferizazione da spatial,
Rimuovere la gestione delle send e recive da sptaialtopim (nuovo mergeNode)
AnalisiDCP
+54 -13
View File
@@ -12,7 +12,16 @@ from gen_network_runner import gen_network_runner
from subprocess_utils import run_command_with_reporter
STAGE_COUNT = 6
STAGE_TITLES = (
"Compile ONNX",
"Build Runner",
"Generate Inputs",
"Run Reference",
"Compile PIM",
"Run Simulator",
"Compare Outputs",
)
STAGE_COUNT = len(STAGE_TITLES)
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
@@ -22,6 +31,8 @@ class ProgressReporter:
self.stages_per_model = stages_per_model
self.total_steps = max(1, total_models * stages_per_model)
self.completed_steps = 0
self.passed_models = 0
self.failed_models = 0
self.current_label = ""
self.enabled = True
self.columns = shutil.get_terminal_size((100, 20)).columns
@@ -36,21 +47,42 @@ class ProgressReporter:
return
bar_width = 24
filled = int(bar_width * self.completed_steps / self.total_steps)
counts_text = f"P:{self.passed_models} F:{self.failed_models}"
prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] {self.completed_steps}/{self.total_steps}"
if len(prefix_text) > self.columns:
prefix_text = f"{self.completed_steps}/{self.total_steps}"
label = f" {self.current_label}" if self.current_label else ""
available_label_width = max(0, self.columns - len(prefix_text))
label = label[:available_label_width]
if prefix_text.startswith("["):
bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled))
prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL
else:
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
sys.stdout.write("\r" + prefix + label + Style.RESET_ALL)
counts = (
" "
+ Style.BRIGHT
+ Fore.GREEN
+ f"P:{self.passed_models}"
+ Style.RESET_ALL
+ " "
+ Style.BRIGHT
+ Fore.RED
+ f"F:{self.failed_models}"
+ Style.RESET_ALL
)
model_counter = ""
label = ""
if self.current_label.startswith("[") and "] " in self.current_label:
model_counter, label = self.current_label.split("] ", 1)
model_counter = f" {model_counter}]"
label = f" {label}"
elif self.current_label:
label = f" {self.current_label}"
available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3)
label = label[:available_label_width]
sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL)
sys.stdout.flush()
def log(self, message="", color=None):
@@ -70,6 +102,13 @@ class ProgressReporter:
self.completed_steps = min(self.total_steps, self.completed_steps + 1)
self._render()
def record_result(self, passed):
if passed:
self.passed_models += 1
else:
self.failed_models += 1
self._render()
def suspend(self):
self.suspended = True
self._clear()
@@ -112,13 +151,13 @@ def clean_workspace_artifacts(workspace_dir, model_stem):
def print_stage(reporter, model_index, model_total, model_name, title):
stage_colors = {
"Compile ONNX": Fore.BLUE,
"Build Runner": Fore.MAGENTA,
"Generate Inputs": Fore.YELLOW,
"Run Reference": Fore.GREEN,
"Compile PIM": Fore.CYAN,
"Run Simulator": Fore.MAGENTA,
"Compare Outputs": Fore.YELLOW,
STAGE_TITLES[0]: Fore.BLUE,
STAGE_TITLES[1]: Fore.MAGENTA,
STAGE_TITLES[2]: Fore.YELLOW,
STAGE_TITLES[3]: Fore.GREEN,
STAGE_TITLES[4]: Fore.CYAN,
STAGE_TITLES[5]: Fore.MAGENTA,
STAGE_TITLES[6]: Fore.YELLOW,
}
color = stage_colors.get(title, Fore.WHITE)
reporter.log(Style.BRIGHT + color + f"[{title}]" + Style.RESET_ALL)
@@ -284,11 +323,13 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
reporter.resume()
reporter.advance()
reporter.record_result(passed)
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
return passed
except Exception:
failed_with_exception = True
reporter.record_result(False)
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
reporter.suspend()
raise