compact spatial IR through new dedicated operations and syntax
fast spatial node merging with batch operations
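For orientation before the diff: the batch ops added in this commit keep a single template region that stands for several identical lanes. Based on the assemblyFormat declared for `pim.core_batch` in the Pim.td hunk further down, a two-lane instance might print roughly like the sketch below; values, types, and the body are illustrative, and the `core_id` attribute name comes from `kCoreIdAttrName`.

```mlir
pim.core_batch lanes 2 (%w0, %w1) [%x0, %x1] {core_id = array<i32: 0, 1>} ({
^bb0(%in: tensor<1x128xf32>):
  // one template body, expanded once per lane by materializeScalarCoreFromBatchLane
  pim.halt
}) : tensor<128x64xf32>, tensor<128x64xf32> [tensor<1x128xf32>, tensor<1x128xf32>] -> ()
```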
.gitignore (vendored) | 9
@@ -1,5 +1,14 @@
.zed
.idea
**/.vscode

.claude
AGENTS.md

CMakeUserPresets.json

build
cmake-build-debug
cmake-build-release

**/__pycache__
@@ -135,7 +135,7 @@ validate.py \
  --raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
  --onnx-include-dir ../onnx-mlir/include \
  --operations-dir ./networks/yolo11n/depth_04 \
- --crossbar-size 2048
+ --crossbar-size 2048 --crossbar-count 256
```

Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
@@ -48,7 +48,9 @@ void dumpModule(ModuleOp moduleOp, const std::string& name) {

  std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
  llvm::raw_os_ostream os(file);
-  os << *moduleOp;
+  OpPrintingFlags flags;
+  flags.elideLargeElementsAttrs();
+  moduleOp.print(os, flags);
  os.flush();
  file.close();
}
@@ -173,6 +175,13 @@ void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> cal
  root->walk([&](pim::PimCoreOp coreOp) {
    walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback);
  });
+  root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
+    auto weights = coreBatchOp.getWeights();
+    for (auto weight : weights)
+      for (OpOperand& use : weight.getUses())
+        if (use.getOwner() == coreBatchOp.getOperation())
+          callback(use);
+  });
}

memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
@@ -181,66 +190,6 @@ memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp
  return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
}

-FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
-
-  auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
-  if (!channelNewOp) {
-    op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
-    return failure();
-  }
-  // channelNewOp should have two users: `op` and a
-  // `ChannelSendOp`/`ChannelReceiveOp`
-  auto channelUsers = channelNewOp->getUsers();
-  auto usersIterator = channelUsers.begin();
-  auto firstUser = *usersIterator;
-  usersIterator++;
-  if (usersIterator == channelUsers.end()) {
-    op->emitError("Operand generated by ChannelNewOp must have two users, "
-                  "only one found.");
-    channelNewOp->dump();
-    op->dump();
-    channelNewOp->getParentOp()->dump();
-    return failure();
-  }
-  auto secondUser = *usersIterator;
-  usersIterator++;
-  if (usersIterator != channelUsers.end()) {
-    op->emitError("Operand generated by ChannelNewOp must have two users, "
-                  "more than two found.");
-    return failure();
-  }
-  Operation* notOpUser;
-  if (firstUser == op) {
-    notOpUser = secondUser;
-  }
-  else if (secondUser == op) {
-    notOpUser = firstUser;
-  }
-  else {
-    op->emitError("Operand generated by ChannelNewOp must have two users, "
-                  "and one of them must be me, but"
-                  "none of them is actually me.");
-    return failure();
-  }
-
-  if (opIsReceive) {
-    if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
-      op->emitError("Operand generated by ChannelNewOp has two user, one is "
-                    "me, the other is not a ChannelSendOp.");
-      return failure();
-    }
-    return notOpUser;
-  }
-  else {
-    if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
-      op->emitError("Operand generated by ChannelNewOp has two user, one is "
-                    "me, the other is not a ChannelReceiveOp.");
-      return failure();
-    }
-    return notOpUser;
-  }
-}
-
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> strides(shape.size(), 1);
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
@@ -17,6 +17,8 @@ inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";

namespace onnx_mlir {

+inline constexpr llvm::StringLiteral kCoreIdAttrName = "core_id";
+
struct ResolvedContiguousAddress {
  mlir::Value base;
  int64_t byteOffset = 0;

@@ -48,9 +50,6 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir

mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);

-llvm::FailureOr<mlir::Operation*>
-getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
-
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);

llvm::SmallVector<int64_t>
@@ -1,8 +1,10 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -474,19 +476,113 @@ std::string getMemorySizeAsString(size_t size) {
  return std::to_string(size) + " Bytes";
}

-static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
+static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
  SmallVector<unsigned, 8> indices;
  auto addIndex = [&](unsigned weightIndex) {
    if (!llvm::is_contained(indices, weightIndex))
      indices.push_back(weightIndex);
  };

-  coreOp.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
-  coreOp.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
+  block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
+  block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
  llvm::sort(indices);
  return indices;
}

+static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
+  return getUsedWeightIndices(coreOp.getBody().front());
+}
+
+static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
+  auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName);
+  assert(coreIdsAttr && "pim.core_batch requires core_id array attribute");
+  return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
+}
+
+static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
+  SmallVector<Operation*> coreLikeOps;
+  for (Operation& op : funcOp.getBody().front()) {
+    if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
+      coreLikeOps.push_back(&op);
+  }
+  return coreLikeOps;
+}
+
+static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
+  OpBuilder builder(coreBatchOp);
+  builder.setInsertionPointAfter(coreBatchOp);
+
+  size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
+  size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
+  SmallVector<mlir::Value> laneWeights;
+  laneWeights.reserve(weightsPerLane);
+  for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
+    laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
+
+  auto coreIds = getBatchCoreIds(coreBatchOp);
+  auto scalarCore = pim::PimCoreOp::create(builder,
+      coreBatchOp.getLoc(),
+      ValueRange(laneWeights),
+      builder.getI32IntegerAttr(coreIds[lane]));
+  Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
+  IRMapping mapper;
+  if (coreBatchOp.getBody().front().getNumArguments() == 1)
+    mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
+
+  builder.setInsertionPointToEnd(block);
+  for (Operation& op : coreBatchOp.getBody().front()) {
+    if (isa<pim::PimHaltOp>(op)) {
+      pim::PimHaltOp::create(builder, op.getLoc());
+      continue;
+    }
+
+    if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
+      pim::PimSendOp::create(builder,
+          sendBatchOp.getLoc(),
+          mapper.lookup(sendBatchOp.getInput()),
+          sendBatchOp.getSizeAttr(),
+          builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
+      continue;
+    }
+
+    if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
+      auto scalarReceive = pim::PimReceiveOp::create(builder,
+          receiveBatchOp.getLoc(),
+          receiveBatchOp.getOutput().getType(),
+          mapper.lookup(receiveBatchOp.getOutputBuffer()),
+          receiveBatchOp.getSizeAttr(),
+          builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
+      mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
+      continue;
+    }
+
+    if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
+      mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
+      if (!hostSource)
+        hostSource = memcpBatchOp.getHostSource();
+
+      auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
+          memcpBatchOp.getLoc(),
+          memcpBatchOp.getOutput().getType(),
+          mapper.lookup(memcpBatchOp.getDeviceTarget()),
+          hostSource,
+          memcpBatchOp.getDeviceTargetOffsetAttr(),
+          memcpBatchOp.getHostSourceOffsetAttr(),
+          memcpBatchOp.getSizeAttr());
+      mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
+      continue;
+    }
+
+    Operation* cloned = builder.clone(op, mapper);
+    for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
+      mapper.map(originalResult, clonedResult);
+  }
+
+  if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
+    pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
+  return scalarCore;
+}
+
/// Write global constant data into a binary memory image at their allocated addresses.
static OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
@@ -670,7 +766,7 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
  return CompilerSuccess;
}

-llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>>
+llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
  ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
  auto coreWeightsDirPath = outputDirPath + "/weights";

@@ -679,10 +775,24 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
  size_t indexFileName = 0;

  int64_t xbarSize = crossbarSize.getValue();
-  llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
+  llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
  llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;

-  for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
+  SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
+
+  for (Operation* op : coreLikeOps) {
+    SmallVector<pim::PimCoreOp> scalarCores;
+    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
+      scalarCores.push_back(coreOp);
+    }
+    else {
+      auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
+      for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
+        scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane));
+    }
+
+    for (pim::PimCoreOp coreOp : scalarCores) {
+      size_t coreId = static_cast<size_t>(coreOp.getCoreId());
      for (unsigned index : getUsedWeightIndices(coreOp)) {
        if (index >= coreOp.getWeights().size()) {
          coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");

@@ -717,7 +827,7 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
        if (mapGlobalOpToFileName.contains(globalOp)) {
          auto& fileName = mapGlobalOpToFileName[globalOp];
          std::pair<mlir::Value, std::string> weightToFile = {weight, fileName};
-          mapCoreWeightToFileName[coreOp].insert(weightToFile);
+          mapCoreWeightToFileName[coreId].insert(weightToFile);
          continue;
        }

@@ -756,9 +866,14 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {

        weightFileStream.close();
        mapGlobalOpToFileName.insert({globalOp, newFileName});
-        mapCoreWeightToFileName[coreOp].insert({weight, newFileName});
+        mapCoreWeightToFileName[coreId].insert({weight, newFileName});
      }
    }

+    for (pim::PimCoreOp coreOp : scalarCores)
+      if (coreOp.getOperation() != op)
+        coreOp.erase();
+  }
  return mapCoreWeightToFileName;
}
@@ -850,7 +965,20 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
  // Create Weight Folder
  auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);

-  for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) {
+  SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
+
+  for (Operation* op : coreLikeOps) {
+    SmallVector<pim::PimCoreOp> scalarCores;
+    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
+      scalarCores.push_back(coreOp);
+    }
+    else {
+      auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
+      for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
+        scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane));
+    }
+
+    for (pim::PimCoreOp coreOp : scalarCores) {
      auto coreId = coreOp.getCoreId();
      coreCount++;

@@ -871,19 +999,17 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
        return CompilerFailure;
      assert(processedOperations > 0);

      // Remove trailing comma, close JSON array
      coreFileStream.seek(coreFileStream.tell() - 1);
      coreFileStream << ']';
      coreFileStream.close();

      // Write crossbar weights for this core
      auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
      if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
        errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
        return InvalidOutputFileAccess;
      }

-      auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
+      auto& mapWeightToFile = mapCoreWeightToFileName[static_cast<size_t>(coreId)];
      json::Array xbarsPerGroup;
      for (unsigned index : getUsedWeightIndices(coreOp)) {
        if (index >= coreOp.getWeights().size()) {

@@ -897,8 +1023,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
        if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
                coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
          errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
-                 << (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:" << error.message()
-                 << '\n';
+                 << (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:"
+                 << error.message() << '\n';
          return InvalidOutputFileAccess;
        }
      }

@@ -906,5 +1032,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
      xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
    }

+    for (pim::PimCoreOp coreOp : scalarCores)
+      if (coreOp.getOperation() != op)
+        coreOp.erase();
+  }

  return writeConfigJson(funcOp, memory, coreCount, std::move(xbarsPerArrayGroup), outputDirPath);
}
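The paths assembled in createAndPopulateWeightFolder and compileToPimJson above imply roughly the following output layout; entries in angle brackets are assumptions, since the corresponding file names are not visible in this diff:

```
<outputDir>/
  weights/                      deduplicated weight blobs, one per memref.global
  core_<id>/                    one directory per scalar core or batch lane
    crossbar_<k>.bin            hard link into weights/ for each used weight index k
  <per-core instruction JSON>   the array closed with ']' above (name not shown)
  <config JSON>                 written by writeConfigJson (name not shown)
```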
@@ -12,6 +12,7 @@
#include <utility>

#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/STLExtras.h"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -174,6 +175,31 @@ using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;

} // namespace detail

+template <typename RewriterT>
+inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
+  assert(!inputs.empty() && "spat.concat requires at least one input");
+  if (inputs.size() == 1)
+    return inputs.front();
+
+  auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
+  auto outputShape = llvm::to_vector(firstType.getShape());
+  int64_t concatDimSize = 0;
+  bool concatDimDynamic = false;
+
+  for (mlir::Value input : inputs) {
+    auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
+    assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
+    if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
+      concatDimDynamic = true;
+    else
+      concatDimSize += inputType.getDimSize(axis);
+  }
+
+  outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
+  auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
+  return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
+}
+
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
    mlir::Location loc,
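As a usage note: for two static 1x64 rows concatenated on axis 0, the helper above infers tensor<2x64xf32> and emits a single spat.concat. A hypothetical generic-form rendering follows; the op's custom syntax is not part of this diff, and the `dim` attribute name is inferred from the `getDim()` accessor used elsewhere in the commit.

```mlir
// %r0, %r1 : tensor<1x64xf32>; axis = 0, so the static result dims sum to 2.
%rows = "spat.concat"(%r0, %r1) {dim = 0 : i64}
    : (tensor<1x64xf32>, tensor<1x64xf32>) -> tensor<2x64xf32>
```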
@@ -54,6 +54,43 @@ private:

} // namespace

+static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
+  IRRewriter rewriter(funcOp.getContext());
+  SmallVector<spatial::SpatComputeBatch> batchOps;
+  funcOp.walk([&](spatial::SpatComputeBatch batchOp) { batchOps.push_back(batchOp); });
+
+  for (auto batchOp : batchOps) {
+    if (batchOp.getLaneCount() != 1)
+      continue;
+
+    auto loc = batchOp.getLoc();
+    rewriter.setInsertionPoint(batchOp);
+    auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
+    computeOp.getProperties().setOperandSegmentSizes(
+        {static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
+
+    Block& templateBlock = batchOp.getBody().front();
+    SmallVector<Type> blockArgTypes;
+    SmallVector<Location> blockArgLocs;
+    for (BlockArgument arg : templateBlock.getArguments()) {
+      blockArgTypes.push_back(arg.getType());
+      blockArgLocs.push_back(loc);
+    }
+    auto* newBlock = rewriter.createBlock(
+        &computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
+
+    IRMapping mapper;
+    for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
+      mapper.map(oldArg, newArg);
+    rewriter.setInsertionPointToEnd(newBlock);
+    for (Operation& op : templateBlock)
+      rewriter.clone(op, mapper);
+
+    batchOp.replaceAllUsesWith(computeOp.getResults());
+    rewriter.eraseOp(batchOp);
+  }
+}
+
void ONNXToSpatialPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = &getContext();

@@ -124,6 +161,8 @@ void ONNXToSpatialPass::runOnOperation() {
    return;
  }

+  foldSingleLaneComputeBatches(*entryFunc);
+
  // Count the number of compute ops and check they do not exceed the core count
  if (coresCount != -1) {
    int computeOpsCount = 0;
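A sketch of the rewrite performed by foldSingleLaneComputeBatches, in generic op form with illustrative types (the spat.* mnemonics are guesses from the C++ class names): a one-lane batch is degenerate, so its template body is cloned into an ordinary compute node.

```mlir
// Before: a degenerate one-lane batch.
%y = "spat.compute_batch"(%w, %x) ({
^bb0(%arg0: tensor<1x128xf32>):
  // ... template body ...
}) {laneCount = 1 : i32, operandSegmentSizes = array<i32: 1, 1>}
    : (tensor<128x64xf32>, tensor<1x128xf32>) -> tensor<1x64xf32>

// After: the same body, cloned into a plain compute node.
%y = "spat.compute"(%w, %x) ({
^bb0(%arg0: tensor<1x128xf32>):
  // ... same body ...
}) {operandSegmentSizes = array<i32: 1, 1>}
    : (tensor<128x64xf32>, tensor<1x128xf32>) -> tensor<1x64xf32>
```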
@@ -196,8 +235,12 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
  IRMapping mapper;
  for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
    mapper.map(source, bbArg);
-  auto newConcat = rewriter.clone(*inst, mapper);
-  spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
+  auto newConcat = spatial::SpatConcatOp::create(rewriter,
+      loc,
+      toRemoveOp.getType(),
+      rewriter.getI64IntegerAttr(toRemoveOp.getDim()),
+      ValueRange(BB->getArguments()));
+  spatial::SpatYieldOp::create(rewriter, loc, newConcat.getOutput());
  inst->replaceAllUsesWith(newCompute->getResults());
  inst->erase();
  return true;
@@ -147,11 +147,11 @@ static Value buildPackedBias(bool hasBias,
  return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
}

-static SmallVector<Value> createIm2colRowComputes(Value x,
+static Value createIm2colRowComputes(Value x,
    RankedTensorType xType,
    RankedTensorType im2colType,
    RankedTensorType im2colRowType,
-    RankedTensorType gemmInputRowType,
+    RankedTensorType gemmInputRowsType,
    int64_t batchSize,
    int64_t numChannelsIn,
    int64_t xHeight,

@@ -176,8 +176,8 @@ static SmallVector<Value> createIm2colRowComputes(Value x,
  auto elemType = xType.getElementType();
  constexpr size_t numInputs = 1;
  const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
-  SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
-  auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
+  auto im2colComputeOp =
+      createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
    Value paddedInput = xArg;

    // Pad input with zeros if needed:

@@ -285,23 +285,10 @@ static SmallVector<Value> createIm2colRowComputes(Value x,
    });
  }

-  SmallVector<Value> rowResults;
-  rowResults.reserve(packedNumRows);
-  for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
-    SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
-    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(packFactor * patchSize)};
-    SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
-    rowResults.push_back(
-        tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
-  }
-  spatial::SpatYieldOp::create(rewriter, loc, rowResults);
+  spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
  });

-  SmallVector<Value> rows;
-  rows.reserve(im2colComputeOp.getNumResults());
-  for (Value result : im2colComputeOp.getResults())
-    rows.push_back(result);
-  return rows;
+  return im2colComputeOp.getResult(0);
}

static Value createCollectedConvOutput(ValueRange gemmRows,

@@ -319,15 +306,12 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
  auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
    Value gemmOut;
    if (packFactor == 1) {
-      gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
-                                        : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
+      gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
    }
    else {
      auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
      auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
-      Value packedOutput = gemmRowArgs.size() == 1
-                               ? gemmRowArgs.front()
-                               : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
+      Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
      Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
          loc,
          expandedType,

@@ -509,14 +493,15 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
  // A_packed: [ceil(numPatches / N), N * patchSize]
  // B_packed: [N * patchSize, N * cOut]
  // Y_packed: [ceil(numPatches / N), N * cOut]
-  auto gemmInputRowType = RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
-  auto gemmOutputRowType =
-      RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
-  SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
+  const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
+  auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType);
+  auto gemmOutputRowsType =
+      RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
+  Value gemmInputRows = createIm2colRowComputes(x,
      xType,
      im2colType,
      rowType,
-      gemmInputRowType,
+      gemmInputRowsType,
      batchSize,
      numChannelsIn,
      xHeight,

@@ -553,13 +538,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
  Value gemmC = buildPackedBias(
      hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);

-  SmallVector<Value> gemmRows;
-  gemmRows.reserve(gemmInputRows.size());
-  for (Value gemmInputRow : gemmInputRows) {
-    Value gemmRow = ONNXGemmOp::create(rewriter,
+  Value gemmRows = ONNXGemmOp::create(rewriter,
      loc,
-      gemmOutputRowType,
-      gemmInputRow,
+      gemmOutputRowsType,
+      gemmInputRows,
      gemmB,
      gemmC,
      rewriter.getF32FloatAttr(1.0f),

@@ -567,11 +549,9 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
      rewriter.getBoolAttr(false),
      rewriter.getBoolAttr(false))
      .getY();
-    gemmRows.push_back(gemmRow);
-  }

  rewriter.replaceOp(convOp,
-      createCollectedConvOutput(gemmRows,
+      createCollectedConvOutput(ValueRange {gemmRows},
          convOp.getType(),
          gemmOutType,
          nhwcType,
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -65,6 +66,66 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
      ConversionPatternRewriter& rewriter) const override;
};

+struct GemmToSpatialComputeBatch : OpConversionPattern<ONNXGemmOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
+      ONNXGemmOpAdaptor gemmOpAdaptor,
+      ConversionPatternRewriter& rewriter) const override;
+};
+
+static SmallVector<Value> materializeBatchRowSlices(Value matrix,
+    RankedTensorType matrixType,
+    ConversionPatternRewriter& rewriter,
+    Location loc) {
+  const int64_t numRows = matrixType.getDimSize(0);
+  auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
+  SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
+
+  auto buildRowSlices = [&](Value matrixArg) {
+    auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
+    return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
+  };
+
+  auto cloneBatchInputChainIntoSliceCompute =
+      [&](Value rootInput, SmallVector<Operation*> chainOps, Value rootValue) -> SmallVector<Value> {
+    auto sliceCompute =
+        createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) {
+          Value transformedMatrix = input;
+          if (!chainOps.empty()) {
+            IRMapping mapper;
+            mapper.map(rootValue, input);
+            for (Operation* chainOp : chainOps)
+              rewriter.clone(*chainOp, mapper);
+            transformedMatrix = cast<Value>(mapper.lookup(matrix));
+          }
+          spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
+        });
+    SmallVector<Value> rowSlices(sliceCompute->result_begin(), sliceCompute->result_end());
+    return rowSlices;
+  };
+
+  SmallVector<Operation*> chainOps;
+  Value rootValue = matrix;
+  while (Operation* definingOp = rootValue.getDefiningOp()) {
+    if (auto rootCompute = dyn_cast<spatial::SpatCompute>(definingOp)) {
+      SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
+      return cloneBatchInputChainIntoSliceCompute(
+          rootCompute.getResult(cast<OpResult>(rootValue).getResultNumber()), reversedChainOps, rootValue);
+    }
+
+    if (definingOp->getNumOperands() != 1)
+      break;
+    if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
+      break;
+
+    chainOps.push_back(definingOp);
+    rootValue = definingOp->getOperand(0);
+  }
+
+  return buildRowSlices(matrix);
+}
+
} // namespace

LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,

@@ -156,8 +217,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
  }

  auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
-    auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
-    spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
+    spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
  });

  rewriter.replaceOp(gemmOp, concatComputeOp);
@@ -313,8 +373,108 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,

  auto concatComputeOp =
      createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
-        auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
-        spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
+        spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
      });

  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}

+LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
+    ONNXGemmOpAdaptor gemmOpAdaptor,
+    ConversionPatternRewriter& rewriter) const {
+  Location loc = gemmOp.getLoc();
+  Value a = gemmOpAdaptor.getA();
+  Value b = gemmOpAdaptor.getB();
+  Value c = gemmOpAdaptor.getC();
+
+  assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
+
+  bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
+
+  auto aType = cast<RankedTensorType>(a.getType());
+  auto bType = cast<RankedTensorType>(b.getType());
+  auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
+  assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && outType.hasStaticShape());
+
+  const int64_t numOutRows = aType.getDimSize(0);
+  if (numOutRows <= 1)
+    return failure();
+
+  // Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
+  if (aType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue())
+      || outType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()))
+    return failure();
+
+  auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
+  if (failed(scaledB))
+    return failure();
+  b = *scaledB;
+  bType = cast<RankedTensorType>(b.getType());
+
+  if (gemmOpAdaptor.getTransB()) {
+    auto bShape = bType.getShape();
+    auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
+    b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
+    bType = cast<RankedTensorType>(b.getType());
+  }
+  (void) bType;
+
+  Value sharedBias;
+  if (hasC) {
+    auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
+    if (failed(scaledC))
+      return failure();
+    c = *scaledC;
+    auto cType = cast<RankedTensorType>(c.getType());
+    if (cType.getRank() == 1) {
+      auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
+      c = tensor::ExpandShapeOp::create(rewriter,
+          loc,
+          expandedType,
+          c,
+          SmallVector<ReassociationIndices> {
+              {0, 1}
+          });
+      cType = cast<RankedTensorType>(c.getType());
+    }
+    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
+    // Row-specific bias can't share a single template body; fall through to GemmToManyGemv
+    if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
+      return failure();
+    if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
+      c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
+    sharedBias = c;
+  }
+
+  SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
+  auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
+
+  auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
+  SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
+  SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
+
+  auto batchOp = spatial::SpatComputeBatch::create(rewriter,
+      loc,
+      TypeRange(resultTypes),
+      rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
+      ValueRange(weights),
+      ValueRange(aSlices));
+
+  Block* body = rewriter.createBlock(
+      &batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
+  rewriter.setInsertionPointToEnd(body);
+
+  Value vmmResult = spatial::SpatWeightedVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
+  Value laneResult = vmmResult;
+  if (sharedBias)
+    laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
+  spatial::SpatYieldOp::create(rewriter, loc, laneResult);
+
+  rewriter.setInsertionPointAfter(batchOp);
+  SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
+  auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
+    spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
+  });
+
+  rewriter.replaceOp(gemmOp, concatComputeOp);
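To make the shapes concrete, a hypothetical single-tile case for this pattern (illustrative values, assuming crossbarSize >= 128): for A of 4x128 and B of 128x64, numOutRows = 4, so the rewrite emits one four-lane spat.compute_batch whose shared body performs a 1x128 by 128x64 VMM, and the four 1x64 lane results are concatenated back to the 4x64 GEMM output. In generic op form, with mnemonics and the weightIndex attribute name guessed from the C++ class names:

```mlir
%lanes:4 = "spat.compute_batch"(%b, %b, %b, %b, %a0, %a1, %a2, %a3) ({
^bb0(%row: tensor<1x128xf32>):
  %v = "spat.weighted_vmm"(%row) {weightIndex = 0 : i32}
      : (tensor<1x128xf32>) -> tensor<1x64xf32>
  "spat.yield"(%v) : (tensor<1x64xf32>) -> ()
}) {laneCount = 4 : i32, operandSegmentSizes = array<i32: 4, 4>}
    : (tensor<128x64xf32>, tensor<128x64xf32>, tensor<128x64xf32>, tensor<128x64xf32>,
       tensor<1x128xf32>, tensor<1x128xf32>, tensor<1x128xf32>, tensor<1x128xf32>)
    -> (tensor<1x64xf32>, tensor<1x64xf32>, tensor<1x64xf32>, tensor<1x64xf32>)
```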
@@ -322,6 +482,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
}

void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
+  patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(2));
  patterns.insert<GemmToManyGemv>(ctx);
  patterns.insert<GemvToSpatialCompute>(ctx);
}
@@ -232,9 +232,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
    }));
  }

-  Value result = batchResults.size() == 1
-                     ? batchResults.front()
-                     : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, batchResults).getResult();
+  Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults);
  rewriter.replaceOp(matmulOp, result);
  return success();
}
@@ -100,8 +100,7 @@ static Value buildReduceMeanKeepdims(Value input,
  for (Value slice : slices)
    reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));

-  return reducedSlices.size() == 1 ? reducedSlices.front()
-                                   : tensor::ConcatOp::create(rewriter, loc, axis, reducedSlices).getResult();
+  return createSpatConcat(rewriter, loc, axis, reducedSlices);
}

static Value squeezeReducedAxes(Value keepdimsValue,
@@ -33,9 +33,7 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,

static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
  assert(!values.empty() && "Expected at least one value to concatenate.");
-  if (values.size() == 1)
-    return values.front();
-  return tensor::ConcatOp::create(rewriter, loc, axis, values);
+  return createSpatConcat(rewriter, loc, axis, values);
}

static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
@@ -47,8 +47,7 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe
  for (Value slice : slices)
    rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));

-  return rebuiltSlices.size() == 1 ? rebuiltSlices.front()
-                                   : tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult();
+  return createSpatConcat(rewriter, loc, axis, rebuiltSlices);
}

struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
@@ -2,6 +2,7 @@
#include "mlir/IR/PatternMatch.h"

+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

@@ -17,7 +18,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
    auto inputs = adaptor.getInputs();
    int64_t axis = adaptor.getAxis();

-    rewriter.replaceOpWithNewOp<tensor::ConcatOp>(maxpoolOp, axis, inputs);
+    rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));

    return success();
  }
@@ -49,7 +49,7 @@ static Value concatGatherSlices(Value data,
  }
  if (slices.empty())
    return {};
-  return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
+  return createSpatConcat(rewriter, loc, axis, slices);
}

static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {

@@ -130,9 +130,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
      return failure();
    rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
  }
-  result = rows.size() == 1
-               ? rows.front()
-               : tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
+  result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows);
}
else {
  return failure();
@@ -50,7 +50,7 @@ static Value buildNearestResize(Value input,
    slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
  }

-  return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
+  return createSpatConcat(rewriter, loc, axis, slices);
}

struct Resize : OpConversionPattern<ONNXResizeOp> {
@@ -7,23 +7,12 @@
#include <cstddef>

#include "Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

using namespace llvm;
using namespace mlir;

namespace onnx_mlir {

-namespace {
-
-IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
-  auto attr = op->getAttrOfType<IntegerAttr>(attrName);
-  assert(attr && "required precomputed channel attr is missing");
-  return IntegerAttr::get(builder.getI32Type(), attr.getInt());
-}
-
-} // namespace
-
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
  /*
  EXAMPLE RUN:

@@ -74,37 +63,6 @@ IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
  return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
}

-IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
-  auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
-  assert(channelNewOp && "spatial channel value must come from spat.channel_new");
-  return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName);
-}
-
-IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
-  auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
-  assert(channelNewOp && "spatial channel value must come from spat.channel_new");
-  return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
-}
-
-bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
-  auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
-  return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
-}
-
-bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
-  auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
-  return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
-}
-
-mlir::Value
-createPimReceiveFromSpatialChannel(PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
-  mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
-  auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
-  auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
-  return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
-      .getOutput();
-}
-
Operation* getEarliestUserWithinBlock(mlir::Value value) {
  auto users = value.getUsers();
@@ -2,16 +2,10 @@

#include "mlir/Dialect/Tensor/IR/Tensor.h"

-#include "llvm/ADT/StringRef.h"
-
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

namespace onnx_mlir {

-inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
-inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
-
/**
 * \brief Get the offset of the ExtractSliceOp based on its static offsets and
 * its static tensor input.

@@ -30,17 +24,6 @@ size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);

mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);

-mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
-
-mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
-
-bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
-
-bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
-
-mlir::Value createPimReceiveFromSpatialChannel(
-    mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
-
template <class T>
size_t rangeLength(const mlir::iterator_range<T> range) {
  return std::distance(range.begin(), range.end());
@@ -9,17 +9,6 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE

-def HasSpatialChannelSourceCoreIdAttr: Constraint<
-  CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
-  "spatial channel has precomputed source core id">;
-
-def HasSpatialChannelTargetCoreIdAttr: Constraint<
-  CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
-  "spatial channel has precomputed target core id">;
-
-def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
-  "onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
-
def onnxToPimTranspose : Pat<
  (ONNXTransposeOp:$srcOpRes $data, $perms),
  (PimTransposeOp $data, $perms,

@@ -80,18 +69,4 @@ def spatToPimVSoftmax : Pat<
  (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;

-def spatChannelSendToPimSend : Pat<
-  (SpatChannelSendOp $channel, $input),
-  (PimSendOp $input,
-    (NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
-    (NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
-  [(HasSpatialChannelTargetCoreIdAttr $channel)]
->;
-
-def spatChannelReceiveToPimReceive : Pat<
-  (SpatChannelReceiveOp:$srcOpRes $channel),
-  (createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
-  [(HasSpatialChannelSourceCoreIdAttr $channel)]
->;
-
#endif // SPATIAL_TO_PIM
[File diff suppressed because it is too large.]
@@ -39,6 +39,22 @@ def PimCoreOp : PimOp<"core", [SingleBlock]> {
  }];
}

+def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments]> {
+  let summary = "Execute equivalent batched core bodies";
+
+  let regions = (region SizedRegion<1>:$body);
+
+  let arguments = (ins
+    I32Attr:$laneCount,
+    Variadic<PimTensor>:$weights,
+    Variadic<PimTensor>:$inputs
+  );
+
+  let assemblyFormat = [{
+    `lanes` $laneCount `(` $weights `)` `[` $inputs `]` attr-dict regions `:` type($weights) `[` type($inputs) `]` `->` `(` `)`
+  }];
+}
+
def PimHaltOp : PimOp<"halt", [Terminator]> {
  let summary = "Halt execution of the core";

@@ -65,6 +81,20 @@ def PimSendOp : PimOp<"send", []> {
  }];
}

+def PimSendBatchOp : PimOp<"send_batch", []> {
+  let summary = "Send a per-lane tensor to target cores from a batched core";
+
+  let arguments = (ins
+    PimTensor:$input,
+    I32Attr:$size,
+    DenseI32ArrayAttr:$targetCoreIds
+  );
+
+  let assemblyFormat = [{
+    `(` $input `)` attr-dict `:` type($input) `->` `(` `)`
+  }];
+}
+
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
  let summary = "Receive a tensor from another core";

@@ -89,6 +119,30 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
  }];
}

+def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
+  let summary = "Receive per-lane tensors from source cores into a batched core";
+
+  let arguments = (ins
+    PimTensor:$outputBuffer,
+    I32Attr:$size,
+    DenseI32ArrayAttr:$sourceCoreIds
+  );
+
+  let results = (outs
+    PimTensor:$output
+  );
+
+  let extraClassDeclaration = [{
+    mlir::MutableOperandRange getDpsInitsMutable() {
+      return getOutputBufferMutable();
+    }
+  }];
+
+  let assemblyFormat = [{
+    `(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
+  }];
+}
+
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
  let summary = "Copy a memory region from host memory into device memory";

@@ -115,6 +169,32 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
  }];
}

+def PimMemCopyHostToDevBatchOp : PimOp<"memcp_hd_batch", [DestinationStyleOpInterface]> {
+  let summary = "Copy a per-lane tensor from host memory into device memory inside a batched core";
+
+  let arguments = (ins
+    PimTensor:$deviceTarget,
+    PimTensor:$hostSource,
+    I32Attr:$deviceTargetOffset,
+    I32Attr:$hostSourceOffset,
+    I32Attr:$size
+  );
+
+  let results = (outs
+    PimTensor:$output
+  );
+
+  let extraClassDeclaration = [{
+    mlir::MutableOperandRange getDpsInitsMutable() {
+      return getDeviceTargetMutable();
+    }
+  }];
+
+  let assemblyFormat = [{
+    `(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
+  }];
+}
+
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
  let summary = "Copy a memory region from device memory into host memory";
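Derived from the assemblyFormat strings above, the two batched transfer ops would print roughly as follows; operand names, sizes, and core ids are illustrative, and the exact attribute rendering may differ:

```mlir
pim.send_batch (%t) {size = 256 : i32, targetCoreIds = array<i32: 4, 5>}
    : tensor<1x64xf32> -> ()
%r = pim.receive_batch (%buf) {size = 256 : i32, sourceCoreIds = array<i32: 2, 3>}
    : tensor<1x64xf32> -> tensor<1x64xf32>
```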
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "OpBufferizationInterfaces.hpp"
@@ -65,6 +66,32 @@ struct MemCopyHostToDevOpInterface
  }
};

+struct MemCopyHostToDevBatchOpInterface
+    : DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevBatchOpInterface, PimMemCopyHostToDevBatchOp> {
+  LogicalResult bufferize(Operation* op,
+      RewriterBase& rewriter,
+      const BufferizationOptions& options,
+      BufferizationState& state) const {
+    auto memCopyHostToDevOp = cast<PimMemCopyHostToDevBatchOp>(op);
+    auto deviceTargetOpt = getBuffer(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state);
+    if (failed(deviceTargetOpt))
+      return failure();
+    auto hostSourceOpt = getBuffer(rewriter, memCopyHostToDevOp.getHostSource(), options, state);
+    if (failed(hostSourceOpt))
+      return failure();
+
+    replaceOpWithNewBufferizedOp<PimMemCopyHostToDevBatchOp>(rewriter,
+        memCopyHostToDevOp,
+        deviceTargetOpt->getType(),
+        *deviceTargetOpt,
+        *hostSourceOpt,
+        memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
+        memCopyHostToDevOp.getHostSourceOffsetAttr(),
+        memCopyHostToDevOp.getSizeAttr());
+    return success();
+  }
+};
+
struct MemCopyDevToHostOpInterface
    : DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
  LogicalResult bufferize(Operation* op,
@@ -122,6 +149,127 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveBatchOpInterface, PimReceiveBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto receiveOp = cast<PimReceiveBatchOp>(op);
|
||||
auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state);
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimReceiveBatchOp>(rewriter,
|
||||
op,
|
||||
outputBufferOpt->getType(),
|
||||
*outputBufferOpt,
|
||||
receiveOp.getSizeAttr(),
|
||||
receiveOp.getSourceCoreIdsAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOpInterface, PimCoreBatchOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
|
||||
auto coreBatchOp = cast<PimCoreBatchOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
||||
return {};
|
||||
|
||||
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
|
||||
return {{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}};
|
||||
}
|
||||
|
||||
bool isWritable(Operation* op, Value value, const AnalysisState& state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
FailureOr<BufferLikeType>
|
||||
getBufferType(Operation* op,
|
||||
Value value,
|
||||
const BufferizationOptions& options,
|
||||
const BufferizationState& state,
|
||||
SmallVector<Value>& invocationStack) const {
|
||||
auto coreBatchOp = cast<PimCoreBatchOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
||||
return failure();
|
||||
|
||||
Value tiedInput = coreBatchOp.getInputs()[bbArg.getArgNumber()];
|
||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedInput.getType()))
|
||||
return memRefType;
|
||||
|
||||
return bufferization::getBufferType(tiedInput, options, state, invocationStack);
|
||||
}
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto coreBatchOp = cast<PimCoreBatchOp>(op);

    SmallVector<Value> weights;
    SmallVector<Value> inputs;
    weights.reserve(coreBatchOp.getWeights().size());
    inputs.reserve(coreBatchOp.getInputs().size());

    for (Value weight : coreBatchOp.getWeights()) {
      if (isa<TensorType>(weight.getType())) {
        auto weightOpt = getBuffer(rewriter, weight, options, state);
        if (failed(weightOpt))
          return failure();
        weights.push_back(*weightOpt);
      }
      else {
        weights.push_back(weight);
      }
    }

    for (Value input : coreBatchOp.getInputs()) {
      if (isa<TensorType>(input.getType())) {
        auto inputOpt = getBuffer(rewriter, input, options, state);
        if (failed(inputOpt))
          return failure();
        inputs.push_back(*inputOpt);
      }
      else {
        inputs.push_back(input);
      }
    }

    rewriter.setInsertionPoint(coreBatchOp);
    auto newOp = PimCoreBatchOp::create(
        rewriter, coreBatchOp.getLoc(), coreBatchOp.getLaneCountAttr(), ValueRange(weights), ValueRange(inputs));
    newOp.getProperties().setOperandSegmentSizes({static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
    if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdAttrName))
      newOp->setAttr(onnx_mlir::kCoreIdAttrName, coreIdsAttr);

    rewriter.inlineRegionBefore(coreBatchOp.getBody(), newOp.getBody(), newOp.getBody().begin());
    for (Block& block : newOp.getBody())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state)))
        return failure();

    rewriter.eraseOp(coreBatchOp);
    return success();
  }
};

struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -287,8 +435,11 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI

void registerOpBufferizationInterfaces(DialectRegistry& registry) {
  registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
    PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
    PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
    PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
    PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
    PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
    PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
    PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
    PimVMMOp::attachInterface<VMMOpInterface>(*ctx);

@@ -93,8 +93,8 @@ void PimBufferizationPass::runOnOperation() {
}

void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
  funcOp.walk([&](PimCoreOp coreOp) {
    walkPimMvmVmmWeightUses(coreOp, [&](OpOperand& weightUse) {
  auto markWeights = [&](Operation* op) {
    walkPimMvmVmmWeightUses(op, [&](OpOperand& weightUse) {
      Value weight = weightUse.get();
      auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
      if (!getGlobalOp)
@@ -104,7 +104,10 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
      markWeightAlways(getGlobalOp);
      markWeightAlways(globalMemrefOp);
    });
  });
  };

  funcOp.walk([&](PimCoreOp coreOp) { markWeights(coreOp); });
  funcOp.walk([&](PimCoreBatchOp coreBatchOp) { markWeights(coreBatchOp); });
}

std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }

@@ -2,6 +2,7 @@ add_onnx_mlir_dialect(Spatial spat)
add_onnx_mlir_dialect_doc(spat Spatial.td)

add_pim_library(SpatialOps
  Channels.cpp
  SpatialOps.cpp
  Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
  Transforms/MergeComputeNodes/DCPGraph/Graph.cpp

120
src/PIM/Dialect/Spatial/Channels.cpp
Normal file
@@ -0,0 +1,120 @@
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"

using namespace mlir;

namespace onnx_mlir::spatial {

namespace {

static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); }

static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); }

static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
  if (!endpoints.send || !endpoints.receive)
    return failure();

  if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) {
    endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
    return failure();
  }
  if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
    endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
    return failure();
  }
  if (endpoints.send.getInput().getType() != endpoints.receive.getOutput().getType()) {
    endpoints.send.emitOpError("input type does not match paired spat.channel_receive result type");
    return failure();
  }

  return success();
}

} // namespace

Channels::Channels(func::FuncOp funcOp) {
  if (!funcOp)
    return;

  funcOp.walk([&](SpatChannelSendOp sendOp) { insertSend(sendOp); });
  funcOp.walk([&](SpatChannelReceiveOp receiveOp) { insertReceive(receiveOp); });
}
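
// insertSend/insertReceive bump nextChannelId past every id they register, so
// allocate() only hands out ids that no walked endpoint already uses.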
Channels::ChannelId Channels::allocate() { return nextChannelId++; }

void Channels::insertSend(SpatChannelSendOp sendOp) {
  ChannelId channelId = getChannelId(sendOp);
  nextChannelId = std::max(nextChannelId, channelId + 1);
  endpoints[channelId].send = sendOp;
}

void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
  ChannelId channelId = getChannelId(receiveOp);
  nextChannelId = std::max(nextChannelId, channelId + 1);
  endpoints[channelId].receive = receiveOp;
}

void Channels::eraseSend(SpatChannelSendOp sendOp) {
  ChannelId channelId = getChannelId(sendOp);
  auto it = endpoints.find(channelId);
  if (it == endpoints.end())
    return;
  it->second.send = {};
  if (!it->second.receive)
    endpoints.erase(it);
}

void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
  ChannelId channelId = getChannelId(receiveOp);
  auto it = endpoints.find(channelId);
  if (it == endpoints.end())
    return;
  it->second.receive = {};
  if (!it->second.send)
    endpoints.erase(it);
}

FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
  auto it = endpoints.find(id);
  if (it == endpoints.end())
    return failure();
  return it->second;
}

FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
  auto endpointsOr = lookup(getChannelId(sendOp));
  if (failed(endpointsOr) || !endpointsOr->receive)
    return failure();
  return endpointsOr->receive;
}

FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
  auto endpointsOr = lookup(getChannelId(receiveOp));
  if (failed(endpointsOr) || !endpointsOr->send)
    return failure();
  return endpointsOr->send;
}

LogicalResult Channels::verify() const {
  for (const auto& [channelId, pair] : endpoints) {
    if (!pair.send || !pair.receive) {
      if (pair.send) {
        auto sendOp = pair.send;
        sendOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_receive";
      }
      else if (pair.receive) {
        auto receiveOp = pair.receive;
        receiveOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_send";
      }
      return failure();
    }
    if (failed(verifyEndpointPair(pair)))
      return failure();
  }
  return success();
}

} // namespace onnx_mlir::spatial
43
src/PIM/Dialect/Spatial/Channels.hpp
Normal file
@@ -0,0 +1,43 @@
#pragma once

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

namespace onnx_mlir::spatial {

struct ChannelEndpoints {
  SpatChannelSendOp send;
  SpatChannelReceiveOp receive;
};

class Channels {
public:
  using ChannelId = int64_t;

  explicit Channels(mlir::func::FuncOp funcOp);

  ChannelId allocate();

  void insertSend(SpatChannelSendOp sendOp);
  void insertReceive(SpatChannelReceiveOp receiveOp);
  void eraseSend(SpatChannelSendOp sendOp);
  void eraseReceive(SpatChannelReceiveOp receiveOp);

  llvm::FailureOr<ChannelEndpoints> lookup(ChannelId id) const;
  llvm::FailureOr<SpatChannelReceiveOp> getReceiveFor(SpatChannelSendOp sendOp) const;
  llvm::FailureOr<SpatChannelSendOp> getSendFor(SpatChannelReceiveOp receiveOp) const;

  mlir::LogicalResult verify() const;

private:
  ChannelId nextChannelId = 0;
  llvm::DenseMap<ChannelId, ChannelEndpoints> endpoints;
};

} // namespace onnx_mlir::spatial
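
A minimal usage sketch for the registry above (the helper itself is hypothetical; only the Channels API calls come from this header):

// Hypothetical helper: pair every spat.channel_send with its receive and
// bail out early when the registry reports a dangling or mismatched endpoint.
static mlir::LogicalResult pairChannelEndpoints(mlir::func::FuncOp funcOp) {
  onnx_mlir::spatial::Channels channels(funcOp); // walks sends, then receives
  if (mlir::failed(channels.verify()))
    return mlir::failure(); // diagnostics were already emitted on the ops
  funcOp.walk([&](onnx_mlir::spatial::SpatChannelSendOp sendOp) {
    auto receiveOr = channels.getReceiveFor(sendOp);
    if (mlir::succeeded(receiveOr)) {
      // *receiveOr shares channelId, core ids, and tensor type with sendOp.
    }
  });
  return mlir::success();
}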

@@ -9,7 +9,6 @@ def SpatialDialect : Dialect {
  let name = "spat";
  let summary = "Dialect designed for deep learning computation in a spatial architecture";
  let cppNamespace = "::onnx_mlir::spatial";
  let useDefaultTypePrinterParser = 1;
}

class SpatOp<string mnemonic, list<Trait> traits = []> :
@@ -19,15 +18,6 @@ class SpatOp<string mnemonic, list<Trait> traits = []> :
def SpatTensor :
  AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;

class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
    : TypeDef<SpatialDialect, name, traits> {
  let mnemonic = typeMnemonic;
}

def SpatChannelType : SpatType<"SpatChannel", "ch"> {
  let summary = "Virtual channel type";
}

//===----------------------------------------------------------------------===//
// Execution
//===----------------------------------------------------------------------===//
@@ -48,10 +38,27 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {

  let hasVerifier = 1;
  let hasFolder = 1;
  let hasCustomAssemblyFormat = 1;
}

  let assemblyFormat = [{
    `[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body
  }];
def SpatComputeBatch : SpatOp<"compute_batch",
    [SingleBlock, AttrSizedOperandSegments]> {
  let summary = "Compressed batch of independent equivalent compute lanes";

  let arguments = (ins
    I32Attr:$laneCount,
    Variadic<SpatTensor>:$weights,
    Variadic<SpatTensor>:$inputs
  );

  let results = (outs
    Variadic<SpatTensor>:$outputs
  );

  let regions = (region SizedRegion<1>:$body);

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
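
// Hypothetical textual form of spat.compute_batch (the real syntax is custom
// and implemented in C++, so this is only an illustration):
//   %y0, %y1 = spat.compute_batch {laneCount = 2 : i32} [%w](%x0, %x1)
//       : [tensor<4x4xf32>](tensor<1x4xf32>, tensor<1x4xf32>)
//       -> (tensor<1x4xf32>, tensor<1x4xf32>) { ... }
// One region body is shared by all lanes; lane i consumes inputs[i].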

def SpatYieldOp : SpatOp<"yield", [Terminator]> {
@@ -61,51 +68,66 @@ def SpatYieldOp : SpatOp<"yield", [Terminator]> {
    Variadic<SpatTensor>:$outputs
  );

  let assemblyFormat = [{
    $outputs attr-dict `:` type($outputs)
  }];
  let hasCustomAssemblyFormat = 1;
}

def SpatExtractRowsOp : SpatOp<"extract_rows", []> {
  let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";

  let arguments = (ins
    SpatTensor:$input
  );

  let results = (outs
    Variadic<SpatTensor>:$outputs
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

def SpatConcatOp : SpatOp<"concat", []> {
  let summary = "Concatenate tensors with compact Spatial operand syntax";

  let arguments = (ins
    I64Attr:$axis,
    Variadic<SpatTensor>:$inputs
  );

  let results = (outs
    SpatTensor:$output
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Communication
//===----------------------------------------------------------------------===//

def SpatChannelNewOp : SpatOp<"channel_new", []> {
  let summary = "Create a new virtual channel";

  let results = (outs
    SpatChannelType:$channel
  );

  let builders = [
    OpBuilder<(ins ), [{
      $_state.addTypes(SpatChannelType());
    }]>
  ];

  let assemblyFormat = [{
    attr-dict
  }];
}

def SpatChannelSendOp : SpatOp<"channel_send", []> {
  let summary = "Send a tensor through a channel";
  let summary = "Send a tensor through a logical channel";

  let arguments = (ins
    SpatChannelType:$channel,
    I64Attr:$channelId,
    I32Attr:$sourceCoreId,
    I32Attr:$targetCoreId,
    SpatTensor:$input
  );

  let assemblyFormat = [{
    $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
    $input attr-dict `:` type($input)
  }];
}

def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
  let summary = "Receive a tensor from a channel";
  let summary = "Receive a tensor from a logical channel";

  let arguments = (ins
    SpatChannelType:$channel
    I64Attr:$channelId,
    I32Attr:$sourceCoreId,
    I32Attr:$targetCoreId
  );

  let results = (outs
@@ -113,37 +135,70 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
  );

  let assemblyFormat = [{
    $channel attr-dict `:` `(` type($channel) `->` type($output) `)`
    attr-dict `:` type($output)
  }];
}

def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
  let summary = "Broadcast a tensor through a shared channel buffer";
def SpatChannelSendManyOp : SpatOp<"channel_send_many", []> {
  let summary = "Send multiple tensors through logical channels";

  let arguments = (ins
    SpatChannelType:$channel,
    DenseI64ArrayAttr:$channelIds,
    DenseI32ArrayAttr:$sourceCoreIds,
    DenseI32ArrayAttr:$targetCoreIds,
    Variadic<SpatTensor>:$inputs
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
  let summary = "Receive multiple tensors from logical channels";

  let arguments = (ins
    DenseI64ArrayAttr:$channelIds,
    DenseI32ArrayAttr:$sourceCoreIds,
    DenseI32ArrayAttr:$targetCoreIds
  );

  let results = (outs
    Variadic<SpatTensor>:$outputs
  );

  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
  let summary = "Send per-lane tensors through logical channels in a batch body";

  let arguments = (ins
    DenseI64ArrayAttr:$channelIds,
    DenseI32ArrayAttr:$sourceCoreIds,
    DenseI32ArrayAttr:$targetCoreIds,
    SpatTensor:$input
  );

  let assemblyFormat = [{
    $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
  }];
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
  let summary = "Receive a tensor from a shared channel buffer";
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
  let summary = "Receive a per-lane tensor through logical channels in a batch body";

  let arguments = (ins
    SpatChannelType:$channel
    DenseI64ArrayAttr:$channelIds,
    DenseI32ArrayAttr:$sourceCoreIds,
    DenseI32ArrayAttr:$targetCoreIds
  );

  let results = (outs
    SpatTensor:$output
  );

  let assemblyFormat = [{
    $channel attr-dict `:` `(` type($channel) `->` type($output) `)`
  }];
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//

File diff suppressed because it is too large
@@ -28,6 +28,8 @@ namespace spatial {
using namespace mlir;

namespace {
using SpatCompute = onnx_mlir::spatial::SpatCompute;
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;

struct VirtualNode {
  SmallVector<size_t, 4> originalComputeIndices;
@@ -54,6 +56,45 @@ struct WindowScheduleResult {
  size_t maxMergeGroupSize = 0;
};

constexpr CPU kDefaultMaxCpuCount = 1000;

size_t getSchedulingCpuBudget() {
  if (coresCount.getValue() > 0)
    return static_cast<size_t>(coresCount.getValue());
  return static_cast<size_t>(kDefaultMaxCpuCount);
}

size_t getBatchChunkTargetCount(int32_t laneCount) {
  assert(laneCount > 0 && "laneCount must be positive");
  return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
}

ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
  size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
  size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
  size_t baseChunkSize = totalLanes / chunkCount;
  size_t largeChunkCount = totalLanes % chunkCount;

  size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
  size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
  return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
}

ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
  size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
  size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
  size_t baseChunkSize = totalLanes / chunkCount;
  size_t largeChunkCount = totalLanes % chunkCount;
  size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);

  size_t chunkIndex = 0;
  if (static_cast<size_t>(lane) < largeChunkSpan)
    chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
  else
    chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
  return getBatchChunkForIndex(batch, chunkIndex);
}
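
// Worked example: laneCount = 10 with a CPU budget of 4 gives chunkCount = 4,
// baseChunkSize = 2, largeChunkCount = 2, so the chunks cover lanes [0,3),
// [3,6), [6,8), [8,10). Lane 7 lies past largeChunkSpan = 6, so its chunk
// index is 2 + (7 - 6) / 2 = 2, the chunk covering lanes [6,8).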

std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
  llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
  for (auto [start, end, weight] : edges) {
@@ -81,14 +122,96 @@ std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
  return aggregatedEdges;
}

VirtualGraph buildInitialVirtualGraph(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges) {
Weight getComputeBodyWeight(Region& body) {
  constexpr Weight kOperationWeight = 100;
  Weight numOperations = 0;
  for (auto& block : body)
    for ([[maybe_unused]] auto& op : block)
      numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
  return checkedMultiply(numOperations, kOperationWeight);
}

CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
  CrossbarUsage crossbarUsage = 0;
  for (auto& block : body)
    for (auto& op : block)
      if (isa<SpatWeightedVMMOp>(op))
        crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
  return crossbarUsage;
}

Weight getComputeInstanceWeight(const ComputeInstance& instance) {
  if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
    return getSpatComputeWeight(spatCompute);
  auto batch = cast<SpatComputeBatch>(instance.op);
  return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
}

CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
  if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
    return getSpatComputeCrossbarUsage(spatCompute);
  auto batch = cast<SpatComputeBatch>(instance.op);
  return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
}

SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
  if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
    return SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
  auto batch = cast<SpatComputeBatch>(instance.op);
  SmallVector<Value, 4> inputs;
  inputs.reserve(instance.laneCount);
  for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
    inputs.push_back(batch.getInputs()[lane]);
  return inputs;
}

std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
  Operation* op = value.getDefiningOp();
  if (!op)
    return std::nullopt;

  while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
    value = extract.getSource();
    op = value.getDefiningOp();
    if (!op)
      return std::nullopt;
  }

  if (auto spatCompute = dyn_cast<SpatCompute>(op))
    return ComputeInstance {spatCompute.getOperation(), 0, 1};
  if (auto batch = dyn_cast<SpatComputeBatch>(op))
    return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
  return std::nullopt;
}

SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
  SmallVector<ComputeInstance> instances;
  for (Region& region : entryOp->getRegions()) {
    for (Block& block : region) {
      for (Operation& op : block) {
        if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
          instances.push_back({spatCompute.getOperation(), 0, 1});
          continue;
        }
        if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
          size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
          for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
            instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
        }
      }
    }
  }
  return instances;
}
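
// Every spat.compute becomes one schedulable instance; every
// spat.compute_batch is pre-split into at most one chunk per schedulable CPU,
// so its lanes can spread across cores without one graph node per lane.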

VirtualGraph buildInitialVirtualGraph(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges) {
  VirtualGraph graph;
  graph.nodes.reserve(spatComputes.size());
  for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
  graph.nodes.reserve(computeInstances.size());
  for (auto [index, computeInstance] : llvm::enumerate(computeInstances)) {
    VirtualNode node;
    node.originalComputeIndices.push_back(index);
    node.weight = getSpatComputeWeight(spatCompute);
    node.crossbarUsage = getSpatComputeCrossbarUsage(spatCompute);
    node.weight = getComputeInstanceWeight(computeInstance);
    node.crossbarUsage = getComputeInstanceCrossbarUsage(computeInstance);
    graph.nodes.push_back(std::move(node));
  }
  graph.edges = aggregateEdges(edges);
@@ -116,22 +239,34 @@ TimingInfo computeTiming(const VirtualGraph& graph) {
    incomingEdgeCount[endIndex]++;
  }
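
  // The ready list below changes from a plain FIFO vector to a min-heap keyed
  // by each virtual node's first original compute index, so topological-order
  // ties are broken by original program order instead of insertion order.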
  std::vector<size_t> readyNodes;
  readyNodes.reserve(nodeCount);
  auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
    const VirtualNode& node = graph.nodes[nodeIndex];
    if (!node.originalComputeIndices.empty())
      return node.originalComputeIndices.front();
    return nodeIndex;
  };
  auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
    size_t lhsKey = getVirtualNodeOrderKey(lhs);
    size_t rhsKey = getVirtualNodeOrderKey(rhs);
    if (lhsKey != rhsKey)
      return lhsKey > rhsKey;
    return lhs > rhs;
  };
  std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
  for (size_t i = 0; i < nodeCount; ++i)
    if (incomingEdgeCount[i] == 0)
      readyNodes.push_back(i);
      readyNodes.push(i);

  size_t readyIndex = 0;
  while (readyIndex != readyNodes.size()) {
    size_t current = readyNodes[readyIndex++];
  while (!readyNodes.empty()) {
    size_t current = readyNodes.top();
    readyNodes.pop();
    timing.topologicalOrder.push_back(current);
    for (auto [child, weight] : children[current]) {
      (void) weight;
      assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
      incomingEdgeCount[child]--;
      if (incomingEdgeCount[child] == 0)
        readyNodes.push_back(child);
        readyNodes.push(child);
    }
  }

@@ -287,17 +422,21 @@ std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
  std::vector<Weight> windowWeights;
  std::vector<CrossbarUsage> windowCrossbarUsage;
  std::vector<int64_t> windowNodeOrderKeys;
  std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
  windowWeights.reserve(selectedNodes.size());
  windowCrossbarUsage.reserve(selectedNodes.size());
  windowNodeOrderKeys.reserve(selectedNodes.size());

  for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
    nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
    windowWeights.push_back(graph.nodes[nodeIndex].weight);
    windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
    windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
  }

  GraphDCP windowGraph(windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowCrossbarUsage);
  GraphDCP windowGraph(
      windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
  if (coresCount.getValue() > 0)
    windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
  windowGraph.setContext(context);
@@ -414,13 +553,7 @@ bool coarsenGraph(const VirtualGraph& graph,
  return true;
}

constexpr CPU kDefaultMaxCpuCount = 1000;

CPU getVirtualGraphMaxCpuCount() {
  if (coresCount.getValue() > 0)
    return static_cast<CPU>(coresCount.getValue());
  return kDefaultMaxCpuCount;
}
CPU getVirtualGraphMaxCpuCount() { return static_cast<CPU>(getSchedulingCpuBudget()); }

size_t getDcpCoarseningWindowSize(size_t nodeCount) {
  size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
@@ -430,7 +563,7 @@ size_t getDcpCoarseningWindowSize(size_t nodeCount) {
  return windowSize;
}

DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<SpatCompute> spatComputes) {
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<ComputeInstance> computeInstances) {
  DCPAnalysisResult result;

  TimingInfo timing = computeTiming(graph);
@@ -443,19 +576,19 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRe
    std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
  }

  std::vector<size_t> originalComputeToCpu(spatComputes.size(), 0);
  std::vector<size_t> originalComputeToCpu(computeInstances.size(), 0);
  for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
    const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
    for (size_t originalIndex : virtualNode.originalComputeIndices)
      originalComputeToCpu[originalIndex] = cpu;
  }

  result.dominanceOrderCompute.reserve(spatComputes.size());
  for (auto [originalIndex, spatCompute] : llvm::enumerate(spatComputes)) {
  result.dominanceOrderCompute.reserve(computeInstances.size());
  for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
    size_t cpu = originalComputeToCpu[originalIndex];
    result.dominanceOrderCompute.push_back(spatCompute);
    result.computeToCpuMap[spatCompute] = cpu;
    result.cpuToLastComputeMap[cpu] = spatCompute;
    result.dominanceOrderCompute.push_back(computeInstance);
    result.computeToCpuMap[computeInstance] = cpu;
    result.cpuToLastComputeMap[cpu] = computeInstance;
  }
  for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
    result.isLastComputeOfCpu.insert(lastCompute);
@@ -463,13 +596,44 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRe
  return result;
}

DCPAnalysisResult runLegacyDcp(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
  GraphDCP graphDCP(spatComputes, edges);
DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<ComputeInstance> computeInstances) {
  DCPAnalysisResult result;
  result.dominanceOrderCompute.assign(computeInstances.begin(), computeInstances.end());

  for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
    auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
    if (scheduledTasks.empty())
      continue;

    for (const auto& task : scheduledTasks)
      result.computeToCpuMap[computeInstances[task.nodeIndex]] = cpu;
    result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
    result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
  }

  return result;
}

DCPAnalysisResult
runLegacyDcp(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
  SmallVector<Weight> nodeWeights;
  SmallVector<CrossbarUsage> nodeCrossbarUsage;
  SmallVector<int64_t> nodeOrderKeys;
  nodeWeights.reserve(computeInstances.size());
  nodeCrossbarUsage.reserve(computeInstances.size());
  nodeOrderKeys.reserve(computeInstances.size());
  for (auto [index, instance] : llvm::enumerate(computeInstances)) {
    nodeWeights.push_back(getComputeInstanceWeight(instance));
    nodeCrossbarUsage.push_back(getComputeInstanceCrossbarUsage(instance));
    nodeOrderKeys.push_back(static_cast<int64_t>(index));
  }

  GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
  if (coresCount.getValue() > 0)
    graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
  graphDCP.setContext(context);
  graphDCP.runDcp();
  return graphDCP.getResult();
  return buildResultFromScheduledGraph(graphDCP, computeInstances);
}

} // namespace
@@ -488,27 +652,31 @@ SpatCompute getOriginalSpatCompute(Operation* op) {
}

DCPAnalysisResult DCPAnalysis::run() {
  SmallVector<SpatCompute, 10> spatComputes;
  SmallVector<ComputeInstance> computeInstances = collectComputeInstances(entryOp);
  SmallVector<IndexedEdge, 10> edges;
  for (auto& region : entryOp->getRegions())
    for (SpatCompute spatCompute : region.getOps<SpatCompute>())
      spatComputes.push_back(spatCompute);

  for (auto [indexEndEdge, spatCompute] : llvm::enumerate(spatComputes)) {
    for (Value input : spatCompute.getInputs()) {
      if (auto producerCompute = getOriginalSpatCompute(input.getDefiningOp())) {
        auto producerIt = llvm::find(spatComputes, producerCompute);
        assert(producerIt != spatComputes.end());
        auto indexStartEdge = std::distance(spatComputes.begin(), producerIt);
        edges.push_back({indexStartEdge, indexEndEdge, getSizeInBytes(cast<ShapedType>(input.getType()))});
  llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
  instanceToIndex.reserve(computeInstances.size());
  for (auto [index, instance] : llvm::enumerate(computeInstances))
    instanceToIndex[instance] = index;

  for (auto [indexEndEdge, computeInstance] : llvm::enumerate(computeInstances)) {
    for (Value input : getComputeInstanceInputs(computeInstance)) {
      if (auto producerInstance = getOriginalComputeInstance(input)) {
        auto producerIt = instanceToIndex.find(*producerInstance);
        assert(producerIt != instanceToIndex.end());
        auto indexStartEdge = producerIt->second;
        edges.push_back({static_cast<int64_t>(indexStartEdge),
            static_cast<int64_t>(indexEndEdge),
            static_cast<int64_t>(getSizeInBytes(cast<ShapedType>(input.getType())))});
      }
    }
  }

  if (dcpCriticalWindowSize.getValue() == 0)
    return runLegacyDcp(spatComputes, edges, entryOp->getContext());
    return runLegacyDcp(computeInstances, edges, entryOp->getContext());

  VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges);
  VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
  size_t iteration = 0;
  auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
    size_t oldNodeCount = virtualGraph.nodes.size();
@@ -545,6 +713,13 @@ DCPAnalysisResult DCPAnalysis::run() {
  };

  while (virtualGraph.nodes.size() > 1) {
    if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
      if (virtualGraph.nodes.size() >= 200)
        llvm::errs() << llvm::formatv(
            "[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
      break;
    }

    iteration++;
    TimingInfo timing = computeTiming(virtualGraph);
    if (!timing.valid) {
@@ -576,7 +751,7 @@ DCPAnalysisResult DCPAnalysis::run() {
    break;
  }

  return buildResultFromVirtualGraph(virtualGraph, spatComputes);
  return buildResultFromVirtualGraph(virtualGraph, computeInstances);
}

} // namespace spatial

@@ -5,15 +5,28 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"

#include <cstdint>
#include <vector>

#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

// A scheduling identity that covers both spat.compute and scheduled shards of
// spat.compute_batch.
struct ComputeInstance {
  mlir::Operation* op = nullptr;
  uint32_t laneStart = 0;
  uint32_t laneCount = 1;

  bool operator==(const ComputeInstance& other) const {
    return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
  }
};

struct DCPAnalysisResult {
  std::vector<onnx_mlir::spatial::SpatCompute> dominanceOrderCompute;
  llvm::DenseMap<onnx_mlir::spatial::SpatCompute, size_t> computeToCpuMap;
  llvm::DenseSet<onnx_mlir::spatial::SpatCompute> isLastComputeOfCpu;
  llvm::DenseMap<size_t, onnx_mlir::spatial::SpatCompute> cpuToLastComputeMap;
  std::vector<ComputeInstance> dominanceOrderCompute;
  llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
  llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
  llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
};

namespace onnx_mlir {
@@ -34,3 +47,21 @@ public:

} // namespace spatial
} // namespace onnx_mlir

namespace llvm {
template <>
struct DenseMapInfo<ComputeInstance> {
  static ComputeInstance getEmptyKey() {
    return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
  }
  static ComputeInstance getTombstoneKey() {
    return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
  }
  static unsigned getHashValue(const ComputeInstance& v) {
    return llvm::hash_combine(v.op, v.laneStart, v.laneCount);
  }
  static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) {
    return a == b;
  }
};
} // namespace llvm
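
// Note: the empty/tombstone keys use UINT32_MAX lane fields; real instances
// never reach that value because laneCount originates from an i32 attribute,
// so the sentinel keys cannot collide with a live ComputeInstance.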

@@ -1491,18 +1491,21 @@ void GraphDCP::runDcp() {
  struct ReadyEntry {
    Time slack;
    Time aest;
    int64_t orderKey;
    TaskDCP* task;
    bool operator>(const ReadyEntry& other) const {
      if (slack != other.slack)
        return slack > other.slack;
      if (aest != other.aest)
        return aest > other.aest;
      return orderKey > other.orderKey;
    }
  };
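
  // The explicit orderKey (the task's original node id) breaks ties between
  // entries with equal slack and AEST, so heap pops are deterministic across
  // runs.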
  std::priority_queue<ReadyEntry, std::vector<ReadyEntry>, std::greater<ReadyEntry>> readyQueue;
  size_t readyCount = 0;

  auto pushReady = [&](TaskDCP* node) {
    readyQueue.push({slackOrZero(node->getAest(), node->getAlst()), node->getAest(), node});
    readyQueue.push({slackOrZero(node->getAest(), node->getAlst()), node->getAest(), node->Id(), node});
  };

  for (auto& node : nodes) {
@@ -1528,7 +1531,7 @@ void GraphDCP::runDcp() {
      candidate = entry.task;
      break;
    }
    readyQueue.push({curSlack, curAest, entry.task});
    readyQueue.push({curSlack, curAest, entry.orderKey, entry.task});
  }
  assert(candidate != nullptr && "readyCount > 0 but heap exhausted");
  --readyCount;
@@ -1579,8 +1582,11 @@ DCPAnalysisResult GraphDCP::getResult() {

  auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size());
  ret.dominanceOrderCompute.reserve(dominanceOrder.size());
  for (auto elem : dominanceOrder)
    ret.dominanceOrderCompute.push_back(elem->getSpatCompute());
  for (auto elem : dominanceOrder) {
    auto spatCompute = elem->getSpatCompute();
    if (spatCompute)
      ret.dominanceOrderCompute.push_back({spatCompute.getOperation(), 0});
  }

  for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) {
    const CpuTaskList* tasks = findCpuTasks(cpu);
@@ -1588,10 +1594,14 @@ DCPAnalysisResult GraphDCP::getResult() {
      continue;
    size_t i = 0;
    for (auto node : *tasks) {
      ret.computeToCpuMap[node->getSpatCompute()] = cpu;
      auto spatCompute = node->getSpatCompute();
      if (!spatCompute)
        continue;
      ComputeInstance instance {spatCompute.getOperation(), 0};
      ret.computeToCpuMap[instance] = cpu;
      if (i++ == tasks->size() - 1) {
        ret.isLastComputeOfCpu.insert(node->getSpatCompute());
        ret.cpuToLastComputeMap[cpu] = node->getSpatCompute();
        ret.isLastComputeOfCpu.insert(instance);
        ret.cpuToLastComputeMap[cpu] = instance;
      }
    }
  }

@@ -138,13 +138,18 @@ public:

  GraphDCP(llvm::ArrayRef<Weight> nodeWeights,
      llvm::ArrayRef<IndexedEdge> edges,
      llvm::ArrayRef<int64_t> nodeOrderKeys = {},
      llvm::ArrayRef<CrossbarUsage> nodeCrossbarUsage = {})
      : nodes(), cpuTasks(), cpuCrossbarUsage() {
    assert((nodeCrossbarUsage.empty() || nodeCrossbarUsage.size() == nodeWeights.size())
        && "synthetic crossbar usage must match synthetic node weights");
    assert((nodeOrderKeys.empty() || nodeOrderKeys.size() == nodeWeights.size())
        && "synthetic node order keys must match synthetic node weights");
    nodes.reserve(nodeWeights.size());
    for (auto [index, weight] : llvm::enumerate(nodeWeights))
      nodes.emplace_back(index, weight, nodeCrossbarUsage.empty() ? 0 : nodeCrossbarUsage[index]);
      nodes.emplace_back(nodeOrderKeys.empty() ? static_cast<int64_t>(index) : nodeOrderKeys[index],
          weight,
          nodeCrossbarUsage.empty() ? 0 : nodeCrossbarUsage[index]);
    for (auto [start, end, weight] : edges)
      makeEdge(start, end, weight);
  }
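
  // When callers pass no explicit order keys, each node's index is used as its
  // key, which keeps the old two- and three-argument constructor calls valid.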
File diff suppressed because it is too large
@@ -23,8 +23,7 @@ static bool isAddressOnlyHostOp(Operation* op) {
      memref::SubViewOp,
      memref::CastOp,
      memref::CollapseShapeOp,
      memref::ExpandShapeOp,
      spatial::SpatChannelNewOp>(op);
      memref::ExpandShapeOp>(op);
}

static bool isCodegenAddressableValue(Value value) {
@@ -38,6 +37,8 @@ static bool isCodegenAddressableValue(Value value) {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
  if (isa<pim::PimMemCopyHostToDevOp>(op))
    return operandIndex == 1;
  if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
    return operandIndex == 1;
  if (isa<pim::PimMemCopyDevToHostOp>(op))
    return operandIndex == 0;
  return false;
@@ -69,6 +70,12 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
        continue;
      }

      if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
        if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp)))
          hasFailure = true;
        continue;
      }

      if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
        if (failed(verifyReturnOp(returnOp)))
          hasFailure = true;
@@ -92,10 +99,11 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
}

private:
  static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) {
  template <typename CoreOpTy>
  static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) {
    bool hasFailure = false;
    for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
      auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
      auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
      if (!getGlobalOp) {
        coreOp.emitOpError() << "weight #" << weightIndex
                             << " must be materialized as memref.get_global before JSON codegen";
@@ -131,7 +139,8 @@ private:
    return success(!hasFailure);
  }

  static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
  template <typename CoreOpTy>
  static LogicalResult verifyCoreOperands(CoreOpTy coreOp) {
    return walkPimCoreBlock(
        coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
          bool hasFailure = false;