Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a7dee5b840 | |||
| 77fe293062 | |||
| 525792e545 | |||
| eade488d13 | |||
| 30ee9640d4 | |||
| 368e340a40 | |||
| e866ec6f87 |
@@ -33,7 +33,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPim) {
|
if (pimEmissionTarget >= EmitPim) {
|
||||||
pm.addPass(createMergeComputeNodePass());
|
pm.addPass(createMergeComputeNodesPass());
|
||||||
pm.addPass(createSpatialToPimPass());
|
pm.addPass(createSpatialToPimPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||||
@@ -46,9 +46,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||||
pm.addPass(createPimConstantFoldingPass());
|
pm.addPass(createPimHostConstantFoldingPass());
|
||||||
pm.addPass(createMessagePass("Pim constants folded"));
|
pm.addPass(createMessagePass("Pim host constants folded"));
|
||||||
pm.addPass(createPimMaterializeConstantsPass());
|
pm.addPass(createPimMaterializeHostConstantsPass());
|
||||||
pm.addPass(createPimVerificationPass());
|
pm.addPass(createPimVerificationPass());
|
||||||
pm.addPass(createMessagePass("Pim verified"));
|
pm.addPass(createMessagePass("Pim verified"));
|
||||||
pm.addPass(createEmitPimJsonPass());
|
pm.addPass(createEmitPimJsonPass());
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallSet.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
@@ -22,8 +23,8 @@
|
|||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -51,7 +52,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
|||||||
private:
|
private:
|
||||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||||
void mergeSingleChildCompute(func::FuncOp funcOp);
|
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -148,10 +149,10 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
encapsulateGlobalInstruction(*entryFunc);
|
encapsulateGlobalInstruction(*entryFunc);
|
||||||
mergeSingleChildCompute(*entryFunc);
|
mergeTriviallyConnectedComputes(*entryFunc);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
dumpModule(moduleOp, "spatial");
|
dumpModule(moduleOp, "spatial0");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@@ -230,66 +231,61 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ONNXToSpatialPass::mergeSingleChildCompute(func::FuncOp funcOp) {
|
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||||
llvm::SmallVector<spatial::SpatWeightedCompute> computeSingleChild;
|
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
|
SmallVector<spatial::SpatWeightedCompute> trivialComputes;
|
||||||
|
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
|
||||||
|
|
||||||
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
||||||
if (std::distance(compute->getUses().begin(), compute->getUses().end()) == 1) {
|
if (compute->hasOneUse()) {
|
||||||
auto user = *compute->getUsers().begin();
|
auto user = *compute->getUsers().begin();
|
||||||
if (user->getNumOperands() == 1)
|
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
|
||||||
if (llvm::isa<spatial::SpatWeightedCompute>(user))
|
trivialComputes.push_back(compute);
|
||||||
computeSingleChild.push_back(compute);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
IRMapping mapper;
|
while (!trivialComputes.empty()) {
|
||||||
while (!computeSingleChild.empty()) {
|
auto compute = trivialComputes.front();
|
||||||
auto compute = computeSingleChild.front();
|
|
||||||
auto child = dyn_cast_if_present<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
if (compute.use_empty()) {
|
||||||
assert(child && "Child required!");
|
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||||
|
trivialComputes.pop_back();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||||
auto newCompute =
|
auto newCompute =
|
||||||
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{(int) compute.getWeights().size(), (int) compute.getInputs().size()});
|
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||||
llvm::dbgs() << "After Creation\n";
|
|
||||||
newCompute.dump();
|
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||||
llvm::dbgs() << "After Clone\n";
|
|
||||||
newCompute.dump();
|
|
||||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||||
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
||||||
newTerminator->erase();
|
newTerminator->erase();
|
||||||
llvm::dbgs() << "After terminator\n";
|
|
||||||
newCompute.dump();
|
|
||||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||||
|
|
||||||
for (auto& op : child.getBody().front())
|
for (auto& op : child.getBody().front())
|
||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
child.replaceAllUsesWith(newCompute);
|
child.replaceAllUsesWith(newCompute);
|
||||||
assert(child->getUses().empty() && "It's not obvius");
|
toErase.insert(child);
|
||||||
llvm::dbgs() << "Node\n";
|
|
||||||
newCompute.dump();
|
|
||||||
|
|
||||||
llvm::dbgs() << "Parent\n";
|
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||||
compute.dump();
|
trivialComputes.pop_back();
|
||||||
|
toErase.insert(compute);
|
||||||
|
|
||||||
llvm::dbgs() << "Child\n";
|
if (newCompute->hasOneUse()) {
|
||||||
child.dump();
|
|
||||||
|
|
||||||
child.erase();
|
|
||||||
compute.erase();
|
|
||||||
|
|
||||||
if (std::distance(newCompute->getUses().begin(), newCompute->getUses().end()) == 1) {
|
|
||||||
auto user = *newCompute->getUsers().begin();
|
auto user = *newCompute->getUsers().begin();
|
||||||
if (user->getNumOperands() == 1)
|
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
|
||||||
if (llvm::isa<spatial::SpatWeightedCompute>(user))
|
trivialComputes.push_back(newCompute);
|
||||||
computeSingleChild.push_back(newCompute);
|
|
||||||
}
|
}
|
||||||
std::swap(computeSingleChild.front(), computeSingleChild.back());
|
}
|
||||||
computeSingleChild.pop_back();
|
|
||||||
|
for (auto compute : toErase) {
|
||||||
|
compute.getResult(0).dropAllUses();
|
||||||
|
compute.erase();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -24,122 +28,150 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
|||||||
ConversionPatternRewriter& rewriter) const override;
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||||
|
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||||
|
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||||
|
|
||||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||||
ONNXConvOpAdaptor convOpAdaptor,
|
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||||
ConversionPatternRewriter& rewriter) const {
|
|
||||||
Location loc = convOp.getLoc();
|
|
||||||
Value x = convOpAdaptor.getX();
|
|
||||||
Value w = convOpAdaptor.getW();
|
|
||||||
Value b = convOpAdaptor.getB();
|
|
||||||
|
|
||||||
auto xType = cast<RankedTensorType>(x.getType());
|
return nullptr;
|
||||||
auto wType = cast<RankedTensorType>(w.getType());
|
}
|
||||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
|
||||||
|
|
||||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
|
||||||
|
|
||||||
// We need to understand what is group
|
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||||
|
if (biasType.getRank() != 1)
|
||||||
|
return bias;
|
||||||
|
|
||||||
const int64_t batchSize = xType.getDimSize(0);
|
auto expandedBiasType = RankedTensorType::get({1, biasType.getDimSize(0)}, biasType.getElementType());
|
||||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
return tensor::ExpandShapeOp::create(rewriter,
|
||||||
const int64_t xHeight = xType.getDimSize(2);
|
|
||||||
const int64_t xWidth = xType.getDimSize(3);
|
|
||||||
const int64_t numChannelsOut = wType.getDimSize(0);
|
|
||||||
const int64_t wHeight = wType.getDimSize(2);
|
|
||||||
const int64_t wWidth = wType.getDimSize(3);
|
|
||||||
const int64_t outHeight = outType.getDimSize(2);
|
|
||||||
const int64_t outWidth = outType.getDimSize(3);
|
|
||||||
|
|
||||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
|
||||||
auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast<IntegerAttr>(arr[idx]).getInt(); };
|
|
||||||
|
|
||||||
const auto stridesAttr = convOp.getStrides();
|
|
||||||
const auto dilationsAttr = convOp.getDilations();
|
|
||||||
const auto padsAttr = convOp.getPads();
|
|
||||||
|
|
||||||
const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1;
|
|
||||||
const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1;
|
|
||||||
const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1;
|
|
||||||
const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1;
|
|
||||||
|
|
||||||
int64_t padHeightBegin = 0;
|
|
||||||
int64_t padHeightEnd = 0;
|
|
||||||
int64_t padWidthBegin = 0;
|
|
||||||
int64_t padWidthEnd = 0;
|
|
||||||
|
|
||||||
if (padsAttr) {
|
|
||||||
padHeightBegin = getI64(*padsAttr, 0);
|
|
||||||
padWidthBegin = getI64(*padsAttr, 1);
|
|
||||||
padHeightEnd = getI64(*padsAttr, 2);
|
|
||||||
padWidthEnd = getI64(*padsAttr, 3);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// Compute padding from auto_pad attribute
|
|
||||||
const auto autoPad = convOp.getAutoPad();
|
|
||||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
|
||||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
|
||||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
|
||||||
const int64_t totalPadH =
|
|
||||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
|
||||||
const int64_t totalPadW =
|
|
||||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
|
||||||
|
|
||||||
if (autoPad == "SAME_UPPER") {
|
|
||||||
padHeightBegin = totalPadH / 2;
|
|
||||||
padHeightEnd = totalPadH - padHeightBegin;
|
|
||||||
padWidthBegin = totalPadW / 2;
|
|
||||||
padWidthEnd = totalPadW - padWidthBegin;
|
|
||||||
}
|
|
||||||
else { // SAME_LOWER
|
|
||||||
padHeightEnd = totalPadH / 2;
|
|
||||||
padHeightBegin = totalPadH - padHeightEnd;
|
|
||||||
padWidthEnd = totalPadW / 2;
|
|
||||||
padWidthBegin = totalPadW - padWidthEnd;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// "NOTSET" or "VALID" -> all pads stay 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
|
||||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
|
||||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
|
||||||
// Gemm output: [numPatches, cOut]
|
|
||||||
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
|
||||||
const int64_t numPatchesPerBatch = outHeight * outWidth;
|
|
||||||
const int64_t numPatches = batchSize * numPatchesPerBatch;
|
|
||||||
|
|
||||||
auto elemType = xType.getElementType();
|
|
||||||
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
|
||||||
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
|
||||||
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
|
||||||
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
|
||||||
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
|
||||||
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
|
||||||
|
|
||||||
// Prepare weight matrix W for crossbar storage:
|
|
||||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
|
||||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
loc,
|
||||||
wFlatType,
|
expandedBiasType,
|
||||||
w,
|
bias,
|
||||||
SmallVector<ReassociationIndices> {
|
SmallVector<ReassociationIndices> {
|
||||||
{0},
|
{0, 1}
|
||||||
{1, 2, 3}
|
|
||||||
});
|
});
|
||||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
}
|
||||||
|
|
||||||
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
static Value createPaddedRows(Value tensorValue,
|
||||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
RankedTensorType tensorType,
|
||||||
Value gemmC;
|
int64_t paddedRows,
|
||||||
if (hasB)
|
ConversionPatternRewriter& rewriter,
|
||||||
gemmC = b;
|
Location loc) {
|
||||||
else
|
if (tensorType.getDimSize(0) == paddedRows)
|
||||||
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
return tensorValue;
|
||||||
|
|
||||||
|
auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
|
||||||
|
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
|
||||||
|
rewriter.getIndexAttr(0)};
|
||||||
|
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
|
||||||
|
auto* padBlock = new Block();
|
||||||
|
for (int i = 0; i < 2; i++)
|
||||||
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||||
|
padOp.getRegion().push_back(padBlock);
|
||||||
|
rewriter.setInsertionPointToStart(padBlock);
|
||||||
|
auto zero = arith::ConstantOp::create(
|
||||||
|
rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType()));
|
||||||
|
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||||
|
rewriter.setInsertionPointAfter(padOp);
|
||||||
|
return padOp.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||||
|
Value wTrans,
|
||||||
|
RankedTensorType wType,
|
||||||
|
int64_t numChannelsIn,
|
||||||
|
int64_t numChannelsOut,
|
||||||
|
int64_t wHeight,
|
||||||
|
int64_t wWidth,
|
||||||
|
int64_t patchSize,
|
||||||
|
int64_t packFactor,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (packFactor == 1)
|
||||||
|
return wTrans;
|
||||||
|
|
||||||
|
auto packedWeightType =
|
||||||
|
RankedTensorType::get({packFactor * patchSize, packFactor * numChannelsOut}, wType.getElementType());
|
||||||
|
SmallVector<Attribute> sourceValues(wDenseAttr.getValues<Attribute>());
|
||||||
|
SmallVector<Attribute> packedValues(packedWeightType.getNumElements(),
|
||||||
|
cast<Attribute>(rewriter.getZeroAttr(wType.getElementType())));
|
||||||
|
|
||||||
|
for (int64_t copyId = 0; copyId < packFactor; copyId++) {
|
||||||
|
for (int64_t outChannel = 0; outChannel < numChannelsOut; outChannel++) {
|
||||||
|
for (int64_t inChannel = 0; inChannel < numChannelsIn; inChannel++) {
|
||||||
|
for (int64_t kernelH = 0; kernelH < wHeight; kernelH++) {
|
||||||
|
for (int64_t kernelW = 0; kernelW < wWidth; kernelW++) {
|
||||||
|
const int64_t sourceFlatIndex =
|
||||||
|
(((outChannel * numChannelsIn) + inChannel) * wHeight + kernelH) * wWidth + kernelW;
|
||||||
|
const int64_t patchIndex = ((inChannel * wHeight) + kernelH) * wWidth + kernelW;
|
||||||
|
const int64_t targetRow = copyId * patchSize + patchIndex;
|
||||||
|
const int64_t targetCol = copyId * numChannelsOut + outChannel;
|
||||||
|
packedValues[targetRow * (packFactor * numChannelsOut) + targetCol] = sourceValues[sourceFlatIndex];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
|
||||||
|
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value buildPackedBias(bool hasBias,
|
||||||
|
Value gemmBias,
|
||||||
|
Value biasMatrix,
|
||||||
|
DenseElementsAttr biasDenseAttr,
|
||||||
|
RankedTensorType outType,
|
||||||
|
int64_t numChannelsOut,
|
||||||
|
int64_t packFactor,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (!hasBias)
|
||||||
|
return gemmBias;
|
||||||
|
|
||||||
|
if (packFactor == 1)
|
||||||
|
return biasMatrix;
|
||||||
|
|
||||||
|
SmallVector<Attribute> sourceValues(biasDenseAttr.getValues<Attribute>());
|
||||||
|
SmallVector<Attribute> packedValues;
|
||||||
|
packedValues.reserve(packFactor * numChannelsOut);
|
||||||
|
for (int64_t copyId = 0; copyId < packFactor; copyId++)
|
||||||
|
packedValues.append(sourceValues.begin(), sourceValues.end());
|
||||||
|
|
||||||
|
auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType());
|
||||||
|
auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues);
|
||||||
|
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createIm2colCompute(Value x,
|
||||||
|
RankedTensorType xType,
|
||||||
|
RankedTensorType im2colType,
|
||||||
|
RankedTensorType rowType,
|
||||||
|
int64_t batchSize,
|
||||||
|
int64_t numChannelsIn,
|
||||||
|
int64_t xHeight,
|
||||||
|
int64_t xWidth,
|
||||||
|
int64_t wHeight,
|
||||||
|
int64_t wWidth,
|
||||||
|
int64_t padHeightBegin,
|
||||||
|
int64_t padHeightEnd,
|
||||||
|
int64_t padWidthBegin,
|
||||||
|
int64_t padWidthEnd,
|
||||||
|
int64_t strideHeight,
|
||||||
|
int64_t strideWidth,
|
||||||
|
int64_t dilationHeight,
|
||||||
|
int64_t dilationWidth,
|
||||||
|
int64_t outWidth,
|
||||||
|
int64_t patchSize,
|
||||||
|
int64_t numPatches,
|
||||||
|
int64_t numPatchesPerBatch,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto elemType = xType.getElementType();
|
||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
|
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
|
||||||
Value paddedInput = xArg;
|
Value paddedInput = xArg;
|
||||||
@@ -226,23 +258,104 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
Value im2col = im2colLoop.getResult(0);
|
Value im2col = im2colLoop.getResult(0);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
||||||
});
|
});
|
||||||
|
return im2colComputeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
// Gemm: A @ B + C = im2col @ W^T + b
|
static Value createPackedIm2colRows(Value im2col,
|
||||||
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
|
RankedTensorType im2colType,
|
||||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
Type elemType,
|
||||||
|
int64_t numPatches,
|
||||||
|
int64_t patchSize,
|
||||||
|
int64_t packFactor,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (packFactor == 1)
|
||||||
|
return im2col;
|
||||||
|
|
||||||
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||||
|
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||||
|
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||||
|
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||||
|
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) {
|
||||||
|
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc);
|
||||||
|
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
gemmOutType,
|
groupedType,
|
||||||
im2colComputeOp.getResult(0),
|
paddedIm2col,
|
||||||
wTrans,
|
SmallVector<ReassociationIndices> {
|
||||||
gemmC,
|
{0, 1},
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
{2}
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
});
|
||||||
rewriter.getBoolAttr(false),
|
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter,
|
||||||
rewriter.getBoolAttr(false));
|
loc,
|
||||||
Value gemmOut = gemmOp.getY();
|
packedType,
|
||||||
|
groupedIm2col,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2}
|
||||||
|
});
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
|
||||||
|
});
|
||||||
|
return packedComputeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createUnpackedOutput(Value packedOutput,
|
||||||
|
RankedTensorType gemmOutType,
|
||||||
|
RankedTensorType outType,
|
||||||
|
int64_t numPatches,
|
||||||
|
int64_t numChannelsOut,
|
||||||
|
int64_t packFactor,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (packFactor == 1)
|
||||||
|
return packedOutput;
|
||||||
|
|
||||||
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||||
|
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||||
|
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||||
|
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||||
|
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) {
|
||||||
|
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
expandedType,
|
||||||
|
packedOutputArg,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2}
|
||||||
|
});
|
||||||
|
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
paddedType,
|
||||||
|
expandedOutput,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
|
||||||
|
Value unpackedOutput = paddedOutput;
|
||||||
|
if (paddedNumPatches != numPatches) {
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
unpackedOutput =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput);
|
||||||
|
});
|
||||||
|
return unpackComputeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createCollectedConvOutput(Value gemmOut,
|
||||||
|
Type convType,
|
||||||
|
RankedTensorType nhwcType,
|
||||||
|
RankedTensorType outType,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
auto collectComputeOp =
|
auto collectComputeOp =
|
||||||
createSpatCompute<numInputs>(rewriter, loc, convOp.getType(), {}, ValueRange {gemmOut}, [&](Value gemmOutArg) {
|
createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
|
||||||
|
Value gemmOutArg = gemmOutArgs.front();
|
||||||
|
|
||||||
// Restore to NCHW layout:
|
// Restore to NCHW layout:
|
||||||
// [numPatches, numChannelsOut]
|
// [numPatches, numChannelsOut]
|
||||||
// -> [1, outHeight, outWidth, numChannelsOut]
|
// -> [1, outHeight, outWidth, numChannelsOut]
|
||||||
@@ -256,11 +369,225 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
{3}
|
{3}
|
||||||
});
|
});
|
||||||
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
|
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
|
||||||
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
||||||
});
|
});
|
||||||
|
return collectComputeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||||
|
ONNXConvOpAdaptor convOpAdaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const {
|
||||||
|
Location loc = convOp.getLoc();
|
||||||
|
Value x = convOpAdaptor.getX();
|
||||||
|
Value w = convOpAdaptor.getW();
|
||||||
|
Value b = convOpAdaptor.getB();
|
||||||
|
|
||||||
|
auto xType = cast<RankedTensorType>(x.getType());
|
||||||
|
auto wType = cast<RankedTensorType>(w.getType());
|
||||||
|
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||||
|
|
||||||
|
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
||||||
|
assert("Only support 2D convolution" && xType.getRank() == 4);
|
||||||
|
|
||||||
|
// We need to understand what is group
|
||||||
|
assert("Only support group=1" && convOp.getGroup() == 1);
|
||||||
|
|
||||||
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
|
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||||
|
const int64_t xHeight = xType.getDimSize(2);
|
||||||
|
const int64_t xWidth = xType.getDimSize(3);
|
||||||
|
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||||
|
const int64_t wHeight = wType.getDimSize(2);
|
||||||
|
const int64_t wWidth = wType.getDimSize(3);
|
||||||
|
const int64_t outHeight = outType.getDimSize(2);
|
||||||
|
const int64_t outWidth = outType.getDimSize(3);
|
||||||
|
|
||||||
|
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||||
|
const auto stridesAttr = convOp.getStrides();
|
||||||
|
const auto dilationsAttr = convOp.getDilations();
|
||||||
|
const auto padsAttr = convOp.getPads();
|
||||||
|
|
||||||
|
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||||
|
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||||
|
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||||
|
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||||
|
|
||||||
|
int64_t padHeightBegin = 0;
|
||||||
|
int64_t padHeightEnd = 0;
|
||||||
|
int64_t padWidthBegin = 0;
|
||||||
|
int64_t padWidthEnd = 0;
|
||||||
|
|
||||||
|
if (padsAttr) {
|
||||||
|
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||||
|
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||||
|
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||||
|
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Compute padding from auto_pad attribute
|
||||||
|
const auto autoPad = convOp.getAutoPad();
|
||||||
|
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||||
|
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||||
|
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||||
|
const int64_t totalPadH =
|
||||||
|
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||||
|
const int64_t totalPadW =
|
||||||
|
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||||
|
|
||||||
|
if (autoPad == "SAME_UPPER") {
|
||||||
|
padHeightBegin = totalPadH / 2;
|
||||||
|
padHeightEnd = totalPadH - padHeightBegin;
|
||||||
|
padWidthBegin = totalPadW / 2;
|
||||||
|
padWidthEnd = totalPadW - padWidthBegin;
|
||||||
|
}
|
||||||
|
else { // SAME_LOWER
|
||||||
|
padHeightEnd = totalPadH / 2;
|
||||||
|
padHeightBegin = totalPadH - padHeightEnd;
|
||||||
|
padWidthEnd = totalPadW / 2;
|
||||||
|
padWidthBegin = totalPadW - padWidthEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// "NOTSET" or "VALID" -> all pads stay 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||||
|
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||||
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
|
// Gemm output: [numPatches, cOut]
|
||||||
|
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
||||||
|
const int64_t numPatchesPerBatch = outHeight * outWidth;
|
||||||
|
const int64_t numPatches = batchSize * numPatchesPerBatch;
|
||||||
|
|
||||||
|
auto elemType = xType.getElementType();
|
||||||
|
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
||||||
|
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
||||||
|
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
||||||
|
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
||||||
|
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
||||||
|
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
||||||
|
|
||||||
|
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||||
|
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||||
|
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||||
|
auto wDenseAttr = getDenseConstantAttr(w);
|
||||||
|
|
||||||
|
// Prepare weight matrix W for crossbar storage:
|
||||||
|
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||||
|
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
wFlatType,
|
||||||
|
w,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2, 3}
|
||||||
|
});
|
||||||
|
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
|
||||||
|
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||||
|
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||||
|
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
Value biasMatrix;
|
||||||
|
DenseElementsAttr biasDenseAttr;
|
||||||
|
if (hasB) {
|
||||||
|
gemmC = b;
|
||||||
|
biasDenseAttr = getDenseConstantAttr(b);
|
||||||
|
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||||
|
}
|
||||||
|
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||||
|
const bool canPackBiasAsConstants = !hasB || static_cast<bool>(biasDenseAttr);
|
||||||
|
const int64_t effectiveMaxParallelPixels =
|
||||||
|
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
|
||||||
|
|
||||||
|
Value im2col = createIm2colCompute(x,
|
||||||
|
xType,
|
||||||
|
im2colType,
|
||||||
|
rowType,
|
||||||
|
batchSize,
|
||||||
|
numChannelsIn,
|
||||||
|
xHeight,
|
||||||
|
xWidth,
|
||||||
|
wHeight,
|
||||||
|
wWidth,
|
||||||
|
padHeightBegin,
|
||||||
|
padHeightEnd,
|
||||||
|
padWidthBegin,
|
||||||
|
padWidthEnd,
|
||||||
|
strideHeight,
|
||||||
|
strideWidth,
|
||||||
|
dilationHeight,
|
||||||
|
dilationWidth,
|
||||||
|
outWidth,
|
||||||
|
patchSize,
|
||||||
|
numPatches,
|
||||||
|
numPatchesPerBatch,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
|
||||||
|
Value gemmOut;
|
||||||
|
if (effectiveMaxParallelPixels == 1) {
|
||||||
|
// Fall back to the plain im2col GEMM when a single crossbar cannot fit multiple pixels,
// or when the weights/bias are not dense constants and therefore cannot be packed.
|
||||||
|
gemmOut = ONNXGemmOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
gemmOutType,
|
||||||
|
im2col,
|
||||||
|
wTrans,
|
||||||
|
gemmC,
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getBoolAttr(false),
|
||||||
|
rewriter.getBoolAttr(false))
|
||||||
|
.getY();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Keep the standard im2col view of convolution:
|
||||||
|
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||||
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
|
// but repack several old rows into one new row so we use the available crossbar size better.
|
||||||
|
//
|
||||||
|
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
|
||||||
|
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
|
||||||
|
// containing N copies of W^T and concatenate N im2col rows into one longer row:
|
||||||
|
// A_packed: [ceil(numPatches / N), N * patchSize]
|
||||||
|
// B_packed: [N * patchSize, N * cOut]
|
||||||
|
// Y_packed: [ceil(numPatches / N), N * cOut]
|
||||||
|
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
|
||||||
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
|
||||||
|
auto packedOutType =
|
||||||
|
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||||
|
|
||||||
|
Value packedA = createPackedIm2colRows(
|
||||||
|
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
|
||||||
|
Value packedB = buildPackedWeight(wDenseAttr,
|
||||||
|
wTrans,
|
||||||
|
wType,
|
||||||
|
numChannelsIn,
|
||||||
|
numChannelsOut,
|
||||||
|
wHeight,
|
||||||
|
wWidth,
|
||||||
|
patchSize,
|
||||||
|
effectiveMaxParallelPixels,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
Value packedC = buildPackedBias(
|
||||||
|
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||||
|
Value packedOut = ONNXGemmOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
packedOutType,
|
||||||
|
packedA,
|
||||||
|
packedB,
|
||||||
|
packedC,
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getBoolAttr(false),
|
||||||
|
rewriter.getBoolAttr(false))
|
||||||
|
.getY();
|
||||||
|
gemmOut = createUnpackedOutput(
|
||||||
|
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,29 @@
|
|||||||
#include "mlir/IR/ValueRange.h"
|
#include "mlir/IR/ValueRange.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
#include "Common.hpp"
|
#include "Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/// Looks up the integer attribute `attrName` on `op` and returns its value re-wrapped
/// as an i32-typed IntegerAttr. The attribute is required: a missing (or non-integer)
/// attribute trips the assertion.
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
  auto precomputed = op->getAttrOfType<IntegerAttr>(attrName);
  assert(precomputed && "required precomputed channel attr is missing");

  auto i32Type = builder.getI32Type();
  return IntegerAttr::get(i32Type, precomputed.getInt());
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
||||||
/*
|
/*
|
||||||
EXAMPLE RUN:
|
EXAMPLE RUN:
|
||||||
@@ -54,6 +66,45 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
|
|||||||
return returnValue;
|
return returnValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the number of bytes needed to hold every element of `shapedType`.
///
/// The size is computed from the total bit count (element count times element bit
/// width) and rounded UP to a whole byte. For element widths that are a multiple of
/// 8 bits this equals the plain division; for sub-byte element types whose total bit
/// count is not byte-aligned, the trailing partial byte is counted instead of being
/// silently dropped (floor division would report 0 bytes for three 1-bit elements).
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
  const size_t totalBits = static_cast<size_t>(shapedType.getNumElements()) * shapedType.getElementTypeBitWidth();
  return (totalBits + 7) / 8;
}
|
||||||
|
|
||||||
|
/// Builds an i32 attribute holding the byte size of `value`, whose type must be a ShapedType.
/// The size is narrowed to int32_t before being stored.
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
  auto shapedType = cast<ShapedType>(value.getType());
  const size_t byteCount = getShapedTypeSizeInBytes(shapedType);
  return builder.getI32IntegerAttr(static_cast<int32_t>(byteCount));
}
|
||||||
|
|
||||||
|
/// Returns, as an i32 attribute, the precomputed source core id stored on the
/// spat.channel_new op that defines `channel`. Asserts that `channel` is defined by such an op.
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
  auto definingChannel = channel.getDefiningOp<spatial::SpatChannelNewOp>();
  assert(definingChannel && "spatial channel value must come from spat.channel_new");

  return getRequiredI32Attr(builder, definingChannel.getOperation(), kChannelSourceCoreIdAttrName);
}
|
||||||
|
|
||||||
|
/// Returns, as an i32 attribute, the precomputed target core id stored on the
/// spat.channel_new op that defines `channel`. Asserts that `channel` is defined by such an op.
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
  auto definingChannel = channel.getDefiningOp<spatial::SpatChannelNewOp>();
  assert(definingChannel && "spatial channel value must come from spat.channel_new");

  return getRequiredI32Attr(builder, definingChannel.getOperation(), kChannelTargetCoreIdAttrName);
}
|
||||||
|
|
||||||
|
/// True when `channel` is defined by a spat.channel_new op that carries the
/// precomputed source core id attribute; false otherwise (including for any other producer).
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
  if (auto definingChannel = channel.getDefiningOp<spatial::SpatChannelNewOp>())
    return definingChannel->hasAttr(kChannelSourceCoreIdAttrName);
  return false;
}
|
||||||
|
|
||||||
|
/// True when `channel` is defined by a spat.channel_new op that carries the
/// precomputed target core id attribute; false otherwise (including for any other producer).
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
  if (auto definingChannel = channel.getDefiningOp<spatial::SpatChannelNewOp>())
    return definingChannel->hasAttr(kChannelTargetCoreIdAttrName);
  return false;
}
|
||||||
|
|
||||||
|
/// Creates a pim receive op that stands in for a spatial channel receive and returns its output.
///
/// \param rewriter  Rewriter used to create the new ops.
/// \param loc       Location given to the pim receive op.
/// \param output    Result value of the op being replaced; its defining op is used to pick
///                  (or allocate) the destination tensor and its type gives the byte size.
/// \param channel   Channel value; must be defined by a spat.channel_new op that carries
///                  the precomputed source core id.
mlir::Value createPimReceiveFromSpatialChannel(
    PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
  Operation* replacedOp = output.getDefiningOp();
  mlir::Value destination = getBestOutputTensorFromOperandsOrAllocate(rewriter, replacedOp);

  IntegerAttr byteCountAttr = getTensorSizeInBytesAttr(rewriter, output);
  IntegerAttr senderCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);

  auto receive =
      pim::PimReceiveOp::create(rewriter, loc, destination.getType(), destination, byteCountAttr, senderCoreIdAttr);
  return receive.getOutput();
}
|
||||||
|
|
||||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||||
auto users = value.getUsers();
|
auto users = value.getUsers();
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,16 @@
|
|||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
/** \brief Name of the attribute on spat.channel_new that holds the precomputed id of the sending core. */
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
|
||||||
|
/** \brief Name of the attribute on spat.channel_new that holds the precomputed id of the receiving core. */
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||||
* its static tensor input.
|
* its static tensor input.
|
||||||
@@ -21,6 +26,21 @@ namespace onnx_mlir {
|
|||||||
*/
|
*/
|
||||||
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
||||||
|
|
||||||
|
/** \brief Size in bytes of \p shapedType, derived from its element count and element bit width. */
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||||
|
|
||||||
|
/** \brief i32 attribute holding the byte size of \p value; the type of \p value must be a ShapedType. */
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||||
|
|
||||||
|
/**
 * \brief i32 attribute with the precomputed source core id of \p channel.
 *
 * \p channel must be defined by a spat.channel_new op carrying that attribute (asserted).
 */
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||||
|
|
||||||
|
/**
 * \brief i32 attribute with the precomputed target core id of \p channel.
 *
 * \p channel must be defined by a spat.channel_new op carrying that attribute (asserted).
 */
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||||
|
|
||||||
|
/** \brief True when \p channel comes from a spat.channel_new op that has the precomputed source core id attribute. */
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
|
||||||
|
|
||||||
|
/** \brief True when \p channel comes from a spat.channel_new op that has the precomputed target core id attribute. */
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
|
||||||
|
|
||||||
|
/**
 * \brief Create a pim receive op for a value that was received over a spatial channel.
 *
 * \param rewriter The rewriter used to create the new ops.
 * \param loc The location assigned to the pim receive op.
 * \param output The received value being replaced; its type determines the transfer size.
 * \param channel The channel value; must carry the precomputed source core id.
 *
 * \return The output value of the newly created pim receive op.
 */
mlir::Value createPimReceiveFromSpatialChannel(
|
||||||
|
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||||
return std::distance(range.begin(), range.end());
|
return std::distance(range.begin(), range.end());
|
||||||
|
|||||||
@@ -9,6 +9,17 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
|||||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||||
#endif // OP_BASE
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
// Pattern constraint: holds when onnx_mlir::hasSpatialChannelSourceCoreIdAttr accepts the bound channel value.
def HasSpatialChannelSourceCoreIdAttr: Constraint<
|
||||||
|
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
|
||||||
|
"spatial channel has precomputed source core id">;
|
||||||
|
|
||||||
|
// Pattern constraint: holds when onnx_mlir::hasSpatialChannelTargetCoreIdAttr accepts the bound channel value.
def HasSpatialChannelTargetCoreIdAttr: Constraint<
|
||||||
|
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
|
||||||
|
"spatial channel has precomputed target core id">;
|
||||||
|
|
||||||
|
// Native call wrapper: $0 is the result of the op being replaced, $1 is the channel value.
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
|
||||||
|
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
|
||||||
|
|
||||||
def onnxToPimTranspose : Pat<
|
def onnxToPimTranspose : Pat<
|
||||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||||
(PimTransposeOp $data, $perms,
|
(PimTransposeOp $data, $perms,
|
||||||
@@ -69,4 +80,18 @@ def spatToPimVSoftmax : Pat<
|
|||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
// Rewrites a spatial channel send into a PimSendOp. The size attribute is computed from
// the sent value and the target core id is read from the channel; the pattern only
// applies when the channel has a precomputed target core id.
def spatChannelSendToPimSend : Pat<
|
||||||
|
(SpatChannelSendOp $channel, $input),
|
||||||
|
(PimSendOp $input,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
|
||||||
|
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
|
||||||
|
[(HasSpatialChannelTargetCoreIdAttr $channel)]
|
||||||
|
>;
|
||||||
|
|
||||||
|
// Replaces a spatial channel receive with the value built by
// onnx_mlir::createPimReceiveFromSpatialChannel; the pattern only applies when the
// channel has a precomputed source core id.
def spatChannelReceiveToPimReceive : Pat<
|
||||||
|
(SpatChannelReceiveOp:$srcOpRes $channel),
|
||||||
|
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
|
||||||
|
[(HasSpatialChannelSourceCoreIdAttr $channel)]
|
||||||
|
>;
|
||||||
|
|
||||||
#endif // SPATIAL_TO_PIM
|
#endif // SPATIAL_TO_PIM
|
||||||
|
|||||||
@@ -10,8 +10,10 @@
|
|||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallSet.h"
|
#include "llvm/ADT/SmallSet.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
@@ -68,6 +70,8 @@ private:
|
|||||||
bool useBroadcastOp,
|
bool useBroadcastOp,
|
||||||
IRRewriter& rewriter);
|
IRRewriter& rewriter);
|
||||||
void markOpToRemove(Operation* op);
|
void markOpToRemove(Operation* op);
|
||||||
|
void annotateChannelCoreIds(func::FuncOp funcOp);
|
||||||
|
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
|
||||||
|
|
||||||
@@ -175,6 +179,16 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
runOnComputeOp(computeOp, rewriter);
|
runOnComputeOp(computeOp, rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
annotateChannelCoreIds(funcOp);
|
||||||
|
lowerBroadcastChannelOps(funcOp, rewriter);
|
||||||
|
|
||||||
|
RewritePatternSet channelPatterns(ctx);
|
||||||
|
populateWithGenerated(channelPatterns);
|
||||||
|
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||||
replaceReturnOpOperands(returnOp, rewriter);
|
replaceReturnOpOperands(returnOp, rewriter);
|
||||||
|
|
||||||
@@ -623,6 +637,94 @@ void SpatialToPimPass::markOpToRemove(Operation* op) {
|
|||||||
operationsToRemove.push_back(op);
|
operationsToRemove.push_back(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Marks every spat.channel_new op in `funcOp` for removal and records on it the
/// core ids of its endpoints.
///
/// - A channel without uses is only marked for removal.
/// - If a broadcast send uses the channel, only the source core id is recorded,
///   taken from the pim.core op that is the parent of that send.
/// - Otherwise a send and a receive must both be present; the parent pim.core ops
///   of the two provide the source and target core ids.
void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
  // Core id attribute of the pim.core op that is the direct parent of `op`.
  auto parentCoreId = [](Operation* op) { return cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr(); };

  funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) {
    markOpToRemove(channelNewOp);
    if (channelNewOp->use_empty())
      return;

    // Sort the users by kind; if several users of one kind exist, the last one visited is kept.
    spatial::SpatChannelSendOp pointSend;
    spatial::SpatChannelReceiveOp pointReceive;
    spatial::SpatChannelBroadcastSendOp broadcastSend;
    for (Operation* user : channelNewOp->getUsers()) {
      if (auto asSend = dyn_cast<spatial::SpatChannelSendOp>(user))
        pointSend = asSend;
      else if (auto asReceive = dyn_cast<spatial::SpatChannelReceiveOp>(user))
        pointReceive = asReceive;
      else if (auto asBroadcastSend = dyn_cast<spatial::SpatChannelBroadcastSendOp>(user))
        broadcastSend = asBroadcastSend;
      else if (!isa<spatial::SpatChannelBroadcastReceiveOp>(user))
        llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering");
    }

    // Broadcast channel: only the sender side is recorded on the channel.
    if (broadcastSend) {
      channelNewOp->setAttr(kChannelSourceCoreIdAttrName, parentCoreId(broadcastSend));
      return;
    }

    if (!(pointSend && pointReceive))
      llvm_unreachable("spat.channel_new must connect exactly one send and one receive");

    auto sourceCoreIdAttr = parentCoreId(pointSend);
    auto targetCoreIdAttr = parentCoreId(pointReceive);
    channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr);
    channelNewOp->setAttr(kChannelTargetCoreIdAttrName, targetCoreIdAttr);
  });
}
|
||||||
|
|
||||||
|
/// Lowers the broadcast channel ops found in `funcOp` to pim send/receive ops.
///
/// Each broadcast send is replaced by one PimSendOp per broadcast receive that uses
/// the same channel, targeting the pim.core that is the parent of that receive. Each
/// broadcast receive is then replaced by a PimReceiveOp writing into a fresh empty
/// tensor, with the source core id read from the channel.
void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Gather the sends up front; they are erased only after the walk has finished.
  SmallVector<spatial::SpatChannelBroadcastSendOp> pendingSends;
  funcOp.walk([&](spatial::SpatChannelBroadcastSendOp op) { pendingSends.push_back(op); });

  for (auto broadcastSend : pendingSends) {
    auto channelNewOp = cast<spatial::SpatChannelNewOp>(broadcastSend.getChannel().getDefiningOp());
    auto byteCountAttr = getTensorSizeInBytesAttr(rewriter, broadcastSend.getInput());

    rewriter.setInsertionPoint(broadcastSend);

    // Emit one point-to-point send for every broadcast receiver of this channel.
    unsigned receiverCount = 0;
    for (Operation* user : channelNewOp->getUsers()) {
      if (auto broadcastReceive = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user)) {
        ++receiverCount;
        auto receiverCoreIdAttr = cast<pim::PimCoreOp>(broadcastReceive->getParentOp()).getCoreIdAttr();
        PimSendOp::create(
            rewriter, broadcastSend.getLoc(), broadcastSend.getInput(), byteCountAttr, receiverCoreIdAttr);
      }
    }

    if (receiverCount == 0)
      llvm_unreachable("spat.channel_broadcast_send has no matching broadcast receive");

    rewriter.eraseOp(broadcastSend);
  }

  // Same two-phase scheme for the receives: gather first, then rewrite.
  SmallVector<spatial::SpatChannelBroadcastReceiveOp> pendingReceives;
  funcOp.walk([&](spatial::SpatChannelBroadcastReceiveOp op) { pendingReceives.push_back(op); });

  for (auto broadcastReceive : pendingReceives) {
    rewriter.setInsertionPoint(broadcastReceive);

    auto resultType = cast<ShapedType>(broadcastReceive.getResult().getType());
    Value destination = createEmptyTensorFromShaped(rewriter, broadcastReceive.getLoc(), resultType);
    auto byteCountAttr = getTensorSizeInBytesAttr(rewriter, broadcastReceive.getResult());
    auto senderCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, broadcastReceive.getChannel());

    auto pimReceive = PimReceiveOp::create(
        rewriter, broadcastReceive.getLoc(), destination.getType(), destination, byteCountAttr, senderCoreIdAttr);
    Value received = pimReceive.getOutput();
    rewriter.replaceOp(broadcastReceive, received);
  }
}
|
||||||
|
|
||||||
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||||
for (auto it : llvm::enumerate(originalOperands)) {
|
for (auto it : llvm::enumerate(originalOperands)) {
|
||||||
|
|||||||
@@ -97,6 +97,31 @@ struct MemCopyDevToHostOpInterface
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Bufferization external model for PimReceiveOp.
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
  /// An operand counts as a memory read unless it is a destination (DPS init) operand.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    const bool isDestination = dpsOp.isDpsInit(&opOperand);
    return !isDestination;
  }

  /// Replaces the op with a PimReceiveOp that operates on the buffer obtained for its
  /// output operand, keeping the size and source core id attributes. Fails if no
  /// buffer can be obtained.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto receiveOp = cast<PimReceiveOp>(op);

    auto destinationBuffer = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state);
    if (failed(destinationBuffer))
      return failure();

    auto sizeAttr = receiveOp.getSizeAttr();
    auto sourceCoreIdAttr = receiveOp.getSourceCoreIdAttr();
    replaceOpWithNewBufferizedOp<PimReceiveOp>(
        rewriter, op, destinationBuffer->getType(), *destinationBuffer, sizeAttr, sourceCoreIdAttr);
    return success();
  }
};
|
||||||
|
|
||||||
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
@@ -258,6 +283,7 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
|||||||
|
|
||||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||||
|
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||||
|
|||||||
@@ -3,11 +3,10 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
|
|||||||
|
|
||||||
add_pim_library(SpatialOps
|
add_pim_library(SpatialOps
|
||||||
SpatialOps.cpp
|
SpatialOps.cpp
|
||||||
Transforms/SpatialBufferizableOpInterface.cpp
|
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||||
Transforms/MergeComputeNode/MergeComputeNodePass.cpp
|
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||||
DCPGraph/Graph.cpp
|
Transforms/MergeComputeNodes/DCPGraph/Task.cpp
|
||||||
DCPGraph/Task.cpp
|
Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp
|
||||||
DCPGraph/DCPAnalysis.cpp
|
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
-1
@@ -8,7 +8,6 @@
|
|||||||
|
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
|
||||||
#include "../SpatialOps.hpp"
|
|
||||||
#include "DCPAnalysis.hpp"
|
#include "DCPAnalysis.hpp"
|
||||||
#include "Graph.hpp"
|
#include "Graph.hpp"
|
||||||
#include "src/Support/TypeUtilities.hpp"
|
#include "src/Support/TypeUtilities.hpp"
|
||||||
+1
-1
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
struct DCPAnalysisResult {
|
struct DCPAnalysisResult {
|
||||||
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
|
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
|
||||||
+3
-3
@@ -7,11 +7,11 @@
|
|||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../../../Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "DCPAnalysis.hpp"
|
#include "DCPAnalysis.hpp"
|
||||||
#include "Graph.hpp"
|
#include "Graph.hpp"
|
||||||
#include "Task.hpp"
|
#include "Task.hpp"
|
||||||
#include "Uniqueworklist.hpp"
|
#include "UniqueWorklist.hpp"
|
||||||
#include "Utils.hpp"
|
#include "Utils.hpp"
|
||||||
|
|
||||||
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) {
|
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) {
|
||||||
@@ -420,7 +420,7 @@ void GraphDCP::to_dot() {
|
|||||||
std::string outputDir = onnx_mlir::getOutputDir();
|
std::string outputDir = onnx_mlir::getOutputDir();
|
||||||
if (outputDir.empty())
|
if (outputDir.empty())
|
||||||
return;
|
return;
|
||||||
std::string graphDir = outputDir + "/DCPGraph";
|
std::string graphDir = outputDir + "/dcp_graph";
|
||||||
onnx_mlir::createDirectory(graphDir);
|
onnx_mlir::createDirectory(graphDir);
|
||||||
std::fstream file(graphDir + "/graph_" + std::to_string(index++) + ".dot", std::ios::out);
|
std::fstream file(graphDir + "/graph_" + std::to_string(index++) + ".dot", std::ios::out);
|
||||||
file << "digraph G {\n";
|
file << "digraph G {\n";
|
||||||
+1
-1
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#include "Graph.hpp"
|
#include "Graph.hpp"
|
||||||
#include "Task.hpp"
|
#include "Task.hpp"
|
||||||
#include "Uniqueworklist.hpp"
|
#include "UniqueWorklist.hpp"
|
||||||
|
|
||||||
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
|
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
|
||||||
std::optional<Edge_t> oldEdge = std::nullopt;
|
std::optional<Edge_t> oldEdge = std::nullopt;
|
||||||
+1
-1
@@ -5,7 +5,7 @@
|
|||||||
#include <optional>
|
#include <optional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "Utils.hpp"
|
#include "Utils.hpp"
|
||||||
|
|
||||||
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
|
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
|
||||||
+1
-1
@@ -9,7 +9,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Support/TypeUtilities.hpp"
|
#include "src/Support/TypeUtilities.hpp"
|
||||||
|
|
||||||
using CPU = int;
|
using CPU = int;
|
||||||
+7
-7
@@ -20,7 +20,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
|
#include "DCPGraph/DCPAnalysis.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -81,7 +81,7 @@ public:
|
|||||||
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
|
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MergeComputeNodePass : PassWrapper<MergeComputeNodePass, OperationPass<func::FuncOp>> {
|
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
|
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
|
||||||
@@ -89,11 +89,11 @@ private:
|
|||||||
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
|
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
|
||||||
|
|
||||||
StringRef getArgument() const override { return "pim-merge-node-pass"; }
|
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
|
||||||
StringRef getDescription() const override {
|
StringRef getDescription() const override {
|
||||||
return "Merge Spatial-Weighted-Compute-Node in order to reduce the total "
|
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
|
||||||
"execution time";
|
"execution time";
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ public:
|
|||||||
computeNodetoRemove.erase();
|
computeNodetoRemove.erase();
|
||||||
}
|
}
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "SpatialDCPMerged");
|
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -346,6 +346,6 @@ private:
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createMergeComputeNodePass() { return std::make_unique<MergeComputeNodePass>(); }
|
std::unique_ptr<Pass> createMergeComputeNodesPass() { return std::make_unique<MergeComputeNodesPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -1,532 +0,0 @@
|
|||||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
||||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
#include "llvm/Support/LogicalResult.h"
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
using namespace bufferization;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
namespace spatial {
|
|
||||||
|
|
||||||
// Creates a memref.alloc whose memref type is built from the shape and element
// type of `resultType`.
//
// `resultType` is cast to ShapedType, so it must be a shaped type. The alloc is
// created through `rewriter` at its current insertion point.
memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase& rewriter) {
  auto shaped = cast<ShapedType>(resultType);
  return memref::AllocOp::create(rewriter, loc, MemRefType::get(shaped.getShape(), shaped.getElementType()));
}
|
|
||||||
|
|
||||||
// Returns a memref value for which resolveContiguousAddress succeeds.
//
// If `memrefValue` already resolves, it is returned unchanged. Otherwise a new
// buffer of the same shape/element type is allocated and a pim mem-copy of the
// whole value into that buffer is emitted; the copy's output is returned.
Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) {
  // Fast path: nothing to materialize.
  if (succeeded(resolveContiguousAddress(memrefValue)))
    return memrefValue;

  auto sourceType = cast<ShapedType>(memrefValue.getType());
  auto destination = createEmptyFromType(memrefValue.getType(), loc, rewriter);

  // Total payload: element count times element bit width, converted to bytes.
  auto byteCount = sourceType.getNumElements() * sourceType.getElementTypeBitWidth() / 8;

  // NOTE(review): the two zero attributes are presumably the destination and
  // source offsets of the copy; confirm against the PimMemCopyOp definition.
  auto copyOp = pim::PimMemCopyOp::create(rewriter,
      loc,
      destination.getType(),
      destination,
      memrefValue,
      rewriter.getI32IntegerAttr(0),
      rewriter.getI32IntegerAttr(0),
      rewriter.getI32IntegerAttr(byteCount));
  return copyOp.getOutput();
}
|
|
||||||
|
|
||||||
// Name of the attribute used to cache, on a channel send/receive op, the core
// id of the operation at the other end of its channel.
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
|
||||||
|
|
||||||
// Finds the operation sitting at the opposite end of the channel used by `op`.
//
// `op` must take the channel as operand 0, and that channel must be produced
// by a SpatChannelNewOp with exactly two users: `op` itself and its
// counterpart. When `opIsReceive` is true the counterpart has to be a
// SpatChannelSendOp, otherwise a SpatChannelReceiveOp. Every violation emits an
// error on `op` and yields failure.
static FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive) {
  auto channel = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
  if (!channel) {
    op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
    return failure();
  }

  // Snapshot the users so the arity can be validated before picking the peer.
  auto users = channel->getUsers();
  SmallVector<Operation*, 4> endpoints(users.begin(), users.end());

  if (endpoints.size() < 2) {
    op->emitError("Operand generated by ChannelNewOp must have two users, only one found.");
    // Debug output: the channel, the offending op and the enclosing operation.
    channel->dump();
    op->dump();
    channel->getParentOp()->dump();
    return failure();
  }
  if (endpoints.size() > 2) {
    op->emitError("Operand generated by ChannelNewOp must have two users, more than two found.");
    return failure();
  }

  // Exactly two users: the peer is whichever one is not `op`.
  Operation* peer = nullptr;
  if (endpoints[0] == op) {
    peer = endpoints[1];
  }
  else if (endpoints[1] == op) {
    peer = endpoints[0];
  }
  else {
    op->emitError("Operand generated by ChannelNewOp must have two users and one of them must be the current op.");
    return failure();
  }

  // The peer must be of the complementary kind.
  if (opIsReceive) {
    if (!isa<spatial::SpatChannelSendOp>(peer)) {
      op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelSendOp.");
      return failure();
    }
  }
  else if (!isa<spatial::SpatChannelReceiveOp>(peer)) {
    op->emitError("Operand generated by ChannelNewOp has two users, but the other one is not a ChannelReceiveOp.");
    return failure();
  }

  return peer;
}
|
|
||||||
|
|
||||||
// Returns the core id of the pim core that hosts the other end of `op`'s
// channel.
//
// Resolving the peer needs both the ChannelNewOp and the peer Send/Receive op
// to still exist, but bufferization removes whichever of the two ends it
// processes first. To survive that, the first call made for a channel stores
// this op's own core id on the peer under PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME;
// when the peer is processed later it finds that attribute and returns it
// without inspecting the channel.
//
// `rewriter` is not used by this function.
llvm::FailureOr<uint32_t> getCoreIdOfOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
  // Cached by an earlier call made on the other end of the channel.
  if (auto cachedCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME))
    return cast<IntegerAttr>(cachedCoreId).getInt();

  auto peerOrFailure = getOtherEndOfChannel(op, opIsReceive);
  if (failed(peerOrFailure))
    return failure();
  Operation* peer = *peerOrFailure;

  // Leave our own core id on the peer for when it gets bufferized.
  auto ownCore = cast<pim::PimCoreOp>(op->getParentOp());
  peer->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, ownCore.getCoreIdAttr());

  auto peerCore = cast<pim::PimCoreOp>(peer->getParentOp());
  return peerCore.getCoreId();
}
|
|
||||||
|
|
||||||
// Bufferization model for SpatWeightedCompute.
struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpInterface, SpatWeightedCompute> {

  // Every tensor operand of the compute op is reported as read.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }

  // No tensor operand of the compute op is reported as written.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // No operand is reported as aliasing a result.
  // TODO: Is it an empty list or a list of "UNKNOWN" values?
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return {};
  }

  // Bufferizing the op amounts to bufferizing the signature of the first block
  // of its first region.
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    return bufferizeBlockSignature(&op->getRegion(0).front(), rewriter, options, state);
  }
};
|
|
||||||
|
|
||||||
/*
|
|
||||||
* This can be used for operation that have a single argument, which is a
|
|
||||||
* variadic of tensors, and a single output with the same same shape
|
|
||||||
* Example: VAdd, VSub, VExp
|
|
||||||
*/
|
|
||||||
template <typename InterfaceName, typename OpTy, typename ToTy>
|
|
||||||
struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
|
|
||||||
|
|
||||||
// Input tensors to the OP are always read
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
|
||||||
|
|
||||||
// Input tensors to the OP are _never_ written
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
|
||||||
|
|
||||||
// In general, no tensor is aliased with any other tensor in the OP
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cast tensor values into memref values
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
|
|
||||||
// Turn Tensor Operands into Memref Operands
|
|
||||||
SmallVector<Value> memrefOperands;
|
|
||||||
memrefOperands.reserve(op->getNumOperands());
|
|
||||||
for (auto operand : op->getOperands()) {
|
|
||||||
auto memref = getBuffer(rewriter, operand, options, state);
|
|
||||||
if (failed(memref))
|
|
||||||
return failure();
|
|
||||||
memrefOperands.push_back(materializeContiguousMemRef(*memref, op->getLoc(), rewriter));
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Support addiction with more than 2 operands
|
|
||||||
if (memrefOperands.size() > 2) {
|
|
||||||
op->emitError("VariadicArgumentElementWiseOpInterface only supports OPs "
|
|
||||||
"with 1 or 2 operands, for now.");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Alloc an output memref
|
|
||||||
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
|
||||||
|
|
||||||
memrefOperands.push_back(outputTensor);
|
|
||||||
|
|
||||||
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutput();
|
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename InterfaceName, typename OpTy, typename ToTy>
|
|
||||||
struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalModel<InterfaceName, OpTy> {
|
|
||||||
|
|
||||||
// Input tensors to the OP are always read
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
|
||||||
|
|
||||||
// Input tensors to the OP are _never_ written
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
|
||||||
|
|
||||||
// In general, no tensor is aliased with any other tensor in the OP
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cast tensor value into memref value
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
|
|
||||||
if (failed(memrefOperandOpt))
|
|
||||||
return failure();
|
|
||||||
auto memrefOperand = *memrefOperandOpt;
|
|
||||||
|
|
||||||
// Alloc an output memref
|
|
||||||
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
|
||||||
|
|
||||||
Value newValue = ToTy::create(rewriter,
|
|
||||||
op->getLoc(),
|
|
||||||
outputTensor.getType(),
|
|
||||||
cast<OpTy>(op).getWeightIndexAttr(),
|
|
||||||
memrefOperand,
|
|
||||||
outputTensor)
|
|
||||||
.getOutput();
|
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ChannelReceiveOpInterface
|
|
||||||
: BufferizableOpInterface::ExternalModel<ChannelReceiveOpInterface, SpatChannelReceiveOp> {
|
|
||||||
|
|
||||||
// Input value is the channel (not read/written, its more of an attribute)
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
|
||||||
|
|
||||||
// See above
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
|
||||||
|
|
||||||
// See above
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
// TODO: Is it an empty list or a list of "UNKNOWN" values?
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Turn the channel receive to pim.recv
|
|
||||||
*/
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
|
|
||||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
|
||||||
|
|
||||||
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
|
||||||
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
|
|
||||||
|
|
||||||
auto srcCoreId = getCoreIdOfOtherEndOfChannel(op, true, rewriter);
|
|
||||||
if (failed(srcCoreId))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
|
||||||
op->getLoc(),
|
|
||||||
outputTensor.getType(),
|
|
||||||
outputTensor,
|
|
||||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
|
||||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
|
||||||
.getOutput();
|
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Bufferization model for SpatChannelSendOp: the channel send becomes a pim
// send.
struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSendOpInterface, SpatChannelSendOp> {

  // Operand 0 is the channel (neither read nor written); operand 1 is the
  // tensor to send, which is read. Operand numbers are 0-based, so the tensor
  // is operand number 1, the same operand bufferize() fetches with
  // getOperand(1).
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return opOperand.getOperandNumber() == 1;
  }

  // Neither operand is written.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }

  // No operand is reported as aliasing a result.
  // TODO: Is it an empty list or a list of "UNKNOWN" values?
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return {};
  }

  /*
   * Turn the channel send to pim.send
   */
  LogicalResult bufferize(Operation* op,
      RewriterBase& rewriter,
      const BufferizationOptions& options,
      BufferizationState& state) const {
    auto srcTensor = op->getOperand(1);

    auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
    if (failed(srcTensorOpt))
      return failure();
    auto srcMemRef = *srcTensorOpt;

    // Payload size: element count times the per-element size in bytes.
    auto numElements = cast<ShapedType>(srcTensor.getType()).getNumElements();
    auto elementSize = cast<ShapedType>(srcTensor.getType()).getElementTypeBitWidth() / 8;

    // Core id of the pim core hosting the matching receive.
    auto dstCoreId = getCoreIdOfOtherEndOfChannel(op, false, rewriter);
    if (failed(dstCoreId))
      return failure();

    replaceOpWithNewBufferizedOp<pim::PimSendOp>(rewriter,
        op,
        srcMemRef,
        rewriter.getI32IntegerAttr(numElements * elementSize),
        rewriter.getI32IntegerAttr(dstCoreId.value()));

    return success();
  }
};
|
|
||||||
|
|
||||||
struct ChannelBroadcastReceiveOpInterface
|
|
||||||
: BufferizableOpInterface::ExternalModel<ChannelBroadcastReceiveOpInterface, SpatChannelBroadcastReceiveOp> {
|
|
||||||
|
|
||||||
// Input value is the channel (not read/written, its more of an attribute)
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
|
||||||
|
|
||||||
// See above
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
|
||||||
|
|
||||||
// See above
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
// TODO: Is it an empty list or a list of "UNKNOWN" values?
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Turn the broadcast receive into a regular pim.receive from the broadcaster.
|
|
||||||
*/
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
|
|
||||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
|
||||||
|
|
||||||
auto numElements = cast<ShapedType>(outputTensor.getType()).getNumElements();
|
|
||||||
auto elementSize = cast<ShapedType>(outputTensor.getType()).getElementTypeBitWidth() / 8;
|
|
||||||
|
|
||||||
auto precomputedOtherCoreId = op->getAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME);
|
|
||||||
if (precomputedOtherCoreId) {
|
|
||||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
|
||||||
op->getLoc(),
|
|
||||||
outputTensor.getType(),
|
|
||||||
outputTensor,
|
|
||||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
|
||||||
cast<IntegerAttr>(precomputedOtherCoreId))
|
|
||||||
.getOutput();
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
|
||||||
if (!channelNewOp) {
|
|
||||||
op->emitError("ChannelBroadcastReceiveOp does not use a channel as operand");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto srcCoreId = [&]() -> FailureOr<uint32_t> {
|
|
||||||
for (Operation* user : channelNewOp->getUsers()) {
|
|
||||||
auto sendOp = dyn_cast<SpatChannelBroadcastSendOp>(user);
|
|
||||||
if (!sendOp)
|
|
||||||
continue;
|
|
||||||
auto sendCoreIdAttr = cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreIdAttr();
|
|
||||||
op->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, sendCoreIdAttr);
|
|
||||||
return cast<pim::PimCoreOp>(sendOp->getParentOp()).getCoreId();
|
|
||||||
}
|
|
||||||
op->emitError("ChannelBroadcastReceiveOp has no matching ChannelBroadcastSendOp");
|
|
||||||
return failure();
|
|
||||||
}();
|
|
||||||
if (failed(srcCoreId))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
Value newValue = pim::PimReceiveOp::create(rewriter,
|
|
||||||
op->getLoc(),
|
|
||||||
outputTensor.getType(),
|
|
||||||
outputTensor,
|
|
||||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
|
||||||
rewriter.getI32IntegerAttr(srcCoreId.value()))
|
|
||||||
.getOutput();
|
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ChannelBroadcastSendOpInterface
|
|
||||||
: BufferizableOpInterface::ExternalModel<ChannelBroadcastSendOpInterface, SpatChannelBroadcastSendOp> {
|
|
||||||
|
|
||||||
// First input is channel (not read/writter) second input is Tensor to send,
|
|
||||||
// which is read
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return opOperand.getOperandNumber() == 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// See above (both non-written)
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
|
||||||
|
|
||||||
// See above
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
// TODO: Is it an empty list or a list of "UNKNOWN" values?
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Turn the broadcast send into one pim.send per broadcast receiver.
|
|
||||||
*/
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto srcTensor = op->getOperand(1);
|
|
||||||
|
|
||||||
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
|
|
||||||
if (failed(srcTensorOpt))
|
|
||||||
return failure();
|
|
||||||
auto srcMemRef = *srcTensorOpt;
|
|
||||||
|
|
||||||
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
|
|
||||||
if (!channelNewOp) {
|
|
||||||
op->emitError("SpatChannelBroadcastSendOp does not use a channel as operand");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto srcType = cast<ShapedType>(srcTensor.getType());
|
|
||||||
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
|
||||||
auto srcCoreIdAttr = cast<pim::PimCoreOp>(op->getParentOp()).getCoreIdAttr();
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(op);
|
|
||||||
bool foundReceiver = false;
|
|
||||||
for (Operation* user : channelNewOp->getUsers()) {
|
|
||||||
auto receiveOp = dyn_cast<SpatChannelBroadcastReceiveOp>(user);
|
|
||||||
if (!receiveOp)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
foundReceiver = true;
|
|
||||||
auto dstCoreId = cast<pim::PimCoreOp>(receiveOp->getParentOp()).getCoreId();
|
|
||||||
receiveOp->setAttr(PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME, srcCoreIdAttr);
|
|
||||||
pim::PimSendOp::create(rewriter,
|
|
||||||
op->getLoc(),
|
|
||||||
srcMemRef,
|
|
||||||
rewriter.getI32IntegerAttr(sizeInBytes),
|
|
||||||
rewriter.getI32IntegerAttr(dstCoreId));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!foundReceiver) {
|
|
||||||
op->emitError("SpatChannelBroadcastSendOp has no matching ChannelBroadcastReceiveOp");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct VAddOpInterfaceFromTemplate
|
|
||||||
: VariadicArgumentElementWiseOpInterface<VAddOpInterfaceFromTemplate, SpatVAddOp, pim::PimVVAddOp> {};
|
|
||||||
|
|
||||||
// Bufferization model for SpatWeightedVMMOp; the weighted-multiplication
// template replaces it with a pim::PimVMMOp.
struct WVMMOpInterface : WeightedMultiplicationsOpInterface<WVMMOpInterface, SpatWeightedVMMOp, pim::PimVMMOp> {};
|
|
||||||
|
|
||||||
// Bufferization model for SpatWeightedMVMOp; the weighted-multiplication
// template replaces it with a pim::PimMVMOp.
struct WMVMOpInterface : WeightedMultiplicationsOpInterface<WMVMOpInterface, SpatWeightedMVMOp, pim::PimMVMOp> {};
|
|
||||||
|
|
||||||
// Bufferization model for SpatVMaxOp; the variadic element-wise template
// replaces it with a pim::PimVVMaxOp.
struct VMaxOpInterface : VariadicArgumentElementWiseOpInterface<VMaxOpInterface, SpatVMaxOp, pim::PimVVMaxOp> {};
|
|
||||||
|
|
||||||
// Adds to `registry` an extension on SpatialDialect that attaches the
// bufferization external models defined in this file to the Spatial ops.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* ctx, SpatialDialect* dialect) {
    MLIRContext& context = *ctx;

    // Compute and arithmetic ops.
    SpatWeightedCompute::attachInterface<WComputeOpInterface>(context);
    SpatVAddOp::attachInterface<VAddOpInterfaceFromTemplate>(context);
    SpatWeightedVMMOp::attachInterface<WVMMOpInterface>(context);
    SpatWeightedMVMOp::attachInterface<WMVMOpInterface>(context);
    SpatVMaxOp::attachInterface<VMaxOpInterface>(context);

    // Point-to-point and broadcast channel ops.
    SpatChannelReceiveOp::attachInterface<ChannelReceiveOpInterface>(context);
    SpatChannelSendOp::attachInterface<ChannelSendOpInterface>(context);
    SpatChannelBroadcastReceiveOp::attachInterface<ChannelBroadcastReceiveOpInterface>(context);
    SpatChannelBroadcastSendOp::attachInterface<ChannelBroadcastSendOpInterface>(context);
  });
}
|
|
||||||
|
|
||||||
// Bufferization model for ONNXReluOp; the variadic element-wise template
// replaces it with a pim::PimVReluOp.
struct ONNXReluInterface : VariadicArgumentElementWiseOpInterface<ONNXReluInterface, ONNXReluOp, pim::PimVReluOp> {};
|
|
||||||
|
|
||||||
// Bufferization model for ONNXTanhOp; the variadic element-wise template
// replaces it with a pim::PimVTanhOp.
struct ONNXTanhInterface : VariadicArgumentElementWiseOpInterface<ONNXTanhInterface, ONNXTanhOp, pim::PimVTanhOp> {};
|
|
||||||
|
|
||||||
struct ONNXSigmoidInterface
|
|
||||||
: VariadicArgumentElementWiseOpInterface<ONNXSigmoidInterface, ONNXSigmoidOp, pim::PimVSigmOp> {};
|
|
||||||
|
|
||||||
// Adds to `registry` an extension on ONNXDialect that attaches the
// bufferization external models for the ONNX Relu, Tanh and Sigmoid ops.
void registerONNXBufferizableOpInterfaceExternalModels(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* ctx, ONNXDialect* dialect) {
    MLIRContext& context = *ctx;
    ONNXReluOp::attachInterface<ONNXReluInterface>(context);
    ONNXTanhOp::attachInterface<ONNXTanhInterface>(context);
    ONNXSigmoidOp::attachInterface<ONNXSigmoidInterface>(context);
  });
}
|
|
||||||
|
|
||||||
} // namespace spatial
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/DialectRegistry.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
namespace spatial {
|
|
||||||
|
|
||||||
// Registers on `registry` a SpatialDialect extension that attaches the
// bufferization external models to the Spatial ops.
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
|
|
||||||
|
|
||||||
// Registers on `registry` an ONNXDialect extension that attaches the
// bufferization external models to the ONNX Relu, Tanh and Sigmoid ops.
void registerONNXBufferizableOpInterfaceExternalModels(mlir::DialectRegistry& registry);
|
|
||||||
|
|
||||||
} // namespace spatial
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
add_pim_library(OMPimPasses
|
add_pim_library(OMPimPasses
|
||||||
CountInstructionPass.cpp
|
CountInstructionPass.cpp
|
||||||
MessagePass.cpp
|
MessagePass.cpp
|
||||||
Pim/ConstantFolding/Common.cpp
|
PimCodegen/HostConstantFolding/Common.cpp
|
||||||
Pim/ConstantFolding/Patterns/Constant.cpp
|
PimCodegen/HostConstantFolding/Patterns/Constant.cpp
|
||||||
Pim/ConstantFolding/ConstantFoldingPass.cpp
|
PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp
|
||||||
Pim/ConstantFolding/Patterns/Subview.cpp
|
PimCodegen/HostConstantFolding/Patterns/Subview.cpp
|
||||||
Pim/MaterializeConstantsPass.cpp
|
PimCodegen/MaterializeHostConstantsPass.cpp
|
||||||
Pim/VerificationPass.cpp
|
PimCodegen/VerificationPass.cpp
|
||||||
Pim/EmitPimJsonPass.cpp
|
PimCodegen/EmitPimJsonPass.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,11 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
|||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
|
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createMergeComputeNodePass();
|
std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
|
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();
|
std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||||
|
|
||||||
|
|||||||
+4
-4
@@ -11,10 +11,10 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<ModuleOp>> {
|
struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldingPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostConstantFoldingPass)
|
||||||
|
|
||||||
StringRef getArgument() const override { return "pim-constant-folding-pass"; }
|
StringRef getArgument() const override { return "pim-host-constant-folding-pass"; }
|
||||||
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
StringRef getDescription() const override { return "Fold host-side constant expressions before PIM verification"; }
|
||||||
|
|
||||||
LogicalResult initialize(MLIRContext* context) override {
|
LogicalResult initialize(MLIRContext* context) override {
|
||||||
@@ -47,6 +47,6 @@ struct ConstantFoldingPass : PassWrapper<ConstantFoldingPass, OperationPass<Modu
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimConstantFoldingPass() { return std::make_unique<ConstantFoldingPass>(); }
|
std::unique_ptr<Pass> createPimHostConstantFoldingPass() { return std::make_unique<HostConstantFoldingPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
+6
-4
@@ -31,10 +31,10 @@ static int64_t getValueSizeInBytes(Value value) {
|
|||||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, OperationPass<ModuleOp>> {
|
struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeConstantsPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass)
|
||||||
|
|
||||||
StringRef getArgument() const override { return "materialize-pim-constants"; }
|
StringRef getArgument() const override { return "materialize-pim-host-constants"; }
|
||||||
StringRef getDescription() const override {
|
StringRef getDescription() const override {
|
||||||
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
return "Materialize explicit host-to-device copies for constant globals used by PIM runtime ops";
|
||||||
}
|
}
|
||||||
@@ -126,6 +126,8 @@ struct MaterializeConstantsPass : PassWrapper<MaterializeConstantsPass, Operatio
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimMaterializeConstantsPass() { return std::make_unique<MaterializeConstantsPass>(); }
|
std::unique_ptr<Pass> createPimMaterializeHostConstantsPass() {
|
||||||
|
return std::make_unique<MaterializeHostConstantsPass>();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -18,7 +18,6 @@
|
|||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Accelerators/PIM/PimAccelerator.hpp"
|
#include "src/Accelerators/PIM/PimAccelerator.hpp"
|
||||||
#include "src/Compiler/CompilerUtils.hpp"
|
#include "src/Compiler/CompilerUtils.hpp"
|
||||||
@@ -67,8 +66,6 @@ void PimAccelerator::registerDialects(mlir::DialectRegistry& registry) const {
|
|||||||
mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
|
mlir::arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
mlir::bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
mlir::scf::registerBufferizableOpInterfaceExternalModels(registry);
|
mlir::scf::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
spatial::registerBufferizableOpInterfaceExternalModels(registry);
|
|
||||||
spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
|
|
||||||
pim::registerOpBufferizationInterfaces(registry);
|
pim::registerOpBufferizationInterfaces(registry);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,9 +75,9 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
|||||||
registerPass(createSpatialToGraphvizPass);
|
registerPass(createSpatialToGraphvizPass);
|
||||||
registerPass(createSpatialToPimPass);
|
registerPass(createSpatialToPimPass);
|
||||||
registerPass(createPimBufferizationPass);
|
registerPass(createPimBufferizationPass);
|
||||||
registerPass(createMergeComputeNodePass);
|
registerPass(createMergeComputeNodesPass);
|
||||||
registerPass(createPimConstantFoldingPass);
|
registerPass(createPimHostConstantFoldingPass);
|
||||||
registerPass(createPimMaterializeConstantsPass);
|
registerPass(createPimMaterializeHostConstantsPass);
|
||||||
registerPass(createPimVerificationPass);
|
registerPass(createPimVerificationPass);
|
||||||
registerPass(createEmitPimJsonPass);
|
registerPass(createEmitPimJsonPass);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
|
|
||||||
Rimuovere la logica di bufferizazione da spatial,
|
|
||||||
Rimuovere la gestione delle send e recive da sptaialtopim (nuovo mergeNode)
|
Rimuovere la gestione delle send e recive da sptaialtopim (nuovo mergeNode)
|
||||||
|
|
||||||
AnalisiDCP
|
AnalisiDCP
|
||||||
|
|||||||
+54
-13
@@ -12,7 +12,16 @@ from gen_network_runner import gen_network_runner
|
|||||||
from subprocess_utils import run_command_with_reporter
|
from subprocess_utils import run_command_with_reporter
|
||||||
|
|
||||||
|
|
||||||
STAGE_COUNT = 6
|
STAGE_TITLES = (
|
||||||
|
"Compile ONNX",
|
||||||
|
"Build Runner",
|
||||||
|
"Generate Inputs",
|
||||||
|
"Run Reference",
|
||||||
|
"Compile PIM",
|
||||||
|
"Run Simulator",
|
||||||
|
"Compare Outputs",
|
||||||
|
)
|
||||||
|
STAGE_COUNT = len(STAGE_TITLES)
|
||||||
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
||||||
|
|
||||||
|
|
||||||
@@ -22,6 +31,8 @@ class ProgressReporter:
|
|||||||
self.stages_per_model = stages_per_model
|
self.stages_per_model = stages_per_model
|
||||||
self.total_steps = max(1, total_models * stages_per_model)
|
self.total_steps = max(1, total_models * stages_per_model)
|
||||||
self.completed_steps = 0
|
self.completed_steps = 0
|
||||||
|
self.passed_models = 0
|
||||||
|
self.failed_models = 0
|
||||||
self.current_label = ""
|
self.current_label = ""
|
||||||
self.enabled = True
|
self.enabled = True
|
||||||
self.columns = shutil.get_terminal_size((100, 20)).columns
|
self.columns = shutil.get_terminal_size((100, 20)).columns
|
||||||
@@ -36,21 +47,42 @@ class ProgressReporter:
|
|||||||
return
|
return
|
||||||
bar_width = 24
|
bar_width = 24
|
||||||
filled = int(bar_width * self.completed_steps / self.total_steps)
|
filled = int(bar_width * self.completed_steps / self.total_steps)
|
||||||
|
counts_text = f"P:{self.passed_models} F:{self.failed_models}"
|
||||||
prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] {self.completed_steps}/{self.total_steps}"
|
prefix_text = f"[{'#' * filled}{'-' * (bar_width - filled)}] {self.completed_steps}/{self.total_steps}"
|
||||||
if len(prefix_text) > self.columns:
|
if len(prefix_text) > self.columns:
|
||||||
prefix_text = f"{self.completed_steps}/{self.total_steps}"
|
prefix_text = f"{self.completed_steps}/{self.total_steps}"
|
||||||
|
|
||||||
label = f" {self.current_label}" if self.current_label else ""
|
|
||||||
available_label_width = max(0, self.columns - len(prefix_text))
|
|
||||||
label = label[:available_label_width]
|
|
||||||
|
|
||||||
if prefix_text.startswith("["):
|
if prefix_text.startswith("["):
|
||||||
bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled))
|
bar = Fore.GREEN + ("#" * filled) + Fore.CYAN + ("-" * (bar_width - filled))
|
||||||
prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL
|
prefix = Fore.CYAN + f"[{bar}{Fore.CYAN}] {self.completed_steps}/{self.total_steps}" + Style.RESET_ALL
|
||||||
else:
|
else:
|
||||||
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
|
prefix = Fore.CYAN + prefix_text + Style.RESET_ALL
|
||||||
|
|
||||||
sys.stdout.write("\r" + prefix + label + Style.RESET_ALL)
|
counts = (
|
||||||
|
" "
|
||||||
|
+ Style.BRIGHT
|
||||||
|
+ Fore.GREEN
|
||||||
|
+ f"P:{self.passed_models}"
|
||||||
|
+ Style.RESET_ALL
|
||||||
|
+ " "
|
||||||
|
+ Style.BRIGHT
|
||||||
|
+ Fore.RED
|
||||||
|
+ f"F:{self.failed_models}"
|
||||||
|
+ Style.RESET_ALL
|
||||||
|
)
|
||||||
|
model_counter = ""
|
||||||
|
label = ""
|
||||||
|
if self.current_label.startswith("[") and "] " in self.current_label:
|
||||||
|
model_counter, label = self.current_label.split("] ", 1)
|
||||||
|
model_counter = f" {model_counter}]"
|
||||||
|
label = f" {label}"
|
||||||
|
elif self.current_label:
|
||||||
|
label = f" {self.current_label}"
|
||||||
|
|
||||||
|
available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3)
|
||||||
|
label = label[:available_label_width]
|
||||||
|
|
||||||
|
sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL)
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
def log(self, message="", color=None):
|
def log(self, message="", color=None):
|
||||||
@@ -70,6 +102,13 @@ class ProgressReporter:
|
|||||||
self.completed_steps = min(self.total_steps, self.completed_steps + 1)
|
self.completed_steps = min(self.total_steps, self.completed_steps + 1)
|
||||||
self._render()
|
self._render()
|
||||||
|
|
||||||
|
def record_result(self, passed):
|
||||||
|
if passed:
|
||||||
|
self.passed_models += 1
|
||||||
|
else:
|
||||||
|
self.failed_models += 1
|
||||||
|
self._render()
|
||||||
|
|
||||||
def suspend(self):
|
def suspend(self):
|
||||||
self.suspended = True
|
self.suspended = True
|
||||||
self._clear()
|
self._clear()
|
||||||
@@ -112,13 +151,13 @@ def clean_workspace_artifacts(workspace_dir, model_stem):
|
|||||||
|
|
||||||
def print_stage(reporter, model_index, model_total, model_name, title):
|
def print_stage(reporter, model_index, model_total, model_name, title):
|
||||||
stage_colors = {
|
stage_colors = {
|
||||||
"Compile ONNX": Fore.BLUE,
|
STAGE_TITLES[0]: Fore.BLUE,
|
||||||
"Build Runner": Fore.MAGENTA,
|
STAGE_TITLES[1]: Fore.MAGENTA,
|
||||||
"Generate Inputs": Fore.YELLOW,
|
STAGE_TITLES[2]: Fore.YELLOW,
|
||||||
"Run Reference": Fore.GREEN,
|
STAGE_TITLES[3]: Fore.GREEN,
|
||||||
"Compile PIM": Fore.CYAN,
|
STAGE_TITLES[4]: Fore.CYAN,
|
||||||
"Run Simulator": Fore.MAGENTA,
|
STAGE_TITLES[5]: Fore.MAGENTA,
|
||||||
"Compare Outputs": Fore.YELLOW,
|
STAGE_TITLES[6]: Fore.YELLOW,
|
||||||
}
|
}
|
||||||
color = stage_colors.get(title, Fore.WHITE)
|
color = stage_colors.get(title, Fore.WHITE)
|
||||||
reporter.log(Style.BRIGHT + color + f"[{title}]" + Style.RESET_ALL)
|
reporter.log(Style.BRIGHT + color + f"[{title}]" + Style.RESET_ALL)
|
||||||
@@ -284,11 +323,13 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
|
passed = validate_outputs(sim_arrays, out_dir, outputs_descriptor, threshold)
|
||||||
reporter.resume()
|
reporter.resume()
|
||||||
reporter.advance()
|
reporter.advance()
|
||||||
|
reporter.record_result(passed)
|
||||||
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
|
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
|
||||||
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
||||||
return passed
|
return passed
|
||||||
except Exception:
|
except Exception:
|
||||||
failed_with_exception = True
|
failed_with_exception = True
|
||||||
|
reporter.record_result(False)
|
||||||
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
|
reporter.log(Style.BRIGHT + Fore.RED + "Result: FAIL" + Style.RESET_ALL)
|
||||||
reporter.suspend()
|
reporter.suspend()
|
||||||
raise
|
raise
|
||||||
|
|||||||
Reference in New Issue
Block a user