rework actually broken dcp merge + compute re-batching (still to refine)
This commit is contained in:
@@ -135,7 +135,7 @@ validate.py \
|
||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||
--onnx-include-dir ../onnx-mlir/include \
|
||||
--operations-dir ./networks/yolo11n/depth_04 \
|
||||
--crossbar-size 2048 --crossbar-count 256
|
||||
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
|
||||
```
|
||||
|
||||
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
|
||||
|
||||
@@ -23,21 +23,42 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||
unsigned inputCount = compute.getInputs().size();
|
||||
if (inputCount == 0)
|
||||
return std::nullopt;
|
||||
|
||||
unsigned inputBegin = compute->getNumOperands() - inputCount;
|
||||
if (operandNumber < inputBegin)
|
||||
return std::nullopt;
|
||||
return operandNumber - inputBegin;
|
||||
}
|
||||
|
||||
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner)) {
|
||||
unsigned inputCount = computeBatch.getInputs().size();
|
||||
if (inputCount == 0)
|
||||
return std::nullopt;
|
||||
|
||||
unsigned inputBegin = computeBatch->getNumOperands() - inputCount;
|
||||
if (operandNumber < inputBegin)
|
||||
return std::nullopt;
|
||||
return operandNumber - inputBegin;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
|
||||
Location loc = extractSliceOp.getLoc();
|
||||
|
||||
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
|
||||
return failure();
|
||||
|
||||
for (auto& uses : extractSliceOp->getUses()) {
|
||||
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
||||
auto spatCompute = cast<spatial::SpatCompute>(uses.getOwner());
|
||||
if (spatCompute.getInputs().empty())
|
||||
return failure();
|
||||
if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex())
|
||||
if (!getDirectComputeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
||||
return failure();
|
||||
}
|
||||
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
|
||||
@@ -50,7 +71,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
||||
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
||||
auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
@@ -69,7 +93,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||
auto BBArgIndex = uses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
@@ -165,8 +192,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
|
||||
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||
@@ -183,8 +212,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
|
||||
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||
@@ -201,7 +232,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
}
|
||||
else {
|
||||
{
|
||||
{
|
||||
|
||||
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
@@ -240,8 +271,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
|
||||
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
@@ -253,8 +286,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
|
||||
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
@@ -265,11 +300,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
}
|
||||
else {
|
||||
if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
if (!mapSpatComputeToConst.contains(parent)) {
|
||||
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
if (!mapSpatComputeToConst.contains(parent)) {
|
||||
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
|
||||
}
|
||||
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
|
||||
@@ -285,9 +319,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
||||
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
auto parent = constantOp->getParentOp();
|
||||
rewriter.eraseOp(constantOp);
|
||||
return success();
|
||||
}
|
||||
@@ -333,7 +365,10 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
||||
auto argUser = argUses.getOwner();
|
||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
||||
auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatCompute, argUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
@@ -347,7 +382,10 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
||||
auto BBArgIndex = argUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
||||
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
|
||||
@@ -11,20 +11,15 @@
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <filesystem>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
@@ -34,10 +29,8 @@
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
@@ -214,11 +207,12 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
|
||||
extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering");
|
||||
return;
|
||||
}
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)),
|
||||
rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto rowSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
|
||||
auto rowSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
|
||||
replacements.push_back(rowSlice.getResult());
|
||||
}
|
||||
|
||||
@@ -263,19 +257,19 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
|
||||
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
|
||||
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
|
||||
if (!chainSet.contains(&op)
|
||||
&& !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
||||
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
||||
return failure();
|
||||
|
||||
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||
return false;
|
||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(),
|
||||
[](Operation* user) { return isa<spatial::SpatComputeBatch, pim::PimCoreBatchOp>(user); }))
|
||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||
}))
|
||||
return false;
|
||||
|
||||
Block& block = computeOp.getBody().front();
|
||||
@@ -447,8 +441,7 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
||||
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
||||
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
};
|
||||
if (!hasStaticValues(extractSliceOp.getStaticOffsets())
|
||||
|| !hasStaticValues(extractSliceOp.getStaticSizes())
|
||||
if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(extractSliceOp.getStaticStrides()))
|
||||
return failure();
|
||||
|
||||
@@ -510,10 +503,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
||||
return success();
|
||||
}
|
||||
|
||||
static void cloneHelperChain(Value sourceValue,
|
||||
ArrayRef<Operation*> helperChain,
|
||||
IRRewriter& rewriter,
|
||||
Value& clonedValue) {
|
||||
static void
|
||||
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
|
||||
IRMapping mapping;
|
||||
mapping.map(sourceValue, sourceValue);
|
||||
clonedValue = sourceValue;
|
||||
@@ -734,7 +725,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
Location loc = computeOp->getLoc();
|
||||
|
||||
if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter))
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
|
||||
return;
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
@@ -835,7 +826,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
|
||||
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
||||
if (!storedType) {
|
||||
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
computeOp.emitOpError(
|
||||
"has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -848,10 +840,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
|
||||
|
||||
SmallVector<int64_t> destinationIndices;
|
||||
if (failed(mapIndicesThroughHelperChain(sourceIndices,
|
||||
concatReturnUse->concatShape,
|
||||
concatReturnUse->helperChain,
|
||||
destinationIndices))) {
|
||||
if (failed(mapIndicesThroughHelperChain(
|
||||
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
|
||||
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
@@ -897,9 +887,12 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
|
||||
|
||||
// Replace `spat.compute` with `pim.core`
|
||||
SmallVector<Value> computeWeights;
|
||||
if (!computeOp.getWeights().empty())
|
||||
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
auto coreOp = PimCoreOp::create(
|
||||
rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
|
||||
if (!blockArg.use_empty())
|
||||
@@ -933,15 +926,19 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
|
||||
}
|
||||
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||
SmallVector<Value> batchInputs;
|
||||
if (!computeBatchOp.getInputs().empty())
|
||||
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
|
||||
|
||||
rewriter.setInsertionPointAfter(computeBatchOp);
|
||||
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
|
||||
computeBatchOp.getWeights(),
|
||||
computeBatchOp.getInputs());
|
||||
ValueRange(batchWeights),
|
||||
ValueRange(batchInputs));
|
||||
coreBatchOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(computeBatchOp.getWeights().size()), static_cast<int>(computeBatchOp.getInputs().size())});
|
||||
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
@@ -1124,13 +1121,13 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
|
||||
std::string outputName = "output_" + std::to_string(index);
|
||||
rewriter.setInsertionPoint(returnOp.getParentOp());
|
||||
memref::GlobalOp::create(rewriter,
|
||||
returnOp.getLoc(),
|
||||
rewriter.getStringAttr(outputName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(memRefType),
|
||||
{},
|
||||
{},
|
||||
{});
|
||||
returnOp.getLoc(),
|
||||
rewriter.getStringAttr(outputName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(memRefType),
|
||||
{},
|
||||
{},
|
||||
{});
|
||||
outputTensors.push_back(
|
||||
[memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
|
||||
@@ -1210,8 +1207,9 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
||||
|
||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||
markOpToRemove(computeOp);
|
||||
for (Value input : computeOp.getInputs())
|
||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||
if (!computeOp.getInputs().empty())
|
||||
for (Value input : computeOp.getInputs())
|
||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -184,14 +184,40 @@ std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
|
||||
|
||||
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
|
||||
SmallVector<ComputeInstance> instances;
|
||||
auto isUsedAsWeightOnly = [](Operation* producerOp) {
|
||||
if (producerOp->getNumResults() == 0)
|
||||
return false;
|
||||
for (Value result : producerOp->getResults()) {
|
||||
if (result.use_empty())
|
||||
return false;
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(user)) {
|
||||
if (!llvm::is_contained(compute.getWeights(), result))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
|
||||
if (!llvm::is_contained(batch.getWeights(), result))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
for (Region& region : entryOp->getRegions()) {
|
||||
for (Block& block : region) {
|
||||
for (Operation& op : block) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||
if (isUsedAsWeightOnly(spatCompute.getOperation()))
|
||||
continue;
|
||||
instances.push_back({spatCompute.getOperation(), 0, 1});
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
||||
if (isUsedAsWeightOnly(batch.getOperation()))
|
||||
continue;
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
|
||||
instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
|
||||
@@ -582,10 +608,13 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRe
|
||||
}
|
||||
|
||||
result.dominanceOrderCompute.reserve(computeInstances.size());
|
||||
llvm::DenseMap<size_t, size_t> nextCpuSlot;
|
||||
for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
|
||||
size_t cpu = originalComputeToCpu[originalIndex];
|
||||
result.dominanceOrderCompute.push_back(computeInstance);
|
||||
result.computeToCpuMap[computeInstance] = cpu;
|
||||
result.computeToCpuSlotMap[computeInstance] = nextCpuSlot[cpu]++;
|
||||
result.computeToAestMap[computeInstance] = originalIndex;
|
||||
result.cpuToLastComputeMap[cpu] = computeInstance;
|
||||
}
|
||||
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||
@@ -603,8 +632,12 @@ DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<Com
|
||||
if (scheduledTasks.empty())
|
||||
continue;
|
||||
|
||||
for (const auto& task : scheduledTasks)
|
||||
result.computeToCpuMap[computeInstances[task.nodeIndex]] = cpu;
|
||||
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
|
||||
ComputeInstance instance = computeInstances[task.nodeIndex];
|
||||
result.computeToCpuMap[instance] = cpu;
|
||||
result.computeToCpuSlotMap[instance] = slot;
|
||||
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
||||
}
|
||||
result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
|
||||
result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
|
||||
}
|
||||
@@ -671,6 +704,16 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
}
|
||||
}
|
||||
|
||||
if (coresCount.getValue() > 0) {
|
||||
size_t schedulingCpuBudget = getSchedulingCpuBudget();
|
||||
bool needsExactScheduledBatches = llvm::any_of(computeInstances, [&](const ComputeInstance& instance) {
|
||||
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
||||
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
||||
});
|
||||
if (needsExactScheduledBatches)
|
||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
||||
}
|
||||
|
||||
if (dcpCriticalWindowSize.getValue() == 0)
|
||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
||||
|
||||
|
||||
@@ -25,6 +25,8 @@ struct ComputeInstance {
|
||||
struct DCPAnalysisResult {
|
||||
std::vector<ComputeInstance> dominanceOrderCompute;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "mlir/Analysis/TopologicalSortUtils.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
@@ -36,7 +37,6 @@
|
||||
#include "DCPGraph/DCPAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -61,7 +61,7 @@ static size_t getFastPathCpuBudget() {
|
||||
|
||||
static size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||
assert(laneCount > 0 && "laneCount must be positive");
|
||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getFastPathCpuBudget()));
|
||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, static_cast<size_t>(getFastPathCpuBudget())));
|
||||
}
|
||||
|
||||
static ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||
@@ -129,6 +129,23 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
||||
|
||||
static int32_t getPhysicalCoreId(size_t schedulerCpu) { return static_cast<int32_t>(schedulerCpu + 1); }
|
||||
|
||||
static size_t getMaterializationCpuBudget(size_t laneCount) {
|
||||
if (coresCount.getValue() > 0)
|
||||
return static_cast<size_t>(coresCount.getValue());
|
||||
return std::max<size_t>(1, laneCount);
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getMaterializedBatchCoreIds(size_t startCpu, size_t laneCount) {
|
||||
size_t cpuBudget = getMaterializationCpuBudget(laneCount);
|
||||
assert(laneCount <= cpuBudget && "materialized batch exceeds available CPUs");
|
||||
|
||||
SmallVector<int32_t> coreIds;
|
||||
coreIds.reserve(laneCount);
|
||||
for (size_t laneOffset = 0; laneOffset < laneCount; ++laneOffset)
|
||||
coreIds.push_back(getPhysicalCoreId((startCpu + laneOffset) % cpuBudget));
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getBatchCoreIds(Operation* op, size_t laneCount) {
|
||||
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
@@ -143,6 +160,14 @@ static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
||||
|
||||
static std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
|
||||
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
|
||||
return static_cast<uint64_t>(phaseAttr.getInt());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
||||
if (!lhs || !rhs)
|
||||
return false;
|
||||
@@ -152,6 +177,8 @@ static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
||||
return false;
|
||||
if (lhs.getWeights().size() != rhs.getWeights().size())
|
||||
return false;
|
||||
if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
|
||||
return false;
|
||||
if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
|
||||
return false;
|
||||
|
||||
@@ -841,10 +868,14 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
}
|
||||
|
||||
for (auto compute : group) {
|
||||
compute->removeAttr(kRebatchPhaseAttrName);
|
||||
consumed.insert(compute);
|
||||
rewriter.eraseOp(compute);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto compute : funcOp.getOps<SpatCompute>())
|
||||
compute->removeAttr(kRebatchPhaseAttrName);
|
||||
}
|
||||
|
||||
struct ComputeMotifInfo {
|
||||
@@ -1329,8 +1360,9 @@ public:
|
||||
|
||||
LazyInsertComputeResult(
|
||||
ComputeValueResults computeValueResults,
|
||||
size_t producerCpu,
|
||||
std::function<std::pair<ChannelInfo, std::function<void(InsertPoint)>>(size_t, size_t)> channelInserter)
|
||||
: computeResults(computeValueResults), channelInserter(channelInserter) {}
|
||||
: computeResults(computeValueResults), producerCpu(producerCpu), channelInserter(channelInserter) {}
|
||||
|
||||
struct ChannelOrLocalOp {
|
||||
Value data;
|
||||
@@ -1339,6 +1371,9 @@ public:
|
||||
};
|
||||
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex, size_t targetCpu) {
|
||||
if (targetCpu == producerCpu)
|
||||
return {computeResults.getOuter(resultIndex), false, {}};
|
||||
|
||||
Value innerValue = computeResults.getInner(resultIndex);
|
||||
auto [channelInfo, channelSendInserter] = channelInserter(resultIndex, targetCpu);
|
||||
InsertPoint sendInsertPoint;
|
||||
@@ -1353,6 +1388,7 @@ public:
|
||||
|
||||
private:
|
||||
ComputeValueResults computeResults;
|
||||
size_t producerCpu = 0;
|
||||
std::function<std::pair<ChannelInfo, std::function<void(InsertPoint)>>(size_t, size_t)> channelInserter;
|
||||
};
|
||||
|
||||
@@ -1378,28 +1414,158 @@ public:
|
||||
mergeTriviallyConnectedComputes(getOperation());
|
||||
emitMotifProfile(getOperation());
|
||||
|
||||
func::FuncOp func = getOperation();
|
||||
Location loc = func.getLoc();
|
||||
DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
|
||||
DenseSet<ComputeInstance> materializedInstances;
|
||||
for (size_t index = 0; index < analysisResult.dominanceOrderCompute.size(); ++index) {
|
||||
ComputeInstance currentInstance = analysisResult.dominanceOrderCompute[index];
|
||||
if (!materializedInstances.insert(currentInstance).second)
|
||||
continue;
|
||||
|
||||
size_t cpu = analysisResult.computeToCpuMap.at(currentInstance);
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(currentInstance.op)) {
|
||||
createNewBatchCompute(batch, currentInstance.laneStart, currentInstance.laneCount, cpu, analysisResult);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto scalarCompute = cast<SpatCompute>(currentInstance.op);
|
||||
auto [newCompute, computeValueResults] = createNewComputeNode(scalarCompute, cpu, analysisResult);
|
||||
newComputeNodeResults.insert({currentInstance, createLazyComputeResult(newCompute, computeValueResults, cpu)});
|
||||
}
|
||||
|
||||
DenseSet<Operation*> toEraseSet;
|
||||
for (ComputeInstance instance : analysisResult.dominanceOrderCompute)
|
||||
toEraseSet.insert(instance.op);
|
||||
|
||||
struct ScheduledTask {
|
||||
ComputeInstance key;
|
||||
Operation* sourceOp = nullptr;
|
||||
size_t cpu = 0;
|
||||
size_t slot = 0;
|
||||
size_t order = 0;
|
||||
};
|
||||
struct ChannelInfo {
|
||||
int64_t channelId = -1;
|
||||
int32_t sourceCoreId = -1;
|
||||
int32_t targetCoreId = -1;
|
||||
};
|
||||
struct CpuProgram {
|
||||
SpatCompute op;
|
||||
Block* block = nullptr;
|
||||
DenseMap<Value, Value> externalInputMap;
|
||||
DenseMap<Value, size_t> weightToIndex;
|
||||
};
|
||||
|
||||
auto getTaskInputs = [&](const ScheduledTask& task) {
|
||||
SmallVector<Value> inputs;
|
||||
if (auto compute = dyn_cast<SpatCompute>(task.sourceOp)) {
|
||||
llvm::append_range(inputs, compute.getInputs());
|
||||
return inputs;
|
||||
}
|
||||
auto batch = cast<SpatComputeBatch>(task.sourceOp);
|
||||
for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane)
|
||||
if (!batch.getInputs().empty())
|
||||
inputs.push_back(batch.getInputs()[lane]);
|
||||
return inputs;
|
||||
};
|
||||
|
||||
auto getTaskWeights = [&](const ScheduledTask& task) {
|
||||
SmallVector<Value> weights;
|
||||
if (auto compute = dyn_cast<SpatCompute>(task.sourceOp)) {
|
||||
llvm::append_range(weights, compute.getWeights());
|
||||
return weights;
|
||||
}
|
||||
auto batch = cast<SpatComputeBatch>(task.sourceOp);
|
||||
for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane)
|
||||
weights.push_back(batch.getWeights()[lane]);
|
||||
return weights;
|
||||
};
|
||||
|
||||
auto getTaskOutputValues = [&](const ScheduledTask& task) {
|
||||
SmallVector<Value> outputs;
|
||||
if (auto compute = dyn_cast<SpatCompute>(task.sourceOp)) {
|
||||
for (Value result : compute.getResults())
|
||||
outputs.push_back(result);
|
||||
return outputs;
|
||||
}
|
||||
auto batch = cast<SpatComputeBatch>(task.sourceOp);
|
||||
for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane)
|
||||
if (!batch.getOutputs().empty())
|
||||
outputs.push_back(batch.getOutputs()[lane]);
|
||||
return outputs;
|
||||
};
|
||||
|
||||
auto getTaskOutputTypes = [&](const ScheduledTask& task) {
|
||||
SmallVector<Type> resultTypes;
|
||||
if (auto compute = dyn_cast<SpatCompute>(task.sourceOp)) {
|
||||
llvm::append_range(resultTypes, compute.getResultTypes());
|
||||
return resultTypes;
|
||||
}
|
||||
auto batch = cast<SpatComputeBatch>(task.sourceOp);
|
||||
for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane)
|
||||
if (!batch.getOutputs().empty())
|
||||
resultTypes.push_back(batch.getOutputs()[lane].getType());
|
||||
return resultTypes;
|
||||
};
|
||||
|
||||
auto getTaskTemplateBlock = [&](const ScheduledTask& task) -> Block& {
|
||||
if (auto compute = dyn_cast<SpatCompute>(task.sourceOp))
|
||||
return compute.getBody().front();
|
||||
return cast<SpatComputeBatch>(task.sourceOp).getBody().front();
|
||||
};
|
||||
|
||||
auto appendUniqueValue = [](SmallVectorImpl<Value>& values, DenseSet<Value>& seen, Value value) {
|
||||
if (seen.insert(value).second)
|
||||
values.push_back(value);
|
||||
};
|
||||
|
||||
DenseMap<ComputeInstance, ScheduledTask> taskByKey;
|
||||
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
||||
SmallVector<size_t> orderedCpus;
|
||||
DenseSet<size_t> seenCpus;
|
||||
DenseSet<Operation*> internalInputOpsToErase;
|
||||
DenseMap<Operation*, bool> isInternalInputOpCache;
|
||||
size_t nextOrder = 0;
|
||||
auto markCpuSeen = [&](size_t cpu) {
|
||||
if (seenCpus.insert(cpu).second)
|
||||
orderedCpus.push_back(cpu);
|
||||
};
|
||||
for (ComputeInstance scheduledInstance : analysisResult.dominanceOrderCompute) {
|
||||
size_t cpu = analysisResult.computeToCpuMap.at(scheduledInstance);
|
||||
ScheduledTask task {scheduledInstance,
|
||||
scheduledInstance.op,
|
||||
cpu,
|
||||
analysisResult.computeToCpuSlotMap.lookup(scheduledInstance),
|
||||
nextOrder++};
|
||||
taskByKey[task.key] = task;
|
||||
tasksByCpu[cpu].push_back(task);
|
||||
markCpuSeen(cpu);
|
||||
}
|
||||
|
||||
llvm::sort(orderedCpus);
|
||||
for (size_t cpu : orderedCpus) {
|
||||
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
|
||||
if (lhs.slot != rhs.slot)
|
||||
return lhs.slot < rhs.slot;
|
||||
return lhs.order < rhs.order;
|
||||
});
|
||||
}
|
||||
|
||||
std::function<bool(Operation*)> isInternalInputOp = [&](Operation* op) {
|
||||
auto it = isInternalInputOpCache.find(op);
|
||||
if (it != isInternalInputOpCache.end())
|
||||
return it->second;
|
||||
|
||||
auto extract = dyn_cast_or_null<tensor::ExtractSliceOp>(op);
|
||||
if (!extract)
|
||||
return isInternalInputOpCache[op] = false;
|
||||
|
||||
for (Value result : extract->getResults()) {
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (toEraseSet.contains(user))
|
||||
continue;
|
||||
if (isInternalInputOp(user))
|
||||
continue;
|
||||
return isInternalInputOpCache[op] = false;
|
||||
}
|
||||
}
|
||||
return isInternalInputOpCache[op] = true;
|
||||
};
|
||||
|
||||
auto collectInternalInputOps = [&](Value value) {
|
||||
Operation* op = value.getDefiningOp();
|
||||
while (auto extract = dyn_cast_if_present<tensor::ExtractSliceOp>(op)) {
|
||||
if (isInternalInputOp(extract.getOperation()))
|
||||
internalInputOpsToErase.insert(extract.getOperation());
|
||||
value = extract.getSource();
|
||||
op = value.getDefiningOp();
|
||||
}
|
||||
};
|
||||
|
||||
DenseSet<Operation*> externalUsersToMove;
|
||||
auto collectExternalUsers = [&](Operation* op, auto&& collectExternalUsers) -> void {
|
||||
if (!externalUsersToMove.insert(op).second)
|
||||
@@ -1413,28 +1579,294 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
DenseSet<Operation*> erasedOps;
|
||||
for (ComputeInstance instance : llvm::reverse(analysisResult.dominanceOrderCompute)) {
|
||||
if (!erasedOps.insert(instance.op).second)
|
||||
continue;
|
||||
Operation* oldOp = instance.op;
|
||||
if (Operation* newOp = oldToNewOpMap.lookup(oldOp)) {
|
||||
for (unsigned i = 0; i < oldOp->getNumResults(); ++i) {
|
||||
for (auto& use : llvm::make_early_inc_range(oldOp->getResult(i).getUses())) {
|
||||
DenseMap<ComputeInstance, SmallVector<SmallVector<ChannelInfo>>> remoteSendsByTask;
|
||||
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
||||
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
||||
DenseMap<size_t, SmallVector<Value>> cpuWeights;
|
||||
DenseMap<size_t, SmallVector<ProducerValueRef>> cpuExternalOutputs;
|
||||
DenseMap<size_t, DenseSet<Value>> seenExternalInputsByCpu;
|
||||
DenseMap<size_t, DenseSet<Value>> seenWeightsByCpu;
|
||||
|
||||
for (size_t cpu : orderedCpus) {
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
auto taskWeights = getTaskWeights(task);
|
||||
for (Value weight : taskWeights)
|
||||
appendUniqueValue(cpuWeights[cpu], seenWeightsByCpu[cpu], weight);
|
||||
|
||||
auto taskInputs = getTaskInputs(task);
|
||||
auto& remoteInputs = remoteInputsByTask[task.key];
|
||||
remoteInputs.resize(taskInputs.size());
|
||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||
auto producerRef = getProducerValueRef(input);
|
||||
if (producerRef) {
|
||||
collectInternalInputOps(input);
|
||||
auto producerIt = taskByKey.find(producerRef->instance);
|
||||
if (producerIt != taskByKey.end()) {
|
||||
if (producerIt->second.cpu != cpu) {
|
||||
ChannelInfo info {
|
||||
nextChannelId++,
|
||||
getPhysicalCoreId(producerIt->second.cpu),
|
||||
getPhysicalCoreId(cpu),
|
||||
};
|
||||
remoteInputs[inputIndex] = info;
|
||||
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||
if (perResultChannels.empty())
|
||||
perResultChannels.resize(getTaskOutputTypes(producerIt->second).size());
|
||||
perResultChannels[producerRef->resultIndex].push_back(info);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
appendUniqueValue(cpuExternalInputs[cpu], seenExternalInputsByCpu[cpu], input);
|
||||
}
|
||||
|
||||
auto taskOutputs = getTaskOutputValues(task);
|
||||
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
||||
bool hasExternalUser = false;
|
||||
for (auto& use : output.getUses()) {
|
||||
Operation* useOwner = use.getOwner();
|
||||
if (!toEraseSet.contains(useOwner)) {
|
||||
use.assign(newOp->getResult(i));
|
||||
if (!isa<func::ReturnOp>(useOwner) && useOwner->isBeforeInBlock(newOp))
|
||||
collectExternalUsers(useOwner, collectExternalUsers);
|
||||
if (toEraseSet.contains(useOwner))
|
||||
continue;
|
||||
hasExternalUser = true;
|
||||
if (!isa<func::ReturnOp>(useOwner))
|
||||
collectExternalUsers(useOwner, collectExternalUsers);
|
||||
}
|
||||
if (hasExternalUser)
|
||||
cpuExternalOutputs[cpu].push_back({task.key, resultIndex});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(func.getBody().front().getTerminator());
|
||||
IRRewriter rewriter(&getContext());
|
||||
DenseMap<size_t, CpuProgram> cpuPrograms;
|
||||
DenseMap<Value, Value> oldToNewExternalValueMap;
|
||||
|
||||
for (size_t cpu : orderedCpus) {
|
||||
SmallVector<Value> operands;
|
||||
operands.reserve(cpuWeights[cpu].size() + cpuExternalInputs[cpu].size());
|
||||
llvm::append_range(operands, cpuWeights[cpu]);
|
||||
llvm::append_range(operands, cpuExternalInputs[cpu]);
|
||||
|
||||
SmallVector<Type> resultTypes;
|
||||
resultTypes.reserve(cpuExternalOutputs[cpu].size());
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||
resultTypes.push_back(getTaskOutputTypes(task)[outputRef.resultIndex]);
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(cpuWeights[cpu].size()), static_cast<int>(cpuExternalInputs[cpu].size())});
|
||||
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(cpu)));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(cpuExternalInputs[cpu].size());
|
||||
blockArgLocs.reserve(cpuExternalInputs[cpu].size());
|
||||
for (Value input : cpuExternalInputs[cpu]) {
|
||||
blockArgTypes.push_back(input.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
Block* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
|
||||
CpuProgram program;
|
||||
program.op = newCompute;
|
||||
program.block = newBlock;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[cpu]))
|
||||
program.weightToIndex[weight] = weightIndex;
|
||||
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu]))
|
||||
program.externalInputMap[input] = newBlock->getArgument(inputIndex);
|
||||
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) {
|
||||
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||
oldToNewExternalValueMap[getTaskOutputValues(task)[outputRef.resultIndex]] = newCompute.getResult(resultIndex);
|
||||
}
|
||||
cpuPrograms[cpu] = std::move(program);
|
||||
}
|
||||
|
||||
DenseMap<ComputeInstance, SmallVector<Value>> producedValuesByTask;
|
||||
for (size_t cpu : orderedCpus) {
|
||||
CpuProgram& program = cpuPrograms[cpu];
|
||||
IRRewriter cpuRewriter(&getContext());
|
||||
cpuRewriter.setInsertionPointToEnd(program.block);
|
||||
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
SmallVector<Value> taskInputs = getTaskInputs(task);
|
||||
auto taskWeights = getTaskWeights(task);
|
||||
Block& templateBlock = getTaskTemplateBlock(task);
|
||||
|
||||
SmallVector<Value> resolvedInputs;
|
||||
resolvedInputs.reserve(taskInputs.size());
|
||||
auto remoteInputsIt = remoteInputsByTask.find(task.key);
|
||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||
auto producerRef = getProducerValueRef(input);
|
||||
if (producerRef) {
|
||||
auto producerIt = taskByKey.find(producerRef->instance);
|
||||
if (producerIt != taskByKey.end()) {
|
||||
if (producerIt->second.cpu == cpu) {
|
||||
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
||||
task.sourceOp->emitOpError("missing local producer value during per-cpu merge materialization")
|
||||
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot
|
||||
<< " producerCpu=" << producerIt->second.cpu << " producerSlot=" << producerIt->second.slot
|
||||
<< " producerLaneStart=" << producerRef->instance.laneStart
|
||||
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
|
||||
continue;
|
||||
}
|
||||
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||
auto receive =
|
||||
spatial::SpatChannelReceiveOp::create(cpuRewriter,
|
||||
loc,
|
||||
input.getType(),
|
||||
cpuRewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||
cpuRewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||
cpuRewriter.getI32IntegerAttr(channelInfo.targetCoreId));
|
||||
resolvedInputs.push_back(receive.getResult());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
resolvedInputs.push_back(program.externalInputMap.at(input));
|
||||
}
|
||||
|
||||
SmallVector<Value> taskYieldValues;
|
||||
cpuRewriter.setInsertionPointToEnd(program.block);
|
||||
if (isa<SpatCompute>(task.sourceOp)) {
|
||||
IRMapping mapper;
|
||||
for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments()))
|
||||
mapper.map(oldArg, resolvedInputs[argIndex]);
|
||||
|
||||
for (Operation& op : templateBlock) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
for (Value yieldOperand : yield.getOperands())
|
||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) {
|
||||
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp);
|
||||
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||
}
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) {
|
||||
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp);
|
||||
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
|
||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (size_t laneOffset = 0; laneOffset < task.key.laneCount; ++laneOffset) {
|
||||
IRMapping mapper;
|
||||
if (templateBlock.getNumArguments() == 1)
|
||||
mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]);
|
||||
|
||||
for (Operation& op : templateBlock) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
for (Value yieldOperand : yield.getOperands())
|
||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) {
|
||||
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
||||
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp);
|
||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||
}
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) {
|
||||
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
||||
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp);
|
||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
producedValuesByTask[task.key] = taskYieldValues;
|
||||
if (auto sendsIt = remoteSendsByTask.find(task.key); sendsIt != remoteSendsByTask.end()) {
|
||||
for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) {
|
||||
if (sendInfos.empty())
|
||||
continue;
|
||||
Value producedValue = taskYieldValues[resultIndex];
|
||||
for (const ChannelInfo& sendInfo : sendInfos)
|
||||
spatial::SpatChannelSendOp::create(cpuRewriter,
|
||||
loc,
|
||||
cpuRewriter.getI64IntegerAttr(sendInfo.channelId),
|
||||
cpuRewriter.getI32IntegerAttr(sendInfo.sourceCoreId),
|
||||
cpuRewriter.getI32IntegerAttr(sendInfo.targetCoreId),
|
||||
producedValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
oldOp->erase();
|
||||
|
||||
SmallVector<Value> yieldValues;
|
||||
yieldValues.reserve(cpuExternalOutputs[cpu].size());
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
|
||||
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||
task.sourceOp->emitOpError("missing yielded external value during per-cpu merge materialization")
|
||||
<< " cpu=" << cpu << " slot=" << task.slot << " laneStart=" << outputRef.instance.laneStart;
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
||||
}
|
||||
spatial::SpatYieldOp::create(cpuRewriter, loc, ValueRange(yieldValues));
|
||||
}
|
||||
|
||||
for (auto [oldValue, newValue] : oldToNewExternalValueMap) {
|
||||
for (auto& use : llvm::make_early_inc_range(oldValue.getUses()))
|
||||
if (!toEraseSet.contains(use.getOwner()))
|
||||
use.assign(newValue);
|
||||
}
|
||||
|
||||
DenseSet<Operation*> allOpsToErase = toEraseSet;
|
||||
for (Operation* op : internalInputOpsToErase)
|
||||
allOpsToErase.insert(op);
|
||||
|
||||
SmallVector<Operation*> orderedOpsToErase;
|
||||
for (Operation& op : func.getBody().front())
|
||||
if (allOpsToErase.contains(&op))
|
||||
orderedOpsToErase.push_back(&op);
|
||||
for (Operation* op : llvm::reverse(orderedOpsToErase)) {
|
||||
SmallVector<Operation*> remainingUsers;
|
||||
for (Value result : op->getResults())
|
||||
for (Operation* user : result.getUsers())
|
||||
remainingUsers.push_back(user);
|
||||
if (!remainingUsers.empty()) {
|
||||
llvm::errs() << "[MergeComputeNodesPass] refusing to erase op with remaining uses: " << op->getName() << "\n";
|
||||
llvm::errs() << " erase-set: " << (allOpsToErase.contains(op) ? "yes" : "no") << "\n";
|
||||
op->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
|
||||
llvm::errs() << "\n";
|
||||
for (Operation* user : remainingUsers) {
|
||||
llvm::errs() << " user: " << user->getName()
|
||||
<< " erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no") << "\n";
|
||||
user->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
|
||||
llvm::errs() << "\n";
|
||||
}
|
||||
op->emitOpError("still has uses during per-cpu merge cleanup");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
op->erase();
|
||||
}
|
||||
|
||||
func::FuncOp func = getOperation();
|
||||
auto returnOp = cast<func::ReturnOp>(func.getBody().front().getTerminator());
|
||||
SmallVector<Operation*> orderedUsersToMove;
|
||||
for (Operation& op : func.getBody().front()) {
|
||||
if (&op == returnOp.getOperation())
|
||||
@@ -1445,9 +1877,13 @@ public:
|
||||
for (Operation* op : orderedUsersToMove)
|
||||
op->moveBefore(returnOp);
|
||||
|
||||
sinkChannelsIntoComputes(func, nextChannelId);
|
||||
rebatchEquivalentComputes(func, nextChannelId);
|
||||
compactScalarChannelRuns(func, nextChannelId);
|
||||
if (!sortTopologically(&func.getBody().front())) {
|
||||
func.emitOpError("failed to topologically order merged Spatial IR");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
||||
generateReport(func, "spatial1_dcp_merged_report", analysisResult.cpuToLastComputeMap.size());
|
||||
}
|
||||
@@ -1477,16 +1913,18 @@ private:
|
||||
Value resolvedInput = input;
|
||||
if (auto producerRef = getProducerValueRef(input)) {
|
||||
LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance);
|
||||
auto [channelVal, isChannel, channelInfo] = producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu);
|
||||
(void) isChannel;
|
||||
(void) channelVal;
|
||||
resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||
loc,
|
||||
input.getType(),
|
||||
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.targetCoreId))
|
||||
.getResult();
|
||||
auto [channelVal, isChannel, channelInfo] =
|
||||
producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu);
|
||||
if (isChannel)
|
||||
resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||
loc,
|
||||
input.getType(),
|
||||
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.targetCoreId))
|
||||
.getResult();
|
||||
else
|
||||
resolvedInput = channelVal;
|
||||
}
|
||||
|
||||
newComputeOperands.push_back(resolvedInput);
|
||||
@@ -1532,7 +1970,8 @@ private:
|
||||
uint32_t firstLane,
|
||||
uint32_t laneCount,
|
||||
size_t currentCpu,
|
||||
const DCPAnalysisResult& analysisResult) {
|
||||
const DCPAnalysisResult& analysisResult,
|
||||
std::optional<uint64_t> rebatchPhase = std::nullopt) {
|
||||
func::FuncOp func = getOperation();
|
||||
auto loc = func.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
@@ -1547,24 +1986,29 @@ private:
|
||||
|
||||
for (uint32_t lane = firstLane; lane < firstLane + laneCount; ++lane) {
|
||||
weights.push_back(batch.getWeights()[lane]);
|
||||
resultTypes.push_back(batch.getOutputs()[lane].getType());
|
||||
if (!batch.getOutputs().empty())
|
||||
resultTypes.push_back(batch.getOutputs()[lane].getType());
|
||||
|
||||
Value input = batch.getInputs()[lane];
|
||||
Value resolvedInput = input;
|
||||
if (auto producerRef = getProducerValueRef(input)) {
|
||||
LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance);
|
||||
auto [channelVal, isChannel, channelInfo] = producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu);
|
||||
(void) isChannel;
|
||||
(void) channelVal;
|
||||
resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||
loc,
|
||||
input.getType(),
|
||||
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.targetCoreId))
|
||||
.getResult();
|
||||
if (!batch.getInputs().empty()) {
|
||||
Value input = batch.getInputs()[lane];
|
||||
Value resolvedInput = input;
|
||||
if (auto producerRef = getProducerValueRef(input)) {
|
||||
LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance);
|
||||
auto [channelVal, isChannel, channelInfo] =
|
||||
producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu);
|
||||
if (isChannel)
|
||||
resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||
loc,
|
||||
input.getType(),
|
||||
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.targetCoreId))
|
||||
.getResult();
|
||||
else
|
||||
resolvedInput = channelVal;
|
||||
}
|
||||
inputs.push_back(resolvedInput);
|
||||
}
|
||||
inputs.push_back(resolvedInput);
|
||||
}
|
||||
|
||||
Block& templateBlock = batch.getBody().front();
|
||||
@@ -1574,11 +2018,17 @@ private:
|
||||
compute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
||||
compute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(currentCpu)));
|
||||
if (rebatchPhase)
|
||||
compute->setAttr(kRebatchPhaseAttrName, rewriter.getI64IntegerAttr(*rebatchPhase));
|
||||
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&compute.getBody(), compute.getBody().end(), TypeRange {templateBlock.getArgument(0).getType()}, {loc});
|
||||
SmallVector<Type> blockArgTypes;
|
||||
if (templateBlock.getNumArguments() == 1)
|
||||
blockArgTypes.push_back(templateBlock.getArgument(0).getType());
|
||||
SmallVector<Location> blockArgLocs(templateBlock.getNumArguments(), loc);
|
||||
auto* newBlock = rewriter.createBlock(&compute.getBody(), compute.getBody().end(), blockArgTypes, blockArgLocs);
|
||||
IRMapping mapper;
|
||||
mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0));
|
||||
if (templateBlock.getNumArguments() == 1)
|
||||
mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0));
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : templateBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
@@ -1600,14 +2050,16 @@ private:
|
||||
ValueRange(weights),
|
||||
ValueRange(inputs));
|
||||
rebatched->setAttr(onnx_mlir::kCoreIdAttrName,
|
||||
rewriter.getDenseI32ArrayAttr(SmallVector<int32_t>(laneCount, getPhysicalCoreId(currentCpu))));
|
||||
rewriter.getDenseI32ArrayAttr(getMaterializedBatchCoreIds(currentCpu, laneCount)));
|
||||
|
||||
auto* newBlock = rewriter.createBlock(&rebatched.getBody(),
|
||||
rebatched.getBody().end(),
|
||||
TypeRange {templateBlock.getArgument(0).getType()},
|
||||
SmallVector<Location>(1, loc));
|
||||
SmallVector<Type> blockArgTypes;
|
||||
if (templateBlock.getNumArguments() == 1)
|
||||
blockArgTypes.push_back(templateBlock.getArgument(0).getType());
|
||||
SmallVector<Location> blockArgLocs(templateBlock.getNumArguments(), loc);
|
||||
auto* newBlock = rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), blockArgTypes, blockArgLocs);
|
||||
IRMapping mapper;
|
||||
mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0));
|
||||
if (templateBlock.getNumArguments() == 1)
|
||||
mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0));
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : templateBlock) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
@@ -1621,7 +2073,7 @@ private:
|
||||
ComputeValueResults results;
|
||||
results.outerValues.assign(rebatched->result_begin(), rebatched->result_end());
|
||||
results.innerValues = results.outerValues;
|
||||
if (results.innerValues.empty())
|
||||
if (results.innerValues.empty() && yieldOp.getNumOperands() == 1)
|
||||
results.innerValues.push_back(yieldOp.getOperand(0));
|
||||
newComputeNodeResults.insert({
|
||||
ComputeInstance {batch.getOperation(), firstLane, laneCount},
|
||||
@@ -1658,7 +2110,7 @@ private:
|
||||
channelInfo, insertVal};
|
||||
return ret;
|
||||
};
|
||||
return LazyInsertComputeResult(computeValueResults, insertNew);
|
||||
return LazyInsertComputeResult(computeValueResults, producerCpu, insertNew);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from colorama import Fore, Style
|
||||
@@ -37,8 +38,12 @@ def _parse_pim_pass_timings(output_text):
|
||||
return pass_timings
|
||||
|
||||
|
||||
def _format_command(cmd):
|
||||
return shlex.join(str(arg) for arg in cmd)
|
||||
|
||||
|
||||
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||
crossbar_size, crossbar_count, cwd=None, reporter=None):
|
||||
crossbar_size, crossbar_count, core_count=None, cwd=None, reporter=None):
|
||||
# Define the arguments, with the possibility to set crossbar size and count
|
||||
args = [
|
||||
network_path,
|
||||
@@ -51,10 +56,18 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||
f"--crossbar-count={crossbar_count}",
|
||||
"--enable-timing",
|
||||
]
|
||||
if core_count is not None:
|
||||
args.append(f"--core-count={core_count}")
|
||||
|
||||
cmd = [str(raptor_onnx_path)] + [str(arg) for arg in args]
|
||||
if reporter is not None:
|
||||
reporter.log(f" Raptor command: {_format_command(cmd)}")
|
||||
else:
|
||||
print(f"Raptor command: {_format_command(cmd)}")
|
||||
|
||||
try:
|
||||
output_text = run_command_with_reporter(
|
||||
[str(raptor_onnx_path)] + [str(arg) for arg in args],
|
||||
cmd,
|
||||
cwd=cwd,
|
||||
reporter=reporter,
|
||||
capture_output=True,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import shlex
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -11,12 +10,6 @@ from validate_one import ProgressReporter, clean_workspace_artifacts, validate_n
|
||||
from raptor import PIM_PASS_LABELS
|
||||
|
||||
|
||||
def format_command(cmd):
|
||||
if isinstance(cmd, (list, tuple)):
|
||||
return shlex.join(str(arg) for arg in cmd)
|
||||
return str(cmd)
|
||||
|
||||
|
||||
def format_return_status(returncode):
|
||||
if returncode < 0:
|
||||
signal_num = -returncode
|
||||
@@ -34,8 +27,6 @@ def print_validation_error(reporter, rel, exc):
|
||||
file=sys.stderr, flush=True)
|
||||
if isinstance(exc, subprocess.CalledProcessError):
|
||||
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
|
||||
print("Retry command:", file=sys.stderr, flush=True)
|
||||
print(format_command(exc.cmd), file=sys.stderr, flush=True)
|
||||
else:
|
||||
print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
|
||||
print("=" * 72, file=sys.stderr, flush=True)
|
||||
@@ -65,6 +56,8 @@ def main():
|
||||
ap.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.")
|
||||
ap.add_argument("--crossbar-size", type=int, default=64)
|
||||
ap.add_argument("--crossbar-count", type=int, default=8)
|
||||
ap.add_argument("--core-count", type=int, default=None,
|
||||
help="Core count to pass to Raptor. If omitted, Raptor uses its default.")
|
||||
ap.add_argument("--clean", action="store_true",
|
||||
help="Remove generated validation artifacts under each model workspace and exit.")
|
||||
a = ap.parse_args()
|
||||
@@ -114,7 +107,7 @@ def main():
|
||||
try:
|
||||
result = validate_network(
|
||||
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, core_count=a.core_count,
|
||||
threshold=a.threshold,
|
||||
reporter=reporter,
|
||||
model_index=index,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
import subprocess
|
||||
@@ -258,7 +257,7 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
|
||||
|
||||
|
||||
def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3,
|
||||
simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None, threshold=1e-3,
|
||||
reporter=None, model_index=1, model_total=1):
|
||||
network_onnx_path = Path(network_onnx_path).resolve()
|
||||
raptor_path = Path(raptor_path).resolve()
|
||||
@@ -313,7 +312,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||
pim_pass_timings = compile_with_raptor(
|
||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
||||
crossbar_size, crossbar_count,
|
||||
crossbar_size, crossbar_count, core_count=core_count,
|
||||
cwd=raptor_dir, reporter=reporter)
|
||||
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
|
||||
reporter.advance()
|
||||
@@ -350,18 +349,3 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||
reporter.log("=" * 72)
|
||||
if owns_reporter:
|
||||
reporter.finish()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--network-onnx", required=True)
|
||||
ap.add_argument("--raptor-path", required=True)
|
||||
ap.add_argument("--onnx-include-dir", required=True)
|
||||
a = ap.parse_args()
|
||||
|
||||
simulator_dir = Path(__file__).parent.resolve() / ".." / "backend-simulators" / "pim" / "pim-simulator"
|
||||
|
||||
passed = validate_network(
|
||||
a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir
|
||||
)
|
||||
raise SystemExit(0 if passed.passed else 1)
|
||||
|
||||
Reference in New Issue
Block a user