refactorone
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-20 19:06:41 +02:00
parent f56c4159b5
commit a50e77ff38
50 changed files with 3420 additions and 1187 deletions
+40
View File
@@ -145,6 +145,46 @@ validate.py \
--crossbar-size 2048 --crossbar-count 256 --core-count 1000 --crossbar-size 2048 --crossbar-count 256 --core-count 1000
``` ```
Each validation run writes debugging artifacts into the benchmark's workspace
directory (for example `validation/operations/gemm/small/`):
- `inputs/` — generated input CSVs used for the run.
- `outputs/` — reference outputs dumped by the native ONNX runner.
- `raptor/` — compiler artifacts:
`*.onnx.mlir`, `dialects/spatial0.mlir`, `dialects/spatial1_dcp_merged.mlir`,
`dialects/pim0.mlir`, `dialects/pim1_buff.mlir`, `dialects/pim2_coalesced.mlir`,
`dialects/pim3_folded.mlir`, `dialects/pim4_materialized.mlir`,
`pim/config.json`, `pim/core_*.pim`, `pim/memory.bin`, and reports under
`raptor/reports/` such as `dcp_merge_report.txt`,
`memory_report.txt`, and `static_memory_coalescing_report.txt`.
- `runner/` — generated reference runner source, build tree, and shared library.
- `simulation/out.bin` — raw simulator output dump used for output comparison.
That means you usually do not need to rerun standalone `--EmitSpatial` or
`--EmitPim` commands while debugging validation failures: the per-pass dialect
dumps are already available under `raptor/dialects/`.
The validator does not currently expose a simulator tracing flag, but once a
validation has produced `raptor/pim/` you can rerun the simulator manually with
tracing enabled:
```bash
cd backend-simulators/pim/pim-simulator
cargo run --no-default-features --features tracing --release \
--package pim-simulator --bin pim-simulator -- \
-f /path/to/workspace/raptor/pim \
-o /path/to/workspace/simulation/out.bin \
-d <addr0>,<size0>,<addr1>,<size1>,...
```
With `--features tracing`, the simulator writes per-core traces as
`simulation/TraceCore0`, `simulation/TraceCore1`, ... next to `simulation/out.bin`.
The validator normally computes the `-d` dump ranges from `raptor/pim/config.json`
and the model output shapes. If you need a clean slate before rerunning, use:
```bash
validate.py --clean
```
Available networks under `validation/networks/`: `vgg16`, `yolo11n`. Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
Available operations under `validation/operations/`: `add`, `conv`, `div`, Available operations under `validation/operations/`: `add`, `conv`, `div`,
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`, `gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
+1
View File
@@ -1,5 +1,6 @@
add_pim_library(OMPimCommon add_pim_library(OMPimCommon
IR/AddressAnalysis.cpp IR/AddressAnalysis.cpp
IR/ConstantUtils.cpp
IR/CoreBlockUtils.cpp IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp IR/EntryPointUtils.cpp
IR/ShapeUtils.cpp IR/ShapeUtils.cpp
+46
View File
@@ -1,5 +1,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
@@ -55,6 +57,47 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
} }
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge); llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
const StaticValueKnowledge* knowledge) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
if (indices.size() != static_cast<size_t>(globalType.getRank()))
return mlir::failure();
auto strides = computeRowMajorStrides(globalType.getShape());
int64_t linearIndex = linearizeIndex(indices, strides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge); value = resolveAlias(value, knowledge);
@@ -126,6 +169,9 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs)); return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
} }
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return resolveConstantGlobalLoad(loadOp, knowledge);
return mlir::failure(); return mlir::failure();
} }
+62
View File
@@ -0,0 +1,62 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "ConstantUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Block* getHostConstantBlock(Operation* anchorOp) {
assert(anchorOp && "expected a valid anchor operation");
for (Operation* current = anchorOp; current; current = current->getParentOp())
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
return current->getBlock();
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
return &funcOp.getBody().front();
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
return moduleOp.getBody();
return anchorOp->getBlock();
}
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
}
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
}
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
}
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
}
} // namespace onnx_mlir
+28
View File
@@ -0,0 +1,28 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/FoldUtils.h"
namespace onnx_mlir {
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
mlir::Attribute value,
mlir::Type type,
mlir::OperationFolder& folder);
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
} // namespace onnx_mlir
+4
View File
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
@@ -30,6 +31,9 @@ walkPimCoreBlock(mlir::Block& block,
for (mlir::Operation& op : block) { for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op)) if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue; continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) { if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front(); mlir::Block& loopBody = forOp.getRegion().front();
+23 -14
View File
@@ -21,12 +21,13 @@ namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy> template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex);
bool found = false; bool found = false;
parentOp.walk([&](mlir::Operation* op) { parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op)) if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex; found |= mvmOp.getWeight() == weightArg;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op)) else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex; found |= vmmOp.getWeight() == weightArg;
}); });
return found; return found;
} }
@@ -35,13 +36,18 @@ template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) { void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights(); auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited; llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) { auto walkWeight = [&](mlir::Value weight) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second) for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
if (parentOp.getWeightArgument(weightIndex) != weight)
continue;
if (visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex)); callback(parentOp->getOpOperand(weightIndex));
break;
}
}; };
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); }); parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); }); parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
} }
} // namespace } // namespace
@@ -90,18 +96,21 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
assert(root && "expected valid root op"); assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) { root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) { coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights(); for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
unsigned weightIndex = vmmOp.getWeightIndex(); if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex)); callback(coreOp->getOpOperand(weightIndex));
break;
}
}); });
}); });
root->walk([&](pim::PimCoreBatchOp coreBatchOp) { root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights(); coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
for (auto weight : weights) for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
for (mlir::OpOperand& use : weight.getUses()) if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
if (use.getOwner() == coreBatchOp.getOperation()) callback(coreBatchOp->getOpOperand(weightIndex));
callback(use); break;
}
});
}); });
} }
+1
View File
@@ -12,6 +12,7 @@
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp" #include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
+83 -60
View File
@@ -1,7 +1,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" #include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
@@ -24,113 +28,132 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
return laneCoreIds; return laneCoreIds;
} }
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) { static void cloneScalarizedLaneBody(OpBuilder& builder,
IRRewriter rewriter(scalarCore.getContext()); pim::PimCoreBatchOp coreBatchOp,
SmallVector<Operation*> batchOps; unsigned lane,
scalarCore.walk([&](Operation* op) { OperationFolder& constantFolder) {
if (isa<pim::PimSendBatchOp, Block& oldBlock = coreBatchOp.getBody().front();
pim::PimSendTensorBatchOp, size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
pim::PimReceiveBatchOp, size_t weightCount = coreBatchOp.getWeights().size();
pim::PimReceiveTensorBatchOp,
pim::PimMemCopyHostToDevBatchOp>(op)) {
batchOps.push_back(op);
}
});
for (Operation* op : batchOps) { IRMapping mapper;
rewriter.setInsertionPoint(op); for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
if (blockArg.getType().isIndex()) {
mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast<int64_t>(lane), constantFolder));
continue;
}
if (argIndex <= weightCount) {
mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]);
continue;
}
size_t inputIndex = argIndex - 1 - weightCount;
assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]);
}
for (Operation& op : oldBlock) {
if (isa<pim::PimHaltOp>(op))
continue;
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) { if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(rewriter, Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
pim::PimSendOp::create(
builder,
sendBatchOp.getLoc(), sendBatchOp.getLoc(),
sendBatchOp.getInput(), mapper.lookup(sendBatchOp.getInput()),
sendBatchOp.getSizeAttr(), sendBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane])); getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder));
rewriter.eraseOp(op);
continue; continue;
} }
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) { if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create( pim::PimSendTensorOp::create(
rewriter, builder,
sendTensorBatchOp.getLoc(), sendTensorBatchOp.getLoc(),
sendTensorBatchOp.getInput(), mapper.lookup(sendTensorBatchOp.getInput()),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane))); builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
rewriter.eraseOp(op);
continue; continue;
} }
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) { if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive = Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
pim::PimReceiveOp::create(rewriter, auto scalarReceive = pim::PimReceiveOp::create(
builder,
receiveBatchOp.getLoc(), receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(), receiveBatchOp.getOutput().getType(),
receiveBatchOp.getOutputBuffer(), mapper.lookup(receiveBatchOp.getOutputBuffer()),
receiveBatchOp.getSizeAttr(), receiveBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane])); getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder));
rewriter.replaceOp(op, scalarReceive->getResults()); mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
continue; continue;
} }
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) { if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create( auto scalarReceive = pim::PimReceiveTensorOp::create(
rewriter, builder,
receiveTensorBatchOp.getLoc(), receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(), receiveTensorBatchOp.getOutput().getType(),
receiveTensorBatchOp.getOutputBuffer(), mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane))); builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
rewriter.replaceOp(op, scalarReceive->getResults()); mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
continue; continue;
} }
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op); if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter, auto scalarCopy = pim::PimMemCopyHostToDevOp::create(
builder,
memcpBatchOp.getLoc(), memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(), memcpBatchOp.getOutput().getType(),
memcpBatchOp.getDeviceTarget(), getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
memcpBatchOp.getHostSource(), getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
memcpBatchOp.getDeviceTargetOffsetAttr(), mapper.lookup(memcpBatchOp.getDeviceTarget()),
memcpBatchOp.getHostSourceOffsetAttr(), mapper.lookup(memcpBatchOp.getHostSource()),
memcpBatchOp.getSizeAttr()); memcpBatchOp.getSizeAttr());
rewriter.replaceOp(op, scalarCopy->getResults()); mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
continue;
}
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
} }
} }
} // namespace } // namespace
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
unsigned lane, ArrayRef<unsigned> lanes,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) { llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
assert(!lanes.empty() && "expected at least one batch lane");
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc()); OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
OpBuilder builder(scratchModule->getContext()); OpBuilder builder(scratchModule->getContext());
OperationFolder constantFolder(scratchModule->getContext());
builder.setInsertionPointToStart(scratchModule->getBody()); builder.setInsertionPointToStart(scratchModule->getBody());
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount()); SmallVector<Value> weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
SmallVector<Value> laneWeights;
laneWeights.reserve(weightsPerLane);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
auto coreIds = getBatchCoreIds(coreBatchOp); auto coreIds = getBatchCoreIds(coreBatchOp);
auto scalarCore = pim::PimCoreOp::create( int32_t coreId = coreIds[lanes.front()];
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane])); for (unsigned lane : lanes)
assert(coreIds[lane] == coreId && "all grouped lanes must target the same core");
auto scalarCore =
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end()); Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
IRMapping mapper;
if (coreBatchOp.getBody().front().getNumArguments() == 1)
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
builder.setInsertionPointToEnd(block); builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) { for (unsigned lane : lanes)
Operation* cloned = builder.clone(op, mapper); cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
if (block->empty() || !isa<pim::PimHaltOp>(block->back())) if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc()); pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
return callback(scalarCore); return callback(scalarCore);
} }
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef<unsigned> {lane}, callback);
}
} // namespace onnx_mlir } // namespace onnx_mlir
+3
View File
@@ -9,5 +9,8 @@ namespace onnx_mlir {
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane, unsigned lane,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback); llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
mlir::LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
llvm::ArrayRef<unsigned> lanes,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
} // namespace onnx_mlir } // namespace onnx_mlir
+64 -16
View File
@@ -41,15 +41,23 @@ using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
using namespace onnx_mlir::compact_asm; using namespace onnx_mlir::compact_asm;
static size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (elementType.isIndex())
return sizeof(int64_t);
if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() / 8;
llvm_unreachable("unsupported shaped element type");
}
static size_t getValueSizeInBytes(mlir::Value value) { static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType()); auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * type.getElementTypeBitWidth() / 8; return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
} }
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType()); auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape()); assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8; size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
MemEntry memEntry = {0, allocSize}; MemEntry memEntry = {0, allocSize};
return &memEntries.emplace_back(memEntry, value).first; return &memEntries.emplace_back(memEntry, value).first;
} }
@@ -398,20 +406,28 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
} }
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
auto deviceTargetOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge);
auto hostSourceOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge);
assert(succeeded(deviceTargetOffset) && succeeded(hostSourceOffset)
&& "pim.memcp_hd offsets must be statically resolvable during codegen");
emitMemCopyOp("ld", emitMemCopyOp("ld",
addressOf(loadOp.getDeviceTarget(), knowledge), addressOf(loadOp.getDeviceTarget(), knowledge),
loadOp.getDeviceTargetOffset(), *deviceTargetOffset,
addressOf(loadOp.getHostSource(), knowledge), addressOf(loadOp.getHostSource(), knowledge),
loadOp.getHostSourceOffset(), *hostSourceOffset,
loadOp.getSize()); loadOp.getSize());
} }
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
assert(succeeded(hostTargetOffset) && succeeded(deviceSourceOffset)
&& "pim.memcp_dh offsets must be statically resolvable during codegen");
emitMemCopyOp("st", emitMemCopyOp("st",
addressOf(storeOp.getHostTarget(), knowledge), addressOf(storeOp.getHostTarget(), knowledge),
storeOp.getHostTargetOffset(), *hostTargetOffset,
addressOf(storeOp.getDeviceSource(), knowledge), addressOf(storeOp.getDeviceSource(), knowledge),
storeOp.getDeviceSourceOffset(), *deviceSourceOffset,
storeOp.getSize()); storeOp.getSize());
} }
@@ -426,8 +442,9 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledg
} }
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp( auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize()); assert(succeeded(sourceCoreId) && "pim.receive source core id must be statically resolvable during codegen");
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
} }
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
@@ -439,7 +456,9 @@ void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
} }
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize()); auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
} }
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
@@ -728,12 +747,19 @@ std::string getMemorySizeAsString(size_t size) {
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) { static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices; SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) { auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
auto addWeight = [&](mlir::Value weight) {
if (!coreOp)
return;
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
if (coreOp.getWeightArgument(weightIndex) != weight)
continue;
if (!llvm::is_contained(indices, weightIndex)) if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex); indices.push_back(weightIndex);
return;
}
}; };
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices); llvm::sort(indices);
return indices; return indices;
} }
@@ -795,6 +821,15 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
/// fully resolved before the JSON instructions are emitted. /// fully resolved before the JSON instructions are emitted.
/// Returns the number of emitted instructions, or -1 on failure. /// Returns the number of emitted instructions, or -1 on failure.
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional<unsigned> {
auto coreOp = vmmOp->getParentOfType<pim::PimCoreOp>();
if (!coreOp)
return std::nullopt;
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
return weightIndex;
return std::nullopt;
};
size_t processedOperations = 0; size_t processedOperations = 0;
auto result = auto result =
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) { walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
@@ -814,8 +849,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge); coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
coreCodeGen.codeGenConcatOp(concatOp, knowledge); coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge); auto weightIndex = resolveWeightIndex(vmmOp);
if (!weightIndex)
return failure();
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightIndex, vmmOp, true, knowledge);
}
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op)) else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge); coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op)) else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
@@ -1004,10 +1043,19 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
reportedCoreIds.reserve(batchCoreIds.size()); reportedCoreIds.reserve(batchCoreIds.size());
MemoryReportRow batchRow; MemoryReportRow batchRow;
std::optional<MemoryReportRow> batchPerCoreRow; std::optional<MemoryReportRow> batchPerCoreRow;
llvm::DenseMap<size_t, SmallVector<unsigned>> lanesByCoreId;
SmallVector<size_t> orderedOriginalCoreIds;
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) { for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]); size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId);
if (inserted)
orderedOriginalCoreIds.push_back(originalCoreId);
it->second.push_back(lane);
}
for (size_t originalCoreId : orderedOriginalCoreIds) {
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) {
size_t coreId = emittedCoreIds.lookup(originalCoreId); size_t coreId = emittedCoreIds.lookup(originalCoreId);
reportedCoreIds.push_back(static_cast<int32_t>(coreId)); reportedCoreIds.push_back(static_cast<int32_t>(coreId));
MemoryReportRow laneRow; MemoryReportRow laneRow;
+10 -2
View File
@@ -128,12 +128,20 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) { SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices; SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) { auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
auto addWeight = [&](mlir::Value weight) {
if (!coreOp)
return;
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
if (coreOp.getWeightArgument(weightIndex) != weight)
continue;
if (!llvm::is_contained(indices, weightIndex)) if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex); indices.push_back(weightIndex);
return;
}
}; };
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
llvm::sort(indices); llvm::sort(indices);
return indices; return indices;
} }
@@ -18,13 +18,17 @@ namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); } inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
}
template <typename Fn, size_t... Is> template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) { decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...); return std::forward<Fn>(fn)(block->getArgument(Is)...);
} }
template <typename Fn, size_t... Is> template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) { decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...); return std::forward<Fn>(fn)(values[Is]...);
} }
@@ -85,6 +89,8 @@ auto createSpatCompute(RewriterT& rewriter,
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block(); auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs) for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc); block->addArgument(input.getType(), loc);
@@ -93,14 +99,15 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>; using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) { if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {}); detail::invokeWithValues(
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
return computeOp; return computeOp;
} }
else { else {
auto bodyResult = auto bodyResult = detail::invokeWithValues(
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {}); std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) { if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp); rewriter.eraseOp(computeOp);
@@ -123,6 +130,8 @@ auto createSpatCompute(RewriterT& rewriter,
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block(); auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs) for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc); block->addArgument(input.getType(), loc);
@@ -131,13 +140,13 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>; using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, void>) { if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(detail::getBlockArgs(block)); std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
return computeOp; return computeOp;
} }
else { else {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block)); auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
if (mlir::failed(bodyResult)) { if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp); rewriter.eraseOp(computeOp);
@@ -44,7 +44,8 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext()); IRRewriter rewriter(funcOp.getContext());
IRMapping mapper; IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>()); SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
if (!computes.empty()) SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
if (!computes.empty() || !computeBatches.empty())
return; return;
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator()); auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
@@ -190,16 +191,6 @@ void ONNXToSpatialPass::runOnOperation() {
tensor::TensorDialect, tensor::TensorDialect,
arith::ArithDialect, arith::ArithDialect,
scf::SCFDialect>(); scf::SCFDialect>();
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
RewritePatternSet earlyPostPatterns(ctx);
populateEarlyPostPatterns(earlyPostPatterns, ctx);
if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
signalPassFailure();
return;
}
PassManager cleanupPM(ctx); PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass()); cleanupPM.addPass(createCanonicalizerPass());
@@ -402,13 +402,29 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++) for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
weights.push_back(bTiles[outSliceId][coreId][aSliceId]); weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = createSpatCompute( auto computeOp =
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult { spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, aHSlices[coreId]);
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(weights.size() + aHSlices[coreId].size());
blockArgLocs.reserve(weights.size() + aHSlices[coreId].size());
for (Value weight : weights) {
blockArgTypes.push_back(weight.getType());
blockArgLocs.push_back(gemmLoc);
}
for (Value input : aHSlices[coreId]) {
blockArgTypes.push_back(input.getType());
blockArgLocs.push_back(gemmLoc);
}
Block* body =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body);
SmallVector<Value> vmmOutputs; SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size()); vmmOutputs.reserve(aHSlices[coreId].size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
vmmOutputs.push_back( vmmOutputs.push_back(spatial::SpatVMMOp::create(
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg)); rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId)));
if (vmmOutputs.empty()) { if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure(); return failure();
@@ -416,10 +432,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
Value partialVmmSum = sumTensors(vmmOutputs, rewriter); Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
return success(); rewriter.setInsertionPointAfter(computeOp);
});
if (failed(computeOp))
return failure();
partialResults.push_back(computeOp->getResult(0)); partialResults.push_back(computeOp->getResult(0));
} }
@@ -530,37 +543,47 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
sharedBias = c; sharedBias = c;
} }
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType()); auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType); auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
auto batchOp = spatial::SpatComputeBatch::create(rewriter, auto batchOp = spatial::SpatComputeBatch::create(rewriter,
loc, loc,
TypeRange(resultTypes), TypeRange {outType},
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)), rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
ValueRange(weights), ValueRange {b},
ValueRange(aSlices)); ValueRange {a});
Block* body = rewriter.createBlock( SmallVector<Type> blockArgTypes {rewriter.getIndexType(), bType, aType, outType};
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc)); SmallVector<Location> blockArgLocs(4, loc);
Block* body =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body); rewriter.setInsertionPointToEnd(body);
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult(); Value lane = batchOp.getLaneArgument();
Value weight = batchOp.getWeightArgument(0);
Value packedInput = batchOp.getInputArgument(0);
Value packedOutput = batchOp.getOutputArgument(0);
SmallVector<OpFoldResult> inputOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value row =
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides)
.getResult();
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult();
Value laneResult = vmmResult; Value laneResult = vmmResult;
if (sharedBias) if (sharedBias)
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult(); laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes,
unitStrides);
rewriter.setInsertionPointAfter(batchOp); rewriter.setInsertionPointAfter(batchOp);
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
});
rewriter.replaceOp(gemmOp, concatComputeOp); rewriter.replaceOp(gemmOp, batchOp.getResults());
return success(); return success();
} }
@@ -35,58 +35,15 @@ template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) { static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
Block& block = compute.getBody().front(); Block& block = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= block.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input)) if (!isWeightLikeComputeOperand(input))
continue; continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx))) if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue; continue;
return true; return true;
} }
return false; return false;
} }
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
if (batchOp.getLaneCount() != 1)
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp);
auto computeOp =
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
Block& templateBlock = batchOp.getBody().front();
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(templateBlock.getNumArguments());
blockArgLocs.reserve(templateBlock.getNumArguments());
for (BlockArgument arg : templateBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc);
}
auto* newBlock =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : templateBlock)
rewriter.clone(op, mapper);
batchOp->replaceAllUsesWith(computeOp->getResults());
rewriter.eraseOp(batchOp);
return success();
}
};
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs. // Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> { struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern; using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
@@ -96,11 +53,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
bool needsRewrite = false; bool needsRewrite = false;
Block& oldBlock = compute.getBody().front(); Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input)) if (!isWeightLikeComputeOperand(input))
continue; continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx))) if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue; continue;
promoteInput[inputIdx] = true; promoteInput[inputIdx] = true;
needsRewrite = true; needsRewrite = true;
@@ -131,8 +86,16 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
auto newCompute = auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs); spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
for (Value weight : newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs);
auto* newBlock = auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs); rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())}); {static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock); rewriter.setInsertionPointToStart(newBlock);
@@ -141,14 +104,17 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
bodyRewriter.setInsertionPointToStart(newBlock); bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper; IRMapping mapper;
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
size_t newInputIdx = 0; size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) { for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
if (!promoteInput[oldInputIdx]) { if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++)); mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
continue; continue;
} }
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper); auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue)) if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand"); return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(oldArg, *clonedValue); mapper.map(oldArg, *clonedValue);
@@ -180,11 +146,9 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
bool needsRewrite = false; bool needsRewrite = false;
Block& oldBlock = compute.getBody().front(); Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input)) if (!isWeightLikeComputeOperand(input))
continue; continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx))) if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue; continue;
promoteInput[inputIdx] = true; promoteInput[inputIdx] = true;
needsRewrite = true; needsRewrite = true;
@@ -220,8 +184,25 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())), rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights, newWeights,
newInputs); newInputs);
auto* newBlock = SmallVector<Type> newBlockArgTypes;
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs); SmallVector<Location> newBlockArgLocs;
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
newBlockArgTypes.push_back(compute.getLaneArgument().getType());
newBlockArgLocs.push_back(compute.getLaneArgument().getLoc());
for (Value weight : newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs);
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
newBlockArgTypes.push_back(resultType);
newBlockArgLocs.push_back(compute.getOutputArgument(resultIndex).getLoc());
}
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())}); {static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock); rewriter.setInsertionPointToStart(newBlock);
@@ -230,31 +211,28 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
bodyRewriter.setInsertionPointToStart(newBlock); bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper; IRMapping mapper;
mapper.map(compute.getLaneArgument(), newCompute.getLaneArgument());
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
size_t newInputIdx = 0; size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) { for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
if (!promoteInput[oldInputIdx]) { if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++)); mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
continue; continue;
} }
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper); auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue)) if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand"); return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(oldArg, *clonedValue); mapper.map(oldArg, *clonedValue);
} }
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
for (Operation& op : oldBlock.without_terminator()) for (Operation& op : oldBlock)
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults()); rewriter.replaceOp(compute, newCompute.getResults());
return success(); return success();
} }
@@ -262,10 +240,6 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
} // namespace } // namespace
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx); patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
} }
@@ -277,8 +251,6 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
}); });
} }
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; }
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); } bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); } bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
@@ -7,14 +7,10 @@
namespace onnx_mlir { namespace onnx_mlir {
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp);
bool requiresPostRewrite(spatial::SpatCompute computeOp); bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp); bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void annotateWeightsConstants(mlir::func::FuncOp funcOp); void annotateWeightsConstants(mlir::func::FuncOp funcOp);
@@ -2,6 +2,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -17,6 +18,37 @@ namespace {
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); } static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static FailureOr<int32_t> getConstantI32Value(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return static_cast<int32_t>(constantValue.getSExtValue());
}
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
SmallVector<int32_t> constants;
constants.reserve(values.size());
for (Value value : values) {
FailureOr<int32_t> constantValue = getConstantI32Value(value);
if (failed(constantValue))
return failure();
constants.push_back(*constantValue);
}
return constants;
}
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
static bool isUsedOnlyAsExplicitHostOperand(Value value) {
return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) {
return isExplicitHostOperand(use.getOwner(), use.getOperandNumber());
});
}
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) { static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
@@ -28,27 +60,30 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
return coreIds; return coreIds;
} }
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp, static LogicalResult lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
IRMapping& mapper, IRMapping& mapper,
IRRewriter& rewriter) { IRRewriter& rewriter) {
SmallVector<int32_t> targetCoreIds; FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendTensorBatchOp.getTargetCoreIds());
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size()); if (failed(targetCoreIds))
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds()) return sendTensorBatchOp.emitOpError("expected constant targetCoreIds");
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); for (int32_t& targetCoreId : *targetCoreIds)
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
pim::PimSendTensorBatchOp::create(rewriter, pim::PimSendTensorBatchOp::create(rewriter,
sendTensorBatchOp.getLoc(), sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()), mapper.lookup(sendTensorBatchOp.getInput()),
rewriter.getDenseI32ArrayAttr(targetCoreIds)); rewriter.getDenseI32ArrayAttr(*targetCoreIds));
return success();
} }
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp, static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
IRMapping& mapper, IRMapping& mapper,
IRRewriter& rewriter) { IRRewriter& rewriter) {
SmallVector<int32_t> sourceCoreIds; FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorBatchOp.getSourceCoreIds());
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size()); if (failed(sourceCoreIds))
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds()) return receiveTensorBatchOp.emitOpError("expected constant sourceCoreIds");
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId)); for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType()); auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
@@ -56,24 +91,26 @@ static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatc
receiveTensorBatchOp.getLoc(), receiveTensorBatchOp.getLoc(),
outputBuffer.getType(), outputBuffer.getType(),
outputBuffer, outputBuffer,
rewriter.getDenseI32ArrayAttr(sourceCoreIds)) rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput(); .getOutput();
mapper.map(receiveTensorBatchOp.getOutput(), received); mapper.map(receiveTensorBatchOp.getOutput(), received);
return success();
} }
} // namespace } // namespace
LogicalResult LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) { lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
Location loc = computeBatchOp.getLoc(); Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front(); Block& oldBlock = computeBatchOp.getBody().front();
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator()); if (computeBatchOp.getNumResults() != 0)
if (oldYield.getNumOperands() != 0) return computeBatchOp.emitOpError(
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield"); "batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results; "
"materialize explicit communication before lowering to PIM");
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (!oldYield || oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId); SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
@@ -102,7 +139,12 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
IRMapping mapper; IRMapping mapper;
rewriter.setInsertionPointToStart(newBlock); rewriter.setInsertionPointToStart(newBlock);
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) { mapper.map(computeBatchOp.getLaneArgument(), coreBatchOp.getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex)
mapper.map(computeBatchOp.getWeightArgument(weightIndex), coreBatchOp.getWeightArgument(weightIndex));
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
BlockArgument oldArg = computeBatchOp.getInputArgument(inputIndex);
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
auto newArgType = cast<ShapedType>(newArg.getType()); auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
@@ -142,20 +184,31 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
continue; continue;
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) { if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
if (failed(targetCoreIds))
return sendBatchOp.emitOpError("expected constant targetCoreIds");
for (int32_t& targetCoreId : *targetCoreIds)
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
pim::PimSendBatchOp::create(rewriter, pim::PimSendBatchOp::create(rewriter,
loc, loc,
mapper.lookup(sendBatchOp.getInput()), mapper.lookup(sendBatchOp.getInput()),
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())), getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
sendBatchOp.getTargetCoreIdsAttr()); rewriter.getDenseI32ArrayAttr(*targetCoreIds));
continue; continue;
} }
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) { if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter); if (failed(lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter)))
return failure();
continue; continue;
} }
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) { if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveBatchOp.getSourceCoreIds());
if (failed(sourceCoreIds))
return receiveBatchOp.emitOpError("expected constant sourceCoreIds");
for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType()); auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
auto received = pim::PimReceiveBatchOp::create(rewriter, auto received = pim::PimReceiveBatchOp::create(rewriter,
@@ -163,14 +216,15 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
outputBuffer.getType(), outputBuffer.getType(),
outputBuffer, outputBuffer,
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()), getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
receiveBatchOp.getSourceCoreIdsAttr()) rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput(); .getOutput();
mapper.map(receiveBatchOp.getOutput(), received); mapper.map(receiveBatchOp.getOutput(), received);
continue; continue;
} }
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) { if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter); if (failed(lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter)))
return failure();
continue; continue;
} }
@@ -178,6 +232,10 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) { if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper); Operation* cloned = rewriter.clone(op, mapper);
auto clonedTensor = cloned->getResult(0); auto clonedTensor = cloned->getResult(0);
if (isUsedOnlyAsExplicitHostOperand(toTensorOp.getResult())) {
mapper.map(toTensorOp.getResult(), clonedTensor);
continue;
}
auto clonedType = cast<ShapedType>(clonedTensor.getType()); auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
@@ -194,9 +252,11 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
} }
} }
for (Value operand : op.getOperands()) { for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand)) if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue; continue;
if (isExplicitHostOperand(&op, operandIndex))
continue;
Operation* definingOp = operand.getDefiningOp(); Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock) if (definingOp && definingOp->getBlock() == &oldBlock)
@@ -22,6 +22,8 @@ add_pim_library(OMSpatialToPim
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRSCFDialect MLIRSCFDialect
MLIRSCFUtils
MLIRTransformUtils
MLIRTosaDialect MLIRTosaDialect
OMCompilerOptions OMCompilerOptions
OMPimCommon OMPimCommon
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
@@ -12,15 +13,24 @@ namespace {
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; } static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
SmallVector<int32_t> constants;
constants.reserve(values.size());
for (Value value : values) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
}
return constants;
}
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> { struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
pim::PimSendOp::create(rewriter, pim::PimSendOp::create(
op.getLoc(), rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId());
op.getInput(),
getTensorSizeInBytesAttr(rewriter, op.getInput()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
} }
@@ -42,7 +52,7 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
op.getResult().getType(), op.getResult().getType(),
outputBuffer, outputBuffer,
getTensorSizeInBytesAttr(rewriter, op.getResult()), getTensorSizeInBytesAttr(rewriter, op.getResult()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId()))) op.getSourceCoreId())
.getOutput(); .getOutput();
rewriter.replaceOp(op, received); rewriter.replaceOp(op, received);
return success(); return success();
@@ -53,11 +63,12 @@ struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTens
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> targetCoreIds; FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(op.getTargetCoreIds());
targetCoreIds.reserve(op.getTargetCoreIds().size()); if (failed(targetCoreIds))
for (int32_t targetCoreId : op.getTargetCoreIds()) return rewriter.notifyMatchFailure(op, "expected constant targetCoreIds");
targetCoreIds.push_back(toPimCoreId(targetCoreId)); for (int32_t& targetCoreId : *targetCoreIds)
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds)); targetCoreId = toPimCoreId(targetCoreId);
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(*targetCoreIds));
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
} }
@@ -67,16 +78,17 @@ struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelRecei
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> sourceCoreIds; FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(op.getSourceCoreIds());
sourceCoreIds.reserve(op.getSourceCoreIds().size()); if (failed(sourceCoreIds))
for (int32_t sourceCoreId : op.getSourceCoreIds()) return rewriter.notifyMatchFailure(op, "expected constant sourceCoreIds");
sourceCoreIds.push_back(toPimCoreId(sourceCoreId)); for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = toPimCoreId(sourceCoreId);
auto outputType = cast<ShapedType>(op.getOutput().getType()); auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer = Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult(); tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received = Value received =
pim::PimReceiveTensorOp::create( pim::PimReceiveTensorOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds)) rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput(); .getOutput();
rewriter.replaceOp(op, received); rewriter.replaceOp(op, received);
return success(); return success();
@@ -29,7 +29,10 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
unsigned inputIndex, unsigned inputIndex,
Value replacement) { Value replacement) {
Block& body = owner->getRegion(0).front(); Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument = body.getArgument(inputIndex); BlockArgument bodyArgument = isa<spatial::SpatCompute>(owner)
? cast<spatial::SpatCompute>(owner).getInputArgument(inputIndex)
: cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
unsigned bodyArgIndex = bodyArgument.getArgNumber();
rewriter.startOpModification(owner); rewriter.startOpModification(owner);
bodyArgument.replaceAllUsesWith(replacement); bodyArgument.replaceAllUsesWith(replacement);
@@ -37,7 +40,7 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
compute.getInputsMutable().erase(inputIndex); compute.getInputsMutable().erase(inputIndex);
else else
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex); cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
body.eraseArgument(inputIndex); body.eraseArgument(bodyArgIndex);
rewriter.finalizeOpModification(owner); rewriter.finalizeOpModification(owner);
} }
@@ -3,6 +3,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -27,7 +28,8 @@ static bool isChannelUseChainOp(Operation* op) {
pim::PimTransposeOp>(op); pim::PimTransposeOp>(op);
} }
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) { static void
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
for (Value operand : op->getOperands()) { for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand)) if (mapping.lookupOrNull(operand))
continue; continue;
@@ -36,7 +38,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
if (!definingOp) if (!definingOp)
continue; continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp)) if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
continue;
}
if (!isa<tensor::EmptyOp>(definingOp))
continue; continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping); Operation* clonedOp = rewriter.clone(*definingOp, mapping);
@@ -48,6 +55,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); } static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
SmallVector<int32_t> constants;
constants.reserve(values.size());
for (Value value : values) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
}
return constants;
}
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) { static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(spatialCoreIdAttr.getInt()); return static_cast<int32_t>(spatialCoreIdAttr.getInt());
@@ -92,7 +111,9 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
return success(); return success();
} }
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) { static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
IRRewriter& rewriter,
OperationFolder& constantFolder) {
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1) if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
return false; return false;
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) { if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
@@ -101,7 +122,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
return false; return false;
Block& block = computeOp.getBody().front(); Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 0) if (block.getNumArguments() != computeOp.getWeights().size())
return false; return false;
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator()); auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
@@ -110,8 +131,10 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
rewriter.setInsertionPoint(computeOp); rewriter.setInsertionPoint(computeOp);
IRMapping mapping; IRMapping mapping;
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights()))
mapping.map(computeOp.getWeightArgument(weightIndex), weight);
for (Operation& op : block.without_terminator()) { for (Operation& op : block.without_terminator()) {
cloneMappedHelperOperands(&op, mapping, rewriter); cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
Operation* clonedOp = rewriter.clone(op, mapping); Operation* clonedOp = rewriter.clone(op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult); mapping.map(originalResult, newResult);
@@ -133,7 +156,7 @@ void markOpToRemove(CoreLoweringState& state, Operation* op) {
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) { LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc(); Location loc = computeOp->getLoc();
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter)) if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, state.constantFolder))
return success(); return success();
SmallVector<Operation*> helperChain; SmallVector<Operation*> helperChain;
@@ -143,21 +166,42 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
auto& block = computeOp.getRegion().front(); auto& block = computeOp.getRegion().front();
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator()); auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) { for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp()); BlockArgument blockArg = computeOp.getInputArgument(inputIndex);
if (!receiveOp || blockArg.use_empty()) auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
continue; if (receiveOp && !blockArg.use_empty()) {
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
auto outputType = cast<ShapedType>(blockArg.getType()); auto outputType = cast<ShapedType>(blockArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg); auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId())); Value received =
Value received = PimReceiveOp::create( PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
.getOutput(); .getOutput();
blockArg.replaceAllUsesWith(received); blockArg.replaceAllUsesWith(received);
markOpToRemove(state, receiveOp); markOpToRemove(state, receiveOp);
continue;
}
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
if (receiveTensorOp && !blockArg.use_empty()) {
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
if (failed(sourceCoreIds))
return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
auto outputType = cast<ShapedType>(blockArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
Value received = PimReceiveTensorOp::create(rewriter,
receiveTensorOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput();
blockArg.replaceAllUsesWith(received);
markOpToRemove(state, receiveTensorOp);
}
} }
if (computeOp.getNumResults() != yieldOp.getNumOperands()) if (computeOp.getNumResults() != yieldOp.getNumOperands())
@@ -197,11 +241,36 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
loc, loc,
ValueRange(computeWeights), ValueRange(computeWeights),
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId))); rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
rewriter.setInsertionPointToStart(&block);
auto& coreOpBlocks = coreOp.getBody().getBlocks(); auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
if (!blockArg.use_empty()) BlockArgument blockArg = computeOp.getInputArgument(inputIndex);
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]); if (blockArg.use_empty())
block.eraseArguments(0, block.getNumArguments()); continue;
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, state.constantFolder));
continue;
}
auto inputType = dyn_cast<ShapedType>(input.getType());
if (!inputType)
return computeOp.emitOpError("expected shaped compute input during pim.core lowering");
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType);
auto copied =
PimMemCopyHostToDevOp::create(rewriter,
loc,
outputBuffer.getType(),
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
outputBuffer,
input,
getTensorSizeInBytesAttr(rewriter, input))
.getOutput();
blockArg.replaceAllUsesWith(copied);
}
if (!computeOp.getInputs().empty())
block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks()); coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block(); Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock); computeOp.getBody().push_back(tempComputeBlock);
@@ -1,6 +1,7 @@
#pragma once #pragma once
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -11,6 +12,7 @@ struct CoreLoweringState {
size_t& nextCoreId; size_t& nextCoreId;
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors; llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove; llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
mlir::OperationFolder& constantFolder;
}; };
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op); void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
@@ -76,8 +76,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber()); auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
if (!inputIndex) if (!inputIndex)
return failure(); return failure();
auto BBArgIndex = *inputIndex; auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
if (BBArgValue.use_empty()) if (BBArgValue.use_empty())
continue; continue;
@@ -89,14 +88,13 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
} }
replaceAndEraseDirectComputeLikeInput( replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]); rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
} }
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) { else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber()); auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
if (!inputIndex) if (!inputIndex)
return failure(); return failure();
auto BBArgIndex = *inputIndex; auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
if (BBArgValue.use_empty()) if (BBArgValue.use_empty())
continue; continue;
@@ -108,7 +106,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
} }
replaceAndEraseDirectComputeLikeInput( replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]); rewriter, spatComputeBatch.getOperation(), *inputIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
} }
else { else {
{ {
@@ -254,7 +252,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
} }
} }
else if (constantOp.getType().isIntOrIndexOrFloat()) { else if (constantOp.getType().isIntOrIndexOrFloat()) {
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst; Value hostConstant = constantOp.getResult();
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) { for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner(); auto constUsers = constUses.getOwner();
@@ -264,40 +262,22 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
if (!inputIndex) if (!inputIndex)
return failure(); return failure();
auto BBArgIndex = *inputIndex; auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, hostConstant);
auto newConst = rewriter.clone(*constantOp);
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
} }
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) { else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber()); auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex) if (!inputIndex)
return failure(); return failure();
auto BBArgIndex = *inputIndex; auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, hostConstant);
auto newConst = rewriter.clone(*constantOp);
replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
} }
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) { else if (constUsers->getParentOfType<spatial::SpatCompute>()) {
if (!mapSpatComputeToConst.contains(parent)) { constUses.set(hostConstant);
rewriter.setInsertionPoint(&parent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
} }
else { else {
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>(); auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
assert(batchParent && "Global Constant used direcly not within a compute"); assert(batchParent && "Global Constant used direcly not within a compute");
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) { constUses.set(hostConstant);
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
} }
} }
} }
@@ -6,8 +6,10 @@
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/FoldUtils.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -318,7 +320,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
return success(); return success();
} }
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) { static void
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
for (Value operand : op->getOperands()) { for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand)) if (mapping.lookupOrNull(operand))
continue; continue;
@@ -327,7 +330,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
if (!definingOp) if (!definingOp)
continue; continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp)) if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
continue;
}
if (!isa<tensor::EmptyOp>(definingOp))
continue; continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping); Operation* clonedOp = rewriter.clone(*definingOp, mapping);
@@ -337,15 +345,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
} }
} }
static void static void cloneHelperChain(Value sourceValue,
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) { ArrayRef<Operation*> helperChain,
IRRewriter& rewriter,
OperationFolder& constantFolder,
Value& clonedValue) {
IRMapping mapping; IRMapping mapping;
mapping.map(sourceValue, sourceValue); mapping.map(sourceValue, sourceValue);
clonedValue = sourceValue; clonedValue = sourceValue;
rewriter.setInsertionPointAfterValue(sourceValue); rewriter.setInsertionPointAfterValue(sourceValue);
for (Operation* op : helperChain) { for (Operation* op : helperChain) {
cloneMappedHelperOperands(op, mapping, rewriter); cloneMappedHelperOperands(op, mapping, rewriter, constantFolder);
Operation* clonedOp = rewriter.clone(*op, mapping); Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult); mapping.map(originalResult, newResult);
@@ -360,14 +371,19 @@ static Value emitHostCopy(IRRewriter& rewriter,
Value sourceValue, Value sourceValue,
int32_t hostTargetOffset, int32_t hostTargetOffset,
int32_t deviceSourceOffset, int32_t deviceSourceOffset,
int32_t sizeInBytes) { int32_t sizeInBytes,
OperationFolder& constantFolder) {
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
Value hostTargetOffsetValue = getOrCreateHostIndexConstant(anchorOp, hostTargetOffset, constantFolder);
Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(anchorOp, deviceSourceOffset, constantFolder);
return PimMemCopyDevToHostOp::create(rewriter, return PimMemCopyDevToHostOp::create(rewriter,
loc, loc,
outputTensor.getType(), outputTensor.getType(),
hostTargetOffsetValue,
deviceSourceOffsetValue,
outputTensor, outputTensor,
sourceValue, sourceValue,
rewriter.getI32IntegerAttr(hostTargetOffset),
rewriter.getI32IntegerAttr(deviceSourceOffset),
rewriter.getI32IntegerAttr(sizeInBytes)) rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput(); .getOutput();
} }
@@ -411,69 +427,84 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
} }
} }
ReturnPathLoweringResult lowerComputeResultReturnPath( ReturnPathLoweringResult lowerProducedValueReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) { Operation* producerOp, Value producedValue, Value storedValue, ReturnPathState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc(); Location loc = producerOp->getLoc();
auto yieldType = cast<TensorType>(yieldValue.getType()); OperationFolder constantFolder(producerOp->getContext());
auto storedTensorType = cast<TensorType>(storedValue.getType());
if (auto returnUse = analyzeReturnUse(result)) { if (auto returnUse = analyzeReturnUse(producedValue)) {
Value storedValue = yieldValue; Value currentStoredValue = storedValue;
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue); cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
for (Operation* op : returnUse->helperChain) for (Operation* op : returnUse->helperChain)
markOpToRemove(state, op); markOpToRemove(state, op);
auto storedType = cast<ShapedType>(storedValue.getType()); auto storedType = cast<ShapedType>(currentStoredValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8; size_t elementSize = storedType.getElementTypeBitWidth() / 8;
if (auto storedOp = storedValue.getDefiningOp()) if (auto storedOp = currentStoredValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp); rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc); Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy( emitHostCopy(rewriter,
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize)); loc,
outputTensor,
currentStoredValue,
0,
0,
static_cast<int32_t>(storedType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled; return ReturnPathLoweringResult::Handled;
} }
auto resultUses = result.getUses(); auto resultUses = producedValue.getUses();
if (rangeLength(resultUses) == 1) { if (rangeLength(resultUses) == 1) {
OpOperand& resultUse = *resultUses.begin(); OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner(); Operation* resultUser = resultUse.getOwner();
if (isa<func::ReturnOp>(resultUser)) { if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber(); size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8; size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc); Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy( emitHostCopy(rewriter,
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize)); loc,
outputTensor,
storedValue,
0,
0,
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled; return ReturnPathLoweringResult::Handled;
} }
} }
if (auto concatReturnUse = analyzeConcatReturnUse(result)) { if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
size_t elementSize = yieldType.getElementTypeBitWidth() / 8; size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8;
for (Operation* concatOp : concatReturnUse->concatChain) for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(state, concatOp); markOpToRemove(state, concatOp);
if (concatReturnUse->helperChain.empty()) { if (concatReturnUse->helperChain.empty()) {
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc); Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType()); auto outputType = cast<ShapedType>(outputTensor.getType());
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape()); int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
emitHostCopy(rewriter, emitHostCopy(rewriter,
loc, loc,
outputTensor, outputTensor,
yieldValue, storedValue,
static_cast<int32_t>(flatOffset * elementSize), static_cast<int32_t>(flatOffset * elementSize),
0, 0,
static_cast<int32_t>(yieldType.getNumElements() * elementSize)); static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled; return ReturnPathLoweringResult::Handled;
} }
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType()); auto storedType = dyn_cast<RankedTensorType>(storedValue.getType());
if (!storedType) { if (!storedType) {
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering"); producerOp->emitOpError(
"has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure; return ReturnPathLoweringResult::Failure;
} }
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc); Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType()); auto outputType = cast<ShapedType>(outputTensor.getType());
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) { for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
@@ -484,7 +515,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
SmallVector<int64_t> destinationIndices; SmallVector<int64_t> destinationIndices;
if (failed(mapIndicesThroughHelperChain( if (failed(mapIndicesThroughHelperChain(
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) { sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering"); producerOp->emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure; return ReturnPathLoweringResult::Failure;
} }
@@ -503,7 +534,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
auto scalarTensorType = auto scalarTensorType =
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType()); RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
auto elementSlice = tensor::ExtractSliceOp::create( auto elementSlice = tensor::ExtractSliceOp::create(
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides); rewriter, loc, scalarTensorType, storedValue, extractOffsets, extractSizes, extractStrides);
rewriter.setInsertionPointAfter(elementSlice); rewriter.setInsertionPointAfter(elementSlice);
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape()); int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
@@ -513,7 +544,8 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
elementSlice.getResult(), elementSlice.getResult(),
static_cast<int32_t>(destinationFlatOffset * elementSize), static_cast<int32_t>(destinationFlatOffset * elementSize),
0, 0,
static_cast<int32_t>(elementSize)); static_cast<int32_t>(elementSize),
constantFolder);
} }
return ReturnPathLoweringResult::Handled; return ReturnPathLoweringResult::Handled;
} }
@@ -521,6 +553,11 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
return ReturnPathLoweringResult::NotReturnPath; return ReturnPathLoweringResult::NotReturnPath;
} }
ReturnPathLoweringResult lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, state, rewriter);
}
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) { void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void { auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
if (!op) if (!op)
@@ -569,7 +606,16 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
markOpToRemove(state, concatOp); markOpToRemove(state, concatOp);
for (Value operand : concatOp.getInputs()) for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
} }
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
markOpToRemove(state, receiveOp);
return;
}
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
markOpToRemove(state, receiveTensorOp);
}; };
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
@@ -32,6 +32,12 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute compu
ReturnPathState& state, ReturnPathState& state,
mlir::IRRewriter& rewriter); mlir::IRRewriter& rewriter);
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
mlir::Value producedValue,
mlir::Value storedValue,
ReturnPathState& state,
mlir::IRRewriter& rewriter);
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state); void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -16,8 +16,8 @@ def onnxToPimTranspose : Pat<
>; >;
def spatToPimVMM : Pat< def spatToPimVMM : Pat<
(SpatVMMOp:$srcOpRes $weightIndex, $vector), (SpatVMMOp:$srcOpRes $weight, $vector),
(PimVMMOp $weightIndex, $vector, (PimVMMOp $weight, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>; >;
@@ -3,6 +3,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
@@ -12,6 +13,8 @@
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
@@ -104,23 +107,34 @@ static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc
IntegerAttr {}); IntegerAttr {});
} }
static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) { static Value createZeroedDeviceHVector(IRRewriter& rewriter,
Location loc,
RankedTensorType tensorType,
OperationFolder& constantFolder) {
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType); auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
auto zeroAttr = rewriter.getI32IntegerAttr(0); auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder);
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType))); auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
if (outputBuffer->getParentOfType<PimCoreBatchOp>()) if (outputBuffer->getParentOfType<PimCoreBatchOp>())
return PimMemCopyHostToDevBatchOp::create( return PimMemCopyHostToDevBatchOp::create(rewriter,
rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr) loc,
tensorType,
outputBuffer,
zeroValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
sizeAttr)
.getOutput(); .getOutput();
return PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr) return PimMemCopyHostToDevOp::create(
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
.getOutput(); .getOutput();
} }
static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) { static Value
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
auto vectorType = cast<RankedTensorType>(vector.getType()); auto vectorType = cast<RankedTensorType>(vector.getType());
ArrayRef<int64_t> shape = vectorType.getShape(); ArrayRef<int64_t> shape = vectorType.getShape();
assert(isHVectorShape(shape) && "expected a horizontal vector"); assert(isHVectorShape(shape) && "expected a horizontal vector");
@@ -131,7 +145,7 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
auto paddedType = RankedTensorType::get( auto paddedType = RankedTensorType::get(
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding()); {shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType); Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
auto zeroAttr = rewriter.getI32IntegerAttr(0); auto zeroAttr = rewriter.getI32IntegerAttr(0);
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType))); auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput(); return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
@@ -151,6 +165,7 @@ void SpatialToPimPass::runOnOperation() {
func::FuncOp funcOp = *entryFunc; func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
OperationFolder constantFolder(&getContext());
ConversionTarget target(*ctx); ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect, target.addLegalDialect<PimDialect,
@@ -181,16 +196,17 @@ void SpatialToPimPass::runOnOperation() {
RewritePatternSet globalTensorPatterns(ctx); RewritePatternSet globalTensorPatterns(ctx);
populateGlobalTensorMaterializationPatterns(globalTensorPatterns); populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns)); walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addReturnOutputBuffers(returnOp, rewriter, outputTensors); addReturnOutputBuffers(returnOp, rewriter, outputTensors);
ReturnPathState returnPathState {outputTensors, operationsToRemove};
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) { if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering"); funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
signalPassFailure(); signalPassFailure();
return; return;
} }
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove}; CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder};
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) { for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp); markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) { if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
@@ -251,7 +267,6 @@ void SpatialToPimPass::runOnOperation() {
} }
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
ReturnPathState returnPathState {outputTensors, operationsToRemove};
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState); replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end()); SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
@@ -302,6 +317,7 @@ void SpatialToPimPass::runOnOperation() {
} }
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
OperationFolder constantFolder(funcOp.getContext());
funcOp.walk([&](PimVMMOp vmmOp) { funcOp.walk([&](PimVMMOp vmmOp) {
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType()); auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape(); ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -309,7 +325,7 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar"); assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
rewriter.setInsertionPoint(vmmOp); rewriter.setInsertionPoint(vmmOp);
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput()); Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
auto paddedOutputType = RankedTensorType::get( auto paddedOutputType = RankedTensorType::get(
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding()); {outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize) Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
@@ -336,10 +352,13 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
OperationFolder constantFolder(funcOp.getContext());
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) { auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(inputTensor.getType()); auto tensorType = cast<ShapedType>(inputTensor.getType());
Type elementType = tensorType.getElementType(); Type elementType = tensorType.getElementType();
if (!elementType.isIntOrFloat())
return;
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp()); rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
@@ -349,10 +368,11 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
rewriter, rewriter,
loc, loc,
tensorType, tensorType,
getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder),
getOrCreateHostIndexConstant(
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder),
deviceTensor, deviceTensor,
inputTensor, inputTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize))); rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp}); rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
+34 -16
View File
@@ -2,6 +2,7 @@
#define PIM_DIALECT_H #define PIM_DIALECT_H
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/AttrTypeBase.td" include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -24,7 +25,8 @@ def PimTensor :
// Execution // Execution
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> { def PimCoreOp : PimOp<"core", [SingleBlock,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Execute a block on a PIM core"; let summary = "Execute a block on a PIM core";
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
@@ -34,12 +36,16 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
I32Attr:$coreId I32Attr:$coreId
); );
let assemblyFormat = [{ let extraClassDeclaration = [{
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)` ::mlir::BlockArgument getWeightArgument(unsigned idx);
}]; }];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
} }
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> { def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Execute equivalent batched core bodies"; let summary = "Execute equivalent batched core bodies";
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
@@ -50,6 +56,13 @@ def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSi
Variadic<PimTensor>:$inputs Variadic<PimTensor>:$inputs
); );
let extraClassDeclaration = [{
::mlir::BlockArgument getLaneArgument();
::mlir::BlockArgument getWeightArgument(unsigned idx);
::mlir::BlockArgument getInputArgument(unsigned idx);
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1; let hasCustomAssemblyFormat = 1;
} }
@@ -81,11 +94,11 @@ def PimSendOp : PimOp<"send", []> {
let arguments = (ins let arguments = (ins
PimTensor:$input, PimTensor:$input,
I32Attr:$size, I32Attr:$size,
I32Attr:$targetCoreId Index:$targetCoreId
); );
let assemblyFormat = [{ let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` `(` `)` `(` $input `,` $targetCoreId `)` attr-dict `:` type($input) `->` `(` `)`
}]; }];
} }
@@ -131,7 +144,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
let arguments = (ins let arguments = (ins
PimTensor:$outputBuffer, PimTensor:$outputBuffer,
I32Attr:$size, I32Attr:$size,
I32Attr:$sourceCoreId Index:$sourceCoreId
); );
let results = (outs let results = (outs
@@ -145,7 +158,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
}]; }];
let assemblyFormat = [{ let assemblyFormat = [{
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output) `(` $outputBuffer `,` $sourceCoreId `)` attr-dict `:` type($outputBuffer) `->` type($output)
}]; }];
} }
@@ -219,10 +232,10 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from host memory into device memory"; let summary = "Copy a memory region from host memory into device memory";
let arguments = (ins let arguments = (ins
Index:$deviceTargetOffset,
Index:$hostSourceOffset,
PimTensor:$deviceTarget, PimTensor:$deviceTarget,
PimTensor:$hostSource, PimTensor:$hostSource,
I32Attr:$deviceTargetOffset,
I32Attr:$hostSourceOffset,
I32Attr:$size I32Attr:$size
); );
@@ -237,7 +250,9 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
}]; }];
let assemblyFormat = [{ let assemblyFormat = [{
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output) `[` $deviceTargetOffset `,` $hostSourceOffset `]`
`(` $deviceTarget `,` $hostSource `)` attr-dict
`:` type($deviceTarget) `,` type($hostSource) `->` type($output)
}]; }];
} }
@@ -271,10 +286,10 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from device memory into host memory"; let summary = "Copy a memory region from device memory into host memory";
let arguments = (ins let arguments = (ins
Index:$hostTargetOffset,
Index:$deviceSourceOffset,
PimTensor:$hostTarget, PimTensor:$hostTarget,
PimTensor:$deviceSource, PimTensor:$deviceSource,
I32Attr:$hostTargetOffset,
I32Attr:$deviceSourceOffset,
I32Attr:$size I32Attr:$size
); );
@@ -289,7 +304,9 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
}]; }];
let assemblyFormat = [{ let assemblyFormat = [{
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output) `[` $hostTargetOffset `,` $deviceSourceOffset `]`
`(` $hostTarget `,` $deviceSource `)` attr-dict
`:` type($hostTarget) `,` type($deviceSource) `->` type($output)
}]; }];
} }
@@ -374,7 +391,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
let summary = "Vector-matrix multiplication: c = a * b"; let summary = "Vector-matrix multiplication: c = a * b";
let arguments = (ins let arguments = (ins
I32Attr:$weightIndex, PimTensor:$weight,
PimTensor:$input, PimTensor:$input,
PimTensor:$outputBuffer PimTensor:$outputBuffer
); );
@@ -391,7 +408,8 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
let hasVerifier = 1; let hasVerifier = 1;
let assemblyFormat = [{ let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) `[` $weight `]` `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($weight) `,` type($input) `,`
type($outputBuffer) `)` `->` type($output)
}]; }];
} }
+33
View File
@@ -1,8 +1,41 @@
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include <string>
using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace pim {
BlockArgument PimCoreOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
void PimCoreOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
}
BlockArgument PimCoreBatchOp::getLaneArgument() { return getBody().front().getArgument(0); }
BlockArgument PimCoreBatchOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
BlockArgument PimCoreBatchOp::getInputArgument(unsigned idx) {
return getBody().front().getArgument(1 + getWeights().size() + idx);
}
void PimCoreBatchOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
setNameFn(getLaneArgument(), "lane");
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
}
void PimDialect::initialize() { void PimDialect::initialize() {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
+180 -31
View File
@@ -20,6 +20,80 @@ static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int3
return parser.getBuilder().getDenseI32ArrayAttr(values); return parser.getBuilder().getDenseI32ArrayAttr(values);
} }
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value);
}
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
}
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
printer << "(";
for (auto [index, argument] : llvm::enumerate(arguments)) {
if (index != 0)
printer << ", ";
printer.printOperand(argument);
}
printer << ")";
}
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (parser.parseLParen())
return failure();
if (succeeded(parser.parseOptionalRParen()))
return success();
OpAsmParser::Argument argument;
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
}
return parser.parseRParen();
}
static void printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
printCompressedValueList(printer, arguments, delimiter);
printer << " = ";
printCompressedValueList(printer, operands, delimiter);
}
static ParseResult parseBoundValueList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::Argument>& arguments,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure();
return success();
}
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) {
case ListDelimiter::Paren:
return parser.parseRParen();
case ListDelimiter::Square:
return parser.parseRSquare();
}
llvm_unreachable("unsupported delimiter");
};
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure();
return success();
}
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) { static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
printer << " " << keyword << " "; printer << " " << keyword << " ";
printCompressedIntegerList(printer, coreIds); printCompressedIntegerList(printer, coreIds);
@@ -33,15 +107,76 @@ static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keywor
} // namespace } // namespace
void PimCoreBatchOp::print(OpAsmPrinter& printer) { void PimCoreOp::print(OpAsmPrinter& printer) {
printer << " lanes " << getLaneCount() << " "; SmallVector<Value> weightArgs;
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0; weightArgs.reserve(getWeights().size());
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane)) for (unsigned index = 0; index < getWeights().size(); ++index)
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Paren); weightArgs.push_back(getWeightArgument(index));
else
printCompressedValueList(printer, getWeights(), ListDelimiter::Paren);
printer << " "; printer << " ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Square); printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " coreId " << getCoreId();
printer.printOptionalAttrDict((*this)->getAttrs(), {getCoreIdAttrName().getValue()});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " -> () ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult PimCoreOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<Type> weightTypes;
int32_t coreId = 0;
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
if (hasCoreId && parser.parseInteger(coreId))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (hasCoreId && result.attributes.get("coreId"))
return parser.emitError(parser.getCurrentLocation(),
"coreId cannot be specified both positionally and in attr-dict");
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands))
return failure();
if (hasCoreId)
result.addAttribute("coreId", getI32Attr(parser, coreId));
Region* body = result.addRegion();
applyArgumentTypes(weightTypes, weightArgs);
return parser.parseRegion(*body, weightArgs);
}
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getLaneArgument());
printer << " = 0 to " << getLaneCount() << " ";
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index)
inputArgs.push_back(getInputArgument(index));
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef()); printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());
@@ -49,51 +184,57 @@ void PimCoreBatchOp::print(OpAsmPrinter& printer) {
printer.printOptionalAttrDict( printer.printOptionalAttrDict(
(*this)->getAttrs(), (*this)->getAttrs(),
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName}); {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
printer << " : "; printer << " : ";
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane)) printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Paren);
else
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Paren);
printer << " "; printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Square); printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ()"; printer << " -> () ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
} }
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) { ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
int64_t lowerBound = 0;
int32_t laneCount = 0; int32_t laneCount = 0;
OpAsmParser::Argument laneArg;
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> inputArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights; SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs; SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes; SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes; SmallVector<Type> inputTypes;
SmallVector<int32_t> coreIds; SmallVector<int32_t> coreIds;
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount) if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|| parseCompressedOrTupleOperandList(parser, ListDelimiter::Paren, weights) || parser.parseKeyword("to") || parser.parseInteger(laneCount))
|| parseCompressedOperandList(parser, ListDelimiter::Square, inputs)) return failure();
if (lowerBound != 0)
return parser.emitError(parser.getCurrentLocation(), "core_batch currently requires a zero lower bound");
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
|| parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure(); return failure();
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("coreIds")); bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure(); return failure();
if (parser.parseOptionalAttrDict(result.attributes)) if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
return failure(); || parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
Region* body = result.addRegion(); || parseCompressedRepeatedList(
if (parser.parseRegion(*body)) parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
return failure(); || parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
if (parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Paren, weightTypes)
|| parseCompressedTypeList(parser, ListDelimiter::Square, inputTypes) || parser.parseArrow()
|| parser.parseLParen() || parser.parseRParen())
return failure(); return failure();
if (weights.size() != weightTypes.size()) if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size()) if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of input bindings and input operands must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName)) if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(), return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict"); "coreIds cannot be specified both positionally and in attr-dict");
@@ -110,7 +251,15 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) { || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) {
return failure(); return failure();
} }
return success();
Region* body = result.addRegion();
laneArg.type = builder.getIndexType();
regionArgs.push_back(laneArg);
applyArgumentTypes(weightTypes, weightArgs);
llvm::append_range(regionArgs, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
llvm::append_range(regionArgs, inputArgs);
return parser.parseRegion(*body, regionArgs);
} }
void PimYieldOp::print(OpAsmPrinter& printer) { void PimYieldOp::print(OpAsmPrinter& printer) {
+85 -15
View File
@@ -1,5 +1,7 @@
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
@@ -14,6 +16,52 @@ namespace pim {
namespace { namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (isa<PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (isa<PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
static Region* getParentRegion(Value value) {
if (auto blockArgument = dyn_cast<BlockArgument>(value))
return blockArgument.getParentRegion();
Operation* definingOp = value.getDefiningOp();
return definingOp ? definingOp->getParentRegion() : nullptr;
}
static bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
static bool isConstantExternalValue(Value value) {
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
bool hasFailure = false;
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)
|| isExplicitHostOperand(op, operand.getOperandNumber()))
continue;
InFlightDiagnostic diagnostic =
ownerOp->emitOpError() << kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
hasFailure = true;
}
});
return success(!hasFailure);
}
static bool haveSameShapedContainerKind(Type lhs, Type rhs) { static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs)); return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
} }
@@ -78,24 +126,46 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
return success(); return success();
} }
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) { static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
if (auto coreOp = op->getParentOfType<PimCoreOp>()) { auto shapedType = dyn_cast<ShapedType>(weight.getType());
if (weightIndex >= coreOp.getWeights().size()) if (!shapedType)
return failure();
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
}
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
if (weightIndex >= coreBatchOp.getWeights().size())
return failure();
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure(); return failure();
return shapedType.getShape();
} }
} // namespace } // namespace
LogicalResult PimCoreOp::verify() {
Block& block = getBody().front();
if (block.getNumArguments() != getWeights().size())
return emitError("core body must have one block argument per weight");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core weight block argument types must match weight operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
}
LogicalResult PimCoreBatchOp::verify() {
if (getLaneCount() <= 0)
return emitError("laneCount must be positive");
Block& block = getBody().front();
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return emitError("core_batch body must have lane, weight, and input block arguments");
if (!getLaneArgument().getType().isIndex())
return emitError("core_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core_batch weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("core_batch input block argument types must match input operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
}
LogicalResult PimSendTensorOp::verify() { LogicalResult PimSendTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor"); return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
} }
@@ -126,9 +196,9 @@ LogicalResult PimVMMOp::verify() {
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match"))) getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure(); return failure();
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex()); auto matrixShapeOpt = getWeightShapeForVMM(getWeight());
if (failed(matrixShapeOpt)) if (failed(matrixShapeOpt))
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex"); return emitError("weight must be a shaped value");
ArrayRef<int64_t> matrixShape = *matrixShapeOpt; ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
auto vectorType = dyn_cast<ShapedType>(getInput().getType()); auto vectorType = dyn_cast<ShapedType>(getInput().getType());
@@ -38,10 +38,10 @@ struct MemCopyHostToDevOpInterface
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter, replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
memCopyHostToDevOp, memCopyHostToDevOp,
deviceTargetMemRef.getType(), deviceTargetMemRef.getType(),
memCopyHostToDevOp.getDeviceTargetOffset(),
memCopyHostToDevOp.getHostSourceOffset(),
deviceTargetMemRef, deviceTargetMemRef,
hostSourceMemRef, hostSourceMemRef,
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
memCopyHostToDevOp.getHostSourceOffsetAttr(),
memCopyHostToDevOp.getSizeAttr()); memCopyHostToDevOp.getSizeAttr());
return success(); return success();
} }
@@ -96,10 +96,10 @@ struct MemCopyDevToHostOpInterface
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter, replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
memCopyDevToHostOp, memCopyDevToHostOp,
hostTargetMemRef.getType(), hostTargetMemRef.getType(),
memCopyDevToHostOp.getHostTargetOffset(),
memCopyDevToHostOp.getDeviceSourceOffset(),
hostTargetMemRef, hostTargetMemRef,
deviceSourceMemRef, deviceSourceMemRef,
memCopyDevToHostOp.getHostTargetOffsetAttr(),
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
memCopyDevToHostOp.getSizeAttr()); memCopyDevToHostOp.getSizeAttr());
return success(); return success();
} }
@@ -151,12 +151,8 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
if (failed(outputBufferOpt)) if (failed(outputBufferOpt))
return failure(); return failure();
replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter, replaceOpWithNewBufferizedOp<PimReceiveOp>(
op, rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
outputBufferOpt->getType(),
*outputBufferOpt,
receiveOp.getSizeAttr(),
receiveOp.getSourceCoreIdAttr());
return success(); return success();
} }
}; };
@@ -302,7 +298,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
op, op,
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(), sendOp.getSizeAttr(),
sendOp.getTargetCoreIdAttr()); sendOp.getTargetCoreId());
return success(); return success();
} }
}; };
@@ -368,6 +364,37 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
return {}; return {};
} }
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
auto coreOp = cast<PimCoreOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
return {};
unsigned weightIndex = bbArg.getArgNumber();
return {
{&coreOp->getOpOperand(weightIndex), BufferRelation::Equivalent}
};
}
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
FailureOr<BufferLikeType> getBufferType(Operation* op,
Value value,
const BufferizationOptions& options,
const BufferizationState& state,
SmallVector<Value>& invocationStack) const {
auto coreOp = cast<PimCoreOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
return failure();
Value tiedWeight = coreOp.getWeights()[bbArg.getArgNumber()];
if (auto memRefType = dyn_cast<BufferLikeType>(tiedWeight.getType()))
return memRefType;
return bufferization::getBufferType(tiedWeight, options, state, invocationStack);
}
LogicalResult bufferize(Operation* op, LogicalResult bufferize(Operation* op,
RewriterBase& rewriter, RewriterBase& rewriter,
const BufferizationOptions& options, const BufferizationOptions& options,
@@ -375,7 +402,10 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
auto coreOp = cast<PimCoreOp>(op); auto coreOp = cast<PimCoreOp>(op);
bool alreadyBufferized = bool alreadyBufferized =
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); }); llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
&& llvm::all_of(coreOp.getBody().front().getArguments(), [](BlockArgument arg) {
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
});
if (alreadyBufferized) if (alreadyBufferized)
return success(); return success();
@@ -420,9 +450,17 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front()) if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
return {}; return {};
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber(); unsigned argNumber = bbArg.getArgNumber();
if (argNumber == 0)
return {};
unsigned weightCount = coreBatchOp.getWeights().size();
unsigned operandIndex = argNumber - 1;
if (argNumber > weightCount + 1)
operandIndex = weightCount + (argNumber - 1 - weightCount);
return { return {
{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent} {&coreBatchOp->getOpOperand(operandIndex), BufferRelation::Equivalent}
}; };
} }
@@ -438,11 +476,21 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front()) if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
return failure(); return failure();
Value tiedInput = coreBatchOp.getInputs()[bbArg.getArgNumber()]; unsigned argNumber = bbArg.getArgNumber();
if (auto memRefType = dyn_cast<BufferLikeType>(tiedInput.getType())) if (argNumber == 0)
return failure();
Value tiedOperand;
unsigned weightCount = coreBatchOp.getWeights().size();
if (argNumber <= weightCount)
tiedOperand = coreBatchOp.getWeights()[argNumber - 1];
else
tiedOperand = coreBatchOp.getInputs()[argNumber - 1 - weightCount];
if (auto memRefType = dyn_cast<BufferLikeType>(tiedOperand.getType()))
return memRefType; return memRefType;
return bufferization::getBufferType(tiedInput, options, state, invocationStack); return bufferization::getBufferType(tiedOperand, options, state, invocationStack);
} }
LogicalResult bufferize(Operation* op, LogicalResult bufferize(Operation* op,
@@ -454,8 +502,9 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
bool alreadyBufferized = bool alreadyBufferized =
llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); }) llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
&& llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); }) && llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); })
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(), && llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) {
[](BlockArgument arg) { return isa<BufferLikeType>(arg.getType()); }); return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
});
if (alreadyBufferized) if (alreadyBufferized)
return success(); return success();
@@ -553,6 +602,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
BufferizationState& state) const { BufferizationState& state) const {
auto vmmOp = cast<PimVMMOp>(op); auto vmmOp = cast<PimVMMOp>(op);
auto weightOpt = getBufferOrValue(rewriter, vmmOp.getWeight(), options, state);
if (failed(weightOpt))
return failure();
auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state); auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state);
if (failed(inputOpt)) if (failed(inputOpt))
return failure(); return failure();
@@ -564,7 +617,7 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter); Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVMMOp>( replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt); rewriter, op, outputBufferOpt->getType(), *weightOpt, contiguousInput, *outputBufferOpt);
return success(); return success();
} }
}; };
+74 -16
View File
@@ -1,5 +1,6 @@
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
@@ -9,19 +10,62 @@ namespace onnx_mlir::spatial {
namespace { namespace {
static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); } static FailureOr<int64_t> getConstantI64(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return constantValue.getSExtValue();
}
static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); } static FailureOr<int32_t> getConstantI32(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return static_cast<int32_t>(constantValue.getSExtValue());
}
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelSendOp sendOp) {
return getConstantI64(sendOp.getChannelId());
}
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelReceiveOp receiveOp) {
return getConstantI64(receiveOp.getChannelId());
}
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getSourceCoreId()); }
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
return getConstantI32(receiveOp.getSourceCoreId());
}
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getTargetCoreId()); }
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
return getConstantI32(receiveOp.getTargetCoreId());
}
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) { static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
if (!endpoints.send || !endpoints.receive) if (!endpoints.send || !endpoints.receive)
return failure(); return failure();
if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) { FailureOr<int32_t> sendSourceCoreId = getSourceCoreId(endpoints.send);
FailureOr<int32_t> receiveSourceCoreId = getSourceCoreId(endpoints.receive);
if (failed(sendSourceCoreId) || failed(receiveSourceCoreId)) {
endpoints.send.emitOpError("channel endpoints must use constant sourceCoreId operands");
return failure();
}
if (*sendSourceCoreId != *receiveSourceCoreId) {
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive"); endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
return failure(); return failure();
} }
if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
FailureOr<int32_t> sendTargetCoreId = getTargetCoreId(endpoints.send);
FailureOr<int32_t> receiveTargetCoreId = getTargetCoreId(endpoints.receive);
if (failed(sendTargetCoreId) || failed(receiveTargetCoreId)) {
endpoints.send.emitOpError("channel endpoints must use constant targetCoreId operands");
return failure();
}
if (*sendTargetCoreId != *receiveTargetCoreId) {
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive"); endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
return failure(); return failure();
} }
@@ -46,20 +90,26 @@ Channels::Channels(func::FuncOp funcOp) {
Channels::ChannelId Channels::allocate() { return nextChannelId++; } Channels::ChannelId Channels::allocate() { return nextChannelId++; }
void Channels::insertSend(SpatChannelSendOp sendOp) { void Channels::insertSend(SpatChannelSendOp sendOp) {
ChannelId channelId = getChannelId(sendOp); FailureOr<ChannelId> channelId = getChannelId(sendOp);
nextChannelId = std::max(nextChannelId, channelId + 1); if (failed(channelId))
endpoints[channelId].send = sendOp; return;
nextChannelId = std::max(nextChannelId, *channelId + 1);
endpoints[*channelId].send = sendOp;
} }
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) { void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
ChannelId channelId = getChannelId(receiveOp); FailureOr<ChannelId> channelId = getChannelId(receiveOp);
nextChannelId = std::max(nextChannelId, channelId + 1); if (failed(channelId))
endpoints[channelId].receive = receiveOp; return;
nextChannelId = std::max(nextChannelId, *channelId + 1);
endpoints[*channelId].receive = receiveOp;
} }
void Channels::eraseSend(SpatChannelSendOp sendOp) { void Channels::eraseSend(SpatChannelSendOp sendOp) {
ChannelId channelId = getChannelId(sendOp); FailureOr<ChannelId> channelId = getChannelId(sendOp);
auto it = endpoints.find(channelId); if (failed(channelId))
return;
auto it = endpoints.find(*channelId);
if (it == endpoints.end()) if (it == endpoints.end())
return; return;
it->second.send = {}; it->second.send = {};
@@ -68,8 +118,10 @@ void Channels::eraseSend(SpatChannelSendOp sendOp) {
} }
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) { void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
ChannelId channelId = getChannelId(receiveOp); FailureOr<ChannelId> channelId = getChannelId(receiveOp);
auto it = endpoints.find(channelId); if (failed(channelId))
return;
auto it = endpoints.find(*channelId);
if (it == endpoints.end()) if (it == endpoints.end())
return; return;
it->second.receive = {}; it->second.receive = {};
@@ -85,14 +137,20 @@ FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
} }
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const { FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
auto endpointsOr = lookup(getChannelId(sendOp)); FailureOr<ChannelId> channelId = getChannelId(sendOp);
if (failed(channelId))
return failure();
auto endpointsOr = lookup(*channelId);
if (failed(endpointsOr) || !endpointsOr->receive) if (failed(endpointsOr) || !endpointsOr->receive)
return failure(); return failure();
return endpointsOr->receive; return endpointsOr->receive;
} }
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const { FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
auto endpointsOr = lookup(getChannelId(receiveOp)); FailureOr<ChannelId> channelId = getChannelId(receiveOp);
if (failed(channelId))
return failure();
auto endpointsOr = lookup(*channelId);
if (failed(endpointsOr) || !endpointsOr->send) if (failed(endpointsOr) || !endpointsOr->send)
return failure(); return failure();
return endpointsOr->send; return endpointsOr->send;
+100 -45
View File
@@ -2,8 +2,12 @@
#define SPATIAL_DIALECT_H #define SPATIAL_DIALECT_H
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/BuiltinTypes.td" include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td" include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
def SpatialDialect : Dialect { def SpatialDialect : Dialect {
let name = "spat"; let name = "spat";
@@ -22,7 +26,9 @@ def SpatTensor :
// Execution // Execution
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> { def SpatCompute : SpatOp<"compute",
[SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Compute region with attached constant weights"; let summary = "Compute region with attached constant weights";
let arguments = (ins let arguments = (ins
@@ -36,14 +42,20 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let extraClassDeclaration = [{
::mlir::BlockArgument getWeightArgument(unsigned idx);
::mlir::BlockArgument getInputArgument(unsigned idx);
}];
let hasVerifier = 1; let hasVerifier = 1;
let hasFolder = 1; let hasFolder = 1;
let hasCustomAssemblyFormat = 1; let hasCustomAssemblyFormat = 1;
} }
def SpatComputeBatch : SpatOp<"compute_batch", def SpatComputeBatch : SpatOp<"compute_batch",
[SingleBlock, AttrSizedOperandSegments]> { [SingleBlock, AttrSizedOperandSegments,
let summary = "Compressed batch of independent equivalent compute lanes"; DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
let arguments = (ins let arguments = (ins
I32Attr:$laneCount, I32Attr:$laneCount,
@@ -57,10 +69,41 @@ def SpatComputeBatch : SpatOp<"compute_batch",
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let extraClassDeclaration = [{
::mlir::BlockArgument getLaneArgument();
::mlir::BlockArgument getWeightArgument(unsigned idx);
::mlir::BlockArgument getInputArgument(unsigned idx);
::mlir::BlockArgument getOutputArgument(unsigned idx);
}];
let hasVerifier = 1; let hasVerifier = 1;
let hasCustomAssemblyFormat = 1; let hasCustomAssemblyFormat = 1;
} }
def SpatInParallelOp : SpatOp<"in_parallel", [
Pure,
Terminator,
DeclareOpInterfaceMethods<InParallelOpInterface>,
HasParent<"SpatComputeBatch">,
] # GraphRegionNoTerminator.traits> {
let summary = "Parallel combining terminator for resultful spat.compute_batch";
let regions = (region SizedRegion<1>:$region);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins)>,
];
let extraClassDeclaration = [{
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
::mlir::OpResult getParentResult(int64_t idx);
}];
}
def SpatYieldOp : SpatOp<"yield", [Terminator]> { def SpatYieldOp : SpatOp<"yield", [Terminator]> {
let summary = "Yield results from a compute region"; let summary = "Yield results from a compute region";
@@ -110,14 +153,14 @@ def SpatChannelSendOp : SpatOp<"channel_send", []> {
let summary = "Send a tensor through a logical channel"; let summary = "Send a tensor through a logical channel";
let arguments = (ins let arguments = (ins
I64Attr:$channelId, Index:$channelId,
I32Attr:$sourceCoreId, Index:$sourceCoreId,
I32Attr:$targetCoreId, Index:$targetCoreId,
SpatTensor:$input SpatTensor:$input
); );
let assemblyFormat = [{ let assemblyFormat = [{
$input attr-dict `:` type($input) $input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input)
}]; }];
} }
@@ -125,9 +168,9 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
let summary = "Receive a tensor from a logical channel"; let summary = "Receive a tensor from a logical channel";
let arguments = (ins let arguments = (ins
I64Attr:$channelId, Index:$channelId,
I32Attr:$sourceCoreId, Index:$sourceCoreId,
I32Attr:$targetCoreId Index:$targetCoreId
); );
let results = (outs let results = (outs
@@ -135,31 +178,33 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
); );
let assemblyFormat = [{ let assemblyFormat = [{
attr-dict `:` type($output) `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output)
}]; }];
} }
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> { def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
let summary = "Send equal contiguous chunks of one tensor through logical channels"; let summary = "Send equal contiguous chunks of one tensor through logical channels";
let arguments = (ins let arguments = (ins
DenseI64ArrayAttr:$channelIds, Variadic<Index>:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds, Variadic<Index>:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds, Variadic<Index>:$targetCoreIds,
SpatTensor:$input SpatTensor:$input
); );
let hasVerifier = 1; let hasVerifier = 1;
let hasCustomAssemblyFormat = 1; let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
} }
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> { def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
let summary = "Receive equal contiguous chunks of one tensor from logical channels"; let summary = "Receive equal contiguous chunks of one tensor from logical channels";
let arguments = (ins let arguments = (ins
DenseI64ArrayAttr:$channelIds, Variadic<Index>:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds, Variadic<Index>:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds Variadic<Index>:$targetCoreIds
); );
let results = (outs let results = (outs
@@ -167,44 +212,50 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
); );
let hasVerifier = 1; let hasVerifier = 1;
let hasCustomAssemblyFormat = 1; let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
} }
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> { def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
let summary = "Send per-lane tensors through logical channels in a batch body"; let summary = "Send per-lane tensors through logical channels in a batch body";
let arguments = (ins let arguments = (ins
DenseI64ArrayAttr:$channelIds, Variadic<Index>:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds, Variadic<Index>:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds, Variadic<Index>:$targetCoreIds,
SpatTensor:$input SpatTensor:$input
); );
let hasVerifier = 1; let hasVerifier = 1;
let hasCustomAssemblyFormat = 1; let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
} }
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> { def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body"; let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins let arguments = (ins
DenseI64ArrayAttr:$channelIds, Variadic<Index>:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds, Variadic<Index>:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds, Variadic<Index>:$targetCoreIds,
SpatTensor:$input SpatTensor:$input
); );
let hasVerifier = 1; let hasVerifier = 1;
let hasCustomAssemblyFormat = 1; let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
} }
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
let summary = "Receive a per-lane tensor through logical channels in a batch body"; let summary = "Receive a per-lane tensor through logical channels in a batch body";
let arguments = (ins let arguments = (ins
DenseI64ArrayAttr:$channelIds, Variadic<Index>:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds, Variadic<Index>:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds Variadic<Index>:$targetCoreIds
); );
let results = (outs let results = (outs
@@ -212,16 +263,18 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
); );
let hasVerifier = 1; let hasVerifier = 1;
let hasCustomAssemblyFormat = 1; let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
} }
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> { def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body"; let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins let arguments = (ins
DenseI64ArrayAttr:$channelIds, Variadic<Index>:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds, Variadic<Index>:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds Variadic<Index>:$targetCoreIds
); );
let results = (outs let results = (outs
@@ -229,7 +282,9 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []>
); );
let hasVerifier = 1; let hasVerifier = 1;
let hasCustomAssemblyFormat = 1; let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -240,7 +295,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
let summary = "Vector-matrix multiplication within a weighted compute operation"; let summary = "Vector-matrix multiplication within a weighted compute operation";
let arguments = (ins let arguments = (ins
I32Attr:$weightIndex, SpatTensor:$weight,
SpatTensor:$input SpatTensor:$input
); );
@@ -251,7 +306,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
let hasVerifier = 1; let hasVerifier = 1;
let assemblyFormat = [{ let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output) `[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
}]; }];
} }
@@ -259,7 +314,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a weighted compute operation"; let summary = "Matrix-vector multiplication within a weighted compute operation";
let arguments = (ins let arguments = (ins
I32Attr:$weightIndex, SpatTensor:$weight,
SpatTensor:$input SpatTensor:$input
); );
@@ -270,7 +325,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
let hasVerifier = 1; let hasVerifier = 1;
let assemblyFormat = [{ let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output) `[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
}]; }];
} }
+64
View File
@@ -1,10 +1,74 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include <string>
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace spatial { namespace spatial {
BlockArgument SpatCompute::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
BlockArgument SpatCompute::getInputArgument(unsigned idx) {
return getBody().front().getArgument(getWeights().size() + idx);
}
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
}
BlockArgument SpatComputeBatch::getLaneArgument() { return getBody().front().getArgument(0); }
BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) {
return getBody().front().getArgument(1 + getWeights().size() + idx);
}
BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) {
return getBody().front().getArgument(1 + getWeights().size() + getInputs().size() + idx);
}
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
setNameFn(getLaneArgument(), "lane");
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getNumResults(); ++index) {
if (index == 0) {
setNameFn(getOutputArgument(index), "out");
continue;
}
setNameFn(getOutputArgument(index), ("out" + std::to_string(index)).c_str());
}
}
void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
OpBuilder::InsertionGuard guard(builder);
Region* bodyRegion = result.addRegion();
builder.createBlock(bodyRegion);
}
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() {
return getRegion().front().getOperations();
}
void SpatialDialect::initialize() { void SpatialDialect::initialize() {
addTypes< addTypes<
#define GET_TYPEDEF_LIST #define GET_TYPEDEF_LIST
+2
View File
@@ -5,7 +5,9 @@
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include <map> #include <map>
#include <string> #include <string>
+151 -217
View File
@@ -23,22 +23,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred,
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy)); return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
} }
static void printChannelMetadata(OpAsmPrinter& printer,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds) {
printer << " channels ";
printCompressedIntegerList(printer, channelIds);
printer << " from ";
printCompressedIntegerList(printer, sourceCoreIds);
printer << " to ";
printCompressedIntegerList(printer, targetCoreIds);
}
static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef<int64_t> values) {
return parser.getBuilder().getDenseI64ArrayAttr(values);
}
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) { static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
return parser.getBuilder().getDenseI32ArrayAttr(values); return parser.getBuilder().getDenseI32ArrayAttr(values);
} }
@@ -47,94 +31,89 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value); return parser.getBuilder().getI32IntegerAttr(value);
} }
template <typename TensorSendOpTy> static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) { printer << "(";
printer << " "; for (auto [index, argument] : llvm::enumerate(arguments)) {
printer.printOperand(op.getInput()); if (index != 0)
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds()); printer << ", ";
printer.printOptionalAttrDict(op->getAttrs(), printer.printOperand(argument);
{op.getChannelIdsAttrName().getValue(), }
op.getSourceCoreIdsAttrName().getValue(), printer << ")";
op.getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(op.getInput().getType());
} }
template <typename TensorReceiveOpTy> static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) { if (parser.parseLParen())
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds()); return failure();
printer.printOptionalAttrDict(op->getAttrs(), if (succeeded(parser.parseOptionalRParen()))
{op.getChannelIdsAttrName().getValue(), return success();
op.getSourceCoreIdsAttrName().getValue(),
op.getTargetCoreIdsAttrName().getValue()}); OpAsmParser::Argument argument;
printer << " : "; if (parser.parseArgument(argument))
printer.printType(op.getOutput().getType()); return failure();
arguments.push_back(argument);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
}
return parser.parseRParen();
} }
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) { static void applyBatchRegionArgumentTypes(ArrayRef<Type> inputTypes,
OpAsmParser::UnresolvedOperand input; ArrayRef<Type> weightTypes,
Type inputType; ArrayRef<Type> outputTypes,
SmallVector<int64_t> channelIds; OpAsmParser::Argument& laneArg,
SmallVector<int32_t> sourceCoreIds; SmallVectorImpl<OpAsmParser::Argument>& weightArgs,
SmallVector<int32_t> targetCoreIds; SmallVectorImpl<OpAsmParser::Argument>& inputArgs,
SmallVectorImpl<OpAsmParser::Argument>& outputArgs,
if (parser.parseOperand(input)) SmallVectorImpl<OpAsmParser::Argument>& regionArgs,
return failure(); Builder& builder) {
laneArg.type = builder.getIndexType();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); regionArgs.push_back(laneArg);
if (hasMetadata) { applyArgumentTypes(weightTypes, weightArgs);
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") llvm::append_range(regionArgs, weightArgs);
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") applyArgumentTypes(inputTypes, inputArgs);
|| parseCompressedIntegerList(parser, targetCoreIds)) applyArgumentTypes(outputTypes, outputArgs);
return failure(); llvm::append_range(regionArgs, inputArgs);
} llvm::append_range(regionArgs, outputArgs);
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperand(input, inputType, result.operands);
} }
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) { static void
Type outputType; printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
SmallVector<int64_t> channelIds; printCompressedValueList(printer, arguments, delimiter);
SmallVector<int32_t> sourceCoreIds; printer << " = ";
SmallVector<int32_t> targetCoreIds; printCompressedValueList(printer, operands, delimiter);
}
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); static ParseResult parseBoundValueList(OpAsmParser& parser,
if (hasMetadata) { ListDelimiter delimiter,
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") SmallVectorImpl<OpAsmParser::Argument>& arguments,
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|| parseCompressedIntegerList(parser, targetCoreIds)) if (parseOpenDelimiter(parser, delimiter))
return failure(); return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure();
return success();
} }
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType)) if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) {
case ListDelimiter::Paren:
return parser.parseRParen();
case ListDelimiter::Square:
return parser.parseRSquare();
}
llvm_unreachable("unsupported delimiter");
};
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) {
return failure(); return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
} }
result.addTypes(outputType);
return success(); return success();
} }
@@ -243,9 +222,17 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
void SpatCompute::print(OpAsmPrinter& printer) { void SpatCompute::print(OpAsmPrinter& printer) {
printer << " "; printer << " ";
printCompressedValueList(printer, getWeights(), ListDelimiter::Square); SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " "; printer << " ";
printArgumentBindings(printer, getBody().front(), getInputs()); SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index)
inputArgs.push_back(getInputArgument(index));
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " coreId " << coreIdAttr.getInt(); printer << " coreId " << coreIdAttr.getInt();
@@ -264,6 +251,7 @@ void SpatCompute::print(OpAsmPrinter& printer) {
} }
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> regionArgs; SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights; SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs; SmallVector<OpAsmParser::UnresolvedOperand> inputs;
@@ -272,10 +260,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<Type> outputTypes; SmallVector<Type> outputTypes;
int32_t coreId = 0; int32_t coreId = 0;
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure(); return failure();
if (parseArgumentBindings(parser, regionArgs, inputs)) SmallVector<OpAsmParser::Argument> inputArgs;
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure(); return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id"); bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
@@ -292,9 +281,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
if (weights.size() != weightTypes.size()) if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size()) if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (regionArgs.size() != inputs.size()) if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName)) if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(), return parser.emitError(parser.getCurrentLocation(),
@@ -313,19 +304,39 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
result.addTypes(outputTypes); result.addTypes(outputTypes);
Region* body = result.addRegion(); Region* body = result.addRegion();
applyArgumentTypes(inputTypes, regionArgs); applyArgumentTypes(weightTypes, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
llvm::append_range(regionArgs, weightArgs);
llvm::append_range(regionArgs, inputArgs);
return parser.parseRegion(*body, regionArgs); return parser.parseRegion(*body, regionArgs);
} }
void SpatComputeBatch::print(OpAsmPrinter& printer) { void SpatComputeBatch::print(OpAsmPrinter& printer) {
printer << " lanes " << getLaneCount() << " ";
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square);
else
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
printer << " "; printer << " ";
printArgumentBindings(printer, getBody().front(), getInputs()); printer.printOperand(getLaneArgument());
printer << " = 0 to " << getLaneCount();
printer << " ";
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index)
inputArgs.push_back(getInputArgument(index));
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (getNumResults() != 0) {
printer << " shared_outs";
SmallVector<BlockArgument> outputArgs;
outputArgs.reserve(getNumResults());
for (unsigned index = 0; index < getNumResults(); ++index)
outputArgs.push_back(getOutputArgument(index));
printBlockArgumentList(printer, outputArgs);
}
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) { if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
printer << " coreIds "; printer << " coreIds ";
@@ -337,9 +348,6 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName}); {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " : "; printer << " : ";
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square);
else
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " "; printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
@@ -350,7 +358,12 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
} }
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) { ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
int64_t lowerBound = 0;
int32_t laneCount = 0; int32_t laneCount = 0;
OpAsmParser::Argument laneArg;
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> inputArgs;
SmallVector<OpAsmParser::Argument> outputArgs;
SmallVector<OpAsmParser::Argument> regionArgs; SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights; SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs; SmallVector<OpAsmParser::UnresolvedOperand> inputs;
@@ -359,13 +372,20 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
SmallVector<Type> outputTypes; SmallVector<Type> outputTypes;
SmallVector<int32_t> coreIds; SmallVector<int32_t> coreIds;
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)) if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
return failure();
if (lowerBound != 0)
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure(); return failure();
if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights)) if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure(); return failure();
if (parseArgumentBindings(parser, regionArgs, inputs)) if (succeeded(parser.parseOptionalKeyword("shared_outs")))
if (parseBlockArgumentList(parser, outputArgs))
return failure(); return failure();
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids"); bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
@@ -381,10 +401,15 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
if (weights.size() != weightTypes.size()) if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size()) if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (regionArgs.size() != inputs.size()) if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (outputArgs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(),
"number of shared output bindings and result types must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName)) if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(), return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict"); "coreIds cannot be specified both positionally and in attr-dict");
@@ -403,119 +428,28 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
result.addTypes(outputTypes); result.addTypes(outputTypes);
Region* body = result.addRegion(); Region* body = result.addRegion();
applyArgumentTypes(inputTypes, regionArgs); applyBatchRegionArgumentTypes(
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
return parser.parseRegion(*body, regionArgs); return parser.parseRegion(*body, regionArgs);
} }
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); } void SpatInParallelOp::print(OpAsmPrinter& printer) {
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorSendOp(parser, result);
}
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
printer << " "; printer << " ";
printer.printOperand(getInput()); printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false);
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); printer.printOptionalAttrDict((*this)->getAttrs());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
} }
ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) { ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input; auto& builder = parser.getBuilder();
Type inputType; std::unique_ptr<Region> region = std::make_unique<Region>();
SmallVector<int64_t> channelIds; SmallVector<OpAsmParser::Argument, 4> regionArgs;
SmallVector<int32_t> sourceCoreIds; if (parser.parseRegion(*region, regionArgs))
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input))
return failure(); return failure();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); if (region->empty())
if (hasMetadata) { OpBuilder(builder.getContext()).createBlock(region.get());
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") result.addRegion(std::move(region));
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") return parser.parseOptionalAttrDict(result.attributes);
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperand(input, inputType, result.operands);
}
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorSendOp(parser, result);
}
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorReceiveOp(parser, result);
}
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutput().getType());
}
ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
Type outputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
result.addTypes(outputType);
return success();
}
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorReceiveOp(parser, result);
} }
} // namespace spatial } // namespace spatial
+227 -112
View File
@@ -1,6 +1,9 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
@@ -82,20 +85,11 @@ inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
return success(); return success();
} }
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) { static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
if (auto computeOp = weightedOp->getParentOfType<SpatCompute>()) auto shapedType = dyn_cast<ShapedType>(weight.getType());
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape(); if (!shapedType)
if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
return failure();
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure(); return failure();
return shapedType.getShape();
} }
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) { static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
@@ -105,15 +99,86 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
return batchOp.getLaneCount(); return batchOp.getLaneCount();
} }
static LogicalResult verifyTensorChannelSizes(Operation* op, static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
Type type, if (batchOp.getNumResults() == 0)
ArrayRef<int64_t> channelIds, return false;
ArrayRef<int32_t> sourceCoreIds, auto blockArg = dyn_cast<BlockArgument>(value);
ArrayRef<int32_t> targetCoreIds, if (!blockArg || blockArg.getOwner() != &batchOp.getBody().front())
return false;
unsigned argNumber = blockArg.getArgNumber();
unsigned firstOutputArg = batchOp.getOutputArgument(0).getArgNumber();
return argNumber >= firstOutputArg && argNumber < firstOutputArg + batchOp.getNumResults();
}
static bool isConstantIndexLike(Value value) {
APInt constantValue;
return matchPattern(value, m_ConstantInt(&constantValue));
}
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || isConstantIndexLike(value))
return true;
auto addOp = value.getDefiningOp<arith::AddIOp>();
if (!addOp)
return false;
return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs()))
|| (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs()));
}
static LogicalResult
verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
auto sourceType = dyn_cast<RankedTensorType>(sliceOp.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
if (!sliceOp.hasUnitStride())
return sliceOp.emitOpError() << kind << " requires unit strides";
for (int64_t size : sliceOp.getStaticSizes())
if (ShapedType::isDynamic(size))
return sliceOp.emitOpError() << kind << " requires static slice sizes";
auto offsets = sliceOp.getOffsets();
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
if (!supported)
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
}
return success();
}
static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::ParallelInsertSliceOp sliceOp,
BlockArgument laneArg,
StringRef kind) { StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) RankedTensorType sourceType = sliceOp.getSourceType();
RankedTensorType destType = sliceOp.getDestType();
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
if (!sliceOp.hasUnitStride())
return sliceOp.emitOpError() << kind << " requires unit strides";
for (int64_t size : sliceOp.getStaticSizes())
if (ShapedType::isDynamic(size))
return sliceOp.emitOpError() << kind << " requires static slice sizes";
auto offsets = sliceOp.getOffsets();
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
if (!supported)
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
}
return success();
}
static LogicalResult verifyTensorChannelSizes(
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
if (channelIds.empty()) if (channelCount == 0)
return op->emitError() << kind << " must carry at least one chunk"; return op->emitError() << kind << " must carry at least one chunk";
auto shapedType = dyn_cast<ShapedType>(type); auto shapedType = dyn_cast<ShapedType>(type);
@@ -125,40 +190,34 @@ static LogicalResult verifyTensorChannelSizes(Operation* op,
return op->emitError() << kind << " requires byte-sized elements"; return op->emitError() << kind << " requires byte-sized elements";
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8; int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0) if (totalBytes % static_cast<int64_t>(channelCount) != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids"; return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
return success(); return success();
} }
static LogicalResult verifyBatchChannelSizes(Operation* op, static LogicalResult
ArrayRef<int64_t> channelIds, verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
ArrayRef<int32_t> sourceCoreIds, if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
ArrayRef<int32_t> targetCoreIds) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op); auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount)) if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch"); return op->emitError("must be nested inside spat.compute_batch");
if (channelIds.size() != static_cast<size_t>(*laneCount)) if (channelCount != static_cast<size_t>(*laneCount))
return op->emitError("channel metadata length must match parent laneCount"); return op->emitError("channel metadata length must match parent laneCount");
return success(); return success();
} }
static LogicalResult verifyTensorBatchChannelSizes(Operation* op, static LogicalResult verifyTensorBatchChannelSizes(
Type type, Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
ArrayRef<int64_t> channelIds, if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op); auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount)) if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch"); return op->emitError("must be nested inside spat.compute_batch");
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0) if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount"; return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
auto shapedType = dyn_cast<ShapedType>(type); auto shapedType = dyn_cast<ShapedType>(type);
@@ -169,7 +228,7 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
if (elementBits <= 0 || elementBits % 8 != 0) if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements"; return op->emitError() << kind << " requires byte-sized elements";
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount; int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8; int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % chunkCount != 0) if (totalBytes % chunkCount != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane"; return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
@@ -177,28 +236,59 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
return success(); return success();
} }
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) { static Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
}
static bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
static bool isConstantExternalValue(Value value) {
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
bool hasFailure = false;
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
continue;
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
<< kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber()
<< " is used by " << op->getName();
hasFailure = true;
}
});
return success(!hasFailure);
}
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
if (batchOp.getNumResults() == 0) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator()); auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp) if (!yieldOp)
return op->emitError("body must terminate with spat.yield"); return batchOp.emitError("resultless compute_batch body must terminate with spat.yield");
if (outputTypes.empty()) {
if (yieldOp.getNumOperands() != 0) if (yieldOp.getNumOperands() != 0)
return op->emitError("body yield must be empty when compute_batch has no results"); return batchOp.emitError("resultless compute_batch body yield must be empty");
} }
else { else if (!isa_and_nonnull<SpatInParallelOp>(block.getTerminator())) {
if (yieldOp.getNumOperands() != 1) return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
return op->emitError("body yield must produce exactly one value");
if (yieldOp.getOperand(0).getType() != outputTypes[0])
return op->emitError("body yield type must match output type");
} }
BlockArgument laneArg = batchOp.getLaneArgument();
for (auto& bodyOp : block) { for (auto& bodyOp : block) {
if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp)) if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane) if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice")))
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane"); return failure();
if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
} }
return success(); return success();
} }
@@ -206,9 +296,9 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
} // namespace } // namespace
LogicalResult SpatMVMOp::verify() { LogicalResult SpatMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
if (failed(matrixShapeOpt)) if (failed(matrixShapeOpt))
return emitError("SpatMVMOp was not within a SpatCompute or Core op"); return emitError("weight must be a shaped value");
auto matrixShape = *matrixShapeOpt; auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape(); auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape(); auto outputShape = getOutput().getType().getShape();
@@ -221,9 +311,9 @@ LogicalResult SpatMVMOp::verify() {
} }
LogicalResult SpatVMMOp::verify() { LogicalResult SpatVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
if (failed(matrixShapeOpt)) if (failed(matrixShapeOpt))
return emitError("SpatVMMOp was not within a SpatCompute or Core op"); return emitError("weight must be a shaped value");
auto matrixShape = *matrixShapeOpt; auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape(); auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape(); auto outputShape = getOutput().getType().getShape();
@@ -347,13 +437,26 @@ LogicalResult verifyComputeResultsUses(Operation* op) {
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>()); return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
}); });
})) { })) {
return op->emitError("ComputeResult used directly inside another Compute" ); return op->emitError("ComputeResult used directly inside another Compute");
} }
return success(); return success();
} }
LogicalResult SpatCompute::verify() { LogicalResult SpatCompute::verify() {
auto& block = getBody().front(); auto& block = getBody().front();
unsigned expectedArgCount = getWeights().size() + getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return emitError("compute body must have weight and input block arguments");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("compute weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("compute input block argument types must match input operand types exactly");
}
if (block.mightHaveTerminator()) { if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator()); auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp) if (!yieldOp)
@@ -386,9 +489,11 @@ LogicalResult SpatCompute::verify() {
} }
} }
for (auto arg : block.getArguments()) for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
if (arg.use_empty()) if (getInputArgument(inputIndex).use_empty())
return emitError("ComputeOp block argument is not used"); return emitError("ComputeOp block argument is not used");
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
return failure();
if (failed(verifyComputeResultsUses(this->getOperation()))) if (failed(verifyComputeResultsUses(this->getOperation())))
return failure(); return failure();
return success(); return success();
@@ -397,44 +502,46 @@ LogicalResult SpatCompute::verify() {
LogicalResult SpatChannelSendTensorOp::verify() { LogicalResult SpatChannelSendTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(), return verifyTensorChannelSizes(getOperation(),
getInput().getType(), getInput().getType(),
getChannelIds(), getChannelIds().size(),
getSourceCoreIds(), getSourceCoreIds().size(),
getTargetCoreIds(), getTargetCoreIds().size(),
"channel_send_tensor"); "channel_send_tensor");
} }
LogicalResult SpatChannelReceiveTensorOp::verify() { LogicalResult SpatChannelReceiveTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(), return verifyTensorChannelSizes(getOperation(),
getOutput().getType(), getOutput().getType(),
getChannelIds(), getChannelIds().size(),
getSourceCoreIds(), getSourceCoreIds().size(),
getTargetCoreIds(), getTargetCoreIds().size(),
"channel_receive_tensor"); "channel_receive_tensor");
} }
LogicalResult SpatChannelSendBatchOp::verify() { LogicalResult SpatChannelSendBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); return verifyBatchChannelSizes(
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
} }
LogicalResult SpatChannelSendTensorBatchOp::verify() { LogicalResult SpatChannelSendTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(), return verifyTensorBatchChannelSizes(getOperation(),
getInput().getType(), getInput().getType(),
getChannelIds(), getChannelIds().size(),
getSourceCoreIds(), getSourceCoreIds().size(),
getTargetCoreIds(), getTargetCoreIds().size(),
"channel_send_tensor_batch"); "channel_send_tensor_batch");
} }
LogicalResult SpatChannelReceiveBatchOp::verify() { LogicalResult SpatChannelReceiveBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); return verifyBatchChannelSizes(
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
} }
LogicalResult SpatChannelReceiveTensorBatchOp::verify() { LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(), return verifyTensorBatchChannelSizes(getOperation(),
getOutput().getType(), getOutput().getType(),
getChannelIds(), getChannelIds().size(),
getSourceCoreIds(), getSourceCoreIds().size(),
getTargetCoreIds(), getTargetCoreIds().size(),
"channel_receive_tensor_batch"); "channel_receive_tensor_batch");
} }
@@ -444,35 +551,6 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("laneCount must be positive"); return emitError("laneCount must be positive");
auto laneCountSz = static_cast<size_t>(count); auto laneCountSz = static_cast<size_t>(count);
if (getWeights().size() % laneCountSz != 0)
return emitError("number of weights must be a multiple of laneCount");
if (!getInputs().empty() && getInputs().size() != laneCountSz)
return emitError("number of inputs must be either 0 or laneCount");
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
return emitError("number of outputs must be either 0 or laneCount");
size_t weightsPerLane = getWeights().size() / laneCountSz;
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
Type weightType = getWeights()[weightIndex].getType();
for (size_t lane = 1; lane < laneCountSz; ++lane)
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
return emitError("corresponding weights across lanes must have the same type");
}
if (!getInputs().empty()) {
Type inputType = getInputs()[0].getType();
for (Value in : getInputs().drop_front())
if (in.getType() != inputType)
return emitError("all inputs must have the same type");
}
if (!getOutputs().empty()) {
Type outputType = getOutputs()[0].getType();
for (Value out : getOutputs().drop_front())
if (out.getType() != outputType)
return emitError("all outputs must have the same type");
}
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) { if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr); auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
@@ -482,27 +560,64 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("compute_batch coreIds array length must match laneCount"); return emitError("compute_batch coreIds array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; })) if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
return emitError("compute_batch coreIds values must be non-negative"); return emitError("compute_batch coreIds values must be non-negative");
llvm::SmallDenseSet<int32_t, 8> seenCoreIds; DenseSet<int32_t> seenCoreIds;
for (int32_t coreId : coreIdsAttr.asArrayRef()) for (int32_t coreId : coreIdsAttr.asArrayRef())
if (!seenCoreIds.insert(coreId).second) if (!seenCoreIds.insert(coreId).second)
return emitError("compute_batch coreIds values must be distinct"); return emitError("compute_batch coreIds values must be unique");
} }
Block& block = getBody().front(); Block& block = getBody().front();
if (getInputs().empty()) { unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
if (block.getNumArguments() != 0) if (block.getNumArguments() != expectedArgCount)
return emitError("compute_batch body must have no block arguments when there are no inputs"); return emitError("compute_batch body must have lane, weight, input, and output block arguments");
if (!getLaneArgument().getType().isIndex())
return emitError("compute_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("compute_batch weight block argument types must match weight operand types exactly");
} }
else { for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
if (block.getNumArguments() != 1) BlockArgument blockArg = getInputArgument(inputIndex);
return emitError("compute_batch body must have exactly one block argument"); if (blockArg.getType() != input.getType())
if (block.getArgument(0).getType() != getInputs()[0].getType()) return emitError("compute_batch input block argument types must match input operand types exactly");
return emitError("body block argument type must match input type"); }
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
BlockArgument blockArg = getOutputArgument(resultIndex);
if (blockArg.getType() != resultType)
return emitError("compute_batch output block argument types must match result types exactly");
} }
if (failed(verifyComputeResultsUses(this->getOperation()))) if (failed(verifyComputeResultsUses(this->getOperation())))
return failure(); return failure();
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane); if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
return failure();
return verifyBatchBody(*this, block);
}
LogicalResult SpatInParallelOp::verify() {
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
if (!batchOp)
return emitOpError("expected spat.compute_batch parent");
if (batchOp.getNumResults() == 0)
return emitOpError("requires a resultful spat.compute_batch parent");
BlockArgument laneArg = batchOp.getLaneArgument();
for (Operation& op : getRegion().front().getOperations()) {
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
if (!insertSliceOp)
return emitOpError("expected only tensor.parallel_insert_slice ops");
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, laneArg, "tensor.parallel_insert_slice")))
return failure();
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
for (OpOperand& destination : destinations)
if (!isBatchOutputArgument(batchOp, destination.get()))
return op.emitOpError("may only insert into a compute_batch output block argument");
}
return success();
} }
} // namespace spatial } // namespace spatial
File diff suppressed because it is too large Load Diff
@@ -167,21 +167,20 @@ bool isTrivialSerialMergeCandidate(SpatCompute compute) {
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size(); return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
} }
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SpatCompute target, ValueRange sourceWeights) { SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights, ValueRange sourceWeights) {
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices; DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
for (auto [weightIndex, weight] : llvm::enumerate(target.getWeights())) for (auto [weightIndex, weight] : llvm::enumerate(targetWeights))
targetWeightIndices[weight].push_back(weightIndex); targetWeightIndices[weight].push_back(weightIndex);
DenseMap<Value, size_t> usedSourceWeightOccurrences; DenseMap<Value, size_t> usedSourceWeightOccurrences;
SmallVector<size_t> sourceToTargetIndex; SmallVector<size_t> sourceToTargetIndex;
sourceToTargetIndex.reserve(sourceWeights.size()); sourceToTargetIndex.reserve(sourceWeights.size());
auto targetWeights = target.getWeightsMutable();
for (Value weight : sourceWeights) { for (Value weight : sourceWeights) {
size_t occurrence = usedSourceWeightOccurrences[weight]++; size_t occurrence = usedSourceWeightOccurrences[weight]++;
auto& matchingIndices = targetWeightIndices[weight]; auto& matchingIndices = targetWeightIndices[weight];
if (occurrence >= matchingIndices.size()) { if (occurrence >= matchingIndices.size()) {
size_t newIndex = target.getWeights().size(); size_t newIndex = targetWeights.size();
targetWeights.append(weight); targetWeights.push_back(weight);
matchingIndices.push_back(newIndex); matchingIndices.push_back(newIndex);
sourceToTargetIndex.push_back(newIndex); sourceToTargetIndex.push_back(newIndex);
continue; continue;
@@ -213,37 +212,36 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
auto& computeUse = *compute->getUses().begin(); auto& computeUse = *compute->getUses().begin();
auto child = cast<SpatCompute>(computeUse.getOwner()); auto child = cast<SpatCompute>(computeUse.getOwner());
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber(); auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size(); auto childInputIndex = computeUse.getOperandNumber() - child.getWeights().size();
rewriter.setInsertionPointAfter(compute.getOperation()); rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); SmallVector<Value> mergedWeights(compute.getWeights().begin(), compute.getWeights().end());
newCompute.getProperties().setOperandSegmentSizes( SmallVector<size_t> childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(mergedWeights, child.getWeights());
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())}); SmallVector<Value> mergedInputs(compute.getInputs().begin(), compute.getInputs().end());
auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), mergedWeights, mergedInputs);
Block* newBody = rewriter.createBlock(&newCompute.getBodyRegion());
for (Value weight : mergedWeights)
newBody->addArgument(weight.getType(), loc);
for (Value input : mergedInputs)
newBody->addArgument(input.getType(), loc);
IRMapping mapper; IRMapping mapper;
SmallVector<size_t> childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(newCompute, child.getWeights()); for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights()))
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs()))
mapper.map(compute.getInputArgument(inputIndex), newCompute.getInputArgument(inputIndex));
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights())) for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights()))
mapper.map(weight, *std::next(newCompute.getWeights().begin(), childWeightToNewIndex[oldIndex])); mapper.map(child.getWeightArgument(oldIndex), newCompute.getWeightArgument(childWeightToNewIndex[oldIndex]));
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper); rewriter.setInsertionPointToEnd(newBody);
auto newTerminator = newCompute.getBody().front().getTerminator(); auto computeYield = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator());
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult)); for (Operation& op : compute.getBody().front().without_terminator())
newTerminator->erase(); rewriter.clone(op, mapper);
mapper.map(child.getInputArgument(childInputIndex), mapper.lookupOrDefault(computeYield.getOperand(usedResult)));
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end()); rewriter.setInsertionPointToEnd(newBody);
auto remapWeightIndex = [&](auto weightedOp) { for (auto& op : child.getBody().front())
auto oldIndex = weightedOp.getWeightIndex(); rewriter.clone(op, mapper);
assert(static_cast<size_t>(oldIndex) < childWeightToNewIndex.size() && "weight index out of range");
weightedOp.setWeightIndex(childWeightToNewIndex[oldIndex]);
};
for (auto& op : child.getBody().front()) {
auto newInst = rewriter.clone(op, mapper);
if (auto weightedMvmOp = dyn_cast<spatial::SpatMVMOp>(newInst))
remapWeightIndex(weightedMvmOp);
if (auto weightedVmmOp = dyn_cast<spatial::SpatVMMOp>(newInst))
remapWeightIndex(weightedVmmOp);
}
child.replaceAllUsesWith(newCompute); child.replaceAllUsesWith(newCompute);
toErase.insert(child); toErase.insert(child);
@@ -2,6 +2,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
@@ -61,6 +62,66 @@ std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase"; static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
static FailureOr<int64_t> getConstantI64Value(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return constantValue.getSExtValue();
}
static FailureOr<int32_t> getConstantI32Value(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return static_cast<int32_t>(constantValue.getSExtValue());
}
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
uint64_t& channelId,
uint32_t& sourceCoreId,
uint32_t& targetCoreId) {
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
return false;
channelId = static_cast<uint64_t>(*constantChannelId);
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
return true;
}
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
uint64_t& channelId,
uint32_t& sourceCoreId,
uint32_t& targetCoreId) {
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
return false;
channelId = static_cast<uint64_t>(*constantChannelId);
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
return true;
}
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
SmallVector<Value> constants;
constants.reserve(values.size());
for (int64_t value : values)
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
return constants;
}
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
SmallVector<Value> constants;
constants.reserve(values.size());
for (int32_t value : values)
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
return constants;
}
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) { std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName)) if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
return static_cast<uint64_t>(phaseAttr.getInt()); return static_cast<uint64_t>(phaseAttr.getInt());
@@ -206,8 +267,215 @@ bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end(); return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
} }
struct BatchYieldInfo {
Value yieldedValue;
tensor::ParallelInsertSliceOp insertSlice;
};
static bool isHostOnlyBatchResultUser(Operation* user) {
return isa<func::ReturnOp,
spatial::SpatConcatOp,
tensor::ExtractSliceOp,
tensor::CastOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp>(user);
}
static FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> collectBatchYieldInfo(SpatComputeBatch batchOp) {
Block& block = batchOp.getBody().front();
auto inParallel = dyn_cast<spatial::SpatInParallelOp>(block.getTerminator());
if (!inParallel)
return failure();
DenseMap<BlockArgument, BatchYieldInfo> batchYieldByOutputArg;
for (Operation& op : inParallel.getRegion().front()) {
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
if (!insertSlice)
return failure();
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
if (!outputArg || outputArg.getOwner() != &block)
return failure();
batchYieldByOutputArg[outputArg] = {insertSlice.getSource(), insertSlice};
}
return batchYieldByOutputArg;
}
static FailureOr<SpatComputeBatch> cloneBatchAsResultless(SpatComputeBatch batchOp, IRRewriter& rewriter) {
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
if (!coreIdsAttr)
return failure();
Block& oldBlock = batchOp.getBody().front();
rewriter.setInsertionPoint(batchOp);
auto newBatch = SpatComputeBatch::create(rewriter,
batchOp.getLoc(),
TypeRange {},
rewriter.getI32IntegerAttr(batchOp.getLaneCount()),
batchOp.getWeights(),
batchOp.getInputs());
newBatch.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr);
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
blockArgLocs.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
blockArgTypes.push_back(batchOp.getLaneArgument().getType());
blockArgLocs.push_back(batchOp.getLaneArgument().getLoc());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) {
blockArgTypes.push_back(batchOp.getWeightArgument(weightIndex).getType());
blockArgLocs.push_back(batchOp.getWeightArgument(weightIndex).getLoc());
}
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) {
blockArgTypes.push_back(batchOp.getInputArgument(inputIndex).getType());
blockArgLocs.push_back(batchOp.getInputArgument(inputIndex).getLoc());
}
Block* newBlock =
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
mapper.map(batchOp.getLaneArgument(), newBatch.getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch.getWeightArgument(weightIndex));
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
mapper.map(batchOp.getInputArgument(inputIndex), newBatch.getInputArgument(inputIndex));
for (Operation& op : oldBlock.without_terminator()) {
Operation* cloned = rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(oldResult, newResult);
}
return newBatch;
}
static LogicalResult materializeBatchResultCommunication(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
SmallVector<SpatComputeBatch> batches(funcOp.getOps<SpatComputeBatch>());
for (auto batchOp : batches) {
if (batchOp.getNumResults() == 0)
continue;
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
if (!coreIdsAttr)
return batchOp.emitOpError("missing coreIds while materializing batch result communication");
FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> batchYieldInfo = collectBatchYieldInfo(batchOp);
if (failed(batchYieldInfo))
return batchOp.emitOpError("failed to collect per-result yielded values from compute_batch body");
FailureOr<SpatComputeBatch> newBatch = cloneBatchAsResultless(batchOp, rewriter);
if (failed(newBatch))
return batchOp.emitOpError("failed to clone resultful compute_batch as resultless");
Block& oldBlock = batchOp.getBody().front();
Block& newBlock = newBatch->getBody().front();
IRMapping mapper;
mapper.map(batchOp.getLaneArgument(), newBatch->getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch->getWeightArgument(weightIndex));
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
mapper.map(batchOp.getInputArgument(inputIndex), newBatch->getInputArgument(inputIndex));
auto oldIt = oldBlock.begin();
auto newIt = newBlock.begin();
for (; oldIt != oldBlock.end() && newIt != newBlock.end(); ++oldIt, ++newIt)
for (auto [oldResult, newResult] : llvm::zip(oldIt->getResults(), newIt->getResults()))
mapper.map(oldResult, newResult);
SmallVector<int32_t> sourceCoreIds(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
rewriter.setInsertionPointToEnd(&newBlock);
for (unsigned resultIndex = 0; resultIndex < batchOp.getNumResults(); ++resultIndex) {
BlockArgument outputArg = batchOp.getOutputArgument(resultIndex);
auto yieldInfoIt = batchYieldInfo->find(outputArg);
if (yieldInfoIt == batchYieldInfo->end())
return batchOp.emitOpError(
"missing yielded value for compute_batch result during communication materialization");
Value mappedYieldedValue = mapper.lookup(yieldInfoIt->second.yieldedValue);
DenseMap<int32_t, SmallVector<OpOperand*>> computeUsesByTargetCore;
SmallVector<OpOperand*> hostUses;
for (OpOperand& use : batchOp.getResult(resultIndex).getUses()) {
if (auto computeOp = dyn_cast<SpatCompute>(use.getOwner())) {
auto coreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
return batchOp.emitOpError("compute user of compute_batch result is missing coreId");
computeUsesByTargetCore[static_cast<int32_t>(coreIdAttr.getInt())].push_back(&use);
continue;
}
if (isHostOnlyBatchResultUser(use.getOwner())) {
hostUses.push_back(&use);
continue;
}
return batchOp.emitOpError("unsupported user of compute_batch result during communication materialization")
<< ": " << use.getOwner()->getName();
}
auto createReceiveForUses = [&](ArrayRef<OpOperand*> uses, ArrayRef<int32_t> targetCoreIds) -> LogicalResult {
if (uses.empty())
return success();
SmallVector<int64_t> channelIds;
channelIds.reserve(sourceCoreIds.size());
for ([[maybe_unused]] int32_t sourceCoreId : sourceCoreIds)
channelIds.push_back(nextChannelId++);
SmallVector<Value> sendChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
SmallVector<Value> sendSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
SmallVector<Value> sendTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
spatial::SpatChannelSendBatchOp::create(rewriter,
batchOp.getLoc(),
sendChannelIdValues,
sendSourceCoreIdValues,
sendTargetCoreIdValues,
mappedYieldedValue);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(newBatch->getOperation());
SmallVector<Value> receiveChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
SmallVector<Value> receiveSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
SmallVector<Value> receiveTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
auto received = spatial::SpatChannelReceiveTensorOp::create(rewriter,
batchOp.getLoc(),
batchOp.getResult(resultIndex).getType(),
receiveChannelIdValues,
receiveSourceCoreIdValues,
receiveTargetCoreIdValues);
for (OpOperand* use : uses)
use->set(received.getOutput());
rewriter.setInsertionPointToEnd(&newBlock);
return success();
};
for (auto& [targetCoreId, uses] : computeUsesByTargetCore) {
SmallVector<int32_t> targetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), targetCoreId);
if (failed(createReceiveForUses(uses, targetCoreIds)))
return failure();
}
if (!hostUses.empty()) {
SmallVector<int32_t> hostTargetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), 0);
if (failed(createReceiveForUses(hostUses, hostTargetCoreIds)))
return failure();
}
}
rewriter.setInsertionPointToEnd(&newBlock);
spatial::SpatYieldOp::create(rewriter, batchOp.getLoc(), ValueRange {});
rewriter.eraseOp(batchOp);
}
return success();
}
void rebatchEquivalentComputes(func::FuncOp funcOp) { void rebatchEquivalentComputes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext()); IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>()); SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
DenseSet<Operation*> consumed; DenseSet<Operation*> consumed;
DenseMap<Operation*, size_t> computeOrder; DenseMap<Operation*, size_t> computeOrder;
@@ -316,8 +584,10 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
entries.reserve(group.size()); entries.reserve(group.size());
for (auto [groupIndex, compute] : llvm::enumerate(group)) { for (auto [groupIndex, compute] : llvm::enumerate(group)) {
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]); auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
entries.push_back( BatchReceiveEntry entry;
{groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()}); if (!getScalarChannelMetadata(groupReceive, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
return;
entries.push_back(entry);
++opIts[groupIndex]; ++opIts[groupIndex];
} }
SmallVector<int64_t> channelIds; SmallVector<int64_t> channelIds;
@@ -331,12 +601,15 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId)); sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId)); targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
} }
SmallVector<Value> channelIdValues = createIndexConstants(receiveOp, channelIds, constantFolder);
SmallVector<Value> sourceCoreIdValues = createIndexConstants(receiveOp, sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues = createIndexConstants(receiveOp, targetCoreIds, constantFolder);
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter, auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
receiveOp.getLoc(), receiveOp.getLoc(),
receiveOp.getOutput().getType(), receiveOp.getOutput().getType(),
rewriter.getDenseI64ArrayAttr(channelIds), channelIdValues,
rewriter.getDenseI32ArrayAttr(sourceCoreIds), sourceCoreIdValues,
rewriter.getDenseI32ArrayAttr(targetCoreIds)); targetCoreIdValues);
mapper.map(receiveOp.getOutput(), batchReceive.getOutput()); mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
continue; continue;
} }
@@ -351,7 +624,10 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
entries.reserve(group.size()); entries.reserve(group.size());
for (auto [groupIndex, compute] : llvm::enumerate(group)) { for (auto [groupIndex, compute] : llvm::enumerate(group)) {
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]); auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()}); BatchSendEntry entry;
if (!getScalarChannelMetadata(groupSend, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
return;
entries.push_back(entry);
++opIts[groupIndex]; ++opIts[groupIndex];
} }
SmallVector<int64_t> channelIds; SmallVector<int64_t> channelIds;
@@ -365,11 +641,14 @@ void rebatchEquivalentComputes(func::FuncOp funcOp) {
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId)); sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId)); targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
} }
SmallVector<Value> channelIdValues = createIndexConstants(sendOp, channelIds, constantFolder);
SmallVector<Value> sourceCoreIdValues = createIndexConstants(sendOp, sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues = createIndexConstants(sendOp, targetCoreIds, constantFolder);
spatial::SpatChannelSendBatchOp::create(rewriter, spatial::SpatChannelSendBatchOp::create(rewriter,
sendOp.getLoc(), sendOp.getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds), channelIdValues,
rewriter.getDenseI32ArrayAttr(sourceCoreIds), sourceCoreIdValues,
rewriter.getDenseI32ArrayAttr(targetCoreIds), targetCoreIdValues,
mapper.lookup(sendOp.getInput())); mapper.lookup(sendOp.getInput()));
continue; continue;
} }
@@ -452,6 +731,11 @@ LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextC
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops"); ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
cleanupDeadPackingOps(funcOp); cleanupDeadPackingOps(funcOp);
} }
{
ScopedMergePhaseTimer timer("materialize-batch-result-communication");
if (failed(materializeBatchResultCommunication(funcOp, nextChannelId)))
return failure();
}
return success(); return success();
} }
@@ -3,6 +3,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
@@ -30,7 +31,7 @@ enum class RegularStepKind {
struct RegularStep { struct RegularStep {
RegularStepKind kind; RegularStepKind kind;
int32_t weightIndex = 0; Value weight;
Value invariantOperand; Value invariantOperand;
Type resultType; Type resultType;
}; };
@@ -73,15 +74,90 @@ static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) {
return (static_cast<uint64_t>(sourceCoreId) << 32) | static_cast<uint64_t>(targetCoreId); return (static_cast<uint64_t>(sourceCoreId) << 32) | static_cast<uint64_t>(targetCoreId);
} }
static void appendChannelAttrs(SmallVectorImpl<int64_t>& channelIds, static FailureOr<int64_t> getConstantI64Value(Value value) {
SmallVectorImpl<int32_t>& sourceCoreIds, APInt constantValue;
SmallVectorImpl<int32_t>& targetCoreIds, if (!matchPattern(value, m_ConstantInt(&constantValue)))
uint64_t channelId, return failure();
uint32_t sourceCoreId, return constantValue.getSExtValue();
uint32_t targetCoreId) { }
channelIds.push_back(static_cast<int64_t>(channelId));
sourceCoreIds.push_back(static_cast<int32_t>(sourceCoreId)); static FailureOr<int32_t> getConstantI32Value(Value value) {
targetCoreIds.push_back(static_cast<int32_t>(targetCoreId)); APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return static_cast<int32_t>(constantValue.getSExtValue());
}
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
uint64_t& channelId,
uint32_t& sourceCoreId,
uint32_t& targetCoreId) {
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
return false;
channelId = static_cast<uint64_t>(*constantChannelId);
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
return true;
}
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
uint64_t& channelId,
uint32_t& sourceCoreId,
uint32_t& targetCoreId) {
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
return false;
channelId = static_cast<uint64_t>(*constantChannelId);
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
return true;
}
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
SmallVector<Value> constants;
constants.reserve(values.size());
for (int64_t value : values)
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
return constants;
}
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
SmallVector<Value> constants;
constants.reserve(values.size());
for (int32_t value : values)
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
return constants;
}
static SmallVector<Operation*> getScalarChannelMetadataDefs(Operation* channelOp, unsigned metadataOperandCount) {
SmallVector<Operation*> defs;
defs.reserve(metadataOperandCount);
for (unsigned operandIndex = 0; operandIndex < metadataOperandCount; ++operandIndex) {
Operation* def = channelOp->getOperand(operandIndex).getDefiningOp();
auto constantOp = dyn_cast_or_null<arith::ConstantOp>(def);
if (!constantOp || def->getBlock() != channelOp->getBlock())
continue;
defs.push_back(def);
}
llvm::sort(defs, [](Operation* lhs, Operation* rhs) { return lhs->isBeforeInBlock(rhs); });
return defs;
}
static void moveScalarChannelBundleBefore(Operation* channelOp, Operation* insertionPoint) {
for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3))
metadataDef->moveBefore(insertionPoint);
channelOp->moveBefore(insertionPoint);
}
static void moveScalarChannelBundleBefore(Operation* channelOp, Block* block, Block::iterator insertionPoint) {
for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3))
metadataDef->moveBefore(block, insertionPoint);
channelOp->moveBefore(block, insertionPoint);
} }
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) { static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
@@ -196,7 +272,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
} }
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) { static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand return lhs.kind == rhs.kind && lhs.weight == rhs.weight && lhs.invariantOperand == rhs.invariantOperand
&& lhs.resultType == rhs.resultType; && lhs.resultType == rhs.resultType;
} }
@@ -227,8 +303,7 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
chunk.input = startOp.getInput(); chunk.input = startOp.getInput();
chunk.output = startOp.getOutput(); chunk.output = startOp.getOutput();
chunk.ops.push_back(startOp.getOperation()); chunk.ops.push_back(startOp.getOperation());
chunk.steps.push_back( chunk.steps.push_back({RegularStepKind::Wvmm, startOp.getWeight(), Value(), startOp.getOutput().getType()});
{RegularStepKind::Wvmm, static_cast<int32_t>(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()});
Value currentValue = startOp.getOutput(); Value currentValue = startOp.getOutput();
while (currentValue.hasOneUse()) { while (currentValue.hasOneUse()) {
@@ -241,9 +316,9 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
break; break;
if (vaddOp.getLhs() == currentValue) if (vaddOp.getLhs() == currentValue)
chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()}); chunk.steps.push_back({RegularStepKind::VAddLhs, Value(), vaddOp.getRhs(), vaddOp.getOutput().getType()});
else if (vaddOp.getRhs() == currentValue) else if (vaddOp.getRhs() == currentValue)
chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()}); chunk.steps.push_back({RegularStepKind::VAddRhs, Value(), vaddOp.getLhs(), vaddOp.getOutput().getType()});
else else
break; break;
@@ -255,7 +330,8 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
return chunk; return chunk;
} }
static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) { static RegularCompactionResult
compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run, OperationFolder& constantFolder) {
assert(!run.empty() && "expected a non-empty regular chunk run"); assert(!run.empty() && "expected a non-empty regular chunk run");
const RegularChunk& anchorChunk = run.front(); const RegularChunk& anchorChunk = run.front();
RegularCompactionResult result; RegularCompactionResult result;
@@ -275,9 +351,9 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra
auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(run.size())); auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(run.size()));
auto packedInit = tensor::EmptyOp::create( auto packedInit = tensor::EmptyOp::create(
rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType()); rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType());
auto zero = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 0); auto zero = getOrCreateHostIndexConstant(anchorChunk.startOp, 0, constantFolder);
auto upper = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), run.size()); auto upper = getOrCreateHostIndexConstant(anchorChunk.startOp, static_cast<int64_t>(run.size()), constantFolder);
auto step = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 1); auto step = getOrCreateHostIndexConstant(anchorChunk.startOp, 1, constantFolder);
auto loop = auto loop =
scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()}); scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
@@ -290,8 +366,7 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra
Value inputRowOffset = iv; Value inputRowOffset = iv;
if (inputType.getDimSize(0) != 1) { if (inputType.getDimSize(0) != 1) {
auto rowsPerValue = auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, inputType.getDimSize(0), constantFolder);
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), inputType.getDimSize(0));
inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue); inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
} }
@@ -320,8 +395,7 @@ static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, Arra
Value mappedOutput = mapping.lookup(anchorChunk.output); Value mappedOutput = mapping.lookup(anchorChunk.output);
Value outputRowOffset = iv; Value outputRowOffset = iv;
if (outputType.getDimSize(0) != 1) { if (outputType.getDimSize(0) != 1) {
auto rowsPerValue = auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, outputType.getDimSize(0), constantFolder);
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), outputType.getDimSize(0));
outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue); outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
} }
@@ -389,35 +463,50 @@ void orderBilateralChannelOps(func::FuncOp funcOp) {
Block& block = compute.getBody().front(); Block& block = compute.getBody().front();
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves; SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
DenseMap<uint64_t, Operation*> firstForwardedSendByEndpoint; DenseMap<uint64_t, Operation*> firstForwardedSendByEndpoint;
Operation* firstForwardedSend = nullptr;
for (Operation& op : block) { for (Operation& op : block) {
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&op)) { if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&op)) {
if (sendOp.getSourceCoreId() == static_cast<uint32_t>(coreId) uint64_t channelId = 0;
&& isForwardedChannelPayload(sendOp.getInput(), block)) { uint32_t sourceCoreId = 0;
uint64_t key = getEndpointKey(sendOp.getSourceCoreId(), sendOp.getTargetCoreId()); uint32_t targetCoreId = 0;
if (getScalarChannelMetadata(sendOp, channelId, sourceCoreId, targetCoreId)
&& sourceCoreId == static_cast<uint32_t>(coreId) && isForwardedChannelPayload(sendOp.getInput(), block)) {
if (!firstForwardedSend)
firstForwardedSend = sendOp.getOperation();
uint64_t key = getEndpointKey(sourceCoreId, targetCoreId);
firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation()); firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation());
} }
continue; continue;
} }
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op); auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
if (!receiveOp || receiveOp.getTargetCoreId() != static_cast<uint32_t>(coreId) uint64_t channelId = 0;
|| receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) { uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId)
|| targetCoreId != static_cast<uint32_t>(coreId) || sourceCoreId >= static_cast<uint32_t>(coreId)) {
continue; continue;
} }
uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), receiveOp.getSourceCoreId()); uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), sourceCoreId);
auto firstMatchingSend = firstForwardedSendByEndpoint.find(key); auto firstMatchingSend = firstForwardedSendByEndpoint.find(key);
if (firstMatchingSend != firstForwardedSendByEndpoint.end()) if (firstMatchingSend != firstForwardedSendByEndpoint.end())
moves.push_back({receiveOp, firstMatchingSend->second}); moves.push_back({receiveOp, firstMatchingSend->second});
else if (firstForwardedSend && firstForwardedSend->isBeforeInBlock(receiveOp))
moves.push_back({receiveOp, firstForwardedSend});
} }
for (auto [receiveOp, insertionPoint] : moves) for (auto [receiveOp, insertionPoint] : moves)
receiveOp->moveBefore(insertionPoint); moveScalarChannelBundleBefore(receiveOp, insertionPoint);
for (auto it = block.begin(); it != block.end();) { for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it); auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
if (!receiveOp || receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) { uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId)
|| sourceCoreId >= static_cast<uint32_t>(coreId)) {
++it; ++it;
continue; continue;
} }
@@ -425,18 +514,32 @@ void orderBilateralChannelOps(func::FuncOp funcOp) {
Type outputType = receiveOp.getOutput().getType(); Type outputType = receiveOp.getOutput().getType();
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>( auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
it, block.end(), [&](spatial::SpatChannelReceiveOp current) { it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
uint64_t currentChannelId = 0;
uint32_t currentSourceCoreId = 0;
uint32_t currentTargetCoreId = 0;
return current.getOutput().getType() == outputType return current.getOutput().getType() == outputType
&& current.getSourceCoreId() < static_cast<uint32_t>(coreId); && getScalarChannelMetadata(current, currentChannelId, currentSourceCoreId, currentTargetCoreId)
&& currentSourceCoreId < static_cast<uint32_t>(coreId);
}); });
if (run.ops.size() > 1) { if (run.ops.size() > 1) {
SmallVector<spatial::SpatChannelReceiveOp> sorted(run.ops); SmallVector<spatial::SpatChannelReceiveOp> sorted(run.ops);
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) { llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
return lhs.getSourceCoreId() > rhs.getSourceCoreId(); uint64_t lhsChannelId = 0;
uint32_t lhsSourceCoreId = 0;
uint32_t lhsTargetCoreId = 0;
uint64_t rhsChannelId = 0;
uint32_t rhsSourceCoreId = 0;
uint32_t rhsTargetCoreId = 0;
bool lhsHasMetadata = getScalarChannelMetadata(lhs, lhsChannelId, lhsSourceCoreId, lhsTargetCoreId);
bool rhsHasMetadata = getScalarChannelMetadata(rhs, rhsChannelId, rhsSourceCoreId, rhsTargetCoreId);
if (!lhsHasMetadata || !rhsHasMetadata)
return false;
return lhsSourceCoreId > rhsSourceCoreId;
}); });
Block::iterator insertIt = run.end; Block::iterator insertIt = run.end;
for (auto op : sorted) for (auto op : sorted)
op->moveBefore(&block, insertIt); moveScalarChannelBundleBefore(op, &block, insertIt);
} }
it = run.end; it = run.end;
@@ -446,6 +549,7 @@ void orderBilateralChannelOps(func::FuncOp funcOp) {
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext()); IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) { for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
Block& block = compute.getBody().front(); Block& block = compute.getBody().front();
@@ -461,7 +565,14 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
bool hasRepeatedEndpoint = false; bool hasRepeatedEndpoint = false;
DenseSet<uint64_t> seenEndpoints; DenseSet<uint64_t> seenEndpoints;
for (auto op : run.ops) { for (auto op : run.ops) {
uint64_t endpointKey = getEndpointKey(op.getSourceCoreId(), op.getTargetCoreId()); uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
hasRepeatedEndpoint = true;
break;
}
uint64_t endpointKey = getEndpointKey(sourceCoreId, targetCoreId);
if (!seenEndpoints.insert(endpointKey).second) { if (!seenEndpoints.insert(endpointKey).second) {
hasRepeatedEndpoint = true; hasRepeatedEndpoint = true;
break; break;
@@ -478,8 +589,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
}; };
SmallVector<ReceiveEntry> sortedEntries; SmallVector<ReceiveEntry> sortedEntries;
sortedEntries.reserve(run.ops.size()); sortedEntries.reserve(run.ops.size());
for (auto [originalIndex, op] : llvm::enumerate(run.ops)) for (auto [originalIndex, op] : llvm::enumerate(run.ops)) {
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
sortedEntries.clear();
break;
}
sortedEntries.push_back({op, originalIndex, sourceCoreId, targetCoreId, channelId});
}
if (sortedEntries.empty()) {
++it;
continue;
}
SmallVector<int64_t> channelIds; SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds; SmallVector<int32_t> sourceCoreIds;
@@ -488,8 +611,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
sourceCoreIds.reserve(sortedEntries.size()); sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size()); targetCoreIds.reserve(sortedEntries.size());
for (ReceiveEntry& entry : sortedEntries) { for (ReceiveEntry& entry : sortedEntries) {
appendChannelAttrs( channelIds.push_back(static_cast<int64_t>(entry.channelId));
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId); sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
} }
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType()); auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
@@ -506,13 +630,11 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
: RankedTensorType {}; : RankedTensorType {};
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType; auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
rewriter.setInsertionPoint(run.ops.front()); rewriter.setInsertionPoint(run.ops.front());
auto compactReceive = SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
spatial::SpatChannelReceiveTensorOp::create(rewriter, SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
run.ops.front().getLoc(), SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
packedType, auto compactReceive = spatial::SpatChannelReceiveTensorOp::create(
rewriter.getDenseI64ArrayAttr(channelIds), rewriter, run.ops.front().getLoc(), packedType, channelIdValues, sourceCoreIdValues, targetCoreIdValues);
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
if (concatOp && concatPackedType) { if (concatOp && concatPackedType) {
replaceConcatRunWithPackedValue(concatOp, replaceConcatRunWithPackedValue(concatOp,
concatStartIndex, concatStartIndex,
@@ -551,8 +673,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
}; };
SmallVector<SendEntry> sortedEntries; SmallVector<SendEntry> sortedEntries;
sortedEntries.reserve(run.ops.size()); sortedEntries.reserve(run.ops.size());
for (auto op : run.ops) for (auto op : run.ops) {
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
sortedEntries.clear();
break;
}
sortedEntries.push_back({op, sourceCoreId, targetCoreId, channelId});
}
if (sortedEntries.empty()) {
++it;
continue;
}
SmallVector<int64_t> channelIds; SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds; SmallVector<int32_t> sourceCoreIds;
@@ -563,20 +697,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
targetCoreIds.reserve(sortedEntries.size()); targetCoreIds.reserve(sortedEntries.size());
inputs.reserve(sortedEntries.size()); inputs.reserve(sortedEntries.size());
for (SendEntry& entry : sortedEntries) { for (SendEntry& entry : sortedEntries) {
appendChannelAttrs( channelIds.push_back(static_cast<int64_t>(entry.channelId));
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId); sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
inputs.push_back(entry.op.getInput()); inputs.push_back(entry.op.getInput());
} }
rewriter.setInsertionPoint(run.ops.front()); rewriter.setInsertionPoint(run.ops.front());
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc()); Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
if (packedInput) { if (packedInput) {
spatial::SpatChannelSendTensorOp::create(rewriter, SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
run.ops.front().getLoc(), SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
rewriter.getDenseI64ArrayAttr(channelIds), SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
rewriter.getDenseI32ArrayAttr(sourceCoreIds), spatial::SpatChannelSendTensorOp::create(
rewriter.getDenseI32ArrayAttr(targetCoreIds), rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput);
packedInput);
for (auto op : run.ops) for (auto op : run.ops)
rewriter.eraseOp(op); rewriter.eraseOp(op);
@@ -606,9 +740,9 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
}); });
if (run.ops.size() > 1) { if (run.ops.size() > 1) {
SmallVector<int64_t> channelIds; SmallVector<Value> channelIds;
SmallVector<int32_t> sourceCoreIds; SmallVector<Value> sourceCoreIds;
SmallVector<int32_t> targetCoreIds; SmallVector<Value> targetCoreIds;
for (auto op : run.ops) { for (auto op : run.ops) {
llvm::append_range(channelIds, op.getChannelIds()); llvm::append_range(channelIds, op.getChannelIds());
llvm::append_range(sourceCoreIds, op.getSourceCoreIds()); llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
@@ -629,13 +763,8 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
: RankedTensorType {}; : RankedTensorType {};
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType; auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
rewriter.setInsertionPoint(run.ops.front()); rewriter.setInsertionPoint(run.ops.front());
auto compactReceive = auto compactReceive = spatial::SpatChannelReceiveTensorBatchOp::create(
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter, rewriter, run.ops.front().getLoc(), packedType, channelIds, sourceCoreIds, targetCoreIds);
run.ops.front().getLoc(),
packedType,
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
if (concatOp && concatPackedType) { if (concatOp && concatPackedType) {
replaceConcatRunWithPackedValue( replaceConcatRunWithPackedValue(
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter); concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
@@ -663,9 +792,9 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
}); });
if (run.ops.size() > 1) { if (run.ops.size() > 1) {
SmallVector<int64_t> channelIds; SmallVector<Value> channelIds;
SmallVector<int32_t> sourceCoreIds; SmallVector<Value> sourceCoreIds;
SmallVector<int32_t> targetCoreIds; SmallVector<Value> targetCoreIds;
SmallVector<Value> inputs; SmallVector<Value> inputs;
inputs.reserve(run.ops.size()); inputs.reserve(run.ops.size());
for (auto op : run.ops) { for (auto op : run.ops) {
@@ -678,12 +807,8 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
rewriter.setInsertionPoint(run.ops.front()); rewriter.setInsertionPoint(run.ops.front());
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc()); Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
if (packedInput) { if (packedInput) {
spatial::SpatChannelSendTensorBatchOp::create(rewriter, spatial::SpatChannelSendTensorBatchOp::create(
run.ops.front().getLoc(), rewriter, run.ops.front().getLoc(), channelIds, sourceCoreIds, targetCoreIds, packedInput);
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
packedInput);
for (auto op : run.ops) for (auto op : run.ops)
rewriter.eraseOp(op); rewriter.eraseOp(op);
@@ -700,6 +825,7 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
void compactRegularOpRuns(func::FuncOp funcOp) { void compactRegularOpRuns(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext()); IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
auto compactInBlock = [&](Block& block) { auto compactInBlock = [&](Block& block) {
for (auto it = block.begin(); it != block.end();) { for (auto it = block.begin(); it != block.end();) {
@@ -740,7 +866,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
for (const RegularChunk& chunk : run) for (const RegularChunk& chunk : run)
originalOpCount += chunk.ops.size(); originalOpCount += chunk.ops.size();
RegularCompactionResult result = compactRegularChunkRun(rewriter, run); RegularCompactionResult result = compactRegularChunkRun(rewriter, run, constantFolder);
if (result.changed) { if (result.changed) {
assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run"); assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run");
if (!result.resumeAfter) { if (!result.resumeAfter) {
@@ -763,6 +889,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
void compactRowWiseWvmmRuns(func::FuncOp funcOp) { void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext()); IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) { for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
Block& block = compute.getBody().front(); Block& block = compute.getBody().front();
@@ -784,7 +911,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber()); int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) { auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) {
if (current.getWeightIndex() != wvmmOp.getWeightIndex() if (current.getWeight() != wvmmOp.getWeight()
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp || current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|| current.getInput().getType() != wvmmOp.getInput().getType() || current.getInput().getType() != wvmmOp.getInput().getType()
|| current.getOutput().getType() != wvmmOp.getOutput().getType()) || current.getOutput().getType() != wvmmOp.getOutput().getType())
@@ -851,9 +978,9 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType()); auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
rewriter.setInsertionPoint(run.ops.front()); rewriter.setInsertionPoint(run.ops.front());
auto zero = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 0); auto zero = getOrCreateHostIndexConstant(run.ops.front(), 0, constantFolder);
auto upper = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), runLength); auto upper = getOrCreateHostIndexConstant(run.ops.front(), runLength, constantFolder);
auto step = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 1); auto step = getOrCreateHostIndexConstant(run.ops.front(), 1, constantFolder);
auto packedInit = auto packedInit =
tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType()); tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType());
auto loop = auto loop =
@@ -868,7 +995,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
Value sourceRow = iv; Value sourceRow = iv;
if (firstRow != 0) { if (firstRow != 0) {
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), firstRow); auto firstRowValue = getOrCreateHostIndexConstant(run.ops.front(), firstRow, constantFolder);
sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue); sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue);
} }
@@ -883,7 +1010,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
extractSizes, extractSizes,
extractStrides); extractStrides);
auto loopWvmm = spatial::SpatVMMOp::create( auto loopWvmm = spatial::SpatVMMOp::create(
rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult()); rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeight(), extractedRow.getResult());
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)}; SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
@@ -23,31 +23,31 @@ using namespace mlir;
namespace { namespace {
Weight getComputeBodyWeight(Region &body) { Weight getComputeBodyWeight(Region& body) {
constexpr Weight kOperationWeight = 100; constexpr Weight kOperationWeight = 100;
Weight numOperations = 0; Weight numOperations = 0;
for (auto &block : body) for (auto& block : body)
for ([[maybe_unused]] auto &op : block) for ([[maybe_unused]] auto& op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1)); numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight); return checkedMultiply(numOperations, kOperationWeight);
} }
CrossbarUsage getComputeBodyCrossbarUsage(Region &body) { CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
CrossbarUsage crossbarUsage = 0; CrossbarUsage crossbarUsage = 0;
for (auto &block : body) for (auto& block : body)
for (auto &op : block) for (auto& op : block)
if (isa<SpatVMMOp>(op)) if (isa<SpatVMMOp>(op))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1)); crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
return crossbarUsage; return crossbarUsage;
} }
bool isUsedAsWeightOnly(Operation *producerOp) { bool isUsedAsWeightOnly(Operation* producerOp) {
if (producerOp->getNumResults() == 0) if (producerOp->getNumResults() == 0)
return false; return false;
for (Value result : producerOp->getResults()) { for (Value result : producerOp->getResults()) {
if (result.use_empty()) if (result.use_empty())
return false; return false;
for (Operation *user : result.getUsers()) { for (Operation* user : result.getUsers()) {
if (auto compute = dyn_cast<SpatCompute>(user)) { if (auto compute = dyn_cast<SpatCompute>(user)) {
if (!llvm::is_contained(compute.getWeights(), result)) if (!llvm::is_contained(compute.getWeights(), result))
return false; return false;
@@ -66,7 +66,7 @@ bool isUsedAsWeightOnly(Operation *producerOp) {
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) { std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights; llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
for (const ComputeGraphEdge &edge : edges) { for (const ComputeGraphEdge& edge : edges) {
if (edge.source == edge.target) if (edge.source == edge.target)
continue; continue;
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost); auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
@@ -76,9 +76,9 @@ std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> ed
std::vector<ComputeGraphEdge> aggregatedEdges; std::vector<ComputeGraphEdge> aggregatedEdges;
aggregatedEdges.reserve(edgeWeights.size()); aggregatedEdges.reserve(edgeWeights.size());
for (const auto &[key, weight] : edgeWeights) for (const auto& [key, weight] : edgeWeights)
aggregatedEdges.push_back({key.first, key.second, weight}); aggregatedEdges.push_back({key.first, key.second, weight});
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge &lhs, const ComputeGraphEdge &rhs) { llvm::sort(aggregatedEdges, [](const ComputeGraphEdge& lhs, const ComputeGraphEdge& rhs) {
if (lhs.source != rhs.source) if (lhs.source != rhs.source)
return lhs.source < rhs.source; return lhs.source < rhs.source;
return lhs.target < rhs.target; return lhs.target < rhs.target;
@@ -88,33 +88,33 @@ std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> ed
} // namespace } // namespace
Weight getComputeInstanceWeight(const ComputeInstance &instance) { Weight getComputeInstanceWeight(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op)) if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeWeight(spatCompute); return getSpatComputeWeight(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op); auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount)); return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
} }
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) { CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op)) if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeCrossbarUsage(spatCompute); return getSpatComputeCrossbarUsage(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op); auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
static_cast<CrossbarUsage>(instance.laneCount));
} }
ComputeGraph buildComputeGraph(Operation *entryOp) { ComputeGraph buildComputeGraph(Operation* entryOp) {
ComputeGraph graph; ComputeGraph graph;
for (Region &region : entryOp->getRegions()) { for (Region& region : entryOp->getRegions()) {
for (Block &block : region) { for (Block& block : region) {
for (Operation &op : block) { for (Operation& op : block) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) { if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
if (isUsedAsWeightOnly(spatCompute.getOperation())) if (isUsedAsWeightOnly(spatCompute.getOperation()))
continue; continue;
ComputeInstance instance {spatCompute.getOperation(), 0, 1}; ComputeInstance instance {spatCompute.getOperation(), 0, 1};
size_t index = graph.nodes.size(); size_t index = graph.nodes.size();
graph.nodes.push_back({instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index}); graph.nodes.push_back(
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
graph.instanceToIndex[instance] = index; graph.instanceToIndex[instance] = index;
continue; continue;
} }
@@ -135,9 +135,21 @@ ComputeGraph buildComputeGraph(Operation *entryOp) {
} }
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges; llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
for (const auto &[targetIndex, node] : llvm::enumerate(graph.nodes)) { for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
for (Value input : getComputeInstanceInputs(node.instance)) { for (Value input : getComputeInstanceInputs(node.instance)) {
auto producerInstance = getComputeProducerInstance(input); if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
if (producerIt == graph.instanceToIndex.end())
continue;
rawEdges.push_back(
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
}
continue;
}
auto producerInstance = getComputeProducerInstance(input, &node.instance);
if (!producerInstance) if (!producerInstance)
continue; continue;
auto producerIt = graph.instanceToIndex.find(*producerInstance); auto producerIt = graph.instanceToIndex.find(*producerInstance);
@@ -152,7 +164,7 @@ ComputeGraph buildComputeGraph(Operation *entryOp) {
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end()); graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
graph.successors.assign(graph.nodes.size(), {}); graph.successors.assign(graph.nodes.size(), {});
graph.predecessors.assign(graph.nodes.size(), {}); graph.predecessors.assign(graph.nodes.size(), {});
for (const ComputeGraphEdge &edge : graph.edges) { for (const ComputeGraphEdge& edge : graph.edges) {
graph.successors[edge.source].push_back({edge.target, edge.transferCost}); graph.successors[edge.source].push_back({edge.target, edge.transferCost});
graph.predecessors[edge.target].push_back({edge.source, edge.transferCost}); graph.predecessors[edge.target].push_back({edge.source, edge.transferCost});
} }
@@ -160,7 +172,7 @@ ComputeGraph buildComputeGraph(Operation *entryOp) {
return graph; return graph;
} }
bool verifyAcyclic(const ComputeGraph &graph) { bool verifyAcyclic(const ComputeGraph& graph) {
std::vector<size_t> remainingParents(graph.nodes.size(), 0); std::vector<size_t> remainingParents(graph.nodes.size(), 0);
std::queue<size_t> readyNodes; std::queue<size_t> readyNodes;
for (size_t node = 0; node < graph.nodes.size(); ++node) { for (size_t node = 0; node < graph.nodes.size(); ++node) {
@@ -174,7 +186,7 @@ bool verifyAcyclic(const ComputeGraph &graph) {
size_t node = readyNodes.front(); size_t node = readyNodes.front();
readyNodes.pop(); readyNodes.pop();
++visited; ++visited;
for (const auto &[child, weight] : graph.successors[node]) { for (const auto& [child, weight] : graph.successors[node]) {
(void) weight; (void) weight;
assert(remainingParents[child] > 0 && "remaining parent count underflow"); assert(remainingParents[child] > 0 && "remaining parent count underflow");
if (--remainingParents[child] == 0) if (--remainingParents[child] == 0)
@@ -1,6 +1,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <limits> #include <limits>
#include <optional>
#include "ComputeInstanceUtils.hpp" #include "ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
@@ -18,40 +20,81 @@ size_t getSchedulingCpuBudget() {
size_t getBatchChunkTargetCount(int32_t laneCount) { size_t getBatchChunkTargetCount(int32_t laneCount) {
assert(laneCount > 0 && "laneCount must be positive"); assert(laneCount > 0 && "laneCount must be positive");
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget())); return static_cast<size_t>(laneCount);
} }
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) { ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
size_t totalLanes = batch.getLaneCount(); assert(chunkIndex < static_cast<size_t>(batch.getLaneCount()) && "chunkIndex out of range");
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); return {batch.getOperation(), static_cast<uint32_t>(chunkIndex), 1};
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
} }
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) { ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
size_t totalLanes = batch.getLaneCount(); assert(lane < static_cast<uint32_t>(batch.getLaneCount()) && "lane out of range");
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); return {batch.getOperation(), lane, 1};
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
size_t chunkIndex = 0;
if (static_cast<size_t>(lane) < largeChunkSpan)
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
else
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
return getBatchChunkForIndex(batch, chunkIndex);
} }
std::optional<ProducerValueRef> getProducerValueRef(Value value) { static std::optional<uint32_t> getConstantExtractLane(tensor::ExtractSliceOp extract) {
Operation *op = value.getDefiningOp(); if (extract.getMixedOffsets().empty())
return std::nullopt;
OpFoldResult offset = extract.getMixedOffsets().front();
if (Attribute attr = llvm::dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
if (!intAttr || intAttr.getInt() < 0)
return std::nullopt;
return static_cast<uint32_t>(intAttr.getInt());
}
Value offsetValue = llvm::cast<Value>(offset);
if (auto constantIndex = offsetValue.getDefiningOp<arith::ConstantIndexOp>()) {
if (constantIndex.value() < 0)
return std::nullopt;
return static_cast<uint32_t>(constantIndex.value());
}
return std::nullopt;
}
static std::optional<ProducerValueRef> getResultfulBatchProducerValueRef(SpatComputeBatch batch,
const ComputeInstance* consumerInstance) {
if (!consumerInstance)
return std::nullopt;
if (!isa<SpatComputeBatch>(consumerInstance->op))
return std::nullopt;
if (consumerInstance->laneStart + consumerInstance->laneCount > static_cast<uint32_t>(batch.getLaneCount()))
return std::nullopt;
return ProducerValueRef {
{batch.getOperation(), consumerInstance->laneStart, consumerInstance->laneCount},
0
};
}
std::optional<ProducerValueRef> getProducerValueRef(Value value, const ComputeInstance* consumerInstance) {
Operation* op = value.getDefiningOp();
if (!op) if (!op)
return std::nullopt; return std::nullopt;
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
Value source = extract.getSource();
auto batch = dyn_cast_or_null<SpatComputeBatch>(source.getDefiningOp());
if (batch && batch.getNumResults() != 0) {
if (std::optional<uint32_t> lane = getConstantExtractLane(extract)) {
if (*lane >= static_cast<uint32_t>(batch.getLaneCount()))
return std::nullopt;
return ProducerValueRef {
{batch.getOperation(), *lane, 1},
0
};
}
return getResultfulBatchProducerValueRef(batch, consumerInstance);
}
value = source;
op = value.getDefiningOp();
if (!op)
return std::nullopt;
}
if (auto compute = dyn_cast<SpatCompute>(op)) { if (auto compute = dyn_cast<SpatCompute>(op)) {
return ProducerValueRef { return ProducerValueRef {
ComputeInstance {compute.getOperation(), 0, 1}, ComputeInstance {compute.getOperation(), 0, 1},
@@ -60,6 +103,8 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
} }
if (auto batch = dyn_cast<SpatComputeBatch>(op)) { if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
if (batch.getNumResults() != 0)
return getResultfulBatchProducerValueRef(batch, consumerInstance);
uint32_t lane = cast<OpResult>(value).getResultNumber(); uint32_t lane = cast<OpResult>(value).getResultNumber();
ComputeInstance instance = getBatchChunkForLane(batch, lane); ComputeInstance instance = getBatchChunkForLane(batch, lane);
size_t resultIndex = lane - instance.laneStart; size_t resultIndex = lane - instance.laneStart;
@@ -69,42 +114,60 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
return std::nullopt; return std::nullopt;
} }
std::optional<ComputeInstance> getComputeProducerInstance(Value value) { std::optional<ComputeInstance> getComputeProducerInstance(Value value, const ComputeInstance* consumerInstance) {
if (std::optional<ProducerValueRef> producer = getProducerValueRef(value)) if (std::optional<ProducerValueRef> producer = getProducerValueRef(value, consumerInstance))
return producer->instance; return producer->instance;
return std::nullopt; return std::nullopt;
} }
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance &instance) { llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op)) if (auto compute = dyn_cast<SpatCompute>(instance.op))
return llvm::SmallVector<Value, 4>(compute.getInputs().begin(), compute.getInputs().end()); return llvm::SmallVector<Value, 4>(compute.getInputs().begin(), compute.getInputs().end());
auto batch = cast<SpatComputeBatch>(instance.op); auto batch = cast<SpatComputeBatch>(instance.op);
if (batch.getNumResults() != 0)
return llvm::SmallVector<Value, 4>(batch.getInputs().begin(), batch.getInputs().end());
assert(batch.getInputs().size() % static_cast<size_t>(batch.getLaneCount()) == 0
&& "resultless compute_batch inputs must be evenly partitioned by lane");
size_t inputsPerLane = batch.getInputs().size() / static_cast<size_t>(batch.getLaneCount());
llvm::SmallVector<Value, 4> inputs; llvm::SmallVector<Value, 4> inputs;
inputs.reserve(instance.laneCount); inputs.reserve(instance.laneCount * inputsPerLane);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) {
if (!batch.getInputs().empty()) size_t firstInput = static_cast<size_t>(lane) * inputsPerLane;
inputs.push_back(batch.getInputs()[lane]); inputs.append(batch.getInputs().begin() + firstInput, batch.getInputs().begin() + firstInput + inputsPerLane);
}
return inputs; return inputs;
} }
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance &instance) { llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance& instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op)) if (auto compute = dyn_cast<SpatCompute>(instance.op))
return llvm::SmallVector<Value, 4>(compute.getWeights().begin(), compute.getWeights().end()); return llvm::SmallVector<Value, 4>(compute.getWeights().begin(), compute.getWeights().end());
auto batch = cast<SpatComputeBatch>(instance.op); auto batch = cast<SpatComputeBatch>(instance.op);
if (batch.getNumResults() != 0)
return llvm::SmallVector<Value, 4>(batch.getWeights().begin(), batch.getWeights().end());
assert(batch.getWeights().size() % static_cast<size_t>(batch.getLaneCount()) == 0
&& "resultless compute_batch weights must be evenly partitioned by lane");
size_t weightsPerLane = batch.getWeights().size() / static_cast<size_t>(batch.getLaneCount());
llvm::SmallVector<Value, 4> weights; llvm::SmallVector<Value, 4> weights;
weights.reserve(instance.laneCount); weights.reserve(instance.laneCount * weightsPerLane);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) {
weights.push_back(batch.getWeights()[lane]); size_t firstWeight = static_cast<size_t>(lane) * weightsPerLane;
weights.append(batch.getWeights().begin() + firstWeight, batch.getWeights().begin() + firstWeight + weightsPerLane);
}
return weights; return weights;
} }
llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance) { llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance& instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op)) if (auto compute = dyn_cast<SpatCompute>(instance.op))
return llvm::SmallVector<Value, 4>(compute.getResults().begin(), compute.getResults().end()); return llvm::SmallVector<Value, 4>(compute.getResults().begin(), compute.getResults().end());
auto batch = cast<SpatComputeBatch>(instance.op); auto batch = cast<SpatComputeBatch>(instance.op);
if (batch.getNumResults() != 0)
return llvm::SmallVector<Value, 4>(batch.getResults().begin(), batch.getResults().end());
llvm::SmallVector<Value, 4> outputs; llvm::SmallVector<Value, 4> outputs;
outputs.reserve(instance.laneCount); outputs.reserve(instance.laneCount);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
@@ -113,14 +176,14 @@ llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance
return outputs; return outputs;
} }
llvm::SmallVector<Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance) { llvm::SmallVector<Type, 4> getComputeInstanceOutputTypes(const ComputeInstance& instance) {
llvm::SmallVector<Type, 4> outputTypes; llvm::SmallVector<Type, 4> outputTypes;
for (Value output : getComputeInstanceOutputValues(instance)) for (Value output : getComputeInstanceOutputValues(instance))
outputTypes.push_back(output.getType()); outputTypes.push_back(output.getType());
return outputTypes; return outputTypes;
} }
Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance) { Block& getComputeInstanceTemplateBlock(const ComputeInstance& instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op)) if (auto compute = dyn_cast<SpatCompute>(instance.op))
return compute.getBody().front(); return compute.getBody().front();
return cast<SpatComputeBatch>(instance.op).getBody().front(); return cast<SpatComputeBatch>(instance.op).getBody().front();
@@ -26,8 +26,10 @@ size_t getBatchChunkTargetCount(int32_t laneCount);
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex); ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane); ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value); std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value); const ComputeInstance *consumerInstance = nullptr);
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value,
const ComputeInstance *consumerInstance = nullptr);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance); llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance); llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
@@ -268,24 +268,31 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto dstOffset = resolveIndexValue(copyOp.getDeviceTargetOffset());
auto srcOffset = resolveIndexValue(copyOp.getHostSourceOffset());
if (failed(dstOffset) || failed(srcOffset))
return failure();
auto status = rewriteSubviewCopyLikeOp( auto status = rewriteSubviewCopyLikeOp(
copyOp, copyOp,
copyOp.getDeviceTarget(), copyOp.getDeviceTarget(),
copyOp.getHostSource(), copyOp.getHostSource(),
copyOp.getDeviceTargetOffset(), *dstOffset,
copyOp.getHostSourceOffset(), *srcOffset,
copyOp.getSize(), copyOp.getSize(),
/*allowLoopRewrite=*/true, /*allowLoopRewrite=*/true,
rewriter, rewriter,
[&]( [&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value dstOffsetValue = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), dstByteOffset);
Value srcOffsetValue = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), srcByteOffset);
pim::PimMemCopyHostToDevOp::create(rewriter, pim::PimMemCopyHostToDevOp::create(rewriter,
copyOp.getLoc(), copyOp.getLoc(),
resultType, resultType,
dstOffsetValue,
srcOffsetValue,
dst, dst,
src, src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes))); rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
}); });
if (failed(status)) if (failed(status))
@@ -301,24 +308,31 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDe
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override { LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
auto dstOffset = resolveIndexValue(copyOp.getHostTargetOffset());
auto srcOffset = resolveIndexValue(copyOp.getDeviceSourceOffset());
if (failed(dstOffset) || failed(srcOffset))
return failure();
auto status = rewriteSubviewCopyLikeOp( auto status = rewriteSubviewCopyLikeOp(
copyOp, copyOp,
copyOp.getHostTarget(), copyOp.getHostTarget(),
copyOp.getDeviceSource(), copyOp.getDeviceSource(),
copyOp.getHostTargetOffset(), *dstOffset,
copyOp.getDeviceSourceOffset(), *srcOffset,
copyOp.getSize(), copyOp.getSize(),
/*allowLoopRewrite=*/false, /*allowLoopRewrite=*/false,
rewriter, rewriter,
[&]( [&](
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
Value dstOffset = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), dstByteOffset);
Value srcOffset = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), srcByteOffset);
pim::PimMemCopyDevToHostOp::create(rewriter, pim::PimMemCopyDevToHostOp::create(rewriter,
copyOp.getLoc(), copyOp.getLoc(),
resultType, resultType,
dstOffset,
srcOffset,
dst, dst,
src, src,
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes))); rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
}); });
if (failed(status)) if (failed(status))
@@ -355,9 +369,7 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
if (failed(staticOffsets)) if (failed(staticOffsets))
return failure(); return failure();
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType(); auto resultMemRefType = cast<MemRefType>(subviewOp.getType());
auto resultMemRefType =
MemRefType::get(SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape()); auto foldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, resultMemRefType.getShape());
if (failed(foldedAttr)) if (failed(foldedAttr))
return failure(); return failure();
@@ -23,11 +23,11 @@ namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op)) if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1; return operandIndex == 3;
if (isa<pim::PimMemCopyHostToDevBatchOp>(op)) if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1; return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op)) if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0; return operandIndex == 2;
return false; return false;
} }
@@ -39,7 +39,10 @@ static int64_t getValueSizeInBytes(Value value) {
} }
template <typename CoreOpTy> template <typename CoreOpTy>
static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, bool& hasFailure) { static void materializeHostConstantsInCore(CoreOpTy coreOp,
IRRewriter& rewriter,
OperationFolder& constantFolder,
bool& hasFailure) {
DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues; DenseMap<Value, DenseMap<int64_t, DenseMap<Type, Value>>> materializedValues;
SmallVector<Operation*> ops; SmallVector<Operation*> ops;
coreOp.getBody().front().walk([&](Operation* op) { coreOp.getBody().front().walk([&](Operation* op) {
@@ -48,6 +51,9 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter
}); });
for (Operation* op : ops) { for (Operation* op : ops) {
if (auto loadOp = dyn_cast<memref::LoadOp>(op); loadOp && loadOp.getType().isIndex())
continue;
for (OpOperand& operand : op->getOpOperands()) { for (OpOperand& operand : op->getOpOperands()) {
Value originalValue = operand.get(); Value originalValue = operand.get();
if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber())) if (!isa<BaseMemRefType>(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber()))
@@ -105,14 +111,15 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter
.getOutput(); .getOutput();
} }
else { else {
copiedValue = pim::PimMemCopyHostToDevOp::create( copiedValue =
pim::PimMemCopyHostToDevOp::create(
rewriter, rewriter,
op->getLoc(), op->getLoc(),
originalType, originalType,
getOrCreateHostIndexConstant(op, 0, constantFolder),
getOrCreateHostIndexConstant(op, static_cast<int64_t>(resolvedAddress->byteOffset), constantFolder),
deviceDst, deviceDst,
getGlobalOp.getResult(), getGlobalOp.getResult(),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(resolvedAddress->byteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes))) rewriter.getI32IntegerAttr(static_cast<int32_t>(totalBytes)))
.getOutput(); .getOutput();
} }
@@ -134,6 +141,7 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
void runOnOperation() override { void runOnOperation() override {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
IRRewriter rewriter(moduleOp.getContext()); IRRewriter rewriter(moduleOp.getContext());
OperationFolder constantFolder(moduleOp.getContext());
bool hasFailure = false; bool hasFailure = false;
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) { for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
@@ -141,10 +149,10 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
continue; continue;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>())
materializeHostConstantsInCore(coreOp, rewriter, hasFailure); materializeHostConstantsInCore(coreOp, rewriter, constantFolder, hasFailure);
for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>()) for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps<pim::PimCoreBatchOp>())
materializeHostConstantsInCore(coreBatchOp, rewriter, hasFailure); materializeHostConstantsInCore(coreBatchOp, rewriter, constantFolder, hasFailure);
SmallVector<Operation*> hostCompactOps; SmallVector<Operation*> hostCompactOps;
for (Operation& op : funcOp.getBody().front()) for (Operation& op : funcOp.getBody().front())
+45 -3
View File
@@ -8,6 +8,7 @@
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -119,11 +120,27 @@ static bool isConstantGlobalView(Value value) {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyHostToDevOp>(op)) if (isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 1; return operandIndex == 3;
if (isa<pim::PimMemCopyHostToDevBatchOp>(op)) if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1; return operandIndex == 1;
if (isa<pim::PimMemCopyDevToHostOp>(op)) if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 0; return operandIndex == 2;
return false;
}
static bool isCoreWeightBlockArgument(Value value) {
auto blockArgument = dyn_cast<BlockArgument>(value);
if (!blockArgument)
return false;
if (auto coreOp = dyn_cast<pim::PimCoreOp>(blockArgument.getOwner()->getParentOp()))
return static_cast<unsigned>(blockArgument.getArgNumber()) < coreOp.getWeights().size();
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(blockArgument.getOwner()->getParentOp())) {
unsigned argNumber = static_cast<unsigned>(blockArgument.getArgNumber());
return argNumber > 0 && argNumber <= coreBatchOp.getWeights().size();
}
return false; return false;
} }
@@ -193,7 +210,9 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) { if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics); (void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
(void) verifyCoreOperands(coreBatchOp, diagnostics); for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
(void) withScalarCoreFromBatchLane(
coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) { return verifyCoreOperands(scalarCore, diagnostics); });
continue; continue;
} }
@@ -297,6 +316,9 @@ private:
if (!isa<BaseMemRefType>(operand.getType())) if (!isa<BaseMemRefType>(operand.getType()))
continue; continue;
if (isCoreWeightBlockArgument(operand))
continue;
auto resolvedAddress = resolveContiguousAddress(operand, knowledge); auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
if (failed(resolvedAddress)) { if (failed(resolvedAddress)) {
diagnostics.report(&op, [&](Operation* illegalOp) { diagnostics.report(&op, [&](Operation* illegalOp) {
@@ -327,6 +349,26 @@ private:
hasFailure = true; hasFailure = true;
} }
} }
if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op)) {
if (failed(resolveIndexValue(storeOp.getHostTargetOffset(), knowledge))
|| failed(resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge))) {
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
});
hasFailure = true;
}
}
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op)) {
if (failed(resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge))
|| failed(resolveIndexValue(loadOp.getHostSourceOffset(), knowledge))) {
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("offset operands must be statically evaluable for PIM codegen");
});
hasFailure = true;
}
}
return success(!hasFailure); return success(!hasFailure);
}); });
} }
+1 -1
View File
@@ -67,7 +67,7 @@ def main():
help="Core count to pass to Raptor. Required for PIM validation.") help="Core count to pass to Raptor. Required for PIM validation.")
ap.add_argument("--pim-merge-scheduler", choices=("peft", "dcp"), default="peft", ap.add_argument("--pim-merge-scheduler", choices=("peft", "dcp"), default="peft",
help="Scheduler used by the Spatial merge-compute-nodes pass.") help="Scheduler used by the Spatial merge-compute-nodes pass.")
ap.add_argument("--command-timeout-seconds", type=float, default=6000000000000000.0, ap.add_argument("--command-timeout-seconds", type=float, default=1000000.0,
help="Per-subprocess timeout in seconds for compiler, runner, and simulator commands.") help="Per-subprocess timeout in seconds for compiler, runner, and simulator commands.")
ap.add_argument("--clean", action="store_true", ap.add_argument("--clean", action="store_true",
help="Remove generated validation artifacts under each model workspace and exit.") help="Remove generated validation artifacts under each model workspace and exit.")