Compare commits
3 Commits
566630b99a
...
b1272d2283
| Author | SHA1 | Date | |
|---|---|---|---|
| b1272d2283 | |||
| 58e6587697 | |||
| f6c8cc4aa5 |
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
@@ -227,7 +228,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
|
if (mlir::isa<onnx_mlir::pim::PimEmptyManyOp, mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
|
||||||
return ResolvedContiguousAddress {value, byteOffset};
|
return ResolvedContiguousAddress {value, byteOffset};
|
||||||
|
|
||||||
return mlir::failure();
|
return mlir::failure();
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
|
|||||||
if (!computeOp || operandIndex >= computeOp.getWeights().size())
|
if (!computeOp || operandIndex >= computeOp.getWeights().size())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
|
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
||||||
|
|||||||
@@ -97,6 +97,11 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
|||||||
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
||||||
gatherMemEntry(allocOp.getResult());
|
gatherMemEntry(allocOp.getResult());
|
||||||
});
|
});
|
||||||
|
funcOp.walk([&](pim::PimEmptyManyOp emptyManyOp) {
|
||||||
|
if (!emptyManyOp->getParentOfType<pim::PimCoreOp>() && !emptyManyOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||||
|
for (mlir::Value output : emptyManyOp.getOutputs())
|
||||||
|
gatherMemEntry(output);
|
||||||
|
});
|
||||||
|
|
||||||
allocateGatheredMemory();
|
allocateGatheredMemory();
|
||||||
|
|
||||||
@@ -106,6 +111,10 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
|||||||
|
|
||||||
void PimMemory::allocateCore(Operation* op) {
|
void PimMemory::allocateCore(Operation* op) {
|
||||||
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
|
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
|
||||||
|
op->walk([&](pim::PimEmptyManyOp emptyManyOp) {
|
||||||
|
for (mlir::Value output : emptyManyOp.getOutputs())
|
||||||
|
gatherMemEntry(output);
|
||||||
|
});
|
||||||
|
|
||||||
allocateGatheredMemory();
|
allocateGatheredMemory();
|
||||||
}
|
}
|
||||||
@@ -169,7 +178,6 @@ void PimMemory::report(llvm::raw_ostream& file) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void PimMemory::remove(mlir::Value val) {
|
void PimMemory::remove(mlir::Value val) {
|
||||||
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
|
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
|
||||||
globalMemEntriesMap.erase(removeIter);
|
globalMemEntriesMap.erase(removeIter);
|
||||||
@@ -361,11 +369,21 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
|
|||||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp,
|
||||||
for (auto [outputBuffer, sourceCoreId] : llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds()))
|
const StaticValueKnowledge& knowledge) const {
|
||||||
|
for (auto [outputBuffer, sourceCoreId] :
|
||||||
|
llvm::zip(receiveManyOp.getOutputBuffers(), receiveManyOp.getSourceCoreIds()))
|
||||||
emitCommunicationOp("recv", addressOf(outputBuffer, knowledge), sourceCoreId, getValueSizeInBytes(outputBuffer));
|
emitCommunicationOp("recv", addressOf(outputBuffer, knowledge), sourceCoreId, getValueSizeInBytes(outputBuffer));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Emits one "recv" instruction per source core for a tensor receive.
///
/// The byte size of the output buffer is split evenly (integer division) across
/// the op's source core ids; the chunk for the i-th source core lands at byte
/// offset `i * chunkSize` from the resolved address of the output buffer.
///
/// @param receiveTensorOp op providing the output buffer and the source core ids.
/// @param knowledge       static value knowledge forwarded to `addressOf`.
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
                                        const StaticValueKnowledge& knowledge) const {
  auto sourceCoreIds = receiveTensorOp.getSourceCoreIds();
  size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);

  // With no source cores there is nothing to emit; bailing out here also keeps
  // the chunk-size computation below from dividing by zero.
  if (sourceCoreIds.empty())
    return;

  size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / sourceCoreIds.size();
  for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(sourceCoreIds))
    emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
