Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into main
All checks were successful
Validate Operations / validate-operations (push) Successful in 24m25s
All checks were successful
Validate Operations / validate-operations (push) Successful in 24m25s
This commit is contained in:
@@ -135,7 +135,7 @@ validate.py \
|
|||||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||||
--onnx-include-dir ../onnx-mlir/include \
|
--onnx-include-dir ../onnx-mlir/include \
|
||||||
--operations-dir ./networks/yolo11n/depth_04 \
|
--operations-dir ./networks/yolo11n/depth_04 \
|
||||||
--crossbar-size 2048 --crossbar-count 256
|
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
|
||||||
```
|
```
|
||||||
|
|
||||||
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
|
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
|
||||||
|
|||||||
@@ -22,6 +22,7 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
inline constexpr llvm::StringLiteral kCoreIdAttrName = "core_id";
|
inline constexpr llvm::StringLiteral kCoreIdAttrName = "coreId";
|
||||||
|
inline constexpr llvm::StringLiteral kCoreIdsAttrName = "coreIds";
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -634,8 +634,8 @@ static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||||
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName);
|
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||||
assert(coreIdsAttr && "pim.core_batch requires core_id array attribute");
|
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
|
||||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,21 +23,42 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
|
||||||
|
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||||
|
unsigned inputCount = compute.getInputs().size();
|
||||||
|
if (inputCount == 0)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
unsigned inputBegin = compute->getNumOperands() - inputCount;
|
||||||
|
if (operandNumber < inputBegin)
|
||||||
|
return std::nullopt;
|
||||||
|
return operandNumber - inputBegin;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner)) {
|
||||||
|
unsigned inputCount = computeBatch.getInputs().size();
|
||||||
|
if (inputCount == 0)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
unsigned inputBegin = computeBatch->getNumOperands() - inputCount;
|
||||||
|
if (operandNumber < inputBegin)
|
||||||
|
return std::nullopt;
|
||||||
|
return operandNumber - inputBegin;
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
|
||||||
Location loc = extractSliceOp.getLoc();
|
|
||||||
|
|
||||||
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
|
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
for (auto& uses : extractSliceOp->getUses()) {
|
for (auto& uses : extractSliceOp->getUses()) {
|
||||||
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
||||||
auto spatCompute = cast<spatial::SpatCompute>(uses.getOwner());
|
if (!getDirectComputeInputIndex(uses.getOwner(), uses.getOperandNumber()))
|
||||||
if (spatCompute.getInputs().empty())
|
|
||||||
return failure();
|
|
||||||
if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex())
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
|
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
|
||||||
@@ -50,7 +71,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
||||||
|
|
||||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
||||||
auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
auto inputIndex = getDirectComputeInputIndex(spatCompute, uses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
|
|
||||||
if (BBArgValue.use_empty())
|
if (BBArgValue.use_empty())
|
||||||
@@ -69,7 +93,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
|||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
}
|
}
|
||||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||||
auto BBArgIndex = uses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||||
|
|
||||||
if (BBArgValue.use_empty())
|
if (BBArgValue.use_empty())
|
||||||
@@ -165,8 +192,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
|||||||
auto constUsers = constUses.getOwner();
|
auto constUsers = constUses.getOwner();
|
||||||
|
|
||||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||||
|
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||||
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||||
@@ -183,8 +212,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
|||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
}
|
}
|
||||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||||
|
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||||
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||||
@@ -240,8 +271,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
|||||||
auto constUsers = constUses.getOwner();
|
auto constUsers = constUses.getOwner();
|
||||||
|
|
||||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||||
|
auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||||
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
auto newConst = rewriter.clone(*constantOp);
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
@@ -253,8 +286,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
|||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
}
|
}
|
||||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||||
|
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||||
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
auto newConst = rewriter.clone(*constantOp);
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
@@ -265,8 +300,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
|||||||
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
||||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||||
}
|
}
|
||||||
else {
|
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||||
if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
|
||||||
if (!mapSpatComputeToConst.contains(parent)) {
|
if (!mapSpatComputeToConst.contains(parent)) {
|
||||||
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||||
auto newConst = rewriter.clone(*constantOp);
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
@@ -286,8 +320,6 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
auto parent = constantOp->getParentOp();
|
|
||||||
rewriter.eraseOp(constantOp);
|
rewriter.eraseOp(constantOp);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -333,7 +365,10 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
|||||||
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
||||||
auto argUser = argUses.getOwner();
|
auto argUser = argUses.getOwner();
|
||||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
||||||
auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
auto inputIndex = getDirectComputeInputIndex(spatCompute, argUses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
@@ -347,7 +382,10 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
|||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
}
|
}
|
||||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
||||||
auto BBArgIndex = argUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, argUses.getOperandNumber());
|
||||||
|
if (!inputIndex)
|
||||||
|
return failure();
|
||||||
|
auto BBArgIndex = *inputIndex;
|
||||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
|||||||
@@ -11,20 +11,15 @@
|
|||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallSet.h"
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/LogicalResult.h"
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <filesystem>
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@@ -34,10 +29,8 @@
|
|||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
@@ -118,7 +111,7 @@ static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t&
|
|||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
|
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
|
||||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName))
|
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||||
|
|
||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
@@ -185,6 +178,43 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan
|
|||||||
rewriter.replaceOp(receiveManyOp, ValueRange(replacements));
|
rewriter.replaceOp(receiveManyOp, ValueRange(replacements));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp,
|
||||||
|
int32_t laneCount,
|
||||||
|
IRMapping& mapper,
|
||||||
|
IRRewriter& rewriter) {
|
||||||
|
auto targetCoreIds = sendManyBatchOp.getTargetCoreIds();
|
||||||
|
for (auto [valueIndex, input] : llvm::enumerate(sendManyBatchOp.getInputs())) {
|
||||||
|
size_t metadataOffset = valueIndex * static_cast<size_t>(laneCount);
|
||||||
|
auto targetSlice = targetCoreIds.slice(metadataOffset, laneCount);
|
||||||
|
pim::PimSendBatchOp::create(rewriter,
|
||||||
|
sendManyBatchOp.getLoc(),
|
||||||
|
mapper.lookup(input),
|
||||||
|
getTensorSizeInBytesAttr(rewriter, mapper.lookup(input)),
|
||||||
|
rewriter.getDenseI32ArrayAttr(targetSlice));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
|
||||||
|
int32_t laneCount,
|
||||||
|
IRMapping& mapper,
|
||||||
|
IRRewriter& rewriter) {
|
||||||
|
auto sourceCoreIds = receiveManyBatchOp.getSourceCoreIds();
|
||||||
|
for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) {
|
||||||
|
size_t metadataOffset = valueIndex * static_cast<size_t>(laneCount);
|
||||||
|
auto sourceSlice = sourceCoreIds.slice(metadataOffset, laneCount);
|
||||||
|
auto outputType = cast<ShapedType>(output.getType());
|
||||||
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), outputType);
|
||||||
|
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
||||||
|
receiveManyBatchOp.getLoc(),
|
||||||
|
outputBuffer.getType(),
|
||||||
|
outputBuffer,
|
||||||
|
getTensorSizeInBytesAttr(rewriter, output),
|
||||||
|
rewriter.getDenseI32ArrayAttr(sourceSlice))
|
||||||
|
.getOutput();
|
||||||
|
mapper.map(output, received);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
|
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
|
||||||
Value input = extractRowsOp.getInput();
|
Value input = extractRowsOp.getInput();
|
||||||
RankedTensorType inputType;
|
RankedTensorType inputType;
|
||||||
@@ -214,11 +244,12 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
|
|||||||
extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering");
|
extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)), rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)),
|
||||||
|
rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
auto rowSlice = tensor::ExtractSliceOp::create(
|
auto rowSlice =
|
||||||
rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
|
tensor::ExtractSliceOp::create(rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
|
||||||
replacements.push_back(rowSlice.getResult());
|
replacements.push_back(rowSlice.getResult());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,6 +263,56 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
|
|||||||
rewriter.replaceOp(concatOp, concatenated);
|
rewriter.replaceOp(concatOp, concatenated);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
|
SmallVector<spatial::SpatWeightedVMMOp> wvmmOps;
|
||||||
|
funcOp.walk([&](spatial::SpatWeightedVMMOp wvmmOp) {
|
||||||
|
if (wvmmOp->getParentOfType<pim::PimCoreOp>() || wvmmOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||||
|
wvmmOps.push_back(wvmmOp);
|
||||||
|
});
|
||||||
|
|
||||||
|
for (auto wvmmOp : wvmmOps) {
|
||||||
|
rewriter.setInsertionPoint(wvmmOp);
|
||||||
|
auto outputType = cast<ShapedType>(wvmmOp.getOutput().getType());
|
||||||
|
Value outputBuffer = createEmptyTensorFromShaped(rewriter, wvmmOp.getLoc(), outputType).getResult();
|
||||||
|
rewriter.replaceOpWithNewOp<pim::PimVMMOp>(wvmmOp,
|
||||||
|
wvmmOp.getOutput().getType(),
|
||||||
|
rewriter.getI32IntegerAttr(wvmmOp.getWeightIndex()),
|
||||||
|
wvmmOp.getInput(),
|
||||||
|
outputBuffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void expandMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
|
SmallVector<spatial::SpatMapOp> mapOps;
|
||||||
|
funcOp.walk([&](spatial::SpatMapOp mapOp) { mapOps.push_back(mapOp); });
|
||||||
|
|
||||||
|
for (auto mapOp : mapOps) {
|
||||||
|
Block& body = mapOp.getBody().front();
|
||||||
|
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
|
||||||
|
|
||||||
|
SmallVector<Value> replacements;
|
||||||
|
replacements.reserve(mapOp.getInputs().size());
|
||||||
|
rewriter.setInsertionPoint(mapOp);
|
||||||
|
for (Value input : mapOp.getInputs()) {
|
||||||
|
IRMapping mapping;
|
||||||
|
mapping.map(body.getArgument(0), input);
|
||||||
|
|
||||||
|
Value replacement = input;
|
||||||
|
for (Operation& op : body.without_terminator()) {
|
||||||
|
Operation* cloned = rewriter.clone(op, mapping);
|
||||||
|
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||||
|
mapping.map(originalResult, clonedResult);
|
||||||
|
rewriter.setInsertionPointAfter(cloned);
|
||||||
|
}
|
||||||
|
|
||||||
|
replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
|
||||||
|
replacements.push_back(replacement);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(mapOp, replacements);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||||
SmallVectorImpl<Operation*>& helperChain,
|
SmallVectorImpl<Operation*>& helperChain,
|
||||||
bool requireReturnUse = true) {
|
bool requireReturnUse = true) {
|
||||||
@@ -263,19 +344,19 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
|||||||
|
|
||||||
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
|
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
|
||||||
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
|
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
|
||||||
if (!chainSet.contains(&op)
|
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
||||||
&& !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
|
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||||
return false;
|
return false;
|
||||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(),
|
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||||
[](Operation* user) { return isa<spatial::SpatComputeBatch, pim::PimCoreBatchOp>(user); }))
|
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
||||||
|
}))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
Block& block = computeOp.getBody().front();
|
Block& block = computeOp.getBody().front();
|
||||||
@@ -447,8 +528,7 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
|||||||
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
||||||
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||||
};
|
};
|
||||||
if (!hasStaticValues(extractSliceOp.getStaticOffsets())
|
if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes())
|
||||||
|| !hasStaticValues(extractSliceOp.getStaticSizes())
|
|
||||||
|| !hasStaticValues(extractSliceOp.getStaticStrides()))
|
|| !hasStaticValues(extractSliceOp.getStaticStrides()))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -510,10 +590,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void cloneHelperChain(Value sourceValue,
|
static void
|
||||||
ArrayRef<Operation*> helperChain,
|
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
|
||||||
IRRewriter& rewriter,
|
|
||||||
Value& clonedValue) {
|
|
||||||
IRMapping mapping;
|
IRMapping mapping;
|
||||||
mapping.map(sourceValue, sourceValue);
|
mapping.map(sourceValue, sourceValue);
|
||||||
clonedValue = sourceValue;
|
clonedValue = sourceValue;
|
||||||
@@ -560,6 +638,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
func::FuncOp funcOp = *entryFunc;
|
func::FuncOp funcOp = *entryFunc;
|
||||||
|
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
|
expandMapOps(funcOp, rewriter);
|
||||||
|
|
||||||
ConversionTarget target(*ctx);
|
ConversionTarget target(*ctx);
|
||||||
target.addLegalDialect<PimDialect,
|
target.addLegalDialect<PimDialect,
|
||||||
@@ -649,6 +728,32 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
for (auto extractRowsOp : extractRowsOps)
|
for (auto extractRowsOp : extractRowsOps)
|
||||||
lowerExtractRows(extractRowsOp, rewriter);
|
lowerExtractRows(extractRowsOp, rewriter);
|
||||||
|
|
||||||
|
{
|
||||||
|
RewritePatternSet coreBodyPatterns(ctx);
|
||||||
|
populateWithGenerated(coreBodyPatterns);
|
||||||
|
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
||||||
|
|
||||||
|
SmallVector<pim::PimCoreOp> coreOps;
|
||||||
|
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||||
|
for (auto coreOp : coreOps) {
|
||||||
|
if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
|
||||||
|
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||||
|
for (auto coreBatchOp : coreBatchOps) {
|
||||||
|
if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) {
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lowerRemainingSpatialMathOps(funcOp, rewriter);
|
||||||
|
|
||||||
RewritePatternSet channelPatterns(ctx);
|
RewritePatternSet channelPatterns(ctx);
|
||||||
populateWithGenerated(channelPatterns);
|
populateWithGenerated(channelPatterns);
|
||||||
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
||||||
@@ -734,7 +839,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||||
Location loc = computeOp->getLoc();
|
Location loc = computeOp->getLoc();
|
||||||
|
|
||||||
if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter))
|
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
SmallVector<Operation*> helperChain;
|
SmallVector<Operation*> helperChain;
|
||||||
@@ -835,7 +940,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
|||||||
|
|
||||||
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
||||||
if (!storedType) {
|
if (!storedType) {
|
||||||
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
computeOp.emitOpError(
|
||||||
|
"has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -848,10 +954,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
|||||||
sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
|
sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx;
|
||||||
|
|
||||||
SmallVector<int64_t> destinationIndices;
|
SmallVector<int64_t> destinationIndices;
|
||||||
if (failed(mapIndicesThroughHelperChain(sourceIndices,
|
if (failed(mapIndicesThroughHelperChain(
|
||||||
concatReturnUse->concatShape,
|
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
|
||||||
concatReturnUse->helperChain,
|
|
||||||
destinationIndices))) {
|
|
||||||
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
@@ -897,9 +1001,12 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
|||||||
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
|
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
|
||||||
|
|
||||||
// Replace `spat.compute` with `pim.core`
|
// Replace `spat.compute` with `pim.core`
|
||||||
|
SmallVector<Value> computeWeights;
|
||||||
|
if (!computeOp.getWeights().empty())
|
||||||
|
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
auto coreOp = PimCoreOp::create(
|
auto coreOp = PimCoreOp::create(
|
||||||
rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
|
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
|
||||||
if (!blockArg.use_empty())
|
if (!blockArg.use_empty())
|
||||||
@@ -933,16 +1040,20 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||||
|
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||||
|
SmallVector<Value> batchInputs;
|
||||||
|
if (!computeBatchOp.getInputs().empty())
|
||||||
|
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(computeBatchOp);
|
rewriter.setInsertionPointAfter(computeBatchOp);
|
||||||
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
|
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
|
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
|
||||||
computeBatchOp.getWeights(),
|
ValueRange(batchWeights),
|
||||||
computeBatchOp.getInputs());
|
ValueRange(batchInputs));
|
||||||
coreBatchOp.getProperties().setOperandSegmentSizes(
|
coreBatchOp.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(computeBatchOp.getWeights().size()), static_cast<int>(computeBatchOp.getInputs().size())});
|
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||||
|
|
||||||
SmallVector<Type> blockArgTypes;
|
SmallVector<Type> blockArgTypes;
|
||||||
SmallVector<Location> blockArgLocs;
|
SmallVector<Location> blockArgLocs;
|
||||||
@@ -1003,6 +1114,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto sendManyBatchOp = dyn_cast<spatial::SpatChannelSendManyBatchOp>(op)) {
|
||||||
|
lowerChannelSendManyBatch(sendManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
||||||
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
||||||
@@ -1017,6 +1133,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto receiveManyBatchOp = dyn_cast<spatial::SpatChannelReceiveManyBatchOp>(op)) {
|
||||||
|
lowerChannelReceiveManyBatch(receiveManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
|
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
|
||||||
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||||
Operation* cloned = rewriter.clone(op, mapper);
|
Operation* cloned = rewriter.clone(op, mapper);
|
||||||
@@ -1210,6 +1331,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
|||||||
|
|
||||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||||
markOpToRemove(computeOp);
|
markOpToRemove(computeOp);
|
||||||
|
if (!computeOp.getInputs().empty())
|
||||||
for (Value input : computeOp.getInputs())
|
for (Value input : computeOp.getInputs())
|
||||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments]> {
|
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> {
|
||||||
let summary = "Execute equivalent batched core bodies";
|
let summary = "Execute equivalent batched core bodies";
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|||||||
@@ -257,8 +257,8 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
|||||||
auto newOp = PimCoreBatchOp::create(
|
auto newOp = PimCoreBatchOp::create(
|
||||||
rewriter, coreBatchOp.getLoc(), coreBatchOp.getLaneCountAttr(), ValueRange(weights), ValueRange(inputs));
|
rewriter, coreBatchOp.getLoc(), coreBatchOp.getLaneCountAttr(), ValueRange(weights), ValueRange(inputs));
|
||||||
newOp.getProperties().setOperandSegmentSizes({static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
newOp.getProperties().setOperandSegmentSizes({static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
||||||
if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdAttrName))
|
if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdsAttrName))
|
||||||
newOp->setAttr(onnx_mlir::kCoreIdAttrName, coreIdsAttr);
|
newOp->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr);
|
||||||
|
|
||||||
rewriter.inlineRegionBefore(coreBatchOp.getBody(), newOp.getBody(), newOp.getBody().begin());
|
rewriter.inlineRegionBefore(coreBatchOp.getBody(), newOp.getBody(), newOp.getBody().begin());
|
||||||
for (Block& block : newOp.getBody())
|
for (Block& block : newOp.getBody())
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ add_pim_library(SpatialOps
|
|||||||
SpatialOpsVerify.cpp
|
SpatialOpsVerify.cpp
|
||||||
SpatialOpsCanonicalization.cpp
|
SpatialOpsCanonicalization.cpp
|
||||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||||
|
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
||||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||||
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
||||||
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
|
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
|
||||||
|
|||||||
@@ -102,6 +102,23 @@ def SpatConcatOp : SpatOp<"concat", []> {
|
|||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def SpatMapOp : SpatOp<"map", [SingleBlock]> {
|
||||||
|
let summary = "Apply the same lane-local region to many independent tensors";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Variadic<SpatTensor>:$inputs
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Variadic<SpatTensor>:$outputs
|
||||||
|
);
|
||||||
|
|
||||||
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Communication
|
// Communication
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -184,6 +201,20 @@ def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
|||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def SpatChannelSendManyBatchOp : SpatOp<"channel_send_many_batch", []> {
|
||||||
|
let summary = "Send multiple per-lane tensors through logical channels in a batch body";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
DenseI64ArrayAttr:$channelIds,
|
||||||
|
DenseI32ArrayAttr:$sourceCoreIds,
|
||||||
|
DenseI32ArrayAttr:$targetCoreIds,
|
||||||
|
Variadic<SpatTensor>:$inputs
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||||
let summary = "Receive a per-lane tensor through logical channels in a batch body";
|
let summary = "Receive a per-lane tensor through logical channels in a batch body";
|
||||||
|
|
||||||
@@ -201,11 +232,28 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
|||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
|
||||||
|
let summary = "Receive multiple per-lane tensors through logical channels in a batch body";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
DenseI64ArrayAttr:$channelIds,
|
||||||
|
DenseI32ArrayAttr:$sourceCoreIds,
|
||||||
|
DenseI32ArrayAttr:$targetCoreIds
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
Variadic<SpatTensor>:$outputs
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Math
|
// Math
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatWeightedVMMOp : SpatOp<"Wvmm", []> {
|
def SpatWeightedVMMOp : SpatOp<"wvmm", []> {
|
||||||
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter)
|
|||||||
printer << (delimiter == ListDelimiter::Square ? "]" : ")");
|
printer << (delimiter == ListDelimiter::Square ? "]" : ")");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
|
||||||
|
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||||
|
}
|
||||||
|
|
||||||
template <typename EntryT, typename ParseEntryFn>
|
template <typename EntryT, typename ParseEntryFn>
|
||||||
static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
||||||
ListDelimiter delimiter,
|
ListDelimiter delimiter,
|
||||||
@@ -75,13 +79,26 @@ static ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename IntT>
|
template <typename IntT>
|
||||||
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
|
static ParseResult
|
||||||
if (parser.parseLSquare())
|
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||||
return failure();
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
if (succeeded(parser.parseOptionalRSquare()))
|
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
SmallVector<IntT> subgroup;
|
||||||
|
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(values, subgroup);
|
||||||
|
}
|
||||||
|
else {
|
||||||
int64_t first = 0;
|
int64_t first = 0;
|
||||||
if (parser.parseInteger(first))
|
if (parser.parseInteger(first))
|
||||||
return failure();
|
return failure();
|
||||||
@@ -118,8 +135,9 @@ static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorIm
|
|||||||
for (int64_t index = 0; index < repeatCount; ++index)
|
for (int64_t index = 0; index < repeatCount; ++index)
|
||||||
values.push_back(static_cast<IntT>(first));
|
values.push_back(static_cast<IntT>(first));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (succeeded(parser.parseOptionalRSquare()))
|
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||||
break;
|
break;
|
||||||
if (parser.parseComma())
|
if (parser.parseComma())
|
||||||
return failure();
|
return failure();
|
||||||
@@ -128,6 +146,14 @@ static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorIm
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename IntT>
|
||||||
|
static ParseResult
|
||||||
|
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||||
|
if (parseOpenDelimiter(parser, delimiter))
|
||||||
|
return failure();
|
||||||
|
return parseCompressedIntegerEntries(parser, delimiter, values);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename RangeT, typename PrintEntryFn>
|
template <typename RangeT, typename PrintEntryFn>
|
||||||
static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
|
static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
|
||||||
for (size_t index = 0; index < entries.size();) {
|
for (size_t index = 0; index < entries.size();) {
|
||||||
@@ -146,35 +172,51 @@ static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, Prin
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename IntT>
|
template <typename IntT>
|
||||||
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
|
static void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef<IntT> values, ListDelimiter delimiter) {
|
||||||
printer << "[";
|
struct FlatCompression {
|
||||||
for (size_t index = 0; index < values.size();) {
|
enum class Kind {
|
||||||
if (index != 0)
|
Single,
|
||||||
printer << ", ";
|
EqualRun,
|
||||||
|
Progression
|
||||||
auto findEqualRunEnd = [&](size_t start) {
|
|
||||||
size_t end = start + 1;
|
|
||||||
while (end < values.size() && values[end] == values[start])
|
|
||||||
++end;
|
|
||||||
return end;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t firstRunEnd = findEqualRunEnd(index);
|
Kind kind = Kind::Single;
|
||||||
size_t repeatCount = firstRunEnd - index;
|
size_t covered = 1;
|
||||||
|
size_t repeatCount = 1;
|
||||||
|
size_t progressionValueCount = 1;
|
||||||
|
int64_t step = 1;
|
||||||
|
IntT firstValue {};
|
||||||
|
IntT lastValue {};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto computeFlatCompression = [&](size_t start) {
|
||||||
|
FlatCompression compression;
|
||||||
|
compression.firstValue = values[start];
|
||||||
|
compression.lastValue = values[start];
|
||||||
|
|
||||||
|
auto findEqualRunEnd = [&](size_t runStart) {
|
||||||
|
size_t runEnd = runStart + 1;
|
||||||
|
while (runEnd < values.size() && values[runEnd] == values[runStart])
|
||||||
|
++runEnd;
|
||||||
|
return runEnd;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t firstRunEnd = findEqualRunEnd(start);
|
||||||
|
compression.repeatCount = firstRunEnd - start;
|
||||||
size_t progressionEnd = firstRunEnd;
|
size_t progressionEnd = firstRunEnd;
|
||||||
int64_t step = 0;
|
int64_t step = 0;
|
||||||
IntT lastValue = values[index];
|
IntT lastValue = values[start];
|
||||||
|
|
||||||
if (firstRunEnd < values.size()) {
|
if (firstRunEnd < values.size()) {
|
||||||
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
|
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
|
||||||
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[index]);
|
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
|
||||||
if (step > 0 && secondRunEnd - firstRunEnd == repeatCount) {
|
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
|
||||||
progressionEnd = secondRunEnd;
|
progressionEnd = secondRunEnd;
|
||||||
lastValue = values[firstRunEnd];
|
lastValue = values[firstRunEnd];
|
||||||
size_t currentRunStart = secondRunEnd;
|
size_t currentRunStart = secondRunEnd;
|
||||||
while (currentRunStart < values.size()) {
|
while (currentRunStart < values.size()) {
|
||||||
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
|
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
|
||||||
if (currentRunEnd - currentRunStart != repeatCount)
|
if (currentRunEnd - currentRunStart != compression.repeatCount)
|
||||||
break;
|
break;
|
||||||
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
|
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
|
||||||
break;
|
break;
|
||||||
@@ -188,27 +230,99 @@ static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> val
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t progressionValueCount = repeatCount == 0 ? 0 : (progressionEnd - index) / repeatCount;
|
compression.covered = 1;
|
||||||
if (progressionEnd > firstRunEnd && progressionValueCount >= 3) {
|
if (progressionEnd > firstRunEnd) {
|
||||||
printer << values[index] << " to " << lastValue;
|
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
|
||||||
if (step != 1)
|
if (progressionValueCount >= 3) {
|
||||||
printer << " by " << step;
|
compression.kind = FlatCompression::Kind::Progression;
|
||||||
if (repeatCount > 1)
|
compression.covered = progressionEnd - start;
|
||||||
printer << " x" << repeatCount;
|
compression.progressionValueCount = progressionValueCount;
|
||||||
index = progressionEnd;
|
compression.step = step;
|
||||||
continue;
|
compression.lastValue = lastValue;
|
||||||
|
return compression;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (repeatCount > 1) {
|
if (compression.repeatCount > 1) {
|
||||||
printer << values[index] << " x" << repeatCount;
|
compression.kind = FlatCompression::Kind::EqualRun;
|
||||||
index = firstRunEnd;
|
compression.covered = compression.repeatCount;
|
||||||
continue;
|
return compression;
|
||||||
}
|
}
|
||||||
|
|
||||||
printer << values[index];
|
return compression;
|
||||||
index = firstRunEnd;
|
};
|
||||||
|
|
||||||
|
auto findRepeatedSublist = [&](size_t start) {
|
||||||
|
size_t bestLength = 0;
|
||||||
|
size_t bestRepeatCount = 1;
|
||||||
|
size_t remaining = values.size() - start;
|
||||||
|
|
||||||
|
for (size_t length = 2; length * 2 <= remaining; ++length) {
|
||||||
|
size_t repeatCount = 1;
|
||||||
|
ArrayRef<IntT> candidate = values.slice(start, length);
|
||||||
|
while (start + (repeatCount + 1) * length <= values.size()
|
||||||
|
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
|
||||||
|
++repeatCount;
|
||||||
}
|
}
|
||||||
printer << "]";
|
|
||||||
|
if (repeatCount <= 1)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
size_t covered = length * repeatCount;
|
||||||
|
size_t bestCovered = bestLength * bestRepeatCount;
|
||||||
|
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
|
||||||
|
bestLength = length;
|
||||||
|
bestRepeatCount = repeatCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::pair(bestLength, bestRepeatCount);
|
||||||
|
};
|
||||||
|
|
||||||
|
printOpenDelimiter(printer, delimiter);
|
||||||
|
for (size_t index = 0; index < values.size();) {
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
|
||||||
|
FlatCompression flat = computeFlatCompression(index);
|
||||||
|
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
|
||||||
|
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
|
||||||
|
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
|
||||||
|
printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren);
|
||||||
|
printer << " x" << sublistRepeatCount;
|
||||||
|
index += repeatedSublistCoverage;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
switch (flat.kind) {
|
||||||
|
case FlatCompression::Kind::Progression:
|
||||||
|
printer << flat.firstValue << " to " << flat.lastValue;
|
||||||
|
if (flat.step != 1)
|
||||||
|
printer << " by " << flat.step;
|
||||||
|
if (flat.repeatCount > 1)
|
||||||
|
printer << " x" << flat.repeatCount;
|
||||||
|
index += flat.covered;
|
||||||
|
break;
|
||||||
|
case FlatCompression::Kind::EqualRun:
|
||||||
|
printer << flat.firstValue << " x" << flat.repeatCount;
|
||||||
|
index += flat.covered;
|
||||||
|
break;
|
||||||
|
case FlatCompression::Kind::Single:
|
||||||
|
printer << flat.firstValue;
|
||||||
|
index += flat.covered;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
printCloseDelimiter(printer, delimiter);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IntT>
|
||||||
|
static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
|
||||||
|
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename IntT>
|
||||||
|
static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
|
||||||
|
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
|
static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
|
||||||
@@ -267,6 +381,165 @@ static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, List
|
|||||||
printCloseDelimiter(printer, delimiter);
|
printCloseDelimiter(printer, delimiter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
|
||||||
|
static ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands);
|
||||||
|
static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty);
|
||||||
|
|
||||||
|
static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
|
||||||
|
if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
SmallVector<Value> valueVec(values.begin(), values.end());
|
||||||
|
ArrayRef<Value> tuple(valueVec.data(), tupleSize);
|
||||||
|
for (size_t index = tupleSize; index < values.size(); index += tupleSize)
|
||||||
|
if (!llvm::equal(tuple, ArrayRef<Value>(valueVec).slice(index, tupleSize)))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
|
||||||
|
if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
SmallVector<Type> typeVec(types.begin(), types.end());
|
||||||
|
ArrayRef<Type> tuple(typeVec.data(), tupleSize);
|
||||||
|
for (size_t index = tupleSize; index < types.size(); index += tupleSize)
|
||||||
|
if (!llvm::equal(tuple, ArrayRef<Type>(typeVec).slice(index, tupleSize)))
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) {
|
||||||
|
printer << "[";
|
||||||
|
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
for (size_t index = 0; index < tupleSize; ++index) {
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
printer.printOperand(values[index]);
|
||||||
|
}
|
||||||
|
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
printer << " x" << (values.size() / tupleSize) << "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) {
|
||||||
|
printer << "[";
|
||||||
|
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
for (size_t index = 0; index < tupleSize; ++index) {
|
||||||
|
if (index != 0)
|
||||||
|
printer << ", ";
|
||||||
|
printer.printType(types[index]);
|
||||||
|
}
|
||||||
|
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||||
|
printer << " x" << (types.size() / tupleSize) << "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (parser.parseLSquare())
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalRSquare()))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
|
||||||
|
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(operands, tupleOperands);
|
||||||
|
|
||||||
|
while (succeeded(parser.parseOptionalComma())) {
|
||||||
|
if (parser.parseLParen())
|
||||||
|
return failure();
|
||||||
|
tupleOperands.clear();
|
||||||
|
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(operands, tupleOperands);
|
||||||
|
}
|
||||||
|
return parser.parseRSquare();
|
||||||
|
}
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (parseOneCompressedOperandEntry(parser, operands))
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalRSquare()))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl<Type>& types) {
|
||||||
|
if (parser.parseLSquare())
|
||||||
|
return failure();
|
||||||
|
if (succeeded(parser.parseOptionalRSquare()))
|
||||||
|
return success();
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
SmallVector<Type> tupleTypes;
|
||||||
|
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(types, tupleTypes);
|
||||||
|
|
||||||
|
while (succeeded(parser.parseOptionalComma())) {
|
||||||
|
if (parser.parseLParen())
|
||||||
|
return failure();
|
||||||
|
tupleTypes.clear();
|
||||||
|
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
llvm::append_range(types, tupleTypes);
|
||||||
|
}
|
||||||
|
return parser.parseRSquare();
|
||||||
|
}
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
Type type;
|
||||||
|
if (parser.parseType(type))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t repeatCount = 1;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||||
|
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||||
|
}
|
||||||
|
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||||
|
types.push_back(type);
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalRSquare()))
|
||||||
|
return success();
|
||||||
|
if (parser.parseComma())
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
|
static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
|
||||||
OpAsmParser::UnresolvedOperand firstOperand,
|
OpAsmParser::UnresolvedOperand firstOperand,
|
||||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
@@ -440,19 +713,88 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
|||||||
return parser.getBuilder().getI32IntegerAttr(value);
|
return parser.getBuilder().getI32IntegerAttr(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildImplicitRegionArgs(OpAsmParser& parser,
|
static void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
|
||||||
ArrayRef<Type> inputTypes,
|
if (block.getNumArguments() == 0) {
|
||||||
SmallVectorImpl<std::string>& generatedNames,
|
printer << "() = ()";
|
||||||
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
return;
|
||||||
generatedNames.reserve(inputTypes.size());
|
|
||||||
arguments.reserve(inputTypes.size());
|
|
||||||
for (auto [index, inputType] : llvm::enumerate(inputTypes)) {
|
|
||||||
generatedNames.push_back("arg" + std::to_string(index + 1));
|
|
||||||
OpAsmParser::Argument arg;
|
|
||||||
arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0};
|
|
||||||
arg.type = inputType;
|
|
||||||
arguments.push_back(arg);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (block.getNumArguments() == 1) {
|
||||||
|
printer.printOperand(block.getArgument(0));
|
||||||
|
printer << " = ";
|
||||||
|
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
|
||||||
|
printer << " = ";
|
||||||
|
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
|
||||||
|
OpAsmParser::Argument firstArgument,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||||
|
OpAsmParser::Argument lastArgument;
|
||||||
|
if (parser.parseArgument(lastArgument))
|
||||||
|
return failure();
|
||||||
|
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|
||||||
|
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
|
||||||
|
}
|
||||||
|
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
|
||||||
|
OpAsmParser::Argument argument;
|
||||||
|
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
|
||||||
|
arguments.push_back(argument);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
arguments.push_back(firstArgument);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||||
|
OpAsmParser::Argument firstArgument;
|
||||||
|
if (parser.parseArgument(firstArgument))
|
||||||
|
return failure();
|
||||||
|
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||||
|
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
|
||||||
|
argument.type = inputType;
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseArgumentBindings(OpAsmParser& parser,
|
||||||
|
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||||
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||||
|
if (succeeded(parser.parseOptionalLParen())) {
|
||||||
|
if (succeeded(parser.parseOptionalRParen())) {
|
||||||
|
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
OpAsmParser::Argument firstArgument;
|
||||||
|
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
|
||||||
|
return failure();
|
||||||
|
while (succeeded(parser.parseOptionalComma()))
|
||||||
|
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||||
|
return failure();
|
||||||
|
if (parser.parseRParen() || parser.parseEqual()
|
||||||
|
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
OpAsmParser::Argument argument;
|
||||||
|
if (parser.parseArgument(argument) || parser.parseEqual()
|
||||||
|
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||||
|
return failure();
|
||||||
|
arguments.push_back(argument);
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -519,8 +861,8 @@ ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result
|
|||||||
|
|
||||||
void SpatConcatOp::print(OpAsmPrinter& printer) {
|
void SpatConcatOp::print(OpAsmPrinter& printer) {
|
||||||
printer << " axis " << getAxis();
|
printer << " axis " << getAxis();
|
||||||
printer << " args = ";
|
printer << " ";
|
||||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
|
printCompressedValueSequence(printer, getInputs());
|
||||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
|
printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()});
|
||||||
printer << " : ";
|
printer << " : ";
|
||||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||||
@@ -537,11 +879,7 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
|
if (parser.parseKeyword("axis") || parser.parseInteger(axis))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (succeeded(parser.parseOptionalKeyword("args"))) {
|
if (parseCompressedOperandSequence(parser, inputs)) {
|
||||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -563,14 +901,54 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SpatMapOp::print(OpAsmPrinter& printer) {
|
||||||
|
printer << " ";
|
||||||
|
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||||
|
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||||
|
printer << " : ";
|
||||||
|
printer.printType(getInputs().front().getType());
|
||||||
|
printer << " -> ";
|
||||||
|
printer.printType(getOutputs().front().getType());
|
||||||
|
printer << " ";
|
||||||
|
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult SpatMapOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||||
|
Type inputType;
|
||||||
|
Type outputType;
|
||||||
|
|
||||||
|
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||||
|
return failure();
|
||||||
|
if (inputs.empty())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "map requires at least one input");
|
||||||
|
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)
|
||||||
|
|| parser.parseArrow() || parser.parseType(outputType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<Type> inputTypes(inputs.size(), inputType);
|
||||||
|
SmallVector<Type> outputTypes(inputs.size(), outputType);
|
||||||
|
if (regionArgs.size() != inputs.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||||
|
if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||||
|
return failure();
|
||||||
|
result.addTypes(outputTypes);
|
||||||
|
|
||||||
|
applyArgumentTypes(inputTypes, regionArgs);
|
||||||
|
Region* body = result.addRegion();
|
||||||
|
return parser.parseRegion(*body, regionArgs);
|
||||||
|
}
|
||||||
|
|
||||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||||
printer << " args = ";
|
printer << " ";
|
||||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
|
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||||
|
|
||||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
printer << " core_id " << coreIdAttr.getInt();
|
printer << " coreId " << coreIdAttr.getInt();
|
||||||
|
|
||||||
printer.printOptionalAttrDict((*this)->getAttrs(),
|
printer.printOptionalAttrDict((*this)->getAttrs(),
|
||||||
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
{getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
||||||
@@ -587,7 +965,6 @@ void SpatCompute::print(OpAsmPrinter& printer) {
|
|||||||
|
|
||||||
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||||
SmallVector<std::string> generatedArgNames;
|
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||||
SmallVector<Type> weightTypes;
|
SmallVector<Type> weightTypes;
|
||||||
@@ -598,15 +975,10 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
|
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (succeeded(parser.parseOptionalKeyword("args"))) {
|
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
|
||||||
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasCoreId = succeeded(parser.parseOptionalKeyword("core_id"));
|
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||||
if (hasCoreId && parser.parseInteger(coreId))
|
if (hasCoreId && parser.parseInteger(coreId))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -622,9 +994,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||||
if (inputs.size() != inputTypes.size())
|
if (inputs.size() != inputTypes.size())
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||||
|
if (regionArgs.size() != inputs.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||||
return parser.emitError(parser.getCurrentLocation(),
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
"core_id cannot be specified both positionally and in attr-dict");
|
"coreId cannot be specified both positionally and in attr-dict");
|
||||||
|
|
||||||
auto& builder = parser.getBuilder();
|
auto& builder = parser.getBuilder();
|
||||||
result.addAttribute(
|
result.addAttribute(
|
||||||
@@ -639,26 +1013,33 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
result.addTypes(outputTypes);
|
result.addTypes(outputTypes);
|
||||||
|
|
||||||
Region* body = result.addRegion();
|
Region* body = result.addRegion();
|
||||||
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
|
applyArgumentTypes(inputTypes, regionArgs);
|
||||||
return parser.parseRegion(*body, regionArgs);
|
return parser.parseRegion(*body, regionArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||||
printer << " lanes " << getLaneCount() << " ";
|
printer << " lanes " << getLaneCount() << " ";
|
||||||
|
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
||||||
|
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
||||||
|
printValueTupleRun(printer, getWeights(), weightsPerLane);
|
||||||
|
else
|
||||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||||
printer << " args = ";
|
printer << " ";
|
||||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Paren);
|
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||||
|
|
||||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName)) {
|
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||||
printer << " core_ids ";
|
printer << " coreIds ";
|
||||||
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
printCompressedIntegerList(printer, coreIdsAttr.asArrayRef());
|
||||||
}
|
}
|
||||||
|
|
||||||
printer.printOptionalAttrDict(
|
printer.printOptionalAttrDict(
|
||||||
(*this)->getAttrs(),
|
(*this)->getAttrs(),
|
||||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName});
|
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||||
|
|
||||||
printer << " : ";
|
printer << " : ";
|
||||||
|
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
||||||
|
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane);
|
||||||
|
else
|
||||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||||
@@ -671,7 +1052,6 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
|||||||
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
int32_t laneCount = 0;
|
int32_t laneCount = 0;
|
||||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||||
SmallVector<std::string> generatedArgNames;
|
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||||
SmallVector<Type> weightTypes;
|
SmallVector<Type> weightTypes;
|
||||||
@@ -682,24 +1062,18 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
|||||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
|
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
|
if (parseCompressedOrTupleOperandList(parser, weights))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (succeeded(parser.parseOptionalKeyword("args"))) {
|
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs))
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
|
||||||
else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("core_ids"));
|
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|| parseCompressedRepeatedList(
|
|| parseCompressedOrTupleTypeList(parser, weightTypes)
|
||||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
|
||||||
|| parseCompressedRepeatedList(
|
|| parseCompressedRepeatedList(
|
||||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||||
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
|| parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true))
|
||||||
@@ -709,8 +1083,11 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
|||||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||||
if (inputs.size() != inputTypes.size())
|
if (inputs.size() != inputTypes.size())
|
||||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
if (regionArgs.size() != inputs.size())
|
||||||
return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both in core_ids and attr-dict");
|
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||||
|
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"coreIds cannot be specified both positionally and in attr-dict");
|
||||||
|
|
||||||
auto& builder = parser.getBuilder();
|
auto& builder = parser.getBuilder();
|
||||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||||
@@ -718,7 +1095,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
|||||||
"operandSegmentSizes",
|
"operandSegmentSizes",
|
||||||
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||||
if (hasCoreIds)
|
if (hasCoreIds)
|
||||||
result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||||
|
|
||||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands)
|
||||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands))
|
||||||
@@ -726,7 +1103,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
|||||||
result.addTypes(outputTypes);
|
result.addTypes(outputTypes);
|
||||||
|
|
||||||
Region* body = result.addRegion();
|
Region* body = result.addRegion();
|
||||||
buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs);
|
applyArgumentTypes(inputTypes, regionArgs);
|
||||||
return parser.parseRegion(*body, regionArgs);
|
return parser.parseRegion(*body, regionArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -867,6 +1244,55 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r
|
|||||||
return parser.resolveOperand(input, inputType, result.operands);
|
return parser.resolveOperand(input, inputType, result.operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SpatChannelSendManyBatchOp::print(OpAsmPrinter& printer) {
|
||||||
|
printer << " ";
|
||||||
|
printCompressedValueSequence(printer, getInputs());
|
||||||
|
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||||
|
printer.printOptionalAttrDict(
|
||||||
|
(*this)->getAttrs(),
|
||||||
|
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||||
|
printer << " : ";
|
||||||
|
printCompressedTypeSequence(printer, TypeRange(getInputs()));
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult SpatChannelSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||||
|
SmallVector<Type> inputTypes;
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
|
||||||
|
if (parseCompressedOperandSequence(parser, inputs))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||||
|
if (hasMetadata) {
|
||||||
|
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||||
|
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||||
|
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|
|| parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (inputs.size() != inputTypes.size())
|
||||||
|
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||||
|
if (hasMetadata
|
||||||
|
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||||
|
|| result.attributes.get("targetCoreIds")))
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||||
|
if (hasMetadata) {
|
||||||
|
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||||
|
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||||
|
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||||
|
}
|
||||||
|
|
||||||
|
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||||
|
}
|
||||||
|
|
||||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||||
printer.printOptionalAttrDict(
|
printer.printOptionalAttrDict(
|
||||||
@@ -908,5 +1334,47 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SpatChannelReceiveManyBatchOp::print(OpAsmPrinter& printer) {
|
||||||
|
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||||
|
printer.printOptionalAttrDict(
|
||||||
|
(*this)->getAttrs(),
|
||||||
|
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||||
|
printer << " : ";
|
||||||
|
printCompressedTypeSequence(printer, getResultTypes());
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult SpatChannelReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||||
|
SmallVector<Type> outputTypes;
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
|
||||||
|
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||||
|
if (hasMetadata) {
|
||||||
|
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||||
|
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||||
|
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|
|| parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (hasMetadata
|
||||||
|
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||||
|
|| result.attributes.get("targetCoreIds")))
|
||||||
|
return parser.emitError(parser.getCurrentLocation(),
|
||||||
|
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||||
|
if (hasMetadata) {
|
||||||
|
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||||
|
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||||
|
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||||
|
}
|
||||||
|
|
||||||
|
result.addTypes(outputTypes);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -83,13 +83,13 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
|
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
|
||||||
if (auto computeOp = dyn_cast<SpatCompute>(weightedOp->getParentOp()))
|
if (auto computeOp = weightedOp->getParentOfType<SpatCompute>())
|
||||||
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
|
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
|
||||||
|
|
||||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weightedOp->getParentOp()))
|
if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
|
||||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||||
|
|
||||||
if (auto batchOp = dyn_cast<SpatComputeBatch>(weightedOp->getParentOp())) {
|
if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
|
||||||
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
|
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
|
||||||
return failure();
|
return failure();
|
||||||
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
|
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
|
||||||
@@ -144,6 +144,23 @@ static LogicalResult verifyBatchChannelSizes(Operation* op,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static LogicalResult verifyManyBatchChannelSizes(Operation* op,
|
||||||
|
ArrayRef<int64_t> channelIds,
|
||||||
|
ArrayRef<int32_t> sourceCoreIds,
|
||||||
|
ArrayRef<int32_t> targetCoreIds,
|
||||||
|
size_t valueCount) {
|
||||||
|
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||||
|
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||||
|
|
||||||
|
auto laneCount = getParentBatchLaneCount(op);
|
||||||
|
if (failed(laneCount))
|
||||||
|
return op->emitError("must be nested inside spat.compute_batch");
|
||||||
|
if (channelIds.size() != valueCount * static_cast<size_t>(*laneCount))
|
||||||
|
return op->emitError("channel metadata length must match the number of values times parent laneCount");
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
|
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
|
||||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||||
if (!yieldOp)
|
if (!yieldOp)
|
||||||
@@ -306,6 +323,39 @@ LogicalResult SpatConcatOp::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatMapOp::verify() {
|
||||||
|
if (getInputs().empty())
|
||||||
|
return emitError("requires at least one input");
|
||||||
|
if (getOutputs().size() != getInputs().size())
|
||||||
|
return emitError("number of outputs must match number of inputs");
|
||||||
|
|
||||||
|
Type inputType = getInputs().front().getType();
|
||||||
|
for (Value input : getInputs().drop_front())
|
||||||
|
if (input.getType() != inputType)
|
||||||
|
return emitError("all inputs must have the same type");
|
||||||
|
|
||||||
|
Type outputType = getOutputs().front().getType();
|
||||||
|
for (Value output : getOutputs().drop_front())
|
||||||
|
if (output.getType() != outputType)
|
||||||
|
return emitError("all outputs must have the same type");
|
||||||
|
|
||||||
|
Block& block = getBody().front();
|
||||||
|
if (block.getNumArguments() != 1)
|
||||||
|
return emitError("body must have exactly one block argument");
|
||||||
|
if (block.getArgument(0).getType() != inputType)
|
||||||
|
return emitError("body block argument type must match input type");
|
||||||
|
|
||||||
|
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||||
|
if (!yieldOp)
|
||||||
|
return emitError("body must terminate with spat.yield");
|
||||||
|
if (yieldOp.getNumOperands() != 1)
|
||||||
|
return emitError("body yield must produce exactly one value");
|
||||||
|
if (yieldOp.getOperand(0).getType() != outputType)
|
||||||
|
return emitError("body yield type must match output type");
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult SpatCompute::verify() {
|
LogicalResult SpatCompute::verify() {
|
||||||
auto& block = getBody().front();
|
auto& block = getBody().front();
|
||||||
if (block.mightHaveTerminator()) {
|
if (block.mightHaveTerminator()) {
|
||||||
@@ -365,10 +415,24 @@ LogicalResult SpatChannelSendBatchOp::verify() {
|
|||||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatChannelSendManyBatchOp::verify() {
|
||||||
|
if (failed(verifyManyBatchChannelSizes(
|
||||||
|
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
|
||||||
|
return failure();
|
||||||
|
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch");
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
||||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatChannelReceiveManyBatchOp::verify() {
|
||||||
|
if (failed(verifyManyBatchChannelSizes(
|
||||||
|
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
|
||||||
|
return failure();
|
||||||
|
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch");
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult SpatComputeBatch::verify() {
|
LogicalResult SpatComputeBatch::verify() {
|
||||||
int32_t count = getLaneCount();
|
int32_t count = getLaneCount();
|
||||||
if (count <= 0)
|
if (count <= 0)
|
||||||
@@ -405,18 +469,18 @@ LogicalResult SpatComputeBatch::verify() {
|
|||||||
return emitError("all outputs must have the same type");
|
return emitError("all outputs must have the same type");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdAttrName)) {
|
if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdsAttrName)) {
|
||||||
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
||||||
if (!coreIdsAttr)
|
if (!coreIdsAttr)
|
||||||
return emitError("compute_batch core_id attribute must be a dense i32 array");
|
return emitError("compute_batch coreIds attribute must be a dense i32 array");
|
||||||
if (coreIdsAttr.size() != laneCountSz)
|
if (coreIdsAttr.size() != laneCountSz)
|
||||||
return emitError("compute_batch core_id array length must match laneCount");
|
return emitError("compute_batch coreIds array length must match laneCount");
|
||||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
|
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
|
||||||
return emitError("compute_batch core_id values must be positive");
|
return emitError("compute_batch coreIds values must be positive");
|
||||||
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
|
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
|
||||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||||
if (!seenCoreIds.insert(coreId).second)
|
if (!seenCoreIds.insert(coreId).second)
|
||||||
return emitError("compute_batch core_id values must be distinct");
|
return emitError("compute_batch coreIds values must be distinct");
|
||||||
}
|
}
|
||||||
|
|
||||||
Block& block = getBody().front();
|
Block& block = getBody().front();
|
||||||
|
|||||||
@@ -184,14 +184,40 @@ std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
|
|||||||
|
|
||||||
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
|
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
|
||||||
SmallVector<ComputeInstance> instances;
|
SmallVector<ComputeInstance> instances;
|
||||||
|
auto isUsedAsWeightOnly = [](Operation* producerOp) {
|
||||||
|
if (producerOp->getNumResults() == 0)
|
||||||
|
return false;
|
||||||
|
for (Value result : producerOp->getResults()) {
|
||||||
|
if (result.use_empty())
|
||||||
|
return false;
|
||||||
|
for (Operation* user : result.getUsers()) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(user)) {
|
||||||
|
if (!llvm::is_contained(compute.getWeights(), result))
|
||||||
|
return false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
|
||||||
|
if (!llvm::is_contained(batch.getWeights(), result))
|
||||||
|
return false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
for (Region& region : entryOp->getRegions()) {
|
for (Region& region : entryOp->getRegions()) {
|
||||||
for (Block& block : region) {
|
for (Block& block : region) {
|
||||||
for (Operation& op : block) {
|
for (Operation& op : block) {
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||||
|
if (isUsedAsWeightOnly(spatCompute.getOperation()))
|
||||||
|
continue;
|
||||||
instances.push_back({spatCompute.getOperation(), 0, 1});
|
instances.push_back({spatCompute.getOperation(), 0, 1});
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
||||||
|
if (isUsedAsWeightOnly(batch.getOperation()))
|
||||||
|
continue;
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||||
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
|
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
|
||||||
instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
|
instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
|
||||||
@@ -582,10 +608,13 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
result.dominanceOrderCompute.reserve(computeInstances.size());
|
result.dominanceOrderCompute.reserve(computeInstances.size());
|
||||||
|
llvm::DenseMap<size_t, size_t> nextCpuSlot;
|
||||||
for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
|
for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
|
||||||
size_t cpu = originalComputeToCpu[originalIndex];
|
size_t cpu = originalComputeToCpu[originalIndex];
|
||||||
result.dominanceOrderCompute.push_back(computeInstance);
|
result.dominanceOrderCompute.push_back(computeInstance);
|
||||||
result.computeToCpuMap[computeInstance] = cpu;
|
result.computeToCpuMap[computeInstance] = cpu;
|
||||||
|
result.computeToCpuSlotMap[computeInstance] = nextCpuSlot[cpu]++;
|
||||||
|
result.computeToAestMap[computeInstance] = originalIndex;
|
||||||
result.cpuToLastComputeMap[cpu] = computeInstance;
|
result.cpuToLastComputeMap[cpu] = computeInstance;
|
||||||
}
|
}
|
||||||
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
|
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||||
@@ -603,8 +632,12 @@ DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<Com
|
|||||||
if (scheduledTasks.empty())
|
if (scheduledTasks.empty())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
for (const auto& task : scheduledTasks)
|
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
|
||||||
result.computeToCpuMap[computeInstances[task.nodeIndex]] = cpu;
|
ComputeInstance instance = computeInstances[task.nodeIndex];
|
||||||
|
result.computeToCpuMap[instance] = cpu;
|
||||||
|
result.computeToCpuSlotMap[instance] = slot;
|
||||||
|
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
||||||
|
}
|
||||||
result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
|
result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
|
||||||
result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
|
result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
|
||||||
}
|
}
|
||||||
@@ -671,6 +704,16 @@ DCPAnalysisResult DCPAnalysis::run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (coresCount.getValue() > 0) {
|
||||||
|
size_t schedulingCpuBudget = getSchedulingCpuBudget();
|
||||||
|
bool needsExactScheduledBatches = llvm::any_of(computeInstances, [&](const ComputeInstance& instance) {
|
||||||
|
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
||||||
|
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
||||||
|
});
|
||||||
|
if (needsExactScheduledBatches)
|
||||||
|
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
||||||
|
}
|
||||||
|
|
||||||
if (dcpCriticalWindowSize.getValue() == 0)
|
if (dcpCriticalWindowSize.getValue() == 0)
|
||||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ struct ComputeInstance {
|
|||||||
struct DCPAnalysisResult {
|
struct DCPAnalysisResult {
|
||||||
std::vector<ComputeInstance> dominanceOrderCompute;
|
std::vector<ComputeInstance> dominanceOrderCompute;
|
||||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
||||||
|
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
||||||
|
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
|
#include "mlir/Analysis/TopologicalSortUtils.h"
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
@@ -34,9 +37,9 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "DCPGraph/DCPAnalysis.hpp"
|
#include "DCPGraph/DCPAnalysis.hpp"
|
||||||
|
#include "RegularOpCompaction.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -61,7 +64,7 @@ static size_t getFastPathCpuBudget() {
|
|||||||
|
|
||||||
static size_t getBatchChunkTargetCount(int32_t laneCount) {
|
static size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||||
assert(laneCount > 0 && "laneCount must be positive");
|
assert(laneCount > 0 && "laneCount must be positive");
|
||||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getFastPathCpuBudget()));
|
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, static_cast<size_t>(getFastPathCpuBudget())));
|
||||||
}
|
}
|
||||||
|
|
||||||
static ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
static ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||||
@@ -129,8 +132,25 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
|||||||
|
|
||||||
static int32_t getPhysicalCoreId(size_t schedulerCpu) { return static_cast<int32_t>(schedulerCpu + 1); }
|
static int32_t getPhysicalCoreId(size_t schedulerCpu) { return static_cast<int32_t>(schedulerCpu + 1); }
|
||||||
|
|
||||||
|
static size_t getMaterializationCpuBudget(size_t laneCount) {
|
||||||
|
if (coresCount.getValue() > 0)
|
||||||
|
return static_cast<size_t>(coresCount.getValue());
|
||||||
|
return std::max<size_t>(1, laneCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<int32_t> getMaterializedBatchCoreIds(size_t startCpu, size_t laneCount) {
|
||||||
|
size_t cpuBudget = getMaterializationCpuBudget(laneCount);
|
||||||
|
assert(laneCount <= cpuBudget && "materialized batch exceeds available CPUs");
|
||||||
|
|
||||||
|
SmallVector<int32_t> coreIds;
|
||||||
|
coreIds.reserve(laneCount);
|
||||||
|
for (size_t laneOffset = 0; laneOffset < laneCount; ++laneOffset)
|
||||||
|
coreIds.push_back(getPhysicalCoreId((startCpu + laneOffset) % cpuBudget));
|
||||||
|
return coreIds;
|
||||||
|
}
|
||||||
|
|
||||||
static SmallVector<int32_t> getBatchCoreIds(Operation* op, size_t laneCount) {
|
static SmallVector<int32_t> getBatchCoreIds(Operation* op, size_t laneCount) {
|
||||||
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName))
|
if (auto coreIdsAttr = op->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||||
if (auto coreIdAttr = op->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
if (auto coreIdAttr = op->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
return SmallVector<int32_t>(laneCount, static_cast<int32_t>(coreIdAttr.getInt()));
|
return SmallVector<int32_t>(laneCount, static_cast<int32_t>(coreIdAttr.getInt()));
|
||||||
@@ -143,6 +163,14 @@ static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
||||||
|
|
||||||
|
static std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
|
||||||
|
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
|
||||||
|
return static_cast<uint64_t>(phaseAttr.getInt());
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
||||||
if (!lhs || !rhs)
|
if (!lhs || !rhs)
|
||||||
return false;
|
return false;
|
||||||
@@ -152,6 +180,8 @@ static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
|||||||
return false;
|
return false;
|
||||||
if (lhs.getWeights().size() != rhs.getWeights().size())
|
if (lhs.getWeights().size() != rhs.getWeights().size())
|
||||||
return false;
|
return false;
|
||||||
|
if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
|
||||||
|
return false;
|
||||||
if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
|
if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
@@ -277,7 +307,7 @@ static void sinkChannelsIntoBatchComputes(func::FuncOp funcOp,
|
|||||||
|
|
||||||
SmallVector<int32_t> coreIds = getBatchCoreIds(batch, static_cast<size_t>(batch.getLaneCount()));
|
SmallVector<int32_t> coreIds = getBatchCoreIds(batch, static_cast<size_t>(batch.getLaneCount()));
|
||||||
if (!coreIds.empty())
|
if (!coreIds.empty())
|
||||||
newBatch->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||||
|
|
||||||
auto* newBlock =
|
auto* newBlock =
|
||||||
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef<Location> {});
|
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef<Location> {});
|
||||||
@@ -521,141 +551,6 @@ void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
|
||||||
|
|
||||||
for (auto compute : funcOp.getOps<SpatCompute>()) {
|
|
||||||
Block& block = compute.getBody().front();
|
|
||||||
for (auto it = block.begin(); it != block.end();) {
|
|
||||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
|
||||||
if (receiveOp) {
|
|
||||||
SmallVector<spatial::SpatChannelReceiveOp> run;
|
|
||||||
Type outputType = receiveOp.getOutput().getType();
|
|
||||||
auto runIt = it;
|
|
||||||
while (runIt != block.end()) {
|
|
||||||
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
|
|
||||||
if (!current || current.getOutput().getType() != outputType)
|
|
||||||
break;
|
|
||||||
run.push_back(current);
|
|
||||||
++runIt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (run.size() > 1) {
|
|
||||||
struct ReceiveEntry {
|
|
||||||
spatial::SpatChannelReceiveOp op;
|
|
||||||
size_t originalIndex = 0;
|
|
||||||
uint32_t sourceCoreId = 0;
|
|
||||||
uint32_t targetCoreId = 0;
|
|
||||||
uint64_t channelId = 0;
|
|
||||||
};
|
|
||||||
SmallVector<ReceiveEntry> sortedEntries;
|
|
||||||
sortedEntries.reserve(run.size());
|
|
||||||
for (auto [originalIndex, op] : llvm::enumerate(run))
|
|
||||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
|
||||||
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
|
|
||||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
|
||||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
|
||||||
});
|
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
|
||||||
SmallVector<int32_t> targetCoreIds;
|
|
||||||
SmallVector<Type> outputTypes;
|
|
||||||
channelIds.reserve(sortedEntries.size());
|
|
||||||
sourceCoreIds.reserve(sortedEntries.size());
|
|
||||||
targetCoreIds.reserve(sortedEntries.size());
|
|
||||||
outputTypes.reserve(sortedEntries.size());
|
|
||||||
for (ReceiveEntry& entry : sortedEntries) {
|
|
||||||
(void) entry;
|
|
||||||
channelIds.push_back(nextChannelId++);
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
|
||||||
outputTypes.push_back(entry.op.getOutput().getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(run.front());
|
|
||||||
auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
|
|
||||||
run.front().getLoc(),
|
|
||||||
TypeRange(outputTypes),
|
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
|
||||||
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
|
||||||
entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
|
|
||||||
for (auto op : run)
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
|
|
||||||
it = compactReceive->getIterator();
|
|
||||||
++it;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
|
|
||||||
if (sendOp) {
|
|
||||||
SmallVector<spatial::SpatChannelSendOp> run;
|
|
||||||
Type inputType = sendOp.getInput().getType();
|
|
||||||
auto runIt = it;
|
|
||||||
while (runIt != block.end()) {
|
|
||||||
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
|
|
||||||
if (!current || current.getInput().getType() != inputType)
|
|
||||||
break;
|
|
||||||
run.push_back(current);
|
|
||||||
++runIt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (run.size() > 1) {
|
|
||||||
struct SendEntry {
|
|
||||||
spatial::SpatChannelSendOp op;
|
|
||||||
uint32_t sourceCoreId = 0;
|
|
||||||
uint32_t targetCoreId = 0;
|
|
||||||
uint64_t channelId = 0;
|
|
||||||
};
|
|
||||||
SmallVector<SendEntry> sortedEntries;
|
|
||||||
sortedEntries.reserve(run.size());
|
|
||||||
for (auto op : run)
|
|
||||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
|
||||||
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
|
|
||||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
|
||||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
|
||||||
});
|
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
|
||||||
SmallVector<int32_t> targetCoreIds;
|
|
||||||
SmallVector<Value> inputs;
|
|
||||||
channelIds.reserve(sortedEntries.size());
|
|
||||||
sourceCoreIds.reserve(sortedEntries.size());
|
|
||||||
targetCoreIds.reserve(sortedEntries.size());
|
|
||||||
inputs.reserve(sortedEntries.size());
|
|
||||||
for (SendEntry& entry : sortedEntries) {
|
|
||||||
(void) entry;
|
|
||||||
channelIds.push_back(nextChannelId++);
|
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
|
||||||
inputs.push_back(entry.op.getInput());
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(run.front());
|
|
||||||
spatial::SpatChannelSendManyOp::create(rewriter,
|
|
||||||
run.front().getLoc(),
|
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
|
||||||
ValueRange(inputs));
|
|
||||||
for (auto op : run)
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
|
|
||||||
it = runIt;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
++it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||||
@@ -728,7 +623,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
rebatched.getProperties().setOperandSegmentSizes(
|
rebatched.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
||||||
if (haveAllCoreIds)
|
if (haveAllCoreIds)
|
||||||
rebatched->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||||
|
|
||||||
SmallVector<Type> blockArgTypes;
|
SmallVector<Type> blockArgTypes;
|
||||||
SmallVector<Location> blockArgLocs;
|
SmallVector<Location> blockArgLocs;
|
||||||
@@ -841,10 +736,14 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (auto compute : group) {
|
for (auto compute : group) {
|
||||||
|
compute->removeAttr(kRebatchPhaseAttrName);
|
||||||
consumed.insert(compute);
|
consumed.insert(compute);
|
||||||
rewriter.eraseOp(compute);
|
rewriter.eraseOp(compute);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (auto compute : funcOp.getOps<SpatCompute>())
|
||||||
|
compute->removeAttr(kRebatchPhaseAttrName);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ComputeMotifInfo {
|
struct ComputeMotifInfo {
|
||||||
@@ -1329,8 +1228,9 @@ public:
|
|||||||
|
|
||||||
LazyInsertComputeResult(
|
LazyInsertComputeResult(
|
||||||
ComputeValueResults computeValueResults,
|
ComputeValueResults computeValueResults,
|
||||||
|
size_t producerCpu,
|
||||||
std::function<std::pair<ChannelInfo, std::function<void(InsertPoint)>>(size_t, size_t)> channelInserter)
|
std::function<std::pair<ChannelInfo, std::function<void(InsertPoint)>>(size_t, size_t)> channelInserter)
|
||||||
: computeResults(computeValueResults), channelInserter(channelInserter) {}
|
: computeResults(computeValueResults), producerCpu(producerCpu), channelInserter(channelInserter) {}
|
||||||
|
|
||||||
struct ChannelOrLocalOp {
|
struct ChannelOrLocalOp {
|
||||||
Value data;
|
Value data;
|
||||||
@@ -1339,6 +1239,9 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex, size_t targetCpu) {
|
ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex, size_t targetCpu) {
|
||||||
|
if (targetCpu == producerCpu)
|
||||||
|
return {computeResults.getOuter(resultIndex), false, {}};
|
||||||
|
|
||||||
Value innerValue = computeResults.getInner(resultIndex);
|
Value innerValue = computeResults.getInner(resultIndex);
|
||||||
auto [channelInfo, channelSendInserter] = channelInserter(resultIndex, targetCpu);
|
auto [channelInfo, channelSendInserter] = channelInserter(resultIndex, targetCpu);
|
||||||
InsertPoint sendInsertPoint;
|
InsertPoint sendInsertPoint;
|
||||||
@@ -1353,6 +1256,7 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
ComputeValueResults computeResults;
|
ComputeValueResults computeResults;
|
||||||
|
size_t producerCpu = 0;
|
||||||
std::function<std::pair<ChannelInfo, std::function<void(InsertPoint)>>(size_t, size_t)> channelInserter;
|
std::function<std::pair<ChannelInfo, std::function<void(InsertPoint)>>(size_t, size_t)> channelInserter;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1378,28 +1282,158 @@ public:
|
|||||||
mergeTriviallyConnectedComputes(getOperation());
|
mergeTriviallyConnectedComputes(getOperation());
|
||||||
emitMotifProfile(getOperation());
|
emitMotifProfile(getOperation());
|
||||||
|
|
||||||
|
func::FuncOp func = getOperation();
|
||||||
|
Location loc = func.getLoc();
|
||||||
DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
|
DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
|
||||||
DenseSet<ComputeInstance> materializedInstances;
|
|
||||||
for (size_t index = 0; index < analysisResult.dominanceOrderCompute.size(); ++index) {
|
|
||||||
ComputeInstance currentInstance = analysisResult.dominanceOrderCompute[index];
|
|
||||||
if (!materializedInstances.insert(currentInstance).second)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
size_t cpu = analysisResult.computeToCpuMap.at(currentInstance);
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(currentInstance.op)) {
|
|
||||||
createNewBatchCompute(batch, currentInstance.laneStart, currentInstance.laneCount, cpu, analysisResult);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto scalarCompute = cast<SpatCompute>(currentInstance.op);
|
|
||||||
auto [newCompute, computeValueResults] = createNewComputeNode(scalarCompute, cpu, analysisResult);
|
|
||||||
newComputeNodeResults.insert({currentInstance, createLazyComputeResult(newCompute, computeValueResults, cpu)});
|
|
||||||
}
|
|
||||||
|
|
||||||
DenseSet<Operation*> toEraseSet;
|
DenseSet<Operation*> toEraseSet;
|
||||||
for (ComputeInstance instance : analysisResult.dominanceOrderCompute)
|
for (ComputeInstance instance : analysisResult.dominanceOrderCompute)
|
||||||
toEraseSet.insert(instance.op);
|
toEraseSet.insert(instance.op);
|
||||||
|
|
||||||
|
struct ScheduledTask {
|
||||||
|
ComputeInstance key;
|
||||||
|
Operation* sourceOp = nullptr;
|
||||||
|
size_t cpu = 0;
|
||||||
|
size_t slot = 0;
|
||||||
|
size_t order = 0;
|
||||||
|
};
|
||||||
|
struct ChannelInfo {
|
||||||
|
int64_t channelId = -1;
|
||||||
|
int32_t sourceCoreId = -1;
|
||||||
|
int32_t targetCoreId = -1;
|
||||||
|
};
|
||||||
|
struct CpuProgram {
|
||||||
|
SpatCompute op;
|
||||||
|
Block* block = nullptr;
|
||||||
|
DenseMap<Value, Value> externalInputMap;
|
||||||
|
DenseMap<Value, size_t> weightToIndex;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto getTaskInputs = [&](const ScheduledTask& task) {
|
||||||
|
SmallVector<Value> inputs;
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(task.sourceOp)) {
|
||||||
|
llvm::append_range(inputs, compute.getInputs());
|
||||||
|
return inputs;
|
||||||
|
}
|
||||||
|
auto batch = cast<SpatComputeBatch>(task.sourceOp);
|
||||||
|
for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane)
|
||||||
|
if (!batch.getInputs().empty())
|
||||||
|
inputs.push_back(batch.getInputs()[lane]);
|
||||||
|
return inputs;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto getTaskWeights = [&](const ScheduledTask& task) {
|
||||||
|
SmallVector<Value> weights;
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(task.sourceOp)) {
|
||||||
|
llvm::append_range(weights, compute.getWeights());
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
auto batch = cast<SpatComputeBatch>(task.sourceOp);
|
||||||
|
for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane)
|
||||||
|
weights.push_back(batch.getWeights()[lane]);
|
||||||
|
return weights;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto getTaskOutputValues = [&](const ScheduledTask& task) {
|
||||||
|
SmallVector<Value> outputs;
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(task.sourceOp)) {
|
||||||
|
for (Value result : compute.getResults())
|
||||||
|
outputs.push_back(result);
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
auto batch = cast<SpatComputeBatch>(task.sourceOp);
|
||||||
|
for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane)
|
||||||
|
if (!batch.getOutputs().empty())
|
||||||
|
outputs.push_back(batch.getOutputs()[lane]);
|
||||||
|
return outputs;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto getTaskOutputTypes = [&](const ScheduledTask& task) {
|
||||||
|
SmallVector<Type> resultTypes;
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(task.sourceOp)) {
|
||||||
|
llvm::append_range(resultTypes, compute.getResultTypes());
|
||||||
|
return resultTypes;
|
||||||
|
}
|
||||||
|
auto batch = cast<SpatComputeBatch>(task.sourceOp);
|
||||||
|
for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane)
|
||||||
|
if (!batch.getOutputs().empty())
|
||||||
|
resultTypes.push_back(batch.getOutputs()[lane].getType());
|
||||||
|
return resultTypes;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto getTaskTemplateBlock = [&](const ScheduledTask& task) -> Block& {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(task.sourceOp))
|
||||||
|
return compute.getBody().front();
|
||||||
|
return cast<SpatComputeBatch>(task.sourceOp).getBody().front();
|
||||||
|
};
|
||||||
|
|
||||||
|
auto appendUniqueValue = [](SmallVectorImpl<Value>& values, DenseSet<Value>& seen, Value value) {
|
||||||
|
if (seen.insert(value).second)
|
||||||
|
values.push_back(value);
|
||||||
|
};
|
||||||
|
|
||||||
|
DenseMap<ComputeInstance, ScheduledTask> taskByKey;
|
||||||
|
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
||||||
|
SmallVector<size_t> orderedCpus;
|
||||||
|
DenseSet<size_t> seenCpus;
|
||||||
|
DenseSet<Operation*> internalInputOpsToErase;
|
||||||
|
DenseMap<Operation*, bool> isInternalInputOpCache;
|
||||||
|
size_t nextOrder = 0;
|
||||||
|
auto markCpuSeen = [&](size_t cpu) {
|
||||||
|
if (seenCpus.insert(cpu).second)
|
||||||
|
orderedCpus.push_back(cpu);
|
||||||
|
};
|
||||||
|
for (ComputeInstance scheduledInstance : analysisResult.dominanceOrderCompute) {
|
||||||
|
size_t cpu = analysisResult.computeToCpuMap.at(scheduledInstance);
|
||||||
|
ScheduledTask task {scheduledInstance,
|
||||||
|
scheduledInstance.op,
|
||||||
|
cpu,
|
||||||
|
analysisResult.computeToCpuSlotMap.lookup(scheduledInstance),
|
||||||
|
nextOrder++};
|
||||||
|
taskByKey[task.key] = task;
|
||||||
|
tasksByCpu[cpu].push_back(task);
|
||||||
|
markCpuSeen(cpu);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::sort(orderedCpus);
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
|
||||||
|
if (lhs.slot != rhs.slot)
|
||||||
|
return lhs.slot < rhs.slot;
|
||||||
|
return lhs.order < rhs.order;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<bool(Operation*)> isInternalInputOp = [&](Operation* op) {
|
||||||
|
auto it = isInternalInputOpCache.find(op);
|
||||||
|
if (it != isInternalInputOpCache.end())
|
||||||
|
return it->second;
|
||||||
|
|
||||||
|
auto extract = dyn_cast_or_null<tensor::ExtractSliceOp>(op);
|
||||||
|
if (!extract)
|
||||||
|
return isInternalInputOpCache[op] = false;
|
||||||
|
|
||||||
|
for (Value result : extract->getResults()) {
|
||||||
|
for (Operation* user : result.getUsers()) {
|
||||||
|
if (toEraseSet.contains(user))
|
||||||
|
continue;
|
||||||
|
if (isInternalInputOp(user))
|
||||||
|
continue;
|
||||||
|
return isInternalInputOpCache[op] = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return isInternalInputOpCache[op] = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto collectInternalInputOps = [&](Value value) {
|
||||||
|
Operation* op = value.getDefiningOp();
|
||||||
|
while (auto extract = dyn_cast_if_present<tensor::ExtractSliceOp>(op)) {
|
||||||
|
if (isInternalInputOp(extract.getOperation()))
|
||||||
|
internalInputOpsToErase.insert(extract.getOperation());
|
||||||
|
value = extract.getSource();
|
||||||
|
op = value.getDefiningOp();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
DenseSet<Operation*> externalUsersToMove;
|
DenseSet<Operation*> externalUsersToMove;
|
||||||
auto collectExternalUsers = [&](Operation* op, auto&& collectExternalUsers) -> void {
|
auto collectExternalUsers = [&](Operation* op, auto&& collectExternalUsers) -> void {
|
||||||
if (!externalUsersToMove.insert(op).second)
|
if (!externalUsersToMove.insert(op).second)
|
||||||
@@ -1413,28 +1447,294 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
DenseSet<Operation*> erasedOps;
|
DenseMap<ComputeInstance, SmallVector<SmallVector<ChannelInfo>>> remoteSendsByTask;
|
||||||
for (ComputeInstance instance : llvm::reverse(analysisResult.dominanceOrderCompute)) {
|
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
||||||
if (!erasedOps.insert(instance.op).second)
|
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
||||||
|
DenseMap<size_t, SmallVector<Value>> cpuWeights;
|
||||||
|
DenseMap<size_t, SmallVector<ProducerValueRef>> cpuExternalOutputs;
|
||||||
|
DenseMap<size_t, DenseSet<Value>> seenExternalInputsByCpu;
|
||||||
|
DenseMap<size_t, DenseSet<Value>> seenWeightsByCpu;
|
||||||
|
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||||
|
auto taskWeights = getTaskWeights(task);
|
||||||
|
for (Value weight : taskWeights)
|
||||||
|
appendUniqueValue(cpuWeights[cpu], seenWeightsByCpu[cpu], weight);
|
||||||
|
|
||||||
|
auto taskInputs = getTaskInputs(task);
|
||||||
|
auto& remoteInputs = remoteInputsByTask[task.key];
|
||||||
|
remoteInputs.resize(taskInputs.size());
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||||
|
auto producerRef = getProducerValueRef(input);
|
||||||
|
if (producerRef) {
|
||||||
|
collectInternalInputOps(input);
|
||||||
|
auto producerIt = taskByKey.find(producerRef->instance);
|
||||||
|
if (producerIt != taskByKey.end()) {
|
||||||
|
if (producerIt->second.cpu != cpu) {
|
||||||
|
ChannelInfo info {
|
||||||
|
nextChannelId++,
|
||||||
|
getPhysicalCoreId(producerIt->second.cpu),
|
||||||
|
getPhysicalCoreId(cpu),
|
||||||
|
};
|
||||||
|
remoteInputs[inputIndex] = info;
|
||||||
|
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||||
|
if (perResultChannels.empty())
|
||||||
|
perResultChannels.resize(getTaskOutputTypes(producerIt->second).size());
|
||||||
|
perResultChannels[producerRef->resultIndex].push_back(info);
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
Operation* oldOp = instance.op;
|
|
||||||
if (Operation* newOp = oldToNewOpMap.lookup(oldOp)) {
|
|
||||||
for (unsigned i = 0; i < oldOp->getNumResults(); ++i) {
|
|
||||||
for (auto& use : llvm::make_early_inc_range(oldOp->getResult(i).getUses())) {
|
|
||||||
Operation* useOwner = use.getOwner();
|
|
||||||
if (!toEraseSet.contains(useOwner)) {
|
|
||||||
use.assign(newOp->getResult(i));
|
|
||||||
if (!isa<func::ReturnOp>(useOwner) && useOwner->isBeforeInBlock(newOp))
|
|
||||||
collectExternalUsers(useOwner, collectExternalUsers);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
appendUniqueValue(cpuExternalInputs[cpu], seenExternalInputsByCpu[cpu], input);
|
||||||
}
|
}
|
||||||
oldOp->erase();
|
|
||||||
|
auto taskOutputs = getTaskOutputValues(task);
|
||||||
|
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
||||||
|
bool hasExternalUser = false;
|
||||||
|
for (auto& use : output.getUses()) {
|
||||||
|
Operation* useOwner = use.getOwner();
|
||||||
|
if (toEraseSet.contains(useOwner))
|
||||||
|
continue;
|
||||||
|
hasExternalUser = true;
|
||||||
|
if (!isa<func::ReturnOp>(useOwner))
|
||||||
|
collectExternalUsers(useOwner, collectExternalUsers);
|
||||||
|
}
|
||||||
|
if (hasExternalUser)
|
||||||
|
cpuExternalOutputs[cpu].push_back({task.key, resultIndex});
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func::FuncOp func = getOperation();
|
|
||||||
auto returnOp = cast<func::ReturnOp>(func.getBody().front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(func.getBody().front().getTerminator());
|
||||||
|
IRRewriter rewriter(&getContext());
|
||||||
|
DenseMap<size_t, CpuProgram> cpuPrograms;
|
||||||
|
DenseMap<Value, Value> oldToNewExternalValueMap;
|
||||||
|
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
SmallVector<Value> operands;
|
||||||
|
operands.reserve(cpuWeights[cpu].size() + cpuExternalInputs[cpu].size());
|
||||||
|
llvm::append_range(operands, cpuWeights[cpu]);
|
||||||
|
llvm::append_range(operands, cpuExternalInputs[cpu]);
|
||||||
|
|
||||||
|
SmallVector<Type> resultTypes;
|
||||||
|
resultTypes.reserve(cpuExternalOutputs[cpu].size());
|
||||||
|
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||||
|
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||||
|
resultTypes.push_back(getTaskOutputTypes(task)[outputRef.resultIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(returnOp);
|
||||||
|
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(cpuWeights[cpu].size()), static_cast<int>(cpuExternalInputs[cpu].size())});
|
||||||
|
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(cpu)));
|
||||||
|
|
||||||
|
SmallVector<Type> blockArgTypes;
|
||||||
|
SmallVector<Location> blockArgLocs;
|
||||||
|
blockArgTypes.reserve(cpuExternalInputs[cpu].size());
|
||||||
|
blockArgLocs.reserve(cpuExternalInputs[cpu].size());
|
||||||
|
for (Value input : cpuExternalInputs[cpu]) {
|
||||||
|
blockArgTypes.push_back(input.getType());
|
||||||
|
blockArgLocs.push_back(loc);
|
||||||
|
}
|
||||||
|
Block* newBlock =
|
||||||
|
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
|
||||||
|
CpuProgram program;
|
||||||
|
program.op = newCompute;
|
||||||
|
program.block = newBlock;
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[cpu]))
|
||||||
|
program.weightToIndex[weight] = weightIndex;
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu]))
|
||||||
|
program.externalInputMap[input] = newBlock->getArgument(inputIndex);
|
||||||
|
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) {
|
||||||
|
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||||
|
oldToNewExternalValueMap[getTaskOutputValues(task)[outputRef.resultIndex]] = newCompute.getResult(resultIndex);
|
||||||
|
}
|
||||||
|
cpuPrograms[cpu] = std::move(program);
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseMap<ComputeInstance, SmallVector<Value>> producedValuesByTask;
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
CpuProgram& program = cpuPrograms[cpu];
|
||||||
|
IRRewriter cpuRewriter(&getContext());
|
||||||
|
cpuRewriter.setInsertionPointToEnd(program.block);
|
||||||
|
|
||||||
|
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||||
|
SmallVector<Value> taskInputs = getTaskInputs(task);
|
||||||
|
auto taskWeights = getTaskWeights(task);
|
||||||
|
Block& templateBlock = getTaskTemplateBlock(task);
|
||||||
|
|
||||||
|
SmallVector<Value> resolvedInputs;
|
||||||
|
resolvedInputs.reserve(taskInputs.size());
|
||||||
|
auto remoteInputsIt = remoteInputsByTask.find(task.key);
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||||
|
auto producerRef = getProducerValueRef(input);
|
||||||
|
if (producerRef) {
|
||||||
|
auto producerIt = taskByKey.find(producerRef->instance);
|
||||||
|
if (producerIt != taskByKey.end()) {
|
||||||
|
if (producerIt->second.cpu == cpu) {
|
||||||
|
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
||||||
|
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
||||||
|
task.sourceOp->emitOpError("missing local producer value during per-cpu merge materialization")
|
||||||
|
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot
|
||||||
|
<< " producerCpu=" << producerIt->second.cpu << " producerSlot=" << producerIt->second.slot
|
||||||
|
<< " producerLaneStart=" << producerRef->instance.laneStart
|
||||||
|
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||||
|
auto receive =
|
||||||
|
spatial::SpatChannelReceiveOp::create(cpuRewriter,
|
||||||
|
loc,
|
||||||
|
input.getType(),
|
||||||
|
cpuRewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||||
|
cpuRewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||||
|
cpuRewriter.getI32IntegerAttr(channelInfo.targetCoreId));
|
||||||
|
resolvedInputs.push_back(receive.getResult());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resolvedInputs.push_back(program.externalInputMap.at(input));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> taskYieldValues;
|
||||||
|
cpuRewriter.setInsertionPointToEnd(program.block);
|
||||||
|
if (isa<SpatCompute>(task.sourceOp)) {
|
||||||
|
IRMapping mapper;
|
||||||
|
for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments()))
|
||||||
|
mapper.map(oldArg, resolvedInputs[argIndex]);
|
||||||
|
|
||||||
|
for (Operation& op : templateBlock) {
|
||||||
|
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||||
|
for (Value yieldOperand : yield.getOperands())
|
||||||
|
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
||||||
|
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) {
|
||||||
|
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp);
|
||||||
|
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
||||||
|
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||||
|
}
|
||||||
|
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) {
|
||||||
|
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp);
|
||||||
|
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
|
||||||
|
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (size_t laneOffset = 0; laneOffset < task.key.laneCount; ++laneOffset) {
|
||||||
|
IRMapping mapper;
|
||||||
|
if (templateBlock.getNumArguments() == 1)
|
||||||
|
mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]);
|
||||||
|
|
||||||
|
for (Operation& op : templateBlock) {
|
||||||
|
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||||
|
for (Value yieldOperand : yield.getOperands())
|
||||||
|
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
||||||
|
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) {
|
||||||
|
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
||||||
|
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp);
|
||||||
|
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||||
|
}
|
||||||
|
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) {
|
||||||
|
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
||||||
|
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp);
|
||||||
|
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
producedValuesByTask[task.key] = taskYieldValues;
|
||||||
|
if (auto sendsIt = remoteSendsByTask.find(task.key); sendsIt != remoteSendsByTask.end()) {
|
||||||
|
for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) {
|
||||||
|
if (sendInfos.empty())
|
||||||
|
continue;
|
||||||
|
Value producedValue = taskYieldValues[resultIndex];
|
||||||
|
for (const ChannelInfo& sendInfo : sendInfos)
|
||||||
|
spatial::SpatChannelSendOp::create(cpuRewriter,
|
||||||
|
loc,
|
||||||
|
cpuRewriter.getI64IntegerAttr(sendInfo.channelId),
|
||||||
|
cpuRewriter.getI32IntegerAttr(sendInfo.sourceCoreId),
|
||||||
|
cpuRewriter.getI32IntegerAttr(sendInfo.targetCoreId),
|
||||||
|
producedValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> yieldValues;
|
||||||
|
yieldValues.reserve(cpuExternalOutputs[cpu].size());
|
||||||
|
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||||
|
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
||||||
|
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
|
||||||
|
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||||
|
task.sourceOp->emitOpError("missing yielded external value during per-cpu merge materialization")
|
||||||
|
<< " cpu=" << cpu << " slot=" << task.slot << " laneStart=" << outputRef.instance.laneStart;
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(cpuRewriter, loc, ValueRange(yieldValues));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto [oldValue, newValue] : oldToNewExternalValueMap) {
|
||||||
|
for (auto& use : llvm::make_early_inc_range(oldValue.getUses()))
|
||||||
|
if (!toEraseSet.contains(use.getOwner()))
|
||||||
|
use.assign(newValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseSet<Operation*> allOpsToErase = toEraseSet;
|
||||||
|
for (Operation* op : internalInputOpsToErase)
|
||||||
|
allOpsToErase.insert(op);
|
||||||
|
|
||||||
|
SmallVector<Operation*> orderedOpsToErase;
|
||||||
|
for (Operation& op : func.getBody().front())
|
||||||
|
if (allOpsToErase.contains(&op))
|
||||||
|
orderedOpsToErase.push_back(&op);
|
||||||
|
for (Operation* op : llvm::reverse(orderedOpsToErase)) {
|
||||||
|
SmallVector<Operation*> remainingUsers;
|
||||||
|
for (Value result : op->getResults())
|
||||||
|
for (Operation* user : result.getUsers())
|
||||||
|
remainingUsers.push_back(user);
|
||||||
|
if (!remainingUsers.empty()) {
|
||||||
|
llvm::errs() << "[MergeComputeNodesPass] refusing to erase op with remaining uses: " << op->getName() << "\n";
|
||||||
|
llvm::errs() << " erase-set: " << (allOpsToErase.contains(op) ? "yes" : "no") << "\n";
|
||||||
|
op->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
|
||||||
|
llvm::errs() << "\n";
|
||||||
|
for (Operation* user : remainingUsers) {
|
||||||
|
llvm::errs() << " user: " << user->getName()
|
||||||
|
<< " erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no") << "\n";
|
||||||
|
user->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions());
|
||||||
|
llvm::errs() << "\n";
|
||||||
|
}
|
||||||
|
op->emitOpError("still has uses during per-cpu merge cleanup");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
op->erase();
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<Operation*> orderedUsersToMove;
|
SmallVector<Operation*> orderedUsersToMove;
|
||||||
for (Operation& op : func.getBody().front()) {
|
for (Operation& op : func.getBody().front()) {
|
||||||
if (&op == returnOp.getOperation())
|
if (&op == returnOp.getOperation())
|
||||||
@@ -1445,9 +1745,16 @@ public:
|
|||||||
for (Operation* op : orderedUsersToMove)
|
for (Operation* op : orderedUsersToMove)
|
||||||
op->moveBefore(returnOp);
|
op->moveBefore(returnOp);
|
||||||
|
|
||||||
sinkChannelsIntoComputes(func, nextChannelId);
|
|
||||||
rebatchEquivalentComputes(func, nextChannelId);
|
rebatchEquivalentComputes(func, nextChannelId);
|
||||||
compactScalarChannelRuns(func, nextChannelId);
|
compactScalarChannelRuns(func, nextChannelId);
|
||||||
|
compactBatchChannelRuns(func);
|
||||||
|
compactRegularOpRuns(func);
|
||||||
|
compactRowWiseWvmmRuns(func);
|
||||||
|
if (!sortTopologically(&func.getBody().front())) {
|
||||||
|
func.emitOpError("failed to topologically order merged Spatial IR");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
||||||
generateReport(func, "spatial1_dcp_merged_report", analysisResult.cpuToLastComputeMap.size());
|
generateReport(func, "spatial1_dcp_merged_report", analysisResult.cpuToLastComputeMap.size());
|
||||||
}
|
}
|
||||||
@@ -1477,9 +1784,9 @@ private:
|
|||||||
Value resolvedInput = input;
|
Value resolvedInput = input;
|
||||||
if (auto producerRef = getProducerValueRef(input)) {
|
if (auto producerRef = getProducerValueRef(input)) {
|
||||||
LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance);
|
LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance);
|
||||||
auto [channelVal, isChannel, channelInfo] = producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu);
|
auto [channelVal, isChannel, channelInfo] =
|
||||||
(void) isChannel;
|
producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu);
|
||||||
(void) channelVal;
|
if (isChannel)
|
||||||
resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter,
|
resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
input.getType(),
|
input.getType(),
|
||||||
@@ -1487,6 +1794,8 @@ private:
|
|||||||
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||||
rewriter.getI32IntegerAttr(channelInfo.targetCoreId))
|
rewriter.getI32IntegerAttr(channelInfo.targetCoreId))
|
||||||
.getResult();
|
.getResult();
|
||||||
|
else
|
||||||
|
resolvedInput = channelVal;
|
||||||
}
|
}
|
||||||
|
|
||||||
newComputeOperands.push_back(resolvedInput);
|
newComputeOperands.push_back(resolvedInput);
|
||||||
@@ -1532,7 +1841,8 @@ private:
|
|||||||
uint32_t firstLane,
|
uint32_t firstLane,
|
||||||
uint32_t laneCount,
|
uint32_t laneCount,
|
||||||
size_t currentCpu,
|
size_t currentCpu,
|
||||||
const DCPAnalysisResult& analysisResult) {
|
const DCPAnalysisResult& analysisResult,
|
||||||
|
std::optional<uint64_t> rebatchPhase = std::nullopt) {
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
auto loc = func.getLoc();
|
auto loc = func.getLoc();
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
@@ -1547,15 +1857,17 @@ private:
|
|||||||
|
|
||||||
for (uint32_t lane = firstLane; lane < firstLane + laneCount; ++lane) {
|
for (uint32_t lane = firstLane; lane < firstLane + laneCount; ++lane) {
|
||||||
weights.push_back(batch.getWeights()[lane]);
|
weights.push_back(batch.getWeights()[lane]);
|
||||||
|
if (!batch.getOutputs().empty())
|
||||||
resultTypes.push_back(batch.getOutputs()[lane].getType());
|
resultTypes.push_back(batch.getOutputs()[lane].getType());
|
||||||
|
|
||||||
|
if (!batch.getInputs().empty()) {
|
||||||
Value input = batch.getInputs()[lane];
|
Value input = batch.getInputs()[lane];
|
||||||
Value resolvedInput = input;
|
Value resolvedInput = input;
|
||||||
if (auto producerRef = getProducerValueRef(input)) {
|
if (auto producerRef = getProducerValueRef(input)) {
|
||||||
LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance);
|
LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance);
|
||||||
auto [channelVal, isChannel, channelInfo] = producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu);
|
auto [channelVal, isChannel, channelInfo] =
|
||||||
(void) isChannel;
|
producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu);
|
||||||
(void) channelVal;
|
if (isChannel)
|
||||||
resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter,
|
resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
input.getType(),
|
input.getType(),
|
||||||
@@ -1563,9 +1875,12 @@ private:
|
|||||||
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||||
rewriter.getI32IntegerAttr(channelInfo.targetCoreId))
|
rewriter.getI32IntegerAttr(channelInfo.targetCoreId))
|
||||||
.getResult();
|
.getResult();
|
||||||
|
else
|
||||||
|
resolvedInput = channelVal;
|
||||||
}
|
}
|
||||||
inputs.push_back(resolvedInput);
|
inputs.push_back(resolvedInput);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Block& templateBlock = batch.getBody().front();
|
Block& templateBlock = batch.getBody().front();
|
||||||
if (laneCount == 1) {
|
if (laneCount == 1) {
|
||||||
@@ -1574,10 +1889,16 @@ private:
|
|||||||
compute.getProperties().setOperandSegmentSizes(
|
compute.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
||||||
compute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(currentCpu)));
|
compute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(currentCpu)));
|
||||||
|
if (rebatchPhase)
|
||||||
|
compute->setAttr(kRebatchPhaseAttrName, rewriter.getI64IntegerAttr(*rebatchPhase));
|
||||||
|
|
||||||
auto* newBlock = rewriter.createBlock(
|
SmallVector<Type> blockArgTypes;
|
||||||
&compute.getBody(), compute.getBody().end(), TypeRange {templateBlock.getArgument(0).getType()}, {loc});
|
if (templateBlock.getNumArguments() == 1)
|
||||||
|
blockArgTypes.push_back(templateBlock.getArgument(0).getType());
|
||||||
|
SmallVector<Location> blockArgLocs(templateBlock.getNumArguments(), loc);
|
||||||
|
auto* newBlock = rewriter.createBlock(&compute.getBody(), compute.getBody().end(), blockArgTypes, blockArgLocs);
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
|
if (templateBlock.getNumArguments() == 1)
|
||||||
mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0));
|
mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0));
|
||||||
rewriter.setInsertionPointToEnd(newBlock);
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
for (Operation& op : templateBlock)
|
for (Operation& op : templateBlock)
|
||||||
@@ -1599,14 +1920,16 @@ private:
|
|||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
|
||||||
ValueRange(weights),
|
ValueRange(weights),
|
||||||
ValueRange(inputs));
|
ValueRange(inputs));
|
||||||
rebatched->setAttr(onnx_mlir::kCoreIdAttrName,
|
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName,
|
||||||
rewriter.getDenseI32ArrayAttr(SmallVector<int32_t>(laneCount, getPhysicalCoreId(currentCpu))));
|
rewriter.getDenseI32ArrayAttr(getMaterializedBatchCoreIds(currentCpu, laneCount)));
|
||||||
|
|
||||||
auto* newBlock = rewriter.createBlock(&rebatched.getBody(),
|
SmallVector<Type> blockArgTypes;
|
||||||
rebatched.getBody().end(),
|
if (templateBlock.getNumArguments() == 1)
|
||||||
TypeRange {templateBlock.getArgument(0).getType()},
|
blockArgTypes.push_back(templateBlock.getArgument(0).getType());
|
||||||
SmallVector<Location>(1, loc));
|
SmallVector<Location> blockArgLocs(templateBlock.getNumArguments(), loc);
|
||||||
|
auto* newBlock = rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), blockArgTypes, blockArgLocs);
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
|
if (templateBlock.getNumArguments() == 1)
|
||||||
mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0));
|
mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0));
|
||||||
rewriter.setInsertionPointToEnd(newBlock);
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
for (Operation& op : templateBlock) {
|
for (Operation& op : templateBlock) {
|
||||||
@@ -1621,7 +1944,7 @@ private:
|
|||||||
ComputeValueResults results;
|
ComputeValueResults results;
|
||||||
results.outerValues.assign(rebatched->result_begin(), rebatched->result_end());
|
results.outerValues.assign(rebatched->result_begin(), rebatched->result_end());
|
||||||
results.innerValues = results.outerValues;
|
results.innerValues = results.outerValues;
|
||||||
if (results.innerValues.empty())
|
if (results.innerValues.empty() && yieldOp.getNumOperands() == 1)
|
||||||
results.innerValues.push_back(yieldOp.getOperand(0));
|
results.innerValues.push_back(yieldOp.getOperand(0));
|
||||||
newComputeNodeResults.insert({
|
newComputeNodeResults.insert({
|
||||||
ComputeInstance {batch.getOperation(), firstLane, laneCount},
|
ComputeInstance {batch.getOperation(), firstLane, laneCount},
|
||||||
@@ -1658,7 +1981,7 @@ private:
|
|||||||
channelInfo, insertVal};
|
channelInfo, insertVal};
|
||||||
return ret;
|
return ret;
|
||||||
};
|
};
|
||||||
return LazyInsertComputeResult(computeValueResults, insertNew);
|
return LazyInsertComputeResult(computeValueResults, producerCpu, insertNew);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,577 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
|
#include "RegularOpCompaction.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
enum class RegularStepKind {
|
||||||
|
Wvmm,
|
||||||
|
VAddLhs,
|
||||||
|
VAddRhs,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RegularStep {
|
||||||
|
RegularStepKind kind;
|
||||||
|
int32_t weightIndex = 0;
|
||||||
|
Value invariantOperand;
|
||||||
|
Type resultType;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RegularChunk {
|
||||||
|
Operation* startOp = nullptr;
|
||||||
|
SmallVector<Operation*> ops;
|
||||||
|
SmallVector<RegularStep> steps;
|
||||||
|
Value input;
|
||||||
|
Value output;
|
||||||
|
};
|
||||||
|
|
||||||
|
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
|
||||||
|
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand
|
||||||
|
&& lhs.resultType == rhs.resultType;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChunk& rhs) {
|
||||||
|
if (lhs.input.getType() != rhs.input.getType() || lhs.output.getType() != rhs.output.getType()
|
||||||
|
|| lhs.steps.size() != rhs.steps.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return llvm::all_of(llvm::zip_equal(lhs.steps, rhs.steps),
|
||||||
|
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatWeightedVMMOp startOp) {
|
||||||
|
RegularChunk chunk;
|
||||||
|
chunk.startOp = startOp.getOperation();
|
||||||
|
chunk.input = startOp.getInput();
|
||||||
|
chunk.output = startOp.getOutput();
|
||||||
|
chunk.ops.push_back(startOp.getOperation());
|
||||||
|
chunk.steps.push_back(
|
||||||
|
{RegularStepKind::Wvmm, static_cast<int32_t>(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()});
|
||||||
|
|
||||||
|
Value currentValue = startOp.getOutput();
|
||||||
|
while (currentValue.hasOneUse()) {
|
||||||
|
Operation* user = *currentValue.getUsers().begin();
|
||||||
|
if (user->getBlock() != startOp->getBlock())
|
||||||
|
break;
|
||||||
|
|
||||||
|
auto vaddOp = dyn_cast<spatial::SpatVAddOp>(user);
|
||||||
|
if (!vaddOp)
|
||||||
|
break;
|
||||||
|
|
||||||
|
if (vaddOp.getLhs() == currentValue)
|
||||||
|
chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()});
|
||||||
|
else if (vaddOp.getRhs() == currentValue)
|
||||||
|
chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()});
|
||||||
|
else
|
||||||
|
break;
|
||||||
|
|
||||||
|
chunk.ops.push_back(vaddOp);
|
||||||
|
chunk.output = vaddOp.getOutput();
|
||||||
|
currentValue = vaddOp.getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunk;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void buildRegularMapBody(spatial::SpatMapOp mapOp, const RegularChunk& anchorChunk, IRRewriter& rewriter) {
|
||||||
|
auto* block = rewriter.createBlock(
|
||||||
|
&mapOp.getBody(), mapOp.getBody().end(), TypeRange {anchorChunk.input.getType()}, {anchorChunk.startOp->getLoc()});
|
||||||
|
rewriter.setInsertionPointToEnd(block);
|
||||||
|
|
||||||
|
IRMapping mapping;
|
||||||
|
mapping.map(anchorChunk.input, block->getArgument(0));
|
||||||
|
|
||||||
|
for (Operation* op : anchorChunk.ops) {
|
||||||
|
Operation* cloned = rewriter.clone(*op, mapping);
|
||||||
|
for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults()))
|
||||||
|
mapping.map(oldResult, newResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
spatial::SpatYieldOp::create(
|
||||||
|
rewriter, anchorChunk.startOp->getLoc(), ValueRange {mapping.lookup(anchorChunk.output)});
|
||||||
|
}
|
||||||
|
|
||||||
|
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||||
|
assert(!run.empty() && "expected a non-empty regular chunk run");
|
||||||
|
const RegularChunk& anchorChunk = run.front();
|
||||||
|
|
||||||
|
SmallVector<Value> inputs;
|
||||||
|
SmallVector<Type> outputTypes;
|
||||||
|
inputs.reserve(run.size());
|
||||||
|
outputTypes.reserve(run.size());
|
||||||
|
for (const RegularChunk& chunk : run) {
|
||||||
|
inputs.push_back(chunk.input);
|
||||||
|
outputTypes.push_back(chunk.output.getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(anchorChunk.startOp);
|
||||||
|
auto mapOp =
|
||||||
|
spatial::SpatMapOp::create(rewriter, anchorChunk.startOp->getLoc(), TypeRange(outputTypes), ValueRange(inputs));
|
||||||
|
buildRegularMapBody(mapOp, anchorChunk, rewriter);
|
||||||
|
|
||||||
|
for (auto [index, chunk] : llvm::enumerate(run)) {
|
||||||
|
Value output = chunk.output;
|
||||||
|
output.replaceAllUsesWith(mapOp.getResult(index));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Operation*> opsToErase;
|
||||||
|
for (const RegularChunk& chunk : run)
|
||||||
|
llvm::append_range(opsToErase, chunk.ops);
|
||||||
|
for (Operation* op : llvm::reverse(opsToErase))
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||||
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
|
||||||
|
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
|
Block& block = compute.getBody().front();
|
||||||
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||||
|
if (receiveOp) {
|
||||||
|
SmallVector<spatial::SpatChannelReceiveOp> run;
|
||||||
|
Type outputType = receiveOp.getOutput().getType();
|
||||||
|
auto runIt = it;
|
||||||
|
while (runIt != block.end()) {
|
||||||
|
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
|
||||||
|
if (!current || current.getOutput().getType() != outputType)
|
||||||
|
break;
|
||||||
|
run.push_back(current);
|
||||||
|
++runIt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (run.size() > 1) {
|
||||||
|
struct ReceiveEntry {
|
||||||
|
spatial::SpatChannelReceiveOp op;
|
||||||
|
size_t originalIndex = 0;
|
||||||
|
uint32_t sourceCoreId = 0;
|
||||||
|
uint32_t targetCoreId = 0;
|
||||||
|
uint64_t channelId = 0;
|
||||||
|
};
|
||||||
|
SmallVector<ReceiveEntry> sortedEntries;
|
||||||
|
sortedEntries.reserve(run.size());
|
||||||
|
for (auto [originalIndex, op] : llvm::enumerate(run))
|
||||||
|
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||||
|
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
|
||||||
|
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||||
|
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||||
|
});
|
||||||
|
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
SmallVector<Type> outputTypes;
|
||||||
|
channelIds.reserve(sortedEntries.size());
|
||||||
|
sourceCoreIds.reserve(sortedEntries.size());
|
||||||
|
targetCoreIds.reserve(sortedEntries.size());
|
||||||
|
outputTypes.reserve(sortedEntries.size());
|
||||||
|
for (ReceiveEntry& entry : sortedEntries) {
|
||||||
|
(void) entry;
|
||||||
|
channelIds.push_back(nextChannelId++);
|
||||||
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
|
outputTypes.push_back(entry.op.getOutput().getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(run.front());
|
||||||
|
auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter,
|
||||||
|
run.front().getLoc(),
|
||||||
|
TypeRange(outputTypes),
|
||||||
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||||
|
for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries))
|
||||||
|
entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex));
|
||||||
|
for (auto op : run)
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
|
it = compactReceive->getIterator();
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
|
||||||
|
if (sendOp) {
|
||||||
|
SmallVector<spatial::SpatChannelSendOp> run;
|
||||||
|
Type inputType = sendOp.getInput().getType();
|
||||||
|
auto runIt = it;
|
||||||
|
while (runIt != block.end()) {
|
||||||
|
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
|
||||||
|
if (!current || current.getInput().getType() != inputType)
|
||||||
|
break;
|
||||||
|
run.push_back(current);
|
||||||
|
++runIt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (run.size() > 1) {
|
||||||
|
struct SendEntry {
|
||||||
|
spatial::SpatChannelSendOp op;
|
||||||
|
uint32_t sourceCoreId = 0;
|
||||||
|
uint32_t targetCoreId = 0;
|
||||||
|
uint64_t channelId = 0;
|
||||||
|
};
|
||||||
|
SmallVector<SendEntry> sortedEntries;
|
||||||
|
sortedEntries.reserve(run.size());
|
||||||
|
for (auto op : run)
|
||||||
|
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||||
|
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
|
||||||
|
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||||
|
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||||
|
});
|
||||||
|
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
SmallVector<Value> inputs;
|
||||||
|
channelIds.reserve(sortedEntries.size());
|
||||||
|
sourceCoreIds.reserve(sortedEntries.size());
|
||||||
|
targetCoreIds.reserve(sortedEntries.size());
|
||||||
|
inputs.reserve(sortedEntries.size());
|
||||||
|
for (SendEntry& entry : sortedEntries) {
|
||||||
|
(void) entry;
|
||||||
|
channelIds.push_back(nextChannelId++);
|
||||||
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
|
inputs.push_back(entry.op.getInput());
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(run.front());
|
||||||
|
spatial::SpatChannelSendManyOp::create(rewriter,
|
||||||
|
run.front().getLoc(),
|
||||||
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||||
|
ValueRange(inputs));
|
||||||
|
for (auto op : run)
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
|
it = runIt;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||||
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
|
||||||
|
for (auto batch : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||||
|
Block& block = batch.getBody().front();
|
||||||
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
|
||||||
|
if (receiveOp) {
|
||||||
|
SmallVector<spatial::SpatChannelReceiveBatchOp> run;
|
||||||
|
Type outputType = receiveOp.getOutput().getType();
|
||||||
|
auto runIt = it;
|
||||||
|
while (runIt != block.end()) {
|
||||||
|
auto current = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*runIt);
|
||||||
|
if (!current || current.getOutput().getType() != outputType)
|
||||||
|
break;
|
||||||
|
run.push_back(current);
|
||||||
|
++runIt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (run.size() > 1) {
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
SmallVector<Type> outputTypes;
|
||||||
|
outputTypes.reserve(run.size());
|
||||||
|
for (auto op : run) {
|
||||||
|
llvm::append_range(channelIds, op.getChannelIds());
|
||||||
|
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||||
|
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||||
|
outputTypes.push_back(op.getOutput().getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(run.front());
|
||||||
|
auto compactReceive =
|
||||||
|
spatial::SpatChannelReceiveManyBatchOp::create(rewriter,
|
||||||
|
run.front().getLoc(),
|
||||||
|
TypeRange(outputTypes),
|
||||||
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||||
|
for (auto [index, op] : llvm::enumerate(run))
|
||||||
|
op.getOutput().replaceAllUsesWith(compactReceive.getResult(index));
|
||||||
|
for (auto op : run)
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
|
it = compactReceive->getIterator();
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
|
||||||
|
if (sendOp) {
|
||||||
|
SmallVector<spatial::SpatChannelSendBatchOp> run;
|
||||||
|
Type inputType = sendOp.getInput().getType();
|
||||||
|
auto runIt = it;
|
||||||
|
while (runIt != block.end()) {
|
||||||
|
auto current = dyn_cast<spatial::SpatChannelSendBatchOp>(&*runIt);
|
||||||
|
if (!current || current.getInput().getType() != inputType)
|
||||||
|
break;
|
||||||
|
run.push_back(current);
|
||||||
|
++runIt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (run.size() > 1) {
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
SmallVector<Value> inputs;
|
||||||
|
inputs.reserve(run.size());
|
||||||
|
for (auto op : run) {
|
||||||
|
llvm::append_range(channelIds, op.getChannelIds());
|
||||||
|
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||||
|
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||||
|
inputs.push_back(op.getInput());
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(run.front());
|
||||||
|
spatial::SpatChannelSendManyBatchOp::create(rewriter,
|
||||||
|
run.front().getLoc(),
|
||||||
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||||
|
ValueRange(inputs));
|
||||||
|
for (auto op : run)
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
|
it = runIt;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||||
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
|
||||||
|
auto compactInBlock = [&](Block& block) {
|
||||||
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
|
auto startOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
|
||||||
|
if (!startOp) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto anchorChunk = analyzeRegularChunk(startOp);
|
||||||
|
if (failed(anchorChunk)) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<RegularChunk> run {*anchorChunk};
|
||||||
|
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||||
|
while (runIt != block.end()) {
|
||||||
|
auto candidateStart = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
|
||||||
|
if (!candidateStart)
|
||||||
|
break;
|
||||||
|
|
||||||
|
auto candidateChunk = analyzeRegularChunk(candidateStart);
|
||||||
|
if (failed(candidateChunk) || !areEquivalentRegularChunks(*anchorChunk, *candidateChunk))
|
||||||
|
break;
|
||||||
|
|
||||||
|
run.push_back(*candidateChunk);
|
||||||
|
runIt = std::next(runIt, static_cast<std::ptrdiff_t>(candidateChunk->ops.size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (run.size() <= 1) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
compactRegularChunkRun(rewriter, run);
|
||||||
|
it = runIt;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
|
||||||
|
compactInBlock(compute.getBody().front());
|
||||||
|
for (auto batch : funcOp.getOps<spatial::SpatComputeBatch>())
|
||||||
|
compactInBlock(batch.getBody().front());
|
||||||
|
}
|
||||||
|
|
||||||
|
void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||||
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
|
||||||
|
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
|
Block& block = compute.getBody().front();
|
||||||
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
|
auto wvmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
|
||||||
|
if (!wvmmOp) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto extractRowsOp = wvmmOp.getInput().getDefiningOp<spatial::SpatExtractRowsOp>();
|
||||||
|
auto rowResult = dyn_cast<OpResult>(wvmmOp.getInput());
|
||||||
|
auto outputType = dyn_cast<RankedTensorType>(wvmmOp.getOutput().getType());
|
||||||
|
if (!extractRowsOp || !rowResult || rowResult.getOwner() != extractRowsOp || !outputType
|
||||||
|
|| !outputType.hasStaticShape() || outputType.getRank() != 2 || outputType.getShape()[0] != 1) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<spatial::SpatWeightedVMMOp> run;
|
||||||
|
auto runIt = it;
|
||||||
|
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||||
|
while (runIt != block.end()) {
|
||||||
|
auto current = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
|
||||||
|
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||||
|
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||||
|
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||||
|
|| current.getOutput().getType() != wvmmOp.getOutput().getType()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto currentRow = dyn_cast<OpResult>(current.getInput());
|
||||||
|
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
|
||||||
|
break;
|
||||||
|
|
||||||
|
run.push_back(current);
|
||||||
|
++expectedRow;
|
||||||
|
++runIt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (run.size() <= 1) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!run.front().getOutput().hasOneUse()) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto concatUse = run.front().getOutput().getUses().begin();
|
||||||
|
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
|
||||||
|
if (!concatOp) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned concatStartIndex = concatUse->getOperandNumber();
|
||||||
|
bool validConcatRun = true;
|
||||||
|
for (auto [index, op] : llvm::enumerate(run)) {
|
||||||
|
if (!op.getOutput().hasOneUse()) {
|
||||||
|
validConcatRun = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
OpOperand& use = *op.getOutput().getUses().begin();
|
||||||
|
if (use.getOwner() != concatOp || use.getOperandNumber() != concatStartIndex + index) {
|
||||||
|
validConcatRun = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!validConcatRun) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto inputType = dyn_cast<RankedTensorType>(wvmmOp.getInput().getType());
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||||
|
if (!inputType || !sourceType || !inputType.hasStaticShape() || !sourceType.hasStaticShape()) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t inputCols = inputType.getShape()[1];
|
||||||
|
int64_t outputCols = outputType.getShape()[1];
|
||||||
|
if (ShapedType::isDynamic(inputCols) || ShapedType::isDynamic(outputCols)) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||||
|
int64_t runLength = static_cast<int64_t>(run.size());
|
||||||
|
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(run.front());
|
||||||
|
auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0);
|
||||||
|
auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength);
|
||||||
|
auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1);
|
||||||
|
auto packedInit =
|
||||||
|
tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||||
|
auto loop =
|
||||||
|
scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||||
|
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
Block* loopBlock = loop.getBody();
|
||||||
|
rewriter.setInsertionPointToStart(loopBlock);
|
||||||
|
Value iv = loopBlock->getArgument(0);
|
||||||
|
Value acc = loopBlock->getArgument(1);
|
||||||
|
|
||||||
|
Value sourceRow = iv;
|
||||||
|
if (firstRow != 0) {
|
||||||
|
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow);
|
||||||
|
sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
|
||||||
|
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
|
||||||
|
run.front().getLoc(),
|
||||||
|
inputType,
|
||||||
|
extractRowsOp.getInput(),
|
||||||
|
extractOffsets,
|
||||||
|
extractSizes,
|
||||||
|
extractStrides);
|
||||||
|
auto loopWvmm = spatial::SpatWeightedVMMOp::create(
|
||||||
|
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
||||||
|
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
auto inserted = tensor::InsertSliceOp::create(
|
||||||
|
rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
||||||
|
scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> newConcatInputs;
|
||||||
|
newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1);
|
||||||
|
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
||||||
|
if (operandIndex == concatStartIndex)
|
||||||
|
newConcatInputs.push_back(loop.getResult(0));
|
||||||
|
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size())
|
||||||
|
newConcatInputs.push_back(operand);
|
||||||
|
}
|
||||||
|
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
|
||||||
|
for (auto op : run)
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
|
it = loop->getIterator();
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
||||||
|
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
|
||||||
|
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
|
||||||
|
void compactRowWiseWvmmRuns(mlir::func::FuncOp funcOp);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
|
import shlex
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
@@ -37,8 +38,12 @@ def _parse_pim_pass_timings(output_text):
|
|||||||
return pass_timings
|
return pass_timings
|
||||||
|
|
||||||
|
|
||||||
|
def _format_command(cmd):
|
||||||
|
return shlex.join(str(arg) for arg in cmd)
|
||||||
|
|
||||||
|
|
||||||
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||||
crossbar_size, crossbar_count, cwd=None, reporter=None):
|
crossbar_size, crossbar_count, core_count=None, cwd=None, reporter=None):
|
||||||
# Define the arguments, with the possibility to set crossbar size and count
|
# Define the arguments, with the possibility to set crossbar size and count
|
||||||
args = [
|
args = [
|
||||||
network_path,
|
network_path,
|
||||||
@@ -51,10 +56,18 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
|||||||
f"--crossbar-count={crossbar_count}",
|
f"--crossbar-count={crossbar_count}",
|
||||||
"--enable-timing",
|
"--enable-timing",
|
||||||
]
|
]
|
||||||
|
if core_count is not None:
|
||||||
|
args.append(f"--core-count={core_count}")
|
||||||
|
|
||||||
|
cmd = [str(raptor_onnx_path)] + [str(arg) for arg in args]
|
||||||
|
if reporter is not None:
|
||||||
|
reporter.log(f" Raptor command: {_format_command(cmd)}")
|
||||||
|
else:
|
||||||
|
print(f"Raptor command: {_format_command(cmd)}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
output_text = run_command_with_reporter(
|
output_text = run_command_with_reporter(
|
||||||
[str(raptor_onnx_path)] + [str(arg) for arg in args],
|
cmd,
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
reporter=reporter,
|
reporter=reporter,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import shlex
|
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
@@ -11,12 +10,6 @@ from validate_one import ProgressReporter, clean_workspace_artifacts, validate_n
|
|||||||
from raptor import PIM_PASS_LABELS
|
from raptor import PIM_PASS_LABELS
|
||||||
|
|
||||||
|
|
||||||
def format_command(cmd):
|
|
||||||
if isinstance(cmd, (list, tuple)):
|
|
||||||
return shlex.join(str(arg) for arg in cmd)
|
|
||||||
return str(cmd)
|
|
||||||
|
|
||||||
|
|
||||||
def format_return_status(returncode):
|
def format_return_status(returncode):
|
||||||
if returncode < 0:
|
if returncode < 0:
|
||||||
signal_num = -returncode
|
signal_num = -returncode
|
||||||
@@ -34,8 +27,6 @@ def print_validation_error(reporter, rel, exc):
|
|||||||
file=sys.stderr, flush=True)
|
file=sys.stderr, flush=True)
|
||||||
if isinstance(exc, subprocess.CalledProcessError):
|
if isinstance(exc, subprocess.CalledProcessError):
|
||||||
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
|
print(format_return_status(exc.returncode), file=sys.stderr, flush=True)
|
||||||
print("Retry command:", file=sys.stderr, flush=True)
|
|
||||||
print(format_command(exc.cmd), file=sys.stderr, flush=True)
|
|
||||||
else:
|
else:
|
||||||
print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
|
print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
|
||||||
print("=" * 72, file=sys.stderr, flush=True)
|
print("=" * 72, file=sys.stderr, flush=True)
|
||||||
@@ -65,6 +56,8 @@ def main():
|
|||||||
ap.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.")
|
ap.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.")
|
||||||
ap.add_argument("--crossbar-size", type=int, default=64)
|
ap.add_argument("--crossbar-size", type=int, default=64)
|
||||||
ap.add_argument("--crossbar-count", type=int, default=8)
|
ap.add_argument("--crossbar-count", type=int, default=8)
|
||||||
|
ap.add_argument("--core-count", type=int, default=None,
|
||||||
|
help="Core count to pass to Raptor. If omitted, Raptor uses its default.")
|
||||||
ap.add_argument("--clean", action="store_true",
|
ap.add_argument("--clean", action="store_true",
|
||||||
help="Remove generated validation artifacts under each model workspace and exit.")
|
help="Remove generated validation artifacts under each model workspace and exit.")
|
||||||
a = ap.parse_args()
|
a = ap.parse_args()
|
||||||
@@ -114,7 +107,7 @@ def main():
|
|||||||
try:
|
try:
|
||||||
result = validate_network(
|
result = validate_network(
|
||||||
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
||||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, core_count=a.core_count,
|
||||||
threshold=a.threshold,
|
threshold=a.threshold,
|
||||||
reporter=reporter,
|
reporter=reporter,
|
||||||
model_index=index,
|
model_index=index,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import argparse
|
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -258,7 +257,7 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
|
|||||||
|
|
||||||
|
|
||||||
def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
||||||
simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3,
|
simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None, threshold=1e-3,
|
||||||
reporter=None, model_index=1, model_total=1):
|
reporter=None, model_index=1, model_total=1):
|
||||||
network_onnx_path = Path(network_onnx_path).resolve()
|
network_onnx_path = Path(network_onnx_path).resolve()
|
||||||
raptor_path = Path(raptor_path).resolve()
|
raptor_path = Path(raptor_path).resolve()
|
||||||
@@ -313,7 +312,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||||
pim_pass_timings = compile_with_raptor(
|
pim_pass_timings = compile_with_raptor(
|
||||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
||||||
crossbar_size, crossbar_count,
|
crossbar_size, crossbar_count, core_count=core_count,
|
||||||
cwd=raptor_dir, reporter=reporter)
|
cwd=raptor_dir, reporter=reporter)
|
||||||
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
|
print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}")
|
||||||
reporter.advance()
|
reporter.advance()
|
||||||
@@ -350,18 +349,3 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
reporter.log("=" * 72)
|
reporter.log("=" * 72)
|
||||||
if owns_reporter:
|
if owns_reporter:
|
||||||
reporter.finish()
|
reporter.finish()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
ap = argparse.ArgumentParser()
|
|
||||||
ap.add_argument("--network-onnx", required=True)
|
|
||||||
ap.add_argument("--raptor-path", required=True)
|
|
||||||
ap.add_argument("--onnx-include-dir", required=True)
|
|
||||||
a = ap.parse_args()
|
|
||||||
|
|
||||||
simulator_dir = Path(__file__).parent.resolve() / ".." / "backend-simulators" / "pim" / "pim-simulator"
|
|
||||||
|
|
||||||
passed = validate_network(
|
|
||||||
a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir
|
|
||||||
)
|
|
||||||
raise SystemExit(0 if passed.passed else 1)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user