This commit is contained in:
@@ -98,7 +98,7 @@ Supporting pieces:
|
||||
core count, DCP window, experimental conv impl, concat error handling, …)
|
||||
and `PimCodeGen` entry points.
|
||||
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
|
||||
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
|
||||
- `src/PIM/Pass` — auxiliary passes (`MessagePass`)
|
||||
and the `PIMPasses.h` registry used by `PimAccelerator`.
|
||||
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
|
||||
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
|
||||
|
||||
@@ -19,23 +19,21 @@ void markWeightAlways(mlir::Operation* op) {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
template <typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg)
|
||||
return false;
|
||||
bool found = false;
|
||||
parentOp.walk([&](mlir::Operation* op) {
|
||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||
found |= mvmOp.getWeight() == *weightArg;
|
||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeight() == *weightArg;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
template <typename VMMOpTy, typename ParentOpTy>
|
||||
void walkVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
auto weights = parentOp.getWeights();
|
||||
llvm::SmallSet<unsigned, 8> visited;
|
||||
auto walkWeight = [&](mlir::Value weight) {
|
||||
@@ -49,7 +47,6 @@ void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpO
|
||||
}
|
||||
};
|
||||
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
|
||||
}
|
||||
|
||||
@@ -63,7 +60,7 @@ bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
|
||||
if (!computeOp || operandIndex >= computeOp.getWeights().size())
|
||||
return false;
|
||||
|
||||
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
|
||||
return hasVmmWeightUse<spatial::SpatVMMOp>(computeOp, operandIndex);
|
||||
}
|
||||
|
||||
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
||||
|
||||
@@ -30,20 +30,17 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
if (pimEmissionTarget >= EmitSpatial) {
|
||||
pm.addPass(createONNXToSpatialPass());
|
||||
pm.addPass(createMergeComputeNodesPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPim) {
|
||||
pm.addPass(createSpatialToPimPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPimBufferized) {
|
||||
pm.addPass(createPimBufferizationPass());
|
||||
pm.addPass(createPimStaticMemoryCoalescingPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim bufferized"));
|
||||
}
|
||||
|
||||
@@ -54,7 +51,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimCodePass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim code emitted"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,16 +74,6 @@ def PimHaltOp : PimOp<"halt", [Terminator]> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Terminator op `yield` of the Pim dialect.
// Carries a variadic list of PimTensor operands ($outputs) and declares a
// hand-written assembly format (hasCustomAssemblyFormat), so its textual form
// is defined by PimYieldOp::print / PimYieldOp::parse rather than by an
// `assemblyFormat` string.
def PimYieldOp : PimOp<"yield", [Terminator]> {
|
||||
let summary = "Yield results from a Pim region";
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<PimTensor>:$outputs
|
||||
);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Communication
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -227,40 +227,6 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
/// Emits the custom textual form of `yield`:
/// a space, the compressed operand list, the optional attribute dictionary,
/// then " : " followed by the compressed list of operand types.
void PimYieldOp::print(OpAsmPrinter& printer) {
  auto yielded = getOutputs();

  printer << " ";
  printCompressedValueSequence(printer, yielded);
  printer.printOptionalAttrDict((*this)->getAttrs());

  printer << " : ";
  printCompressedTypeSequence(printer, yielded.getTypes());
}
|
||||
|
||||
/// Parses the custom textual form of `yield`.
///
/// Grammar, in the order consumed: an optional comma-separated list of
/// compressed operand entries, an optional attribute dictionary, a colon, and
/// a (possibly empty) compressed type sequence. The number of parsed operands
/// must equal the number of parsed types; the operands are then resolved
/// against those types into `result.operands`.
ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) {
  SmallVector<OpAsmParser::UnresolvedOperand> yielded;
  SmallVector<Type> yieldedTypes;

  // The operand list is optional: probe for a first operand before committing.
  OpAsmParser::UnresolvedOperand leading;
  OptionalParseResult leadingResult = parser.parseOptionalOperand(leading);
  if (leadingResult.has_value()) {
    if (failed(*leadingResult))
      return failure();
    if (parseCompressedOperandEntryWithFirst(parser, leading, yielded))
      return failure();
    while (succeeded(parser.parseOptionalComma())) {
      if (parseOneCompressedOperandEntry(parser, yielded))
        return failure();
    }
  }

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColon())
    return failure();
  if (parseCompressedTypeSequence(parser, yieldedTypes, /*allowEmpty=*/true))
    return failure();

  const bool countsAgree = yielded.size() == yieldedTypes.size();
  if (!countsAgree)
    return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match");

  return parser.resolveOperands(yielded, yieldedTypes, parser.getCurrentLocation(), result.operands);
}
|
||||
|
||||
void PimConcatOp::print(OpAsmPrinter& printer) {
|
||||
printer << " axis " << getAxis() << " ";
|
||||
printCompressedValueSequence(printer, getInputs());
|
||||
|
||||
@@ -2,7 +2,6 @@ add_onnx_mlir_dialect(Spatial spat)
|
||||
add_onnx_mlir_dialect_doc(spat Spatial.td)
|
||||
|
||||
add_pim_library(SpatialOps
|
||||
Channels.cpp
|
||||
SpatialOps.cpp
|
||||
SpatialOpsAsm.cpp
|
||||
SpatialOpsVerify.cpp
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Returns the value of `value` as a signed 64-bit integer when it is defined
/// by an integer constant, and failure otherwise.
///
/// Constants whose signed value needs more than 64 bits are rejected instead
/// of being handed to APInt::getSExtValue(), which requires the value to fit.
static FailureOr<int64_t> getConstantI64(Value value) {
  APInt constantValue;
  if (!matchPattern(value, m_ConstantInt(&constantValue)))
    return failure();
  // Guard the narrowing: getSExtValue() is only valid for <= 64 significant bits.
  if (!constantValue.isSignedIntN(64))
    return failure();
  return constantValue.getSExtValue();
}
|
||||
|
||||
/// Returns the value of `value` as a signed 32-bit integer when it is defined
/// by an integer constant that fits in 32 bits, and failure otherwise.
///
/// The range check matters: without it a wider constant would be truncated by
/// the cast below, so two different ids could wrap to the same int32_t and
/// compare equal in the endpoint verification.
static FailureOr<int32_t> getConstantI32(Value value) {
  APInt constantValue;
  if (!matchPattern(value, m_ConstantInt(&constantValue)))
    return failure();
  // Reject anything that cannot be represented losslessly as int32_t.
  if (!constantValue.isSignedIntN(32))
    return failure();
  return static_cast<int32_t>(constantValue.getSExtValue());
}
|
||||
|
||||
/// Constant channel id of a send op, or failure when the channel id operand
/// is not an integer constant.
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelSendOp sendOp) {
  Value idOperand = sendOp.getChannelId();
  return getConstantI64(idOperand);
}

/// Constant channel id of a receive op, or failure when the channel id operand
/// is not an integer constant.
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelReceiveOp receiveOp) {
  Value idOperand = receiveOp.getChannelId();
  return getConstantI64(idOperand);
}
|
||||
|
||||
/// Constant source core id of a send op, or failure when that operand is not
/// an integer constant.
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) {
  Value coreOperand = sendOp.getSourceCoreId();
  return getConstantI32(coreOperand);
}

