compact spatial IR through different new operations and dedicated syntax
fast spatial node merging with batch operations
This commit is contained in:
@@ -39,6 +39,22 @@ def PimCoreOp : PimOp<"core", [SingleBlock]> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Batched counterpart of the core op: a single-block region together with a
// lane count and two variadic operand groups (weights, then inputs).
// AttrSizedOperandSegments is needed because the op carries more than one
// variadic operand group.
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments]> {
  let summary = "Execute equivalent batched core bodies";

  // Operand order matters to users of the op: weights come first, inputs second.
  let arguments = (ins I32Attr:$laneCount, Variadic<PimTensor>:$weights, Variadic<PimTensor>:$inputs);

  // Exactly one block holds the body.
  let regions = (region SizedRegion<1>:$body);

  // The op produces no results, hence the empty `-> ()` suffix.
  let assemblyFormat = [{
    `lanes` $laneCount `(` $weights `)` `[` $inputs `]` attr-dict regions `:` type($weights) `[` type($inputs) `]` `->` `(` `)`
  }];
}
|
||||
|
||||
def PimHaltOp : PimOp<"halt", [Terminator]> {
|
||||
let summary = "Halt execution of the core";
|
||||
|
||||
@@ -65,6 +81,20 @@ def PimSendOp : PimOp<"send", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Batched send: one tensor operand plus a size and an array of target core ids.
// NOTE(review): presumably `targetCoreIds` holds one entry per lane of the
// enclosing batched core -- confirm against the verifier/lowering.
def PimSendBatchOp : PimOp<"send_batch", []> {
  let summary = "Send a per-lane tensor to target cores from a batched core";

  let arguments = (ins PimTensor:$input, I32Attr:$size, DenseI32ArrayAttr:$targetCoreIds);

  // `size` and `targetCoreIds` are printed through attr-dict; no results.
  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` `(` `)`
  }];
}
|
||||
|
||||
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
let summary = "Receive a tensor from another core";
|
||||
|
||||
@@ -89,6 +119,30 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Batched receive in destination-passing style: `outputBuffer` is the DPS init
// operand and `output` is the corresponding result.
// NOTE(review): presumably `sourceCoreIds` holds one entry per lane of the
// enclosing batched core -- confirm against the verifier/lowering.
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
  let summary = "Receive per-lane tensors from source cores into a batched core";

  let arguments = (ins PimTensor:$outputBuffer, I32Attr:$size, DenseI32ArrayAttr:$sourceCoreIds);

  let results = (outs PimTensor:$output);

  // `size` and `sourceCoreIds` are printed through attr-dict.
  let assemblyFormat = [{
    `(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
  }];

  // Hook required by DestinationStyleOpInterface: the only init is the output buffer.
  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputBufferMutable();
    }
  }];
}
|
||||
|
||||
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from host memory into device memory";
|
||||
|
||||
@@ -115,6 +169,32 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Batched host-to-device copy in destination-passing style: `deviceTarget` is
// the DPS init operand, `hostSource` is the other tensor operand, and `output`
// is the result tied to the init.
// NOTE(review): units of the two offsets and of `size` are not visible here -- confirm.
def PimMemCopyHostToDevBatchOp : PimOp<"memcp_hd_batch", [DestinationStyleOpInterface]> {
  let summary = "Copy a per-lane tensor from host memory into device memory inside a batched core";

  let arguments = (ins
    PimTensor:$deviceTarget, PimTensor:$hostSource,
    I32Attr:$deviceTargetOffset, I32Attr:$hostSourceOffset, I32Attr:$size
  );

  let results = (outs PimTensor:$output);

  // The offsets and the size are printed through attr-dict.
  let assemblyFormat = [{
    `(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
  }];

  // Hook required by DestinationStyleOpInterface: the only init is the device target.
  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getDeviceTargetMutable();
    }
  }];
}
|
||||
|
||||
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from device memory into host memory";
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "OpBufferizationInterfaces.hpp"
|
||||
@@ -65,6 +66,32 @@ struct MemCopyHostToDevOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct MemCopyHostToDevBatchOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevBatchOpInterface, PimMemCopyHostToDevBatchOp> {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevBatchOp>(op);
|
||||
auto deviceTargetOpt = getBuffer(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state);
|
||||
if (failed(deviceTargetOpt))
|
||||
return failure();
|
||||
auto hostSourceOpt = getBuffer(rewriter, memCopyHostToDevOp.getHostSource(), options, state);
|
||||
if (failed(hostSourceOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevBatchOp>(rewriter,
|
||||
memCopyHostToDevOp,
|
||||
deviceTargetOpt->getType(),
|
||||
*deviceTargetOpt,
|
||||
*hostSourceOpt,
|
||||
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
|
||||
memCopyHostToDevOp.getHostSourceOffsetAttr(),
|
||||
memCopyHostToDevOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct MemCopyDevToHostOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
@@ -122,6 +149,127 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
||||
}
|
||||
};
|
||||
|
||||
/// Bufferization external model for PimReceiveBatchOp.
struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveBatchOpInterface, PimReceiveBatchOp> {
  /// An operand counts as a memory read unless it is a DPS init of the op.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    const bool isInit = dpsOp.isDpsInit(&opOperand);
    return !isInit;
  }

  /// Replaces `op` with an equivalent op whose output buffer operand is a
  /// buffer; the size and source-core-id attributes are forwarded unchanged.
  /// Returns failure if the output buffer cannot be resolved.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto batchReceiveOp = cast<PimReceiveBatchOp>(op);

    auto destBufferOrErr = getBuffer(rewriter, batchReceiveOp.getOutputBuffer(), options, state);
    if (failed(destBufferOrErr))
      return failure();
    Value destBuffer = *destBufferOrErr;

    // The result type of the new op is the type of the resolved buffer.
    replaceOpWithNewBufferizedOp<PimReceiveBatchOp>(rewriter,
                                                    op,
                                                    destBuffer.getType(),
                                                    destBuffer,
                                                    batchReceiveOp.getSizeAttr(),
                                                    batchReceiveOp.getSourceCoreIdsAttr());
    return success();
  }
};
|
||||
|
||||
/// Bufferization external model for PimCoreBatchOp.
///
/// The op's operands are laid out as [weights..., inputs...]. Block argument
/// number `i` of the body's first block is treated as tied to `inputs[i]`.
struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOpInterface, PimCoreBatchOp> {
  /// Every operand is reported as read.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return true;
  }

  /// No operand is reported as written.
  bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return false;
  }

  /// The op has no results, so no value aliases an operand.
  AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    return {};
  }

  /// For a block argument of the body's first block, reports the tied input
  /// operand as an equivalent buffer. Any other value has no aliasing operand.
  AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
    auto batchOp = cast<PimCoreBatchOp>(op);
    auto bodyArg = dyn_cast<BlockArgument>(value);
    const bool ownedByBody = bodyArg && bodyArg.getOwner() == &batchOp.getBody().front();
    if (!ownedByBody)
      return {};

    // Inputs follow the weights in the operand list, so skip past the weights.
    const unsigned numWeights = batchOp.getWeights().size();
    unsigned operandIndex = numWeights + bodyArg.getArgNumber();
    OpOperand& tiedOperand = batchOp->getOpOperand(operandIndex);
    return {{&tiedOperand, BufferRelation::Equivalent}};
  }

  /// Values handled by this model are never reported as writable.
  bool isWritable(Operation* op, Value value, const AnalysisState& state) const {
    return false;
  }

  /// Computes the buffer type of a body block argument from its tied input:
  /// the input's own type when it is already buffer-like, otherwise whatever
  /// bufferization::getBufferType derives for it. Fails for any other value.
  FailureOr<BufferLikeType>
  getBufferType(Operation* op,
                Value value,
                const BufferizationOptions& options,
                const BufferizationState& state,
                SmallVector<Value>& invocationStack) const {
    auto batchOp = cast<PimCoreBatchOp>(op);
    auto bodyArg = dyn_cast<BlockArgument>(value);
    const bool ownedByBody = bodyArg && bodyArg.getOwner() == &batchOp.getBody().front();
    if (!ownedByBody)
      return failure();

    Value tiedInput = batchOp.getInputs()[bodyArg.getArgNumber()];
    auto alreadyBuffer = dyn_cast<BufferLikeType>(tiedInput.getType());
    if (alreadyBuffer)
      return alreadyBuffer;

    return bufferization::getBufferType(tiedInput, options, state, invocationStack);
  }

  /// Re-creates the op on bufferized operands, moves the body into the new op,
  /// bufferizes the signature of every block in the body and erases the
  /// original op. Returns failure if any operand or block signature cannot be
  /// bufferized.
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto batchOp = cast<PimCoreBatchOp>(op);

    // Converts one operand group: tensor-typed values are replaced by their
    // buffers, every other value is passed through untouched.
    auto collectBuffers = [&](ValueRange operands, SmallVector<Value>& buffers) -> LogicalResult {
      buffers.reserve(operands.size());
      for (Value operand : operands) {
        if (!isa<TensorType>(operand.getType())) {
          buffers.push_back(operand);
          continue;
        }
        auto bufferOrErr = getBuffer(rewriter, operand, options, state);
        if (failed(bufferOrErr))
          return failure();
        buffers.push_back(*bufferOrErr);
      }
      return success();
    };

    // Weights are processed before inputs, matching the operand order.
    SmallVector<Value> weights;
    if (failed(collectBuffers(batchOp.getWeights(), weights)))
      return failure();

    SmallVector<Value> inputs;
    if (failed(collectBuffers(batchOp.getInputs(), inputs)))
      return failure();

    rewriter.setInsertionPoint(batchOp);
    auto bufferizedOp = PimCoreBatchOp::create(
        rewriter, batchOp.getLoc(), batchOp.getLaneCountAttr(), ValueRange(weights), ValueRange(inputs));
    const int numWeights = static_cast<int>(weights.size());
    const int numInputs = static_cast<int>(inputs.size());
    bufferizedOp.getProperties().setOperandSegmentSizes({numWeights, numInputs});

    // Carry the core-id attribute over when the original op has one.
    if (auto coreIdsAttr = batchOp->getAttr(onnx_mlir::kCoreIdAttrName))
      bufferizedOp->setAttr(onnx_mlir::kCoreIdAttrName, coreIdsAttr);

    // Move the body into the new op before rewriting its block signatures.
    Region& newBody = bufferizedOp.getBody();
    rewriter.inlineRegionBefore(batchOp.getBody(), newBody, newBody.begin());
    for (Block& block : newBody) {
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state)))
        return failure();
    }

    rewriter.eraseOp(batchOp);
    return success();
  }
};
|
||||
|
||||
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -287,8 +435,11 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
||||
|
||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
|
||||
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||
|
||||
@@ -93,8 +93,8 @@ void PimBufferizationPass::runOnOperation() {
|
||||
}
|
||||
|
||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](PimCoreOp coreOp) {
|
||||
walkPimMvmVmmWeightUses(coreOp, [&](OpOperand& weightUse) {
|
||||
auto markWeights = [&](Operation* op) {
|
||||
walkPimMvmVmmWeightUses(op, [&](OpOperand& weightUse) {
|
||||
Value weight = weightUse.get();
|
||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
@@ -104,7 +104,10 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
|
||||
markWeightAlways(getGlobalOp);
|
||||
markWeightAlways(globalMemrefOp);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
funcOp.walk([&](PimCoreOp coreOp) { markWeights(coreOp); });
|
||||
funcOp.walk([&](PimCoreBatchOp coreBatchOp) { markWeights(coreBatchOp); });
|
||||
}
|
||||
|
||||
/// Factory returning a newly allocated PimBufferizationPass.
std::unique_ptr<Pass> createPimBufferizationPass() {
  auto pass = std::make_unique<PimBufferizationPass>();
  return pass;
}
|
||||
|
||||
Reference in New Issue
Block a user