||||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
|
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
|
||||||
}
|
}
|
||||||
@@ -375,7 +393,15 @@ void PimCodeGen::codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticVa
|
|||||||
emitCommunicationOp("send", addressOf(input, knowledge), targetCoreId, getValueSizeInBytes(input));
|
emitCommunicationOp("send", addressOf(input, knowledge), targetCoreId, getValueSizeInBytes(input));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const {
|
/// Emits one "send" instruction per target core for a tensor send.
///
/// The byte size of the input is split evenly (integer division) across the
/// op's target core ids; the chunk for the i-th target core is read from byte
/// offset `i * chunkSize` past the resolved address of the input.
///
/// @param sendTensorOp op providing the input value and the target core ids.
/// @param knowledge    static value knowledge forwarded to `addressOf`.
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
  auto targetCoreIds = sendTensorOp.getTargetCoreIds();
  size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);

  // With no target cores there is nothing to emit; bailing out here also keeps
  // the chunk-size computation below from dividing by zero.
  if (targetCoreIds.empty())
    return;

  size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / targetCoreIds.size();
  for (auto [chunkIndex, targetCoreId] : llvm::enumerate(targetCoreIds))
    emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
}
|
||||||
|
|
||||||
|
void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp,
|
||||||
|
const StaticValueKnowledge& knowledge) const {
|
||||||
auto inputType = cast<ShapedType>(extractRowsOp.getInput().getType());
|
auto inputType = cast<ShapedType>(extractRowsOp.getInput().getType());
|
||||||
assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input");
|
assert(inputType.hasStaticShape() && inputType.getRank() == 2 && "extract_rows codegen requires static rank-2 input");
|
||||||
|
|
||||||
@@ -384,13 +410,8 @@ void PimCodeGen::codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const
|
|||||||
size_t inputAddr = addressOf(extractRowsOp.getInput(), knowledge);
|
size_t inputAddr = addressOf(extractRowsOp.getInput(), knowledge);
|
||||||
|
|
||||||
for (auto [rowIndex, outputBuffer] : llvm::enumerate(extractRowsOp.getOutputBuffers()))
|
for (auto [rowIndex, outputBuffer] : llvm::enumerate(extractRowsOp.getOutputBuffers()))
|
||||||
emitMemCopyOp("lmv",
|
emitMemCopyOp(
|
||||||
addressOf(outputBuffer, knowledge),
|
"lmv", addressOf(outputBuffer, knowledge), 0, inputAddr, rowIndex * rowSizeInBytes, rowSizeInBytes, "len");
|
||||||
0,
|
|
||||||
inputAddr,
|
|
||||||
rowIndex * rowSizeInBytes,
|
|
||||||
rowSizeInBytes,
|
|
||||||
"len");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
|
||||||
@@ -733,10 +754,8 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
|
|||||||
for (mlir::Value input : sendManyBatchOp.getInputs())
|
for (mlir::Value input : sendManyBatchOp.getInputs())
|
||||||
mappedInputs.push_back(mapper.lookup(input));
|
mappedInputs.push_back(mapper.lookup(input));
|
||||||
|
|
||||||
pim::PimSendManyOp::create(builder,
|
pim::PimSendManyOp::create(
|
||||||
sendManyBatchOp.getLoc(),
|
builder, sendManyBatchOp.getLoc(), builder.getDenseI32ArrayAttr(laneTargetCoreIds), ValueRange(mappedInputs));
|
||||||
builder.getDenseI32ArrayAttr(laneTargetCoreIds),
|
|
||||||
ValueRange(mappedInputs));
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -764,13 +783,13 @@ static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp cor
|
|||||||
for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers())
|
for (mlir::Value outputBuffer : receiveManyBatchOp.getOutputBuffers())
|
||||||
mappedOutputBuffers.push_back(mapper.lookup(outputBuffer));
|
mappedOutputBuffers.push_back(mapper.lookup(outputBuffer));
|
||||||
|
|
||||||
auto scalarReceiveMany =
|
auto scalarReceiveMany = pim::PimReceiveManyOp::create(builder,
|
||||||
pim::PimReceiveManyOp::create(builder,
|
receiveManyBatchOp.getLoc(),
|
||||||
receiveManyBatchOp.getLoc(),
|
receiveManyBatchOp->getResultTypes(),
|
||||||
receiveManyBatchOp->getResultTypes(),
|
ValueRange(mappedOutputBuffers),
|
||||||
ValueRange(mappedOutputBuffers),
|
builder.getDenseI32ArrayAttr(laneSourceCoreIds));
|
||||||
builder.getDenseI32ArrayAttr(laneSourceCoreIds));
|
for (auto [originalOutput, scalarOutput] :
|
||||||
for (auto [originalOutput, scalarOutput] : llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
|
llvm::zip(receiveManyBatchOp.getOutputs(), scalarReceiveMany.getOutputs()))
|
||||||
mapper.map(originalOutput, scalarOutput);
|
mapper.map(originalOutput, scalarOutput);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -895,10 +914,14 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
|
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
|
||||||
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(op))
|
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(op))
|
||||||
coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge);
|
coreCodeGen.codeGenReceiveManyOp(receiveManyOp, knowledge);
|
||||||
|
else if (auto receiveTensorOp = dyn_cast<pim::PimReceiveTensorOp>(op))
|
||||||
|
coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge);
|
||||||
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
||||||
coreCodeGen.codeGenSendOp(sendOp, knowledge);
|
coreCodeGen.codeGenSendOp(sendOp, knowledge);
|
||||||
else if (auto sendManyOp = dyn_cast<pim::PimSendManyOp>(op))
|
else if (auto sendManyOp = dyn_cast<pim::PimSendManyOp>(op))
|
||||||
coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge);
|
coreCodeGen.codeGenSendManyOp(sendManyOp, knowledge);
|
||||||
|
else if (auto sendTensorOp = dyn_cast<pim::PimSendTensorOp>(op))
|
||||||
|
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
|
||||||
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op))
|
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(op))
|
||||||
coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge);
|
coreCodeGen.codeGenExtractRowsOp(extractRowsOp, knowledge);
|
||||||
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
||||||
@@ -931,6 +954,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
|
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
|
||||||
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
|
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
|
||||||
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
||||||
|
else if (isa<pim::PimEmptyManyOp>(op))
|
||||||
|
return success();
|
||||||
else {
|
else {
|
||||||
op.emitError("Unsupported codegen for this operation");
|
op.emitError("Unsupported codegen for this operation");
|
||||||
op.dump();
|
op.dump();
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
@@ -117,8 +118,10 @@ public:
|
|||||||
|
|
||||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenReceiveManyOp(pim::PimReceiveManyOp receiveManyOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenSendManyOp(pim::PimSendManyOp sendManyOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenExtractRowsOp(pim::PimExtractRowsOp extractRowsOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
|
||||||
|
|
||||||
|
|||||||
@@ -381,7 +381,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
vmmOutputs.reserve(aHSlicesArgs.size());
|
||||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
||||||
vmmOutputs.push_back(
|
vmmOutputs.push_back(
|
||||||
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
||||||
if (vmmOutputs.empty()) {
|
if (vmmOutputs.empty()) {
|
||||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||||
return failure();
|
return failure();
|
||||||
@@ -527,7 +527,7 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
|
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
|
||||||
rewriter.setInsertionPointToEnd(body);
|
rewriter.setInsertionPointToEnd(body);
|
||||||
|
|
||||||
Value vmmResult = spatial::SpatWeightedVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
||||||
Value laneResult = vmmResult;
|
Value laneResult = vmmResult;
|
||||||
if (sharedBias)
|
if (sharedBias)
|
||||||
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
|
||||||
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
||||||
mlir::Value result = operation->getResult(0);
|
mlir::Value result = operation->getResult(0);
|
||||||
auto resultType = result.getType();
|
auto resultType = result.getType();
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
|
|||||||
|
|
||||||
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
|
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
|
||||||
|
|
||||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
|
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
|
||||||
|
|
||||||
inline mlir::tensor::EmptyOp
|
inline mlir::tensor::EmptyOp
|
||||||
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
|
||||||
|
|||||||
@@ -16,13 +16,13 @@ def onnxToPimTranspose : Pat<
|
|||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimVMM : Pat<
|
def spatToPimVMM : Pat<
|
||||||
(SpatWeightedVMMOp:$srcOpRes $weightIndex, $vector),
|
(SpatVMMOp:$srcOpRes $weightIndex, $vector),
|
||||||
(PimVMMOp $weightIndex, $vector,
|
(PimVMMOp $weightIndex, $vector,
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def spatToPimMVM : Pat<
|
def spatToPimMVM : Pat<
|
||||||
(SpatWeightedMVMOp:$srcOpRes $weightIndex, $vector),
|
(SpatMVMOp:$srcOpRes $weightIndex, $vector),
|
||||||
(PimMVMOp $weightIndex, $vector,
|
(PimMVMOp $weightIndex, $vector,
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|||||||
@@ -159,9 +159,7 @@ static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRe
|
|||||||
rewriter.eraseOp(sendManyOp);
|
rewriter.eraseOp(sendManyOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter,
|
static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter, Location loc, TypeRange outputTypes) {
|
||||||
Location loc,
|
|
||||||
TypeRange outputTypes) {
|
|
||||||
SmallVector<Type> tensorTypes;
|
SmallVector<Type> tensorTypes;
|
||||||
tensorTypes.reserve(outputTypes.size());
|
tensorTypes.reserve(outputTypes.size());
|
||||||
for (Type outputType : outputTypes)
|
for (Type outputType : outputTypes)
|
||||||
@@ -177,7 +175,8 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan
|
|||||||
sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size());
|
sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size());
|
||||||
for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds())
|
for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds())
|
||||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
||||||
SmallVector<Value> outputBuffers = createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
|
SmallVector<Value> outputBuffers =
|
||||||
|
createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
|
||||||
|
|
||||||
auto receiveMany = PimReceiveManyOp::create(rewriter,
|
auto receiveMany = PimReceiveManyOp::create(rewriter,
|
||||||
receiveManyOp.getLoc(),
|
receiveManyOp.getLoc(),
|
||||||
@@ -199,10 +198,8 @@ static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendMa
|
|||||||
mappedInputs.reserve(sendManyBatchOp.getInputs().size());
|
mappedInputs.reserve(sendManyBatchOp.getInputs().size());
|
||||||
for (Value input : sendManyBatchOp.getInputs())
|
for (Value input : sendManyBatchOp.getInputs())
|
||||||
mappedInputs.push_back(mapper.lookup(input));
|
mappedInputs.push_back(mapper.lookup(input));
|
||||||
pim::PimSendManyBatchOp::create(rewriter,
|
pim::PimSendManyBatchOp::create(
|
||||||
sendManyBatchOp.getLoc(),
|
rewriter, sendManyBatchOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), ValueRange(mappedInputs));
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
|
||||||
ValueRange(mappedInputs));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
|
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
|
||||||
@@ -252,25 +249,6 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
|
|||||||
rewriter.replaceOp(concatOp, concatenated);
|
rewriter.replaceOp(concatOp, concatenated);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
|
||||||
SmallVector<spatial::SpatWeightedVMMOp> wvmmOps;
|
|
||||||
funcOp.walk([&](spatial::SpatWeightedVMMOp wvmmOp) {
|
|
||||||
if (wvmmOp->getParentOfType<pim::PimCoreOp>() || wvmmOp->getParentOfType<pim::PimCoreBatchOp>())
|
|
||||||
wvmmOps.push_back(wvmmOp);
|
|
||||||
});
|
|
||||||
|
|
||||||
for (auto wvmmOp : wvmmOps) {
|
|
||||||
rewriter.setInsertionPoint(wvmmOp);
|
|
||||||
auto outputType = cast<ShapedType>(wvmmOp.getOutput().getType());
|
|
||||||
Value outputBuffer = createEmptyTensorFromShaped(rewriter, wvmmOp.getLoc(), outputType).getResult();
|
|
||||||
rewriter.replaceOpWithNewOp<pim::PimVMMOp>(wvmmOp,
|
|
||||||
wvmmOp.getOutput().getType(),
|
|
||||||
rewriter.getI32IntegerAttr(wvmmOp.getWeightIndex()),
|
|
||||||
wvmmOp.getInput(),
|
|
||||||
outputBuffer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
SmallVector<spatial::SpatMapOp> mapOps;
|
SmallVector<spatial::SpatMapOp> mapOps;
|
||||||
funcOp.walk([&](spatial::SpatMapOp mapOp) {
|
funcOp.walk([&](spatial::SpatMapOp mapOp) {
|
||||||
@@ -291,6 +269,276 @@ static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds a ranked tensor type shaped like `elementType` except that its
/// leading dimension is multiplied by `count`; the element type is unchanged.
/// The shape is indexed at position 0, so `elementType` must have rank >= 1.
static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
  ArrayRef<int64_t> elementShape = elementType.getShape();
  SmallVector<int64_t> packedShape(elementShape.begin(), elementShape.end());
  packedShape[0] = packedShape[0] * count;
  return RankedTensorType::get(packedShape, elementType.getElementType());
}
|
||||||
|
|
||||||
|
/// Checks whether `values` is a run of consecutive results of one operation.
///
/// Returns false for an empty range or when the first value is not an op
/// result; in both cases the out-parameters are left untouched. Otherwise
/// `owner` and `startIndex` are set from the first value (and stay set even if
/// a later value breaks the run), and the function returns true only when the
/// i-th value is result number `startIndex + i` of `owner`.
static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) {
  if (values.empty())
    return false;

  auto leading = dyn_cast<OpResult>(values.front());
  if (!leading)
    return false;

  owner = leading.getOwner();
  startIndex = leading.getResultNumber();

  size_t position = 0;
  for (Value candidate : values) {
    auto candidateResult = dyn_cast<OpResult>(candidate);
    bool continuesRun = candidateResult && candidateResult.getOwner() == owner
                        && candidateResult.getResultNumber() == startIndex + position;
    if (!continuesRun)
      return false;
    ++position;
  }

  return true;
}
|
||||||
|
|
||||||
|
/// Creates a single `tensor.extract_slice` of the extract-rows input that
/// covers `count` consecutive outputs of `extractRowsOp`, starting at output
/// `startIndex`.
///
/// The slice starts at row `startIndex * rowsPerValue` and spans
/// `count * rowsPerValue` rows, where `rowsPerValue` is the leading dimension
/// of the output type at `startIndex`; every remaining input dimension is taken
/// in full with unit stride. Returns a null Value when the output or input type
/// is not a statically shaped ranked tensor or the output has rank 0.
static Value createPackedExtractRowsSlice(
    pim::PimExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
  auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
  auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
  bool typesUsable =
      rowType && inputType && rowType.hasStaticShape() && inputType.hasStaticShape() && rowType.getRank() != 0;
  if (!typesUsable)
    return {};

  const int64_t rowsPerValue = rowType.getDimSize(0);
  if (ShapedType::isDynamic(rowsPerValue))
    return {};

  auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));

  const int64_t inputRank = inputType.getRank();
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(inputRank);
  sizes.reserve(inputRank);
  strides.reserve(inputRank);

  // Appends one dimension of the slice description; strides are always 1.
  auto appendDim = [&](int64_t offset, int64_t size) {
    offsets.push_back(rewriter.getIndexAttr(offset));
    sizes.push_back(rewriter.getIndexAttr(size));
    strides.push_back(rewriter.getIndexAttr(1));
  };

  // Leading dimension: the row window covered by the selected outputs.
  appendDim(static_cast<int64_t>(startIndex) * rowsPerValue, static_cast<int64_t>(count) * rowsPerValue);
  // Trailing dimensions: taken whole.
  for (int64_t dim = 1; dim < inputRank; ++dim)
    appendDim(0, inputType.getDimSize(dim));

  auto sliceOp =
      tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides);
  return sliceOp.getResult();
}
|
||||||
|
|
||||||
|
/// Tries to express `values` as one packed tensor value.
///
/// Succeeds only when `values` are consecutive results of a single
/// `pim::PimExtractRowsOp`, in which case the packed slice of that op's input
/// is created and returned. Every other case yields a null Value.
static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) {
  Operation* producer = nullptr;
  unsigned firstResultIndex = 0;
  if (getContiguousOpResults(values, producer, firstResultIndex)) {
    auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(producer);
    if (extractRowsOp) {
      auto groupSize = static_cast<unsigned>(values.size());
      return createPackedExtractRowsSlice(extractRowsOp, firstResultIndex, groupSize, rewriter, loc);
    }
  }

  return {};
}
|
||||||
|
|
||||||
|
/// Creates a `pim::PimReceiveTensorOp` that stands in for `count` consecutive
/// outputs of `receiveManyOp`, starting at output `startIndex`.
///
/// A fresh `tensor.empty` of the packed type (output type at `startIndex` with
/// its leading dimension multiplied by `count`) is used as the output buffer,
/// and the matching window of the receive-many op's source core ids is
/// attached. Returns a null Value when the output type at `startIndex` is not
/// a statically shaped ranked tensor of rank >= 1.
static Value createPackedReceiveTensor(
    pim::PimReceiveManyOp receiveManyOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
  auto rowType = dyn_cast<RankedTensorType>(receiveManyOp.getOutputs()[startIndex].getType());
  bool rowTypeUsable = rowType && rowType.hasStaticShape() && rowType.getRank() != 0;
  if (!rowTypeUsable)
    return {};

  auto packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
  auto outputBuffer = tensor::EmptyOp::create(rewriter, loc, packedType.getShape(), packedType.getElementType());

  // Window of source core ids that corresponds to the packed outputs.
  ArrayRef<int32_t> allSourceCoreIds = receiveManyOp.getSourceCoreIds();
  ArrayRef<int32_t> sourceCoreIds = allSourceCoreIds.slice(startIndex, count);

  auto receiveTensor = pim::PimReceiveTensorOp::create(
      rewriter, loc, packedType, outputBuffer.getResult(), rewriter.getDenseI32ArrayAttr(sourceCoreIds));
  return receiveTensor.getOutput();
}
|
||||||
|
|
||||||
|
/// Rewrites `count` consecutive lanes of `mapOp` (starting at lane
/// `startIndex`) into one `scf.for` loop over a packed input tensor and returns
/// the packed result of that loop.
///
/// Per iteration the loop extracts the lane's rows from the packed input,
/// clones the map body (minus its terminator) with the first body argument
/// bound to that slice, and inserts the yielded value into the loop-carried
/// packed output at the lane's row offset.
///
/// Returns a null Value when the lane input/output types are not statically
/// shaped ranked tensors of rank >= 1, or when the lane inputs cannot be
/// packed by `createPackedTensorForValues`. The type checks run before any
/// operation is created, so a rejected group leaves no unused ops behind.
static Value
createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
  auto inputType = dyn_cast<RankedTensorType>(mapOp.getInputs()[startIndex].getType());
  auto outputType = dyn_cast<RankedTensorType>(mapOp.getOutputs()[startIndex].getType());
  if (!inputType || !outputType || !inputType.hasStaticShape() || !outputType.hasStaticShape()
      || inputType.getRank() == 0 || outputType.getRank() == 0)
    return {};

  Value packedInput = createPackedTensorForValues(mapOp.getInputs().slice(startIndex, count), rewriter, loc);
  if (!packedInput)
    return {};

  auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(count));
  auto packedInit =
      tensor::EmptyOp::create(rewriter, loc, packedOutputType.getShape(), packedOutputType.getElementType());
  auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto upper = arith::ConstantIndexOp::create(rewriter, loc, count);
  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
  auto loop = scf::ForOp::create(rewriter, loc, zero, upper, step, ValueRange {packedInit.getResult()});

  // Converts a lane index into a row offset; the multiply is skipped for
  // single-row lanes, where the lane index already is the row offset.
  auto laneToRowOffset = [&](Value laneIndex, int64_t rowsPerLane) -> Value {
    if (rowsPerLane == 1)
      return laneIndex;
    auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, rowsPerLane);
    return arith::MulIOp::create(rewriter, loc, laneIndex, rowsPerValue).getResult();
  };

  // Describes the slice occupied by one lane: `rowOffset` along dimension 0
  // with the lane's row count, and every other dimension taken in full.
  auto appendLaneSlice = [&](RankedTensorType laneType,
                             Value rowOffset,
                             SmallVector<OpFoldResult>& offsets,
                             SmallVector<OpFoldResult>& sizes,
                             SmallVector<OpFoldResult>& strides) {
    offsets.push_back(rowOffset);
    sizes.push_back(rewriter.getIndexAttr(laneType.getDimSize(0)));
    strides.push_back(rewriter.getIndexAttr(1));
    for (int64_t dim = 1; dim < laneType.getRank(); ++dim) {
      offsets.push_back(rewriter.getIndexAttr(0));
      sizes.push_back(rewriter.getIndexAttr(laneType.getDimSize(dim)));
      strides.push_back(rewriter.getIndexAttr(1));
    }
  };

  {
    // Restore the caller's insertion point once the loop body has been built.
    OpBuilder::InsertionGuard guard(rewriter);
    Block* loopBlock = loop.getBody();
    rewriter.setInsertionPointToStart(loopBlock);
    Value iv = loopBlock->getArgument(0);
    Value acc = loopBlock->getArgument(1);

    // Slice this lane's rows out of the packed input.
    Value inputRowOffset = laneToRowOffset(iv, inputType.getDimSize(0));
    SmallVector<OpFoldResult> extractOffsets;
    SmallVector<OpFoldResult> extractSizes;
    SmallVector<OpFoldResult> extractStrides;
    appendLaneSlice(inputType, inputRowOffset, extractOffsets, extractSizes, extractStrides);
    auto inputSlice = tensor::ExtractSliceOp::create(
        rewriter, loc, inputType, packedInput, extractOffsets, extractSizes, extractStrides);

    // Clone the map body with its first argument bound to the lane slice.
    IRMapping mapping;
    Block& body = mapOp.getBody().front();
    mapping.map(body.getArgument(0), inputSlice.getResult());
    for (Operation& bodyOp : body.without_terminator()) {
      Operation* cloned = rewriter.clone(bodyOp, mapping);
      for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults()))
        mapping.map(originalResult, clonedResult);
      rewriter.setInsertionPointAfter(cloned);
    }

    auto yieldOp = cast<pim::PimYieldOp>(body.getTerminator());
    Value mappedOutput = mapping.lookupOrDefault(yieldOp.getOperand(0));

    // Write the lane result into the loop-carried packed output.
    Value outputRowOffset = laneToRowOffset(iv, outputType.getDimSize(0));
    SmallVector<OpFoldResult> insertOffsets;
    SmallVector<OpFoldResult> insertSizes;
    SmallVector<OpFoldResult> insertStrides;
    appendLaneSlice(outputType, outputRowOffset, insertOffsets, insertSizes, insertStrides);

    auto inserted =
        tensor::InsertSliceOp::create(rewriter, loc, mappedOutput, acc, insertOffsets, insertSizes, insertStrides);
    scf::YieldOp::create(rewriter, loc, inserted.getResult());
  }

  return loop.getResult(0);
}
|
||||||
|
|
||||||
|
/// Rewrites grouped PIM communication/concat patterns in `funcOp` to operate
/// on a single packed tensor where the packing helpers can produce one.
///
/// Three phases, in order:
///  1. Every non-empty `pim.send_many` whose inputs can be packed by
///     `createPackedTensorForValues` is replaced by one `pim.send_tensor`
///     carrying the same target core ids.
///  2. For every axis-0 `pim.concat`, runs of consecutive operands that are
///     consecutive results of the same producer op are handed to the matching
///     `createPacked*` helper; if at least one run was packed the concat is
///     rebuilt over the new operand list.
///  3. Producer ops (map / receive_many / extract_rows / empty_many) that no
///     longer have any use are erased.
static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Phase 1: send_many -> send_tensor. Ops are collected first so the walk
  // never observes the erasures performed below.
  SmallVector<pim::PimSendManyOp> pendingSends;
  funcOp.walk([&](pim::PimSendManyOp op) { pendingSends.push_back(op); });
  for (pim::PimSendManyOp sendOp : pendingSends) {
    if (sendOp.getInputs().empty())
      continue;

    rewriter.setInsertionPoint(sendOp);
    Value packed = createPackedTensorForValues(sendOp.getInputs(), rewriter, sendOp.getLoc());
    if (packed) {
      pim::PimSendTensorOp::create(rewriter, sendOp.getLoc(), packed, sendOp.getTargetCoreIdsAttr());
      rewriter.eraseOp(sendOp);
    }
  }

  // Phase 2: pack runs of sibling results feeding axis-0 concats.
  SmallVector<pim::PimConcatOp> pendingConcats;
  funcOp.walk([&](pim::PimConcatOp op) { pendingConcats.push_back(op); });
  for (pim::PimConcatOp concatOp : pendingConcats) {
    if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
      continue;

    SmallVector<Value> rebuiltInputs;
    bool packedAnyRun = false;
    rewriter.setInsertionPoint(concatOp);

    unsigned runBegin = 0;
    while (runBegin < concatOp.getInputs().size()) {
      Value head = concatOp.getInputs()[runBegin];
      auto headResult = dyn_cast<OpResult>(head);
      if (!headResult) {
        // Block arguments have no producer op to pack; forward unchanged.
        rebuiltInputs.push_back(head);
        ++runBegin;
        continue;
      }

      Operation* producer = headResult.getOwner();
      unsigned firstResultNumber = headResult.getResultNumber();

      // Extend the run while the next operand is the next result of `producer`.
      unsigned runEnd = runBegin + 1;
      for (; runEnd < concatOp.getInputs().size(); ++runEnd) {
        auto candidate = dyn_cast<OpResult>(concatOp.getInputs()[runEnd]);
        bool continuesRun = candidate && candidate.getOwner() == producer
                         && candidate.getResultNumber() == firstResultNumber + (runEnd - runBegin);
        if (!continuesRun)
          break;
      }
      unsigned runLength = runEnd - runBegin;

      // Dispatch on the producer kind; unsupported producers leave `packedRun` null.
      Value packedRun;
      if (auto mapOp = dyn_cast<pim::PimMapOp>(producer))
        packedRun = createPackedMapTensor(mapOp, firstResultNumber, runLength, rewriter, concatOp.getLoc());
      else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(producer))
        packedRun = createPackedReceiveTensor(receiveManyOp, firstResultNumber, runLength, rewriter, concatOp.getLoc());
      else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(producer))
        packedRun =
            createPackedExtractRowsSlice(extractRowsOp, firstResultNumber, runLength, rewriter, concatOp.getLoc());

      if (!packedRun) {
        // Packing was not possible: keep the original operands of this run.
        for (unsigned position = runBegin; position < runEnd; ++position)
          rebuiltInputs.push_back(concatOp.getInputs()[position]);
      }
      else {
        rebuiltInputs.push_back(packedRun);
        packedAnyRun = true;
      }

      runBegin = runEnd;
    }

    if (!packedAnyRun)
      continue;

    auto replacement = pim::PimConcatOp::create(rewriter,
                                                concatOp.getLoc(),
                                                concatOp.getOutput().getType(),
                                                concatOp.getAxisAttr(),
                                                ValueRange(rebuiltInputs),
                                                concatOp.getOutputBuffer());
    rewriter.replaceOp(concatOp, replacement.getOutput());
  }

  // Phase 3: drop producers that became dead. The argument only selects the op
  // type; ops are visited in reverse walk order.
  auto eraseDeadOpsOfKind = [&](auto kindTag) {
    using KindTy = decltype(kindTag);
    SmallVector<KindTy> candidates;
    funcOp.walk([&](KindTy candidate) { candidates.push_back(candidate); });
    for (auto candidate : llvm::reverse(candidates)) {
      if (!candidate->use_empty())
        continue;
      rewriter.eraseOp(candidate);
    }
  };
  eraseDeadOpsOfKind(pim::PimMapOp {});
  eraseDeadOpsOfKind(pim::PimReceiveManyOp {});
  eraseDeadOpsOfKind(pim::PimExtractRowsOp {});
  eraseDeadOpsOfKind(pim::PimEmptyManyOp {});
}
|
||||||
|
|
||||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||||
SmallVectorImpl<Operation*>& helperChain,
|
SmallVectorImpl<Operation*>& helperChain,
|
||||||
bool requireReturnUse = true) {
|
bool requireReturnUse = true) {
|
||||||
@@ -418,21 +666,21 @@ static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||||
auto getConcatResult = [](Operation *op) -> Value {
|
auto getConcatResult = [](Operation* op) -> Value {
|
||||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||||
return tensorConcat.getResult();
|
return tensorConcat.getResult();
|
||||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||||
return pimConcat.getOutput();
|
return pimConcat.getOutput();
|
||||||
return {};
|
return {};
|
||||||
};
|
};
|
||||||
auto getConcatAxis = [](Operation *op) -> std::optional<int64_t> {
|
auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
|
||||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||||
return tensorConcat.getDim();
|
return tensorConcat.getDim();
|
||||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||||
return pimConcat.getAxis();
|
return pimConcat.getAxis();
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
};
|
};
|
||||||
auto getConcatOperands = [](Operation *op) -> OperandRange {
|
auto getConcatOperands = [](Operation* op) -> OperandRange {
|
||||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||||
return tensorConcat.getOperands();
|
return tensorConcat.getOperands();
|
||||||
return cast<pim::PimConcatOp>(op).getInputs();
|
return cast<pim::PimConcatOp>(op).getInputs();
|
||||||
@@ -736,7 +984,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
SmallVector<pim::PimCoreOp> coreOps;
|
SmallVector<pim::PimCoreOp> coreOps;
|
||||||
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||||
for (auto coreOp : coreOps) {
|
for (auto coreOp : coreOps) {
|
||||||
if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) {
|
if (failed(applyPartialConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -745,15 +993,13 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
|
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
|
||||||
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||||
for (auto coreBatchOp : coreBatchOps) {
|
for (auto coreBatchOp : coreBatchOps) {
|
||||||
if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) {
|
if (failed(applyPartialConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
lowerRemainingSpatialMathOps(funcOp, rewriter);
|
|
||||||
|
|
||||||
RewritePatternSet channelPatterns(ctx);
|
RewritePatternSet channelPatterns(ctx);
|
||||||
populateWithGenerated(channelPatterns);
|
populateWithGenerated(channelPatterns);
|
||||||
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) {
|
||||||
@@ -820,6 +1066,8 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
for (auto extractRowsOp : remainingExtractRowsOps)
|
for (auto extractRowsOp : remainingExtractRowsOps)
|
||||||
lowerExtractRows(extractRowsOp, rewriter);
|
lowerExtractRows(extractRowsOp, rewriter);
|
||||||
|
|
||||||
|
compactPimTensorGroups(funcOp, rewriter);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
bool hasSpatialOps = false;
|
bool hasSpatialOps = false;
|
||||||
moduleOp.walk([&](Operation* op) {
|
moduleOp.walk([&](Operation* op) {
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ def PimEmptyManyOp : PimOp<"empty_many", []> {
|
|||||||
let summary = "Create many identical empty tensors";
|
let summary = "Create many identical empty tensors";
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
Variadic<AnyRankedTensor>:$outputs
|
Variadic<PimTensor>:$outputs
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
@@ -133,6 +133,18 @@ def PimSendManyOp : PimOp<"send_many", []> {
|
|||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pim.send_tensor: single-tensor counterpart of the "many" send ops. The
// operand is one tensor and `targetCoreIds` lists the destination cores.
def PimSendTensorOp : PimOp<"send_tensor", []> {
  let summary = "Send equal contiguous chunks of one tensor to target cores";

  let arguments = (ins PimTensor:$input, DenseI32ArrayAttr:$targetCoreIds);

  // Both the verifier and the textual format are implemented in C++.
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
|
||||||
|
|
||||||
def PimSendBatchOp : PimOp<"send_batch", []> {
|
def PimSendBatchOp : PimOp<"send_batch", []> {
|
||||||
let summary = "Send a per-lane tensor to target cores from a batched core";
|
let summary = "Send a per-lane tensor to target cores from a batched core";
|
||||||
|
|
||||||
@@ -203,6 +215,28 @@ def PimReceiveManyOp : PimOp<"receive_many", [DestinationStyleOpInterface]> {
|
|||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pim.receive_tensor: destination-style op whose only DPS init is
// `outputBuffer`; `sourceCoreIds` lists the cores the chunks come from.
def PimReceiveTensorOp : PimOp<"receive_tensor", [DestinationStyleOpInterface]> {
  let summary = "Receive equal contiguous chunks from source cores into one tensor";

  let arguments = (ins PimTensor:$outputBuffer, DenseI32ArrayAttr:$sourceCoreIds);

  let results = (outs PimTensor:$output);

  // Required by DestinationStyleOpInterface: expose the init operand(s).
  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputBufferMutable();
    }
  }];

  // Both the verifier and the textual format are implemented in C++.
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
}
|
||||||
|
|
||||||
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
|
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
|
||||||
let summary = "Receive per-lane tensors from source cores into a batched core";
|
let summary = "Receive per-lane tensors from source cores into a batched core";
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,8 @@
|
|||||||
|
|
||||||
#include "llvm/Support/LogicalResult.h"
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -100,9 +100,9 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
|
|
||||||
auto& builder = parser.getBuilder();
|
auto& builder = parser.getBuilder();
|
||||||
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount));
|
||||||
result.addAttribute("operandSegmentSizes",
|
result.addAttribute(
|
||||||
builder.getDenseI32ArrayAttr(
|
"operandSegmentSizes",
|
||||||
{static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights.size()), static_cast<int32_t>(inputs.size())}));
|
||||||
if (hasCoreIds)
|
if (hasCoreIds)
|
||||||
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds));
|
||||||
|
|
||||||
@@ -267,6 +267,33 @@ ParseResult PimSendManyOp::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Prints the op as: the input operand, the `to` core-id list, any remaining
/// attributes, then ` : ` followed by the input type. `targetCoreIds` is
/// elided from the attribute dictionary because it is printed positionally.
void PimSendTensorOp::print(OpAsmPrinter& printer) {
  Value sentValue = getInput();

  printer << " ";
  printer.printOperand(sentValue);
  printCoreIdList(printer, "to", getTargetCoreIds());

  StringRef positionallyPrinted = getTargetCoreIdsAttrName().getValue();
  printer.printOptionalAttrDict((*this)->getAttrs(), {positionallyPrinted});

  printer << " : ";
  printer.printType(sentValue.getType());
}
|
||||||
|
|
||||||
|
/// Parses the form emitted by PimSendTensorOp::print: an operand, an optional
/// `to` core-id list, an optional attribute dictionary, `:` and the operand
/// type. The core ids may come from the positional list or from the attribute
/// dictionary, but not from both.
ParseResult PimSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
  OpAsmParser::UnresolvedOperand sentOperand;
  SmallVector<int32_t> positionalCoreIds;
  Type sentType;

  // Each step short-circuits on the first syntax error.
  if (parser.parseOperand(sentOperand))
    return failure();
  if (parseOptionalCoreIdList(parser, "to", positionalCoreIds))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  if (parser.parseType(sentType))
    return failure();

  if (!positionalCoreIds.empty()) {
    // A positional list must not be shadowed by an attr-dict entry.
    if (result.attributes.get("targetCoreIds"))
      return parser.emitError(parser.getCurrentLocation(),
                              "targetCoreIds cannot be specified both positionally and in attr-dict");
    result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, positionalCoreIds));
  }

  return parser.resolveOperand(sentOperand, sentType, result.operands);
}
|
||||||
|
|
||||||
void PimSendManyBatchOp::print(OpAsmPrinter& printer) {
|
void PimSendManyBatchOp::print(OpAsmPrinter& printer) {
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printCompressedValueSequence(printer, getInputs());
|
printCompressedValueSequence(printer, getInputs());
|
||||||
@@ -333,6 +360,43 @@ ParseResult PimReceiveManyOp::parse(OpAsmParser& parser, OperationState& result)
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Prints the op as: the `from` core-id list, ` into ` and the parenthesised
/// output buffer, any remaining attributes, then
/// ` : <buffer type> -> <result type>`. `sourceCoreIds` is elided from the
/// attribute dictionary because it is printed positionally.
void PimReceiveTensorOp::print(OpAsmPrinter& printer) {
  Value destination = getOutputBuffer();

  printCoreIdList(printer, "from", getSourceCoreIds());

  printer << " into ";
  printOpenDelimiter(printer, ListDelimiter::Paren);
  printer.printOperand(destination);
  printCloseDelimiter(printer, ListDelimiter::Paren);

  StringRef positionallyPrinted = getSourceCoreIdsAttrName().getValue();
  printer.printOptionalAttrDict((*this)->getAttrs(), {positionallyPrinted});

  printer << " : ";
  printer.printType(destination.getType());
  printer << " -> ";
  printer.printType(getOutput().getType());
}
|
||||||
|
|
||||||
|
/// Parses the form emitted by PimReceiveTensorOp::print: an optional `from`
/// core-id list, `into (` buffer `)`, an optional attribute dictionary, then
/// `: <buffer type> -> <result type>`. The core ids may come from the
/// positional list or from the attribute dictionary, but not from both.
ParseResult PimReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
  OpAsmParser::UnresolvedOperand destinationOperand;
  SmallVector<int32_t> positionalCoreIds;
  Type destinationType;
  Type resultType;

  // Header: `from [...] into (%buffer)`.
  if (parseOptionalCoreIdList(parser, "from", positionalCoreIds))
    return failure();
  if (parser.parseKeyword("into") || parser.parseLParen())
    return failure();
  if (parser.parseOperand(destinationOperand) || parser.parseRParen())
    return failure();

  // Trailer: `{attrs} : bufferType -> resultType`.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon() || parser.parseType(destinationType))
    return failure();
  if (parser.parseArrow() || parser.parseType(resultType))
    return failure();

  if (!positionalCoreIds.empty()) {
    // A positional list must not be shadowed by an attr-dict entry.
    if (result.attributes.get("sourceCoreIds"))
      return parser.emitError(parser.getCurrentLocation(),
                              "sourceCoreIds cannot be specified both positionally and in attr-dict");
    result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, positionalCoreIds));
  }

  if (parser.resolveOperand(destinationOperand, destinationType, result.operands))
    return failure();

  result.addTypes(resultType);
  return success();
}
|
||||||
|
|
||||||
void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
|
void PimReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||||
printCoreIdList(printer, "from", getSourceCoreIds());
|
printCoreIdList(printer, "from", getSourceCoreIds());
|
||||||
printer << " into ";
|
printer << " into ";
|
||||||
|
|||||||
@@ -48,12 +48,32 @@ static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types
|
|||||||
return op->emitError() << kind << " values must all have the same type";
|
return op->emitError() << kind << " values must all have the same type";
|
||||||
if (firstIsTensor != isa<RankedTensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
|
if (firstIsTensor != isa<RankedTensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
|
||||||
return op->emitError() << kind << " values must all use the same shaped container kind";
|
return op->emitError() << kind << " values must all use the same shaped container kind";
|
||||||
if (firstShapedType.getElementType() != shapedType.getElementType() || firstShapedType.getShape() != shapedType.getShape())
|
if (firstShapedType.getElementType() != shapedType.getElementType()
|
||||||
|
|| firstShapedType.getShape() != shapedType.getShape())
|
||||||
return op->emitError() << kind << " values must all have the same shape and element type";
|
return op->emitError() << kind << " values must all have the same shape and element type";
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Shared verifier for the single-tensor communication ops.
///
/// Checks that `coreIds` is non-empty, that `type` is a statically shaped
/// ShapedType whose elements are integers or floats occupying a whole number
/// of bytes, and that the total byte size of the value is divisible by the
/// number of core ids.
///
/// @param op      Operation used to emit diagnostics.
/// @param type    Type of the tensor/memref being sent or received.
/// @param coreIds Target or source core ids carried by the op.
/// @param kind    Prefix used in diagnostics (e.g. "send_tensor").
/// @return success() if every check passes; otherwise the emitted error.
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
  if (coreIds.empty())
    return op->emitError() << kind << " must carry at least one chunk";

  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return op->emitError() << kind << " requires a static shaped tensor or memref";

  // Querying the bit width of a non-integer/float element type (e.g. `index`)
  // trips an assertion inside MLIR, so reject those with a diagnostic first.
  Type elementType = shapedType.getElementType();
  if (!elementType.isIntOrFloat())
    return op->emitError() << kind << " requires byte-sized elements";

  int64_t elementBits = elementType.getIntOrFloatBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return op->emitError() << kind << " requires byte-sized elements";

  // Divide before multiplying: elementBits is a multiple of 8 here, so the
  // result is unchanged while the intermediate product stays 8x smaller.
  int64_t totalBytes = shapedType.getNumElements() * (elementBits / 8);
  if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";

  return success();
}
|
||||||
|
|
||||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||||
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
|
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
|
||||||
if (!coreBatchOp)
|
if (!coreBatchOp)
|
||||||
@@ -61,9 +81,7 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
|||||||
return coreBatchOp.getLaneCount();
|
return coreBatchOp.getLaneCount();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op,
|
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
|
||||||
ArrayRef<int32_t> coreIds,
|
|
||||||
size_t valueCount) {
|
|
||||||
auto laneCount = getParentBatchLaneCount(op);
|
auto laneCount = getParentBatchLaneCount(op);
|
||||||
if (failed(laneCount))
|
if (failed(laneCount))
|
||||||
return op->emitError("must be nested inside pim.core_batch");
|
return op->emitError("must be nested inside pim.core_batch");
|
||||||
@@ -79,9 +97,9 @@ LogicalResult PimEmptyManyOp::verify() {
|
|||||||
return emitError("must produce at least one output");
|
return emitError("must produce at least one output");
|
||||||
|
|
||||||
Type firstType = getOutputs().front().getType();
|
Type firstType = getOutputs().front().getType();
|
||||||
auto firstTensorType = dyn_cast<RankedTensorType>(firstType);
|
auto firstShapedType = dyn_cast<ShapedType>(firstType);
|
||||||
if (!firstTensorType)
|
if (!firstShapedType || !firstShapedType.hasRank())
|
||||||
return emitError("outputs must all be ranked tensor types");
|
return emitError("outputs must all be ranked shaped types");
|
||||||
|
|
||||||
for (Value output : getOutputs().drop_front())
|
for (Value output : getOutputs().drop_front())
|
||||||
if (output.getType() != firstType)
|
if (output.getType() != firstType)
|
||||||
@@ -109,7 +127,8 @@ LogicalResult PimMapOp::verify() {
|
|||||||
Block& block = getBody().front();
|
Block& block = getBody().front();
|
||||||
if (block.getNumArguments() != 1)
|
if (block.getNumArguments() != 1)
|
||||||
return emitError("body must have exactly one block argument");
|
return emitError("body must have exactly one block argument");
|
||||||
if (block.getArgument(0).getType() != inputType)
|
if (failed(verifyCompatibleShapedTypes(
|
||||||
|
getOperation(), block.getArgument(0).getType(), inputType, "body block argument type must match input type")))
|
||||||
return emitError("body block argument type must match input type");
|
return emitError("body block argument type must match input type");
|
||||||
|
|
||||||
auto yieldOp = dyn_cast_or_null<PimYieldOp>(block.getTerminator());
|
auto yieldOp = dyn_cast_or_null<PimYieldOp>(block.getTerminator());
|
||||||
@@ -117,7 +136,8 @@ LogicalResult PimMapOp::verify() {
|
|||||||
return emitError("body must terminate with pim.yield");
|
return emitError("body must terminate with pim.yield");
|
||||||
if (yieldOp.getNumOperands() != 1)
|
if (yieldOp.getNumOperands() != 1)
|
||||||
return emitError("body yield must produce exactly one value");
|
return emitError("body yield must produce exactly one value");
|
||||||
if (yieldOp.getOperand(0).getType() != outputType)
|
if (failed(verifyCompatibleShapedTypes(
|
||||||
|
getOperation(), yieldOp.getOperand(0).getType(), outputType, "body yield type must match output type")))
|
||||||
return emitError("body yield type must match output type");
|
return emitError("body yield type must match output type");
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
@@ -129,6 +149,10 @@ LogicalResult PimSendManyOp::verify() {
|
|||||||
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many");
|
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Verifies the sent tensor and its target core ids through the shared
/// tensor-communication check.
LogicalResult PimSendTensorOp::verify() {
  Type sentType = getInput().getType();
  ArrayRef<int32_t> targets = getTargetCoreIds();
  return verifyTensorCommunication(getOperation(), sentType, targets, "send_tensor");
}
|
||||||
|
|
||||||
LogicalResult PimSendManyBatchOp::verify() {
|
LogicalResult PimSendManyBatchOp::verify() {
|
||||||
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
|
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
|
||||||
return failure();
|
return failure();
|
||||||
@@ -153,6 +177,14 @@ LogicalResult PimReceiveManyOp::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Verifies that the output buffer and the result have compatible shaped
/// types, then runs the shared tensor-communication check on the result type
/// and the source core ids.
LogicalResult PimReceiveTensorOp::verify() {
  Type bufferType = getOutputBuffer().getType();
  Type resultType = getOutput().getType();

  LogicalResult typesAgree =
      verifyCompatibleShapedTypes(getOperation(), bufferType, resultType, "output buffer and output must match");
  if (failed(typesAgree))
    return failure();

  return verifyTensorCommunication(getOperation(), resultType, getSourceCoreIds(), "receive_tensor");
}
|
||||||
|
|
||||||
LogicalResult PimReceiveManyBatchOp::verify() {
|
LogicalResult PimReceiveManyBatchOp::verify() {
|
||||||
if (getOutputBuffers().size() != getOutputs().size())
|
if (getOutputBuffers().size() != getOutputs().size())
|
||||||
return emitError("number of output buffers must match the number of outputs");
|
return emitError("number of output buffers must match the number of outputs");
|
||||||
|
|||||||
@@ -34,6 +34,13 @@ static Value materializeContiguousMemRef(Value memrefValue, Location loc, Rewrit
|
|||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns `value` unchanged when its type is already a BufferLikeType;
/// otherwise defers to bufferization's getBuffer() to obtain a buffer for it.
static FailureOr<Value>
getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) {
  const bool alreadyBuffer = isa<BufferLikeType>(value.getType());
  if (!alreadyBuffer)
    return getBuffer(rewriter, value, options, state);
  return value;
}
|
||||||
|
|
||||||
struct MemCopyHostToDevOpInterface
|
struct MemCopyHostToDevOpInterface
|
||||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
||||||
LogicalResult bufferize(Operation* op,
|
LogicalResult bufferize(Operation* op,
|
||||||
@@ -44,12 +51,12 @@ struct MemCopyHostToDevOpInterface
|
|||||||
auto deviceTarget = memCopyHostToDevOp.getDeviceTarget();
|
auto deviceTarget = memCopyHostToDevOp.getDeviceTarget();
|
||||||
auto hostSource = memCopyHostToDevOp.getHostSource();
|
auto hostSource = memCopyHostToDevOp.getHostSource();
|
||||||
|
|
||||||
auto deviceTargetOpt = getBuffer(rewriter, deviceTarget, options, state);
|
auto deviceTargetOpt = getBufferOrValue(rewriter, deviceTarget, options, state);
|
||||||
if (failed(deviceTargetOpt))
|
if (failed(deviceTargetOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto deviceTargetMemRef = *deviceTargetOpt;
|
auto deviceTargetMemRef = *deviceTargetOpt;
|
||||||
|
|
||||||
auto hostSourceOpt = getBuffer(rewriter, hostSource, options, state);
|
auto hostSourceOpt = getBufferOrValue(rewriter, hostSource, options, state);
|
||||||
if (failed(hostSourceOpt))
|
if (failed(hostSourceOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto hostSourceMemRef = *hostSourceOpt;
|
auto hostSourceMemRef = *hostSourceOpt;
|
||||||
@@ -73,10 +80,10 @@ struct MemCopyHostToDevBatchOpInterface
|
|||||||
const BufferizationOptions& options,
|
const BufferizationOptions& options,
|
||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevBatchOp>(op);
|
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevBatchOp>(op);
|
||||||
auto deviceTargetOpt = getBuffer(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state);
|
auto deviceTargetOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state);
|
||||||
if (failed(deviceTargetOpt))
|
if (failed(deviceTargetOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto hostSourceOpt = getBuffer(rewriter, memCopyHostToDevOp.getHostSource(), options, state);
|
auto hostSourceOpt = getBufferOrValue(rewriter, memCopyHostToDevOp.getHostSource(), options, state);
|
||||||
if (failed(hostSourceOpt))
|
if (failed(hostSourceOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -101,13 +108,13 @@ struct MemCopyDevToHostOpInterface
|
|||||||
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
|
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
|
||||||
|
|
||||||
auto hostTarget = memCopyDevToHostOp.getHostTarget();
|
auto hostTarget = memCopyDevToHostOp.getHostTarget();
|
||||||
auto hostTargetOpt = getBuffer(rewriter, hostTarget, options, state);
|
auto hostTargetOpt = getBufferOrValue(rewriter, hostTarget, options, state);
|
||||||
if (failed(hostTargetOpt))
|
if (failed(hostTargetOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto hostTargetMemRef = *hostTargetOpt;
|
auto hostTargetMemRef = *hostTargetOpt;
|
||||||
|
|
||||||
auto deviceSource = memCopyDevToHostOp.getDeviceSource();
|
auto deviceSource = memCopyDevToHostOp.getDeviceSource();
|
||||||
auto deviceSourceOpt = getBuffer(rewriter, deviceSource, options, state);
|
auto deviceSourceOpt = getBufferOrValue(rewriter, deviceSource, options, state);
|
||||||
if (failed(deviceSourceOpt))
|
if (failed(deviceSourceOpt))
|
||||||
return failure();
|
return failure();
|
||||||
auto deviceSourceMemRef = *deviceSourceOpt;
|
auto deviceSourceMemRef = *deviceSourceOpt;
|
||||||
@@ -135,7 +142,7 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto receiveOp = cast<PimReceiveOp>(op);
|
auto receiveOp = cast<PimReceiveOp>(op);
|
||||||
|
|
||||||
auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -159,7 +166,7 @@ struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<Receive
|
|||||||
const BufferizationOptions& options,
|
const BufferizationOptions& options,
|
||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto receiveOp = cast<PimReceiveBatchOp>(op);
|
auto receiveOp = cast<PimReceiveBatchOp>(op);
|
||||||
auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -185,30 +192,44 @@ struct ReceiveManyOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveM
|
|||||||
auto receiveOp = cast<PimReceiveManyOp>(op);
|
auto receiveOp = cast<PimReceiveManyOp>(op);
|
||||||
SmallVector<Value> outputBuffers;
|
SmallVector<Value> outputBuffers;
|
||||||
SmallVector<Type> resultTypes;
|
SmallVector<Type> resultTypes;
|
||||||
SmallVector<Value> tensorResults;
|
|
||||||
outputBuffers.reserve(receiveOp.getOutputBuffers().size());
|
outputBuffers.reserve(receiveOp.getOutputBuffers().size());
|
||||||
resultTypes.reserve(receiveOp.getOutputBuffers().size());
|
resultTypes.reserve(receiveOp.getOutputBuffers().size());
|
||||||
tensorResults.reserve(receiveOp.getOutputBuffers().size());
|
|
||||||
|
|
||||||
for (Value outputBuffer : receiveOp.getOutputBuffers()) {
|
for (Value outputBuffer : receiveOp.getOutputBuffers()) {
|
||||||
auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
outputBuffers.push_back(*outputBufferOpt);
|
outputBuffers.push_back(*outputBufferOpt);
|
||||||
resultTypes.push_back(outputBufferOpt->getType());
|
resultTypes.push_back(outputBufferOpt->getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto newOp = PimReceiveManyOp::create(
|
auto newOp = PimReceiveManyOp::create(rewriter,
|
||||||
rewriter, receiveOp.getLoc(), TypeRange(resultTypes), ValueRange(outputBuffers), receiveOp.getSourceCoreIdsAttr());
|
receiveOp.getLoc(),
|
||||||
|
TypeRange(resultTypes),
|
||||||
|
ValueRange(outputBuffers),
|
||||||
|
receiveOp.getSourceCoreIdsAttr());
|
||||||
|
rewriter.replaceOp(receiveOp, newOp.getOutputs());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
for (auto [bufferResult, tensorResult] : llvm::zip(newOp.getOutputs(), receiveOp.getOutputs())) {
|
struct ReceiveTensorOpInterface
|
||||||
auto tensorType = cast<RankedTensorType>(tensorResult.getType());
|
: DstBufferizableOpInterfaceExternalModel<ReceiveTensorOpInterface, PimReceiveTensorOp> {
|
||||||
auto toTensor =
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
bufferization::ToTensorOp::create(rewriter, receiveOp.getLoc(), tensorType, bufferResult, UnitAttr(), UnitAttr());
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
tensorResults.push_back(toTensor.getResult());
|
}
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(receiveOp, tensorResults);
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto receiveOp = cast<PimReceiveTensorOp>(op);
|
||||||
|
auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state);
|
||||||
|
if (failed(outputBufferOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<PimReceiveTensorOp>(
|
||||||
|
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -226,13 +247,11 @@ struct ReceiveManyBatchOpInterface
|
|||||||
auto receiveOp = cast<PimReceiveManyBatchOp>(op);
|
auto receiveOp = cast<PimReceiveManyBatchOp>(op);
|
||||||
SmallVector<Value> outputBuffers;
|
SmallVector<Value> outputBuffers;
|
||||||
SmallVector<Type> resultTypes;
|
SmallVector<Type> resultTypes;
|
||||||
SmallVector<Value> tensorResults;
|
|
||||||
outputBuffers.reserve(receiveOp.getOutputBuffers().size());
|
outputBuffers.reserve(receiveOp.getOutputBuffers().size());
|
||||||
resultTypes.reserve(receiveOp.getOutputBuffers().size());
|
resultTypes.reserve(receiveOp.getOutputBuffers().size());
|
||||||
tensorResults.reserve(receiveOp.getOutputBuffers().size());
|
|
||||||
|
|
||||||
for (Value outputBuffer : receiveOp.getOutputBuffers()) {
|
for (Value outputBuffer : receiveOp.getOutputBuffers()) {
|
||||||
auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
outputBuffers.push_back(*outputBufferOpt);
|
outputBuffers.push_back(*outputBufferOpt);
|
||||||
@@ -244,15 +263,7 @@ struct ReceiveManyBatchOpInterface
|
|||||||
TypeRange(resultTypes),
|
TypeRange(resultTypes),
|
||||||
ValueRange(outputBuffers),
|
ValueRange(outputBuffers),
|
||||||
receiveOp.getSourceCoreIdsAttr());
|
receiveOp.getSourceCoreIdsAttr());
|
||||||
|
rewriter.replaceOp(receiveOp, newOp.getOutputs());
|
||||||
for (auto [bufferResult, tensorResult] : llvm::zip(newOp.getOutputs(), receiveOp.getOutputs())) {
|
|
||||||
auto tensorType = cast<RankedTensorType>(tensorResult.getType());
|
|
||||||
auto toTensor =
|
|
||||||
bufferization::ToTensorOp::create(rewriter, receiveOp.getLoc(), tensorType, bufferResult, UnitAttr(), UnitAttr());
|
|
||||||
tensorResults.push_back(toTensor.getResult());
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.replaceOp(receiveOp, tensorResults);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -267,7 +278,7 @@ struct ExtractRowsOpInterface : DstBufferizableOpInterfaceExternalModel<ExtractR
|
|||||||
const BufferizationOptions& options,
|
const BufferizationOptions& options,
|
||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto extractRowsOp = cast<PimExtractRowsOp>(op);
|
auto extractRowsOp = cast<PimExtractRowsOp>(op);
|
||||||
auto inputOpt = getBuffer(rewriter, extractRowsOp.getInput(), options, state);
|
auto inputOpt = getBufferOrValue(rewriter, extractRowsOp.getInput(), options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -277,7 +288,7 @@ struct ExtractRowsOpInterface : DstBufferizableOpInterfaceExternalModel<ExtractR
|
|||||||
resultTypes.reserve(extractRowsOp.getOutputBuffers().size());
|
resultTypes.reserve(extractRowsOp.getOutputBuffers().size());
|
||||||
|
|
||||||
for (Value outputBuffer : extractRowsOp.getOutputBuffers()) {
|
for (Value outputBuffer : extractRowsOp.getOutputBuffers()) {
|
||||||
auto outputBufferOpt = getBuffer(rewriter, outputBuffer, options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, outputBuffer, options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
outputBuffers.push_back(*outputBufferOpt);
|
outputBuffers.push_back(*outputBufferOpt);
|
||||||
@@ -307,13 +318,13 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
|
|||||||
SmallVector<Value> inputs;
|
SmallVector<Value> inputs;
|
||||||
inputs.reserve(concatOp.getInputs().size());
|
inputs.reserve(concatOp.getInputs().size());
|
||||||
for (Value input : concatOp.getInputs()) {
|
for (Value input : concatOp.getInputs()) {
|
||||||
auto inputOpt = getBuffer(rewriter, input, options, state);
|
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
inputs.push_back(materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter));
|
inputs.push_back(materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outputBufferOpt = getBuffer(rewriter, concatOp.getOutputBuffer(), options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, concatOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -323,6 +334,55 @@ struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel<ConcatOpInter
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct EmptyManyOpInterface : BufferizableOpInterface::ExternalModel<EmptyManyOpInterface, PimEmptyManyOp> {
|
||||||
|
bool bufferizesToAllocation(Operation* op, Value value) const { return true; }
|
||||||
|
|
||||||
|
bool resultBufferizesToMemoryWrite(Operation* op, OpResult opResult, const AnalysisState& state) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto emptyManyOp = cast<PimEmptyManyOp>(op);
|
||||||
|
|
||||||
|
SmallVector<Type> resultTypes;
|
||||||
|
resultTypes.reserve(emptyManyOp.getOutputs().size());
|
||||||
|
for (Value output : emptyManyOp.getOutputs()) {
|
||||||
|
auto shapedType = cast<ShapedType>(output.getType());
|
||||||
|
resultTypes.push_back(MemRefType::get(shapedType.getShape(), shapedType.getElementType()));
|
||||||
|
}
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<PimEmptyManyOp>(rewriter, emptyManyOp, TypeRange(resultTypes));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SendTensorOpInterface : BufferizableOpInterface::ExternalModel<SendTensorOpInterface, PimSendTensorOp> {
|
||||||
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||||
|
|
||||||
|
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||||
|
|
||||||
|
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto sendOp = cast<PimSendTensorOp>(op);
|
||||||
|
auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state);
|
||||||
|
if (failed(inputOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<PimSendTensorOp>(
|
||||||
|
rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, PimMapOp> {
|
struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, PimMapOp> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||||
|
|
||||||
@@ -335,23 +395,26 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
|
|||||||
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
|
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
|
||||||
auto mapOp = cast<PimMapOp>(op);
|
auto mapOp = cast<PimMapOp>(op);
|
||||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||||
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty())
|
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
|
||||||
|
|| mapOp.getInputs().empty())
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
return {{&mapOp->getOpOperand(0), BufferRelation::Equivalent}};
|
return {
|
||||||
|
{&mapOp->getOpOperand(0), BufferRelation::Equivalent}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
|
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
|
||||||
|
|
||||||
FailureOr<BufferLikeType>
|
FailureOr<BufferLikeType> getBufferType(Operation* op,
|
||||||
getBufferType(Operation* op,
|
Value value,
|
||||||
Value value,
|
const BufferizationOptions& options,
|
||||||
const BufferizationOptions& options,
|
const BufferizationState& state,
|
||||||
const BufferizationState& state,
|
SmallVector<Value>& invocationStack) const {
|
||||||
SmallVector<Value>& invocationStack) const {
|
|
||||||
auto mapOp = cast<PimMapOp>(op);
|
auto mapOp = cast<PimMapOp>(op);
|
||||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||||
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0 || mapOp.getInputs().empty())
|
if (!bbArg || bbArg.getOwner() != &mapOp.getBody().front() || bbArg.getArgNumber() != 0
|
||||||
|
|| mapOp.getInputs().empty())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto inputType = dyn_cast<BufferLikeType>(mapOp.getInputs().front().getType());
|
auto inputType = dyn_cast<BufferLikeType>(mapOp.getInputs().front().getType());
|
||||||
@@ -375,7 +438,7 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
|
|||||||
|
|
||||||
for (Value input : mapOp.getInputs()) {
|
for (Value input : mapOp.getInputs()) {
|
||||||
if (isa<TensorType>(input.getType())) {
|
if (isa<TensorType>(input.getType())) {
|
||||||
auto inputOpt = getBuffer(rewriter, input, options, state);
|
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
inputs.push_back(*inputOpt);
|
inputs.push_back(*inputOpt);
|
||||||
@@ -403,13 +466,9 @@ struct MapOpInterface : BufferizableOpInterface::ExternalModel<MapOpInterface, P
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOpInterface, PimCoreBatchOp> {
|
struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOpInterface, PimCoreBatchOp> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; }
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; }
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
return {};
|
return {};
|
||||||
@@ -422,19 +481,18 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
|||||||
return {};
|
return {};
|
||||||
|
|
||||||
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
|
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
|
||||||
return {{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}};
|
return {
|
||||||
|
{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isWritable(Operation* op, Value value, const AnalysisState& state) const {
|
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<BufferLikeType>
|
FailureOr<BufferLikeType> getBufferType(Operation* op,
|
||||||
getBufferType(Operation* op,
|
Value value,
|
||||||
Value value,
|
const BufferizationOptions& options,
|
||||||
const BufferizationOptions& options,
|
const BufferizationState& state,
|
||||||
const BufferizationState& state,
|
SmallVector<Value>& invocationStack) const {
|
||||||
SmallVector<Value>& invocationStack) const {
|
|
||||||
auto coreBatchOp = cast<PimCoreBatchOp>(op);
|
auto coreBatchOp = cast<PimCoreBatchOp>(op);
|
||||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
||||||
@@ -453,6 +511,14 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto coreBatchOp = cast<PimCoreBatchOp>(op);
|
auto coreBatchOp = cast<PimCoreBatchOp>(op);
|
||||||
|
|
||||||
|
bool alreadyBufferized =
|
||||||
|
llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
|
||||||
|
&& llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); })
|
||||||
|
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(),
|
||||||
|
[](BlockArgument arg) { return isa<BufferLikeType>(arg.getType()); });
|
||||||
|
if (alreadyBufferized)
|
||||||
|
return success();
|
||||||
|
|
||||||
SmallVector<Value> weights;
|
SmallVector<Value> weights;
|
||||||
SmallVector<Value> inputs;
|
SmallVector<Value> inputs;
|
||||||
weights.reserve(coreBatchOp.getWeights().size());
|
weights.reserve(coreBatchOp.getWeights().size());
|
||||||
@@ -460,7 +526,7 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
|||||||
|
|
||||||
for (Value weight : coreBatchOp.getWeights()) {
|
for (Value weight : coreBatchOp.getWeights()) {
|
||||||
if (isa<TensorType>(weight.getType())) {
|
if (isa<TensorType>(weight.getType())) {
|
||||||
auto weightOpt = getBuffer(rewriter, weight, options, state);
|
auto weightOpt = getBufferOrValue(rewriter, weight, options, state);
|
||||||
if (failed(weightOpt))
|
if (failed(weightOpt))
|
||||||
return failure();
|
return failure();
|
||||||
weights.push_back(*weightOpt);
|
weights.push_back(*weightOpt);
|
||||||
@@ -472,7 +538,7 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
|||||||
|
|
||||||
for (Value input : coreBatchOp.getInputs()) {
|
for (Value input : coreBatchOp.getInputs()) {
|
||||||
if (isa<TensorType>(input.getType())) {
|
if (isa<TensorType>(input.getType())) {
|
||||||
auto inputOpt = getBuffer(rewriter, input, options, state);
|
auto inputOpt = getBufferOrValue(rewriter, input, options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
inputs.push_back(*inputOpt);
|
inputs.push_back(*inputOpt);
|
||||||
@@ -510,11 +576,11 @@ struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeO
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto transposeOp = cast<PimTransposeOp>(op);
|
auto transposeOp = cast<PimTransposeOp>(op);
|
||||||
|
|
||||||
auto inputOpt = getBuffer(rewriter, transposeOp.getInput(), options, state);
|
auto inputOpt = getBufferOrValue(rewriter, transposeOp.getInput(), options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto outputBufferOpt = getBuffer(rewriter, transposeOp.getOutputBuffer(), options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, transposeOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -547,11 +613,11 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto vmmOp = cast<PimVMMOp>(op);
|
auto vmmOp = cast<PimVMMOp>(op);
|
||||||
|
|
||||||
auto inputOpt = getBuffer(rewriter, vmmOp.getInput(), options, state);
|
auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto outputBufferOpt = getBuffer(rewriter, vmmOp.getOutputBuffer(), options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, vmmOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -574,11 +640,11 @@ struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface,
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto mvmOp = cast<PimMVMOp>(op);
|
auto mvmOp = cast<PimMVMOp>(op);
|
||||||
|
|
||||||
auto inputOpt = getBuffer(rewriter, mvmOp.getInput(), options, state);
|
auto inputOpt = getBufferOrValue(rewriter, mvmOp.getInput(), options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto outputBufferOpt = getBuffer(rewriter, mvmOp.getOutputBuffer(), options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, mvmOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -608,15 +674,15 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto binaryOp = cast<OpTy>(op);
|
auto binaryOp = cast<OpTy>(op);
|
||||||
|
|
||||||
auto lhsOpt = getBuffer(rewriter, binaryOp.getLhs(), options, state);
|
auto lhsOpt = getBufferOrValue(rewriter, binaryOp.getLhs(), options, state);
|
||||||
if (failed(lhsOpt))
|
if (failed(lhsOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto rhsOpt = getBuffer(rewriter, binaryOp.getRhs(), options, state);
|
auto rhsOpt = getBufferOrValue(rewriter, binaryOp.getRhs(), options, state);
|
||||||
if (failed(rhsOpt))
|
if (failed(rhsOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto outputBufferOpt = getBuffer(rewriter, binaryOp.getOutputBuffer(), options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, binaryOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -647,11 +713,11 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
|||||||
BufferizationState& state) const {
|
BufferizationState& state) const {
|
||||||
auto unaryOp = cast<OpTy>(op);
|
auto unaryOp = cast<OpTy>(op);
|
||||||
|
|
||||||
auto inputOpt = getBuffer(rewriter, unaryOp.getInput(), options, state);
|
auto inputOpt = getBufferOrValue(rewriter, unaryOp.getInput(), options, state);
|
||||||
if (failed(inputOpt))
|
if (failed(inputOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto outputBufferOpt = getBuffer(rewriter, unaryOp.getOutputBuffer(), options, state);
|
auto outputBufferOpt = getBufferOrValue(rewriter, unaryOp.getOutputBuffer(), options, state);
|
||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -664,12 +730,15 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
|||||||
|
|
||||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||||
|
PimEmptyManyOp::attachInterface<EmptyManyOpInterface>(*ctx);
|
||||||
PimMapOp::attachInterface<MapOpInterface>(*ctx);
|
PimMapOp::attachInterface<MapOpInterface>(*ctx);
|
||||||
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
|
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
|
||||||
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||||
PimReceiveManyOp::attachInterface<ReceiveManyOpInterface>(*ctx);
|
PimReceiveManyOp::attachInterface<ReceiveManyOpInterface>(*ctx);
|
||||||
|
PimReceiveTensorOp::attachInterface<ReceiveTensorOpInterface>(*ctx);
|
||||||
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
|
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
|
||||||
PimReceiveManyBatchOp::attachInterface<ReceiveManyBatchOpInterface>(*ctx);
|
PimReceiveManyBatchOp::attachInterface<ReceiveManyBatchOpInterface>(*ctx);
|
||||||
|
PimSendTensorOp::attachInterface<SendTensorOpInterface>(*ctx);
|
||||||
PimExtractRowsOp::attachInterface<ExtractRowsOpInterface>(*ctx);
|
PimExtractRowsOp::attachInterface<ExtractRowsOpInterface>(*ctx);
|
||||||
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
|
PimConcatOp::attachInterface<ConcatOpInterface>(*ctx);
|
||||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||||
|
|||||||
@@ -47,37 +47,26 @@ private:
|
|||||||
|
|
||||||
void PimBufferizationPass::runOnOperation() {
|
void PimBufferizationPass::runOnOperation() {
|
||||||
auto moduleOp = getOperation();
|
auto moduleOp = getOperation();
|
||||||
{
|
|
||||||
SmallVector<pim::PimEmptyManyOp> emptyManyOps;
|
|
||||||
moduleOp.walk([&](pim::PimEmptyManyOp emptyManyOp) { emptyManyOps.push_back(emptyManyOp); });
|
|
||||||
|
|
||||||
IRRewriter rewriter(moduleOp.getContext());
|
|
||||||
for (auto emptyManyOp : emptyManyOps) {
|
|
||||||
SmallVector<Value> replacementValues;
|
|
||||||
replacementValues.reserve(emptyManyOp.getOutputs().size());
|
|
||||||
rewriter.setInsertionPoint(emptyManyOp);
|
|
||||||
for (Value output : emptyManyOp.getOutputs()) {
|
|
||||||
auto outputType = cast<RankedTensorType>(output.getType());
|
|
||||||
replacementValues.push_back(
|
|
||||||
tensor::EmptyOp::create(rewriter, emptyManyOp.getLoc(), outputType.getShape(), outputType.getElementType()));
|
|
||||||
}
|
|
||||||
rewriter.replaceOp(emptyManyOp, replacementValues);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Refactor this into a function
|
// Refactor this into a function
|
||||||
{
|
{
|
||||||
auto funcOp = getPimEntryFunc(moduleOp);
|
auto funcOp = *getPimEntryFunc(moduleOp);
|
||||||
|
|
||||||
auto coreOps = llvm::to_vector(funcOp->getOps<pim::PimCoreOp>());
|
SmallVector<Operation*> coreOps;
|
||||||
|
funcOp->walk<WalkOrder::PreOrder>([&](Operation* op) {
|
||||||
|
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
|
||||||
|
coreOps.push_back(op);
|
||||||
|
});
|
||||||
MLIRContext* ctx = moduleOp.getContext();
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
// failableParallelForEach will run the lambda in parallel and stop if any thread fails
|
// failableParallelForEach will run the lambda in parallel and stop if any thread fails
|
||||||
LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](pim::PimCoreOp coreOp) {
|
LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](Operation* coreOp) {
|
||||||
// Again, allocate state LOCALLY per thread/function
|
// Again, allocate state LOCALLY per thread/function
|
||||||
bufferization::OneShotBufferizationOptions options;
|
bufferization::OneShotBufferizationOptions options;
|
||||||
options.allowUnknownOps = true;
|
options.allowUnknownOps = true;
|
||||||
|
if (isa<pim::PimCoreBatchOp>(coreOp))
|
||||||
|
options.opFilter.denyOperation([coreOp](Operation* op) { return op == coreOp; });
|
||||||
bufferization::BufferizationState state;
|
bufferization::BufferizationState state;
|
||||||
if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) {
|
if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) {
|
||||||
coreOp.emitError("Failed to bufferize PIM and Spatial ops");
|
coreOp->emitError("Failed to bufferize PIM and Spatial ops");
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
@@ -89,13 +78,16 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
funcOp->walk([&](bufferization::ToTensorOp toTensorOp) {
|
funcOp->walk([&](bufferization::ToTensorOp toTensorOp) {
|
||||||
if (llvm::isa_and_present<pim::PimCoreOp>(toTensorOp->getParentOp()))
|
if (llvm::isa_and_present<pim::PimCoreOp, pim::PimCoreBatchOp>(toTensorOp->getParentOp()))
|
||||||
toTensorOp->setAttr("restrict", UnitAttr::get(ctx));
|
toTensorOp->setAttr("restrict", UnitAttr::get(ctx));
|
||||||
});
|
});
|
||||||
|
|
||||||
// One-Shot-Bufferization
|
// One-Shot-Bufferization
|
||||||
bufferization::OneShotBufferizationOptions options;
|
bufferization::OneShotBufferizationOptions options;
|
||||||
options.allowUnknownOps = true;
|
options.allowUnknownOps = true;
|
||||||
|
options.opFilter.denyOperation([](Operation* op) {
|
||||||
|
return op->getParentOfType<pim::PimCoreOp>() || op->getParentOfType<pim::PimCoreBatchOp>();
|
||||||
|
});
|
||||||
bufferization::BufferizationState state;
|
bufferization::BufferizationState state;
|
||||||
|
|
||||||
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
|
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
|
|||||||
// Math
|
// Math
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def SpatWeightedVMMOp : SpatOp<"wvmm", []> {
|
def SpatVMMOp : SpatOp<"wvmm", []> {
|
||||||
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
@@ -272,7 +272,7 @@ def SpatWeightedVMMOp : SpatOp<"wvmm", []> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
|
def SpatMVMOp : SpatOp<"Wmvm", []> {
|
||||||
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ namespace spatial {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
|
inline LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter,
|
||||||
ArrayRef<int64_t>& matrixShape,
|
ArrayRef<int64_t>& matrixShape,
|
||||||
ArrayRef<int64_t>& vectorShape,
|
ArrayRef<int64_t>& vectorShape,
|
||||||
ArrayRef<int64_t>& outputShape) {
|
ArrayRef<int64_t>& outputShape) {
|
||||||
@@ -45,7 +45,7 @@ inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
|
inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
|
||||||
ArrayRef<int64_t>& matrixShape,
|
ArrayRef<int64_t>& matrixShape,
|
||||||
ArrayRef<int64_t>& vectorShape,
|
ArrayRef<int64_t>& vectorShape,
|
||||||
ArrayRef<int64_t>& outputShape) {
|
ArrayRef<int64_t>& outputShape) {
|
||||||
@@ -177,10 +177,10 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (auto& bodyOp : block) {
|
for (auto& bodyOp : block) {
|
||||||
if (auto wvmm = dyn_cast<SpatWeightedVMMOp>(&bodyOp))
|
if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp))
|
||||||
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
|
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
|
||||||
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
|
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
|
||||||
if (auto wmvm = dyn_cast<SpatWeightedMVMOp>(&bodyOp))
|
if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
|
||||||
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
|
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
|
||||||
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
|
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
|
||||||
}
|
}
|
||||||
@@ -189,10 +189,10 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult SpatWeightedMVMOp::verify() {
|
LogicalResult SpatMVMOp::verify() {
|
||||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||||
if (failed(matrixShapeOpt))
|
if (failed(matrixShapeOpt))
|
||||||
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
|
return emitError("SpatMVMOp was not within a SpatCompute or Core op");
|
||||||
auto matrixShape = *matrixShapeOpt;
|
auto matrixShape = *matrixShapeOpt;
|
||||||
auto vectorShape = getInput().getType().getShape();
|
auto vectorShape = getInput().getType().getShape();
|
||||||
auto outputShape = getOutput().getType().getShape();
|
auto outputShape = getOutput().getType().getShape();
|
||||||
@@ -204,10 +204,10 @@ LogicalResult SpatWeightedMVMOp::verify() {
|
|||||||
return emitError("matrix rank must be 2 or 4");
|
return emitError("matrix rank must be 2 or 4");
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult SpatWeightedVMMOp::verify() {
|
LogicalResult SpatVMMOp::verify() {
|
||||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||||
if (failed(matrixShapeOpt))
|
if (failed(matrixShapeOpt))
|
||||||
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
|
return emitError("SpatVMMOp was not within a SpatCompute or Core op");
|
||||||
auto matrixShape = *matrixShapeOpt;
|
auto matrixShape = *matrixShapeOpt;
|
||||||
auto vectorShape = getInput().getType().getShape();
|
auto vectorShape = getInput().getType().getShape();
|
||||||
auto outputShape = getOutput().getType().getShape();
|
auto outputShape = getOutput().getType().getShape();
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
|
|||||||
CrossbarUsage crossbarUsage = 0;
|
CrossbarUsage crossbarUsage = 0;
|
||||||
for (auto& block : body)
|
for (auto& block : body)
|
||||||
for (auto& op : block)
|
for (auto& op : block)
|
||||||
if (isa<SpatWeightedVMMOp>(op))
|
if (isa<SpatVMMOp>(op))
|
||||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||||
return crossbarUsage;
|
return crossbarUsage;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute
|
|||||||
CrossbarUsage crossbarUsage = 0;
|
CrossbarUsage crossbarUsage = 0;
|
||||||
for (auto& region : spatCompute.getBody())
|
for (auto& region : spatCompute.getBody())
|
||||||
for (auto& inst : region)
|
for (auto& inst : region)
|
||||||
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst))
|
if (llvm::isa<onnx_mlir::spatial::SpatVMMOp>(inst))
|
||||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||||
return crossbarUsage;
|
return crossbarUsage;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -838,9 +838,9 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
for (auto& op : child.getBody().front()) {
|
for (auto& op : child.getBody().front()) {
|
||||||
auto newInst = rewriter.clone(op, mapper);
|
auto newInst = rewriter.clone(op, mapper);
|
||||||
if (auto weightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(newInst))
|
if (auto weightedMvmOp = dyn_cast<spatial::SpatMVMOp>(newInst))
|
||||||
remapWeightIndex(weightedMvmOp);
|
remapWeightIndex(weightedMvmOp);
|
||||||
if (auto weightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(newInst))
|
if (auto weightedVmmOp = dyn_cast<spatial::SpatVMMOp>(newInst))
|
||||||
remapWeightIndex(weightedVmmOp);
|
remapWeightIndex(weightedVmmOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -884,9 +884,9 @@ void emitMotifProfile(func::FuncOp funcOp) {
|
|||||||
ComputeMotifInfo& info = computeInfos[index];
|
ComputeMotifInfo& info = computeInfos[index];
|
||||||
for (Operation& op : compute.getBody().front()) {
|
for (Operation& op : compute.getBody().front()) {
|
||||||
info.instructionCount++;
|
info.instructionCount++;
|
||||||
if (isa<spatial::SpatWeightedMVMOp>(&op))
|
if (isa<spatial::SpatMVMOp>(&op))
|
||||||
info.weightedMvmCount++;
|
info.weightedMvmCount++;
|
||||||
if (isa<spatial::SpatWeightedVMMOp>(&op))
|
if (isa<spatial::SpatVMMOp>(&op))
|
||||||
info.weightedVmmCount++;
|
info.weightedVmmCount++;
|
||||||
}
|
}
|
||||||
if (info.weightedVmmCount > 0) {
|
if (info.weightedVmmCount > 0) {
|
||||||
@@ -1617,13 +1617,13 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
||||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) {
|
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||||
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp);
|
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||||
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
||||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
|
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||||
}
|
}
|
||||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) {
|
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||||
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp);
|
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||||
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
|
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
|
||||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||||
}
|
}
|
||||||
@@ -1643,22 +1643,22 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
||||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) {
|
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||||
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
||||||
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp);
|
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||||
}
|
}
|
||||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) {
|
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||||
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
||||||
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp);
|
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
|
|||||||
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatWeightedVMMOp startOp) {
|
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||||
RegularChunk chunk;
|
RegularChunk chunk;
|
||||||
chunk.startOp = startOp.getOperation();
|
chunk.startOp = startOp.getOperation();
|
||||||
chunk.input = startOp.getInput();
|
chunk.input = startOp.getInput();
|
||||||
@@ -376,7 +376,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
auto compactInBlock = [&](Block& block) {
|
auto compactInBlock = [&](Block& block) {
|
||||||
for (auto it = block.begin(); it != block.end();) {
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
auto startOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
|
auto startOp = dyn_cast<spatial::SpatVMMOp>(&*it);
|
||||||
if (!startOp) {
|
if (!startOp) {
|
||||||
++it;
|
++it;
|
||||||
continue;
|
continue;
|
||||||
@@ -391,7 +391,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
|||||||
SmallVector<RegularChunk> run {*anchorChunk};
|
SmallVector<RegularChunk> run {*anchorChunk};
|
||||||
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||||
while (runIt != block.end()) {
|
while (runIt != block.end()) {
|
||||||
auto candidateStart = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
|
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||||
if (!candidateStart)
|
if (!candidateStart)
|
||||||
break;
|
break;
|
||||||
|
|
||||||
@@ -425,7 +425,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
Block& block = compute.getBody().front();
|
Block& block = compute.getBody().front();
|
||||||
for (auto it = block.begin(); it != block.end();) {
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
auto wvmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
|
auto wvmmOp = dyn_cast<spatial::SpatVMMOp>(&*it);
|
||||||
if (!wvmmOp) {
|
if (!wvmmOp) {
|
||||||
++it;
|
++it;
|
||||||
continue;
|
continue;
|
||||||
@@ -440,11 +440,11 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<spatial::SpatWeightedVMMOp> run;
|
SmallVector<spatial::SpatVMMOp> run;
|
||||||
auto runIt = it;
|
auto runIt = it;
|
||||||
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||||
while (runIt != block.end()) {
|
while (runIt != block.end()) {
|
||||||
auto current = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
|
auto current = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||||
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||||
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||||
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||||
@@ -545,7 +545,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
extractOffsets,
|
extractOffsets,
|
||||||
extractSizes,
|
extractSizes,
|
||||||
extractStrides);
|
extractStrides);
|
||||||
auto loopWvmm = spatial::SpatWeightedVMMOp::create(
|
auto loopWvmm = spatial::SpatVMMOp::create(
|
||||||
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||||
|
|
||||||
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ namespace {
|
|||||||
|
|
||||||
static bool isAddressOnlyHostOp(Operation* op) {
|
static bool isAddressOnlyHostOp(Operation* op) {
|
||||||
return isa<arith::ConstantOp,
|
return isa<arith::ConstantOp,
|
||||||
|
pim::PimEmptyManyOp,
|
||||||
memref::AllocOp,
|
memref::AllocOp,
|
||||||
memref::GetGlobalOp,
|
memref::GetGlobalOp,
|
||||||
memref::SubViewOp,
|
memref::SubViewOp,
|
||||||
@@ -36,7 +37,7 @@ static bool isBaseAddressableValue(Value value) {
|
|||||||
Operation* defOp = value.getDefiningOp();
|
Operation* defOp = value.getDefiningOp();
|
||||||
if (!defOp)
|
if (!defOp)
|
||||||
return false;
|
return false;
|
||||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(defOp))
|
if (isa<pim::PimEmptyManyOp, memref::AllocOp, memref::GetGlobalOp>(defOp))
|
||||||
return true;
|
return true;
|
||||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) { value = subview.getSource(); continue; }
|
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) { value = subview.getSource(); continue; }
|
||||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) { value = cast.getSource(); continue; }
|
if (auto cast = dyn_cast<memref::CastOp>(defOp)) { value = cast.getSource(); continue; }
|
||||||
@@ -51,7 +52,7 @@ static bool isCodegenAddressableValue(Value value) {
|
|||||||
if (failed(resolvedAddress))
|
if (failed(resolvedAddress))
|
||||||
return false;
|
return false;
|
||||||
return isa<BlockArgument>(resolvedAddress->base)
|
return isa<BlockArgument>(resolvedAddress->base)
|
||||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
|| isa<pim::PimEmptyManyOp, memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||||
@@ -184,7 +185,7 @@ private:
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
if (!isa<pim::PimEmptyManyOp, memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||||
op.emitOpError() << "operand #" << operandIndex
|
op.emitOpError() << "operand #" << operandIndex
|
||||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
|
|||||||
Reference in New Issue
Block a user