/// Constant source core id of a receive op, or failure when that operand is
/// not an integer constant.
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
  Value coreOperand = receiveOp.getSourceCoreId();
  return getConstantI32(coreOperand);
}
|
||||
|
||||
/// Constant target core id of a send op, or failure when that operand is not
/// an integer constant.
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) {
  Value coreOperand = sendOp.getTargetCoreId();
  return getConstantI32(coreOperand);
}

/// Constant target core id of a receive op, or failure when that operand is
/// not an integer constant.
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
  Value coreOperand = receiveOp.getTargetCoreId();
  return getConstantI32(coreOperand);
}
|
||||
|
||||
/// Checks that the send and receive ends of one channel agree.
///
/// Fails silently when either end is missing. Otherwise both ends must carry
/// constant, equal sourceCoreId operands, constant, equal targetCoreId
/// operands, and the send's input type must equal the receive's result type.
/// Every diagnostic is attached to the send op.
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
  SpatChannelSendOp sender = endpoints.send;
  SpatChannelReceiveOp receiver = endpoints.receive;
  if (!(sender && receiver))
    return failure();

  // Source core: must be constant on both ends and identical.
  FailureOr<int32_t> senderSource = getSourceCoreId(sender);
  FailureOr<int32_t> receiverSource = getSourceCoreId(receiver);
  if (failed(senderSource) || failed(receiverSource))
    return sender.emitOpError("channel endpoints must use constant sourceCoreId operands");
  if (*senderSource != *receiverSource)
    return sender.emitOpError("sourceCoreId does not match paired spat.channel_receive");

  // Target core: same rules as the source core.
  FailureOr<int32_t> senderTarget = getTargetCoreId(sender);
  FailureOr<int32_t> receiverTarget = getTargetCoreId(receiver);
  if (failed(senderTarget) || failed(receiverTarget))
    return sender.emitOpError("channel endpoints must use constant targetCoreId operands");
  if (*senderTarget != *receiverTarget)
    return sender.emitOpError("targetCoreId does not match paired spat.channel_receive");

  // Payload: what is sent must be typed exactly like what is received.
  if (sender.getInput().getType() != receiver.getOutput().getType())
    return sender.emitOpError("input type does not match paired spat.channel_receive result type");

  return success();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Builds the registry from an existing function: first every send op, then
/// every receive op found by walking `funcOp`. A null `funcOp` yields an
/// empty registry.
Channels::Channels(func::FuncOp funcOp) {
  if (funcOp) {
    funcOp.walk([this](SpatChannelSendOp op) { insertSend(op); });
    funcOp.walk([this](SpatChannelReceiveOp op) { insertReceive(op); });
  }
}
|
||||
|
||||
/// Hands out the next unused channel id and advances the counter by one.
Channels::ChannelId Channels::allocate() {
  const ChannelId fresh = nextChannelId;
  nextChannelId = fresh + 1;
  return fresh;
}
|
||||
|
||||
/// Registers `sendOp` as the send end of its channel.
/// Does nothing when the op's channel id is not an integer constant.
void Channels::insertSend(SpatChannelSendOp sendOp) {
  FailureOr<ChannelId> id = getChannelId(sendOp);
  if (failed(id))
    return;

  // Keep the allocator strictly above every id already present in the IR.
  const ChannelId oneAbove = *id + 1;
  if (nextChannelId < oneAbove)
    nextChannelId = oneAbove;

  endpoints[*id].send = sendOp;
}
|
||||
|
||||
/// Registers `receiveOp` as the receive end of its channel.
/// Does nothing when the op's channel id is not an integer constant.
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
  FailureOr<ChannelId> id = getChannelId(receiveOp);
  if (failed(id))
    return;

  // Keep the allocator strictly above every id already present in the IR.
  const ChannelId oneAbove = *id + 1;
  if (nextChannelId < oneAbove)
    nextChannelId = oneAbove;

  endpoints[*id].receive = receiveOp;
}
|
||||
|
||||
/// Forgets the send end of `sendOp`'s channel. The whole channel entry is
/// dropped once neither end remains. No-op when the channel id is not constant
/// or the channel is unknown.
void Channels::eraseSend(SpatChannelSendOp sendOp) {
  FailureOr<ChannelId> id = getChannelId(sendOp);
  if (succeeded(id)) {
    auto entry = endpoints.find(*id);
    if (entry != endpoints.end()) {
      entry->second.send = {};
      if (!entry->second.receive)
        endpoints.erase(entry);
    }
  }
}
|
||||
|
||||
/// Forgets the receive end of `receiveOp`'s channel. The whole channel entry
/// is dropped once neither end remains. No-op when the channel id is not
/// constant or the channel is unknown.
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
  FailureOr<ChannelId> id = getChannelId(receiveOp);
  if (succeeded(id)) {
    auto entry = endpoints.find(*id);
    if (entry != endpoints.end()) {
      entry->second.receive = {};
      if (!entry->second.send)
        endpoints.erase(entry);
    }
  }
}
|
||||
|
||||
/// Returns a copy of the endpoints recorded for channel `id`, or failure when
/// no entry exists for that id.
FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
  auto found = endpoints.find(id);
  if (found != endpoints.end())
    return found->second;
  return failure();
}
|
||||
|
||||
/// Finds the receive op paired with `sendOp`. Fails when the send's channel id
/// is not constant, the channel is unknown, or no receive end is recorded.
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
  FailureOr<ChannelId> id = getChannelId(sendOp);
  if (failed(id))
    return failure();

  auto pair = lookup(*id);
  if (succeeded(pair) && pair->receive)
    return pair->receive;
  return failure();
}
|
||||
|
||||
/// Finds the send op paired with `receiveOp`. Fails when the receive's channel
/// id is not constant, the channel is unknown, or no send end is recorded.
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
  FailureOr<ChannelId> id = getChannelId(receiveOp);
  if (failed(id))
    return failure();

  auto pair = lookup(*id);
  if (succeeded(pair) && pair->send)
    return pair->send;
  return failure();
}
|
||||
|
||||
/// Validates every recorded channel and stops at the first problem.
///
/// A channel with only one end reports the missing counterpart on the end that
/// exists; a complete channel is handed to verifyEndpointPair. Returns success
/// only when every channel passes.
LogicalResult Channels::verify() const {
  for (const auto& entry : endpoints) {
    const ChannelId id = entry.first;
    // Work on a copy: emitting diagnostics needs non-const op handles.
    ChannelEndpoints ends = entry.second;

    if (ends.send && ends.receive) {
      if (failed(verifyEndpointPair(ends)))
        return failure();
      continue;
    }

    if (ends.send)
      ends.send.emitOpError("channel_id ") << id << " is missing a paired spat.channel_receive";
    else if (ends.receive)
      ends.receive.emitOpError("channel_id ") << id << " is missing a paired spat.channel_send";
    return failure();
  }
  return success();
}
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -1,43 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir::spatial {
|
||||
|
||||
/// The two ops that make up one channel. Either member may be a null op
/// handle while only one end of the channel has been registered.
struct ChannelEndpoints {
|
||||
// Sending end of the channel (null until a send is registered).
SpatChannelSendOp send;
|
||||
// Receiving end of the channel (null until a receive is registered).
SpatChannelReceiveOp receive;
|
||||
};
|
||||
|
||||
/// Registry that maps constant channel ids to their send/receive ops and
/// hands out fresh channel ids.
class Channels {
|
||||
public:
|
||||
// Channel ids are signed 64-bit integers.
using ChannelId = int64_t;
|
||||
|
||||
// Populates the registry by walking `funcOp` for send ops and then receive
// ops; a null `funcOp` leaves the registry empty.
explicit Channels(mlir::func::FuncOp funcOp);
|
||||
|
||||
// Returns the current next-id counter and then increments it.
ChannelId allocate();
|
||||
|
||||
// insert*: record the op under its channel id and raise the next-id counter
// above that id. Ops whose channel id is not an integer constant are ignored.
// erase*: clear that end of the channel and drop the entry once both ends
// are gone. Unknown or non-constant channel ids are ignored.
void insertSend(SpatChannelSendOp sendOp);
|
||||
void insertReceive(SpatChannelReceiveOp receiveOp);
|
||||
void eraseSend(SpatChannelSendOp sendOp);
|
||||
void eraseReceive(SpatChannelReceiveOp receiveOp);
|
||||
|
||||
// lookup: endpoints recorded for `id`, or failure when there is no entry.
// getReceiveFor / getSendFor: the opposite end of the given op's channel, or
// failure when the id is not constant, the channel is unknown, or that end
// has not been registered.
llvm::FailureOr<ChannelEndpoints> lookup(ChannelId id) const;
|
||||
llvm::FailureOr<SpatChannelReceiveOp> getReceiveFor(SpatChannelSendOp sendOp) const;
|
||||
llvm::FailureOr<SpatChannelSendOp> getSendFor(SpatChannelReceiveOp receiveOp) const;
|
||||
|
||||
// Fails (emitting an op error) on the first channel that lacks one end or
// whose two ends disagree on source core id, target core id, or value type.
mlir::LogicalResult verify() const;
|
||||
|
||||
private:
|
||||
// Next id returned by allocate(); kept above every id seen by insert*.
ChannelId nextChannelId = 0;
|
||||
// Channel id -> the send/receive ops registered for it.
llvm::DenseMap<ChannelId, ChannelEndpoints> endpoints;
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir::spatial
|
||||
@@ -217,25 +217,6 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Weighted matrix-vector multiply op of the Spatial dialect.
// Operands: $weight (the matrix) and $input (the vector); result: $output.
// `hasVerifier` requests a C++ verify() hook for this op.
// Textual form: `[` weight `]` `(` input `)` attr-dict `:` `(` weight-type `,`
// input-type `)` `->` output-type.
// NOTE(review): the mnemonic is spelled "Wmvm" with a capital W — confirm this
// casing is intentional.
def SpatMVMOp : SpatOp<"Wmvm", []> {
|
||||
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$weight,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVAddOp : SpatOp<"vadd", []> {
|
||||
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
|
||||
|
||||
@@ -272,22 +253,6 @@ def SpatVMulOp : SpatOp<"vmul", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Reduction op `sum` of the Spatial dialect.
// Takes one SpatTensor ($input) and produces one SpatTensor ($output); per the
// summary the result holds a single scalar wrapped in a tensor.
// Textual form: `(` input `)` attr-dict `:` input-type `->` output-type.
// No verifier is declared here, so the result shape is not checked by this
// definition.
def SpatSumOp : SpatOp<"sum", []> {
|
||||
let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVAvgOp : SpatOp<"vavg", []> {
|
||||
let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor";
|
||||
|
||||
|
||||
@@ -26,68 +26,6 @@ namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Rank-2 shape check for a matrix-vector multiply.
///
/// Accepts exactly: matrix (N, M) with N > 0 and M > 0, vector (M, 1) and
/// output (N, 1). Any violation is reported through `emitter` and returned as
/// failure.
inline LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter,
                                      ArrayRef<int64_t>& matrixShape,
                                      ArrayRef<int64_t>& vectorShape,
                                      ArrayRef<int64_t>& outputShape) {
  const bool allRank2 = matrixShape.size() == 2 && vectorShape.size() == 2 && outputShape.size() == 2;
  if (!allRank2)
    return emitter->emitError("matrix, vector and output must have rank 2");

  const int64_t rows = matrixShape[0];
  const int64_t cols = matrixShape[1];
  if (!(rows > 0 && cols > 0))
    return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  const bool vectorOk = vectorShape[0] == cols && vectorShape[1] == 1;
  if (!vectorOk)
    return emitter->emitError("vector shape must be (M, 1)");

  const bool outputOk = outputShape[0] == rows && outputShape[1] == 1;
  if (!outputOk)
    return emitter->emitError("output shape must be (N, 1)");

  return success();
}
|
||||
|
||||
/// Rank-4 shape check for a matrix-vector multiply.
///
/// Accepts: matrix (N, M, 1, 1) with N > 0 and M > 0, vector (1, M, 1, 1) and
/// output (1, N, 1, 1). When `ignoreConcatError` is set, a vector whose
/// dimension 1 differs from M is tolerated as long as its other three
/// dimensions are 1. Violations are reported through `emitter`.
inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
                                      ArrayRef<int64_t>& matrixShape,
                                      ArrayRef<int64_t>& vectorShape,
                                      ArrayRef<int64_t>& outputShape) {
  const bool allRank4 = matrixShape.size() == 4 && vectorShape.size() == 4 && outputShape.size() == 4;
  if (!allRank4)
    return emitter->emitError("matrix, vector and output must have rank 4");

  const int64_t rows = matrixShape[0];
  const int64_t cols = matrixShape[1];
  const bool matrixOk = rows > 0 && cols > 0 && matrixShape[2] == 1 && matrixShape[3] == 1;
  if (!matrixOk)
    return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");

  const bool vectorUnitDims = vectorShape[0] == 1 && vectorShape[2] == 1 && vectorShape[3] == 1;
  const bool vectorOk = vectorUnitDims && vectorShape[1] == cols;
  if (!vectorOk) {
    // A mismatch only in dimension 1 is accepted when the concat error is
    // being ignored: it stems from the simplification of that error.
    const bool tolerated = vectorUnitDims && ignoreConcatError == true;
    if (!tolerated)
      return emitter->emitError("vector shape must be (1, M, 1, 1)");
  }

  const bool outputOk =
      outputShape[0] == 1 && outputShape[1] == rows && outputShape[2] == 1 && outputShape[3] == 1;
  if (!outputOk)
    return emitter->emitError("output shape must be (1, N, 1, 1)");

  return success();
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
|
||||
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
||||
if (!shapedType)
|
||||
@@ -287,21 +225,6 @@ static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Verifier for the weighted matrix-vector multiply.
/// Requires a shaped weight and dispatches on its rank: rank 2 and rank 4 each
/// have a dedicated shape check; every other rank is rejected.
LogicalResult SpatMVMOp::verify() {
  auto weightShape = getWeightShapeForWeightedOp(getWeight());
  if (failed(weightShape))
    return emitError("weight must be a shaped value");

  auto matrixShape = *weightShape;
  auto vectorShape = getInput().getType().getShape();
  auto outputShape = getOutput().getType().getShape();

  switch (matrixShape.size()) {
  case 2:
    return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
  case 4:
    return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
  default:
    return emitError("matrix rank must be 2 or 4");
  }
}
|
||||
|
||||
LogicalResult SpatVMMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
|
||||
@@ -136,7 +136,6 @@ static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
|
||||
struct ComputeMotifInfo {
|
||||
uint64_t instructionCount = 0;
|
||||
uint64_t weightedMvmCount = 0;
|
||||
uint64_t weightedVmmCount = 0;
|
||||
};
|
||||
|
||||
@@ -285,8 +284,6 @@ void emitMotifProfile(func::FuncOp funcOp) {
|
||||
ComputeMotifInfo& info = computeInfos[index];
|
||||
for (Operation& op : compute.getBody().front()) {
|
||||
info.instructionCount++;
|
||||
if (isa<spatial::SpatMVMOp>(&op))
|
||||
info.weightedMvmCount++;
|
||||
if (isa<spatial::SpatVMMOp>(&op))
|
||||
info.weightedVmmCount++;
|
||||
}
|
||||
@@ -400,7 +397,7 @@ void emitMotifProfile(func::FuncOp funcOp) {
|
||||
wideWeightedVmmLevels256 += count >= 256;
|
||||
}
|
||||
|
||||
using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
|
||||
using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
|
||||
SmallVector<ShapeKey> weightedVmmShapeKeys;
|
||||
for (auto [index, compute] : llvm::enumerate(computes)) {
|
||||
const ComputeMotifInfo& info = computeInfos[index];
|
||||
@@ -408,7 +405,6 @@ void emitMotifProfile(func::FuncOp funcOp) {
|
||||
continue;
|
||||
weightedVmmShapeKeys.push_back({info.instructionCount,
|
||||
info.weightedVmmCount,
|
||||
info.weightedMvmCount,
|
||||
static_cast<uint64_t>(compute.getWeights().size()),
|
||||
static_cast<uint64_t>(compute.getInputs().size()),
|
||||
static_cast<uint64_t>(parents[index].size()),
|
||||
@@ -461,14 +457,13 @@ void emitMotifProfile(func::FuncOp funcOp) {
|
||||
|
||||
for (size_t rank = 0, end = std::min<size_t>(weightedVmmShapeCounts.size(), 5); rank < end; ++rank) {
|
||||
auto [count, shape] = weightedVmmShapeCounts[rank];
|
||||
auto [insts, vmmOps, mvmOps, weights, inputs, fanIn, fanOut] = shape;
|
||||
auto [insts, vmmOps, weights, inputs, fanIn, fanOut] = shape;
|
||||
llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmShape rank={0} count={1} insts={2} vmmOps={3} "
|
||||
"mvmOps={4} weights={5} inputs={6} fanIn={7} fanOut={8}\n",
|
||||
"weights={4} inputs={5} fanIn={6} fanOut={7}\n",
|
||||
rank,
|
||||
count,
|
||||
insts,
|
||||
vmmOps,
|
||||
mvmOps,
|
||||
weights,
|
||||
inputs,
|
||||
fanIn,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
add_pim_library(OMPimPasses
|
||||
CountInstructionPass.cpp
|
||||
MessagePass.cpp
|
||||
PimCodegen/HostConstantFolding/Common.cpp
|
||||
PimCodegen/HostConstantFolding/Patterns/Constant.cpp
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Module pass that prints, to llvm::outs(), the number of operations in the
/// first block of every spatial compute op and every Pim core op found
/// directly in the module's first function, followed by the grand total.
///
/// The pass only reads the IR. A module without any function, or an op whose
/// body has no block, contributes a count of zero instead of being
/// dereferenced.
struct CountInstructionPass : public PassWrapper<CountInstructionPass, OperationPass<ModuleOp>> {

  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CountInstructionPass)

  StringRef getArgument() const override { return "count-instruction-pass"; }

  StringRef getDescription() const override { return "Count instructions for each core/compute in the module"; }

  // Make sure that we have a valid default constructor and copy
  // constructor to make sure that the options are initialized properly.
  CountInstructionPass() {}
  CountInstructionPass(const CountInstructionPass& pass)
      : PassWrapper<CountInstructionPass, OperationPass<ModuleOp>>() {}

  void runOnOperation() final {
    ModuleOp module = getOperation();

    // Guard the dereference below: a module is not required to hold a function.
    auto functions = module.getOps<func::FuncOp>();
    if (functions.empty()) {
      llvm::outs() << "Total instruction count: 0\n";
      return;
    }
    func::FuncOp func = *functions.begin();

    // Number of ops in the entry block of `body`; zero when there is no block.
    auto countEntryBlockOps = [](Region& body) -> size_t {
      return body.empty() ? 0 : body.front().getOperations().size();
    };

    size_t totalInstructionCount = 0;

    unsigned computeId = 0;
    for (auto computeOp : func.getOps<spatial::SpatCompute>()) {
      const size_t instructionCount = countEntryBlockOps(computeOp.getBody());
      llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n";
      totalInstructionCount += instructionCount;
      computeId++;
    }

    unsigned coreId = 0;
    for (auto coreOp : func.getOps<pim::PimCoreOp>()) {
      const size_t instructionCount = countEntryBlockOps(coreOp.getBody());
      llvm::outs() << "Core " << coreId << ": " << instructionCount << " instructions\n";
      totalInstructionCount += instructionCount;
      coreId++;
    }

    llvm::outs() << "Total instruction count: " << totalInstructionCount << "\n";
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Factory for the instruction-counting pass.
std::unique_ptr<Pass> createCountInstructionPass() {
  auto pass = std::make_unique<CountInstructionPass>();
  return pass;
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -29,6 +29,4 @@ std::unique_ptr<mlir::Pass> createEmitPimCodePass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createMessagePass(std::string message);
|
||||
|
||||
std::unique_ptr<mlir::Pass> createCountInstructionPass();
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user