628dc630a4
Validate Operations / validate-operations (push) Has been cancelled
better IR compaction after dcp merge remove pim.mvm op better memory report
528 lines
20 KiB
C++
528 lines
20 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinDialect.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
|
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/Support/Casting.h"
|
|
|
|
#include <cassert>
|
|
#include <utility>
|
|
|
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
|
|
|
using namespace mlir;
|
|
using namespace onnx_mlir;
|
|
using namespace pim;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
|
|
|
/// Module-level pass that lowers Spatial dialect ops into PIM-ready IR.
/// Registered on the command line as `convert-spatial-to-pim`.
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
  StringRef getArgument() const override { return "convert-spatial-to-pim"; }
  StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }

  SpatialToPimPass() = default;
  // Copy constructor intentionally does not copy per-run state
  // (outputTensors, coreId, operationsToRemove): each cloned pass instance
  // starts fresh when the pass pipeline duplicates it.
  SpatialToPimPass(const SpatialToPimPass& pass) {}

  void runOnOperation() final;

private:
  // Factories for the function's output tensors; populated via
  // addReturnOutputBuffers and consumed when the return is rewritten.
  SmallVector<OutputTensorFactory> outputTensors;
  // Next core id handed out during compute lowering (reset at the start of
  // runOnOperation).
  size_t coreId = 0;
  // Ops scheduled for deferred erasure once all rewrites have completed.
  SmallVector<Operation*> operationsToRemove;

  // Inserts host-to-device memcpy ops for "arg"/"const_" globals referenced
  // from argument-less spatial.compute regions.
  LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);

  // Records `op` for deferred erasure, skipping duplicates.
  void markOpToRemove(Operation* op);

  // Widens PIM VMM output tensors to the crossbar width and re-slices the
  // result back to the original shape for downstream users.
  void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
};
|
|
|
|
} // namespace
|
|
|
|
/// Map a Spatial core identifier onto the PIM core-id numbering.
/// The mapping is currently the identity, narrowed to the signed 32-bit
/// type that PIM op attributes expect.
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) {
  return static_cast<int32_t>(spatialCoreId);
}
|
|
|
|
/// Replace a spatial channel-send with its PIM equivalent.
/// The PIM op carries the transfer size in bytes and the destination core
/// id as explicit attributes, both derived from the spatial op.
static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) {
  Location loc = sendOp.getLoc();
  Value payload = sendOp.getInput();

  // Byte size and destination core become explicit attributes on the PIM op.
  auto byteSizeAttr = getTensorSizeInBytesAttr(rewriter, payload);
  int32_t pimCoreId = translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId());
  auto targetAttr = rewriter.getI32IntegerAttr(pimCoreId);

  rewriter.setInsertionPoint(sendOp);
  PimSendOp::create(rewriter, loc, payload, byteSizeAttr, targetAttr);
  rewriter.eraseOp(sendOp);
}
|
|
|
|
/// Replace a spatial channel-receive with its PIM equivalent.
/// A dead receive is simply erased; otherwise a fresh destination tensor is
/// materialized and the PIM receive writes into it, carrying the byte size
/// and source core id as attributes.
static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
  // Nothing consumes the received value: drop the op instead of lowering it.
  if (receiveOp->use_empty()) {
    rewriter.eraseOp(receiveOp);
    return;
  }

  Location loc = receiveOp.getLoc();
  rewriter.setInsertionPoint(receiveOp);

  // Materialize the buffer the PIM receive will fill.
  auto resultType = cast<ShapedType>(receiveOp.getResult().getType());
  auto destination = createEmptyTensorFromShaped(rewriter, loc, resultType);

  auto byteSizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
  int32_t pimCoreId = translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId());
  auto sourceAttr = rewriter.getI32IntegerAttr(pimCoreId);

  auto pimReceive =
      PimReceiveOp::create(rewriter, loc, destination.getType(), destination, byteSizeAttr, sourceAttr);
  rewriter.replaceOp(receiveOp, pimReceive.getOutput());
}
|
|
|
|
/// Replace a spatial channel-send-tensor with its PIM equivalent.
/// Every destination core id is remapped into PIM numbering and attached as
/// a dense i32 array attribute.
static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp, IRRewriter& rewriter) {
  SmallVector<int32_t> pimTargets;
  pimTargets.reserve(sendTensorOp.getTargetCoreIds().size());
  for (int32_t spatialId : sendTensorOp.getTargetCoreIds())
    pimTargets.push_back(translateSpatialCoreIdToPimCoreId(spatialId));

  rewriter.setInsertionPoint(sendTensorOp);
  PimSendTensorOp::create(rewriter,
      sendTensorOp.getLoc(),
      sendTensorOp.getInput(),
      rewriter.getDenseI32ArrayAttr(pimTargets));
  rewriter.eraseOp(sendTensorOp);
}
|
|
|
|
/// Replace a spatial channel-receive-tensor with its PIM equivalent.
/// Source core ids are remapped into PIM numbering and the received data is
/// written into a freshly materialized destination tensor.
static void lowerChannelReceiveTensor(spatial::SpatChannelReceiveTensorOp receiveTensorOp, IRRewriter& rewriter) {
  SmallVector<int32_t> pimSources;
  pimSources.reserve(receiveTensorOp.getSourceCoreIds().size());
  for (int32_t spatialId : receiveTensorOp.getSourceCoreIds())
    pimSources.push_back(translateSpatialCoreIdToPimCoreId(spatialId));

  Location loc = receiveTensorOp.getLoc();
  rewriter.setInsertionPoint(receiveTensorOp);

  // Materialize the buffer the PIM receive will fill.
  auto destinationType = cast<ShapedType>(receiveTensorOp.getOutput().getType());
  Value destination = createEmptyTensorFromShaped(rewriter, loc, destinationType).getResult();

  auto pimReceive = PimReceiveTensorOp::create(rewriter,
      loc,
      receiveTensorOp.getOutput().getType(),
      destination,
      rewriter.getDenseI32ArrayAttr(pimSources));
  rewriter.replaceOp(receiveTensorOp, pimReceive.getOutput());
}
|
|
|
|
/// Lower spatial extract-rows into one tensor.extract_slice per output.
/// Output i covers rows [i * rows(out_i), (i + 1) * rows(out_i)) of the
/// input — the offset assumes every output has the same leading dimension —
/// and spans all of the input's columns with unit strides.
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
  rewriter.setInsertionPoint(extractRowsOp);
  Location loc = extractRowsOp.getLoc();
  Value source = extractRowsOp.getInput();
  auto sourceType = cast<RankedTensorType>(source.getType());
  int64_t numColumns = sourceType.getDimSize(1);

  OpFoldResult zero = rewriter.getIndexAttr(0);
  OpFoldResult one = rewriter.getIndexAttr(1);

  SmallVector<Value> slices;
  slices.reserve(extractRowsOp.getNumResults());
  for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
    auto sliceType = cast<RankedTensorType>(output.getType());
    int64_t numRows = sliceType.getDimSize(0);

    SmallVector<OpFoldResult> offsets = {
        rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * numRows), zero};
    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numRows), rewriter.getIndexAttr(numColumns)};
    SmallVector<OpFoldResult> strides = {one, one};

    auto slice =
        tensor::ExtractSliceOp::create(rewriter, loc, sliceType, source, offsets, sizes, strides);
    slices.push_back(slice.getResult());
  }
  rewriter.replaceOp(extractRowsOp, slices);
}
|
|
|
|
// Compact the inputs of axis-0 spatial concats by packing runs of related
// values into single tensors, then replace the spatial concat with a PIM
// concat over the packed inputs. Two kinds of runs are recognized:
//   1. consecutive inputs produced by tensor.extract_slice ops, packed via
//      createPackedExtractSliceTensor;
//   2. consecutive results of one spatial extract-rows op (contiguous result
//      numbers), packed via createPackedExtractRowsSlice.
// Finally, extract-slice / concat / extract-rows ops left without users are
// erased.
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Collect first, rewrite after: replacing ops while walking would
  // invalidate the traversal.
  SmallVector<spatial::SpatConcatOp> concatOps;
  funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
  for (auto concatOp : concatOps) {
    // Only row-wise (axis 0) concats with at least one input are compacted.
    if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
      continue;

    SmallVector<Value> packedInputs;
    bool changed = false;
    rewriter.setInsertionPoint(concatOp);

    // Scan the input list; `index` advances by the length of each run.
    for (unsigned index = 0; index < concatOp.getInputs().size();) {
      Value input = concatOp.getInputs()[index];

      // Case 1: a run of extract_slice-produced inputs.
      if (input.getDefiningOp<tensor::ExtractSliceOp>()) {
        unsigned endIndex = index + 1;
        while (endIndex < concatOp.getInputs().size()
               && concatOp.getInputs()[endIndex].getDefiningOp<tensor::ExtractSliceOp>())
          ++endIndex;

        Value packedInput = createPackedExtractSliceTensor(
            concatOp.getInputs().slice(index, endIndex - index), rewriter, concatOp.getLoc());
        // Packing may fail (null result); then fall through to the generic
        // handling below for this input.
        if (packedInput) {
          packedInputs.push_back(packedInput);
          changed = true;
          index = endIndex;
          continue;
        }
      }

      // Inputs that are not op results (e.g. block arguments) pass through
      // unchanged.
      auto result = dyn_cast<OpResult>(input);
      if (!result) {
        packedInputs.push_back(input);
        ++index;
        continue;
      }

      // Case 2: extend a run of results of the same op with consecutive
      // result numbers (result k+1 must follow result k).
      Operation* owner = result.getOwner();
      unsigned startIndex = result.getResultNumber();
      unsigned endIndex = index + 1;
      while (endIndex < concatOp.getInputs().size()) {
        auto nextResult = dyn_cast<OpResult>(concatOp.getInputs()[endIndex]);
        if (!nextResult || nextResult.getOwner() != owner
            || nextResult.getResultNumber() != startIndex + (endIndex - index))
          break;
        ++endIndex;
      }

      unsigned count = endIndex - index;
      Value packedInput;
      // Only extract-rows runs are packable; other owners keep their inputs.
      if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
        packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());

      if (packedInput) {
        packedInputs.push_back(packedInput);
        changed = true;
      }
      else {
        // No packing possible: forward the run's inputs unmodified.
        for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
          packedInputs.push_back(concatOp.getInputs()[oldIndex]);
      }

      index = endIndex;
    }

    // Keep the original spatial concat when nothing was compacted.
    if (!changed)
      continue;

    // Rebuild as a PIM concat over the packed inputs, with a fresh empty
    // tensor as the destination buffer.
    auto newConcat = pim::PimConcatOp::create(
        rewriter,
        concatOp.getLoc(),
        concatOp.getOutput().getType(),
        concatOp.getAxisAttr(),
        ValueRange(packedInputs),
        createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), cast<ShapedType>(concatOp.getOutput().getType()))
            .getResult());
    rewriter.replaceOp(concatOp, newConcat.getOutput());
  }
  // Erase ops of type OpTy that ended up unused. Reverse order so that
  // erasing a user first can make its producers dead within the same sweep.
  auto eraseUnusedOps = [&](auto tag) {
    using OpTy = decltype(tag);
    SmallVector<OpTy> ops;
    funcOp.walk([&](OpTy op) { ops.push_back(op); });
    for (auto op : llvm::reverse(ops))
      if (op->use_empty())
        rewriter.eraseOp(op);
  };
  eraseUnusedOps(tensor::ConcatOp {});
  eraseUnusedOps(tensor::ExtractSliceOp {});
  eraseUnusedOps(spatial::SpatExtractRowsOp {});
}
|
|
|
|
// Main driver: lowers the module's PIM entry function from the Spatial
// dialect to PIM ops in a fixed sequence of phases. The phase order is
// load-bearing — compaction, channel lowering, and the final full conversion
// each rely on the IR shape left by the previous step.
void SpatialToPimPass::runOnOperation() {
  // Core ids handed out during compute lowering start at 1.
  coreId = 1;
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = moduleOp.getContext();

  // All lowering happens inside the single PIM entry function.
  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }
  func::FuncOp funcOp = *entryFunc;

  IRRewriter rewriter(&getContext());

  // Conversion target shared by the generated-pattern phases: PIM and the
  // standard dialects are legal; the spatial communication / concat /
  // extract-rows ops are explicitly kept legal so they survive until their
  // dedicated lowering further below.
  ConversionTarget target(*ctx);
  target.addLegalDialect<PimDialect,
      tensor::TensorDialect,
      arith::ArithDialect,
      bufferization::BufferizationDialect,
      func::FuncDialect,
      memref::MemRefDialect,
      scf::SCFDialect,
      BuiltinDialect>();
  target.addLegalOp<spatial::SpatConcatOp,
      spatial::SpatChannelReceiveOp,
      spatial::SpatChannelReceiveTensorOp,
      spatial::SpatChannelReceiveTensorBatchOp,
      spatial::SpatChannelSendOp,
      spatial::SpatChannelSendTensorOp,
      spatial::SpatChannelSendTensorBatchOp,
      spatial::SpatExtractRowsOp>();

  // Phase 1: apply the tablegen-generated Spatial->PIM patterns module-wide.
  {
    RewritePatternSet patterns(ctx);
    populateWithGenerated(patterns);

    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
      signalPassFailure();
      return;
    }
  }

  // Phase 2: materialize global tensors (walk-based, no legality checking).
  {
    RewritePatternSet patterns(ctx);
    populateGlobalTensorMaterializationPatterns(patterns);

    walkAndApplyPatterns(moduleOp, std::move(patterns));
  }
  auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());

  // Phase 3: attach output buffers to the return path and insert
  // host->device copies for core-local globals.
  addReturnOutputBuffers(returnOp, rewriter, outputTensors);
  if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
    signalPassFailure();
    return;
  }

  // Phase 4: lower spatial.compute / compute_batch regions into PIM core
  // ops. The originals are only marked for removal so their results remain
  // resolvable until the deferred-erasure step below.
  CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove};
  for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
    markOpToRemove(computeOp);
    if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
      signalPassFailure();
      return;
    }
  }

  for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
    markOpToRemove(computeBatchOp);
    if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
      signalPassFailure();
      return;
    }
  }

  // Phase 5: pack extract/concat chains produced by compute lowering.
  compactSpatialTensorGroups(funcOp, rewriter);

  // Phase 6: lower scalar channel receives. A receive whose users are all
  // pending removal is itself only marked, not lowered.
  // NOTE(review): for a use-empty receive, llvm::all_of over an empty user
  // range is true, so it is marked for removal and the use_empty branch
  // below appears unreachable — confirm which behavior is intended.
  SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
    receiveOps.push_back(op);
  for (auto receiveOp : receiveOps) {
    bool onlyPendingRemovalUsers = llvm::all_of(
        receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
    if (onlyPendingRemovalUsers) {
      markOpToRemove(receiveOp);
      continue;
    }
    if (receiveOp->use_empty()) {
      rewriter.eraseOp(receiveOp);
      continue;
    }
    lowerChannelReceive(receiveOp, rewriter);
  }

  // Phase 6 (cont.): lower tensor receives, scalar sends, tensor sends and
  // extract-rows ops. Each kind is collected first, then rewritten, to avoid
  // mutating the op list mid-iteration.
  SmallVector<spatial::SpatChannelReceiveTensorOp> receiveTensorOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelReceiveTensorOp>())
    receiveTensorOps.push_back(op);
  for (auto receiveTensorOp : receiveTensorOps)
    lowerChannelReceiveTensor(receiveTensorOp, rewriter);

  SmallVector<spatial::SpatChannelSendOp> sendOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
    sendOps.push_back(op);
  for (auto sendOp : sendOps)
    lowerChannelSend(sendOp, rewriter);

  SmallVector<spatial::SpatChannelSendTensorOp> sendTensorOps;
  for (auto op : funcOp.getOps<spatial::SpatChannelSendTensorOp>())
    sendTensorOps.push_back(op);
  for (auto sendTensorOp : sendTensorOps)
    lowerChannelSendTensor(sendTensorOp, rewriter);

  SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
  for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
    extractRowsOps.push_back(op);
  for (auto extractRowsOp : extractRowsOps)
    lowerExtractRows(extractRowsOp, rewriter);

  // Phase 7: run the generated patterns again, this time as a full
  // conversion restricted to the bodies of PIM core / core-batch ops.
  {
    RewritePatternSet coreBodyPatterns(ctx);
    populateWithGenerated(coreBodyPatterns);
    FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));

    SmallVector<pim::PimCoreOp> coreOps;
    funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
    for (auto coreOp : coreOps) {
      if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
        signalPassFailure();
        return;
      }
    }

    SmallVector<pim::PimCoreBatchOp> coreBatchOps;
    funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
    for (auto coreBatchOp : coreBatchOps) {
      if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
        signalPassFailure();
        return;
      }
    }
  }

  // Phase 8: widen VMM outputs to the crossbar size, then rewrite the
  // return to use the prepared output buffers.
  enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
  ReturnPathState returnPathState {outputTensors, operationsToRemove};
  replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);

  // Phase 9: erase everything deferred during the phases above.
  SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
  if (failed(erasePendingOps(pendingRemovals, rewriter))) {
    signalPassFailure();
    return;
  }

  // Phase 10: a second compaction pass — erasing the pending ops can expose
  // new packable groups.
  compactSpatialTensorGroups(funcOp, rewriter);

  // Phase 11: final full conversion of the remaining communication ops; any
  // surviving spatial concat/channel/extract-rows op is now illegal.
  {
    ConversionTarget communicationTarget(*ctx);
    communicationTarget.addLegalDialect<PimDialect,
        tensor::TensorDialect,
        arith::ArithDialect,
        bufferization::BufferizationDialect,
        func::FuncDialect,
        memref::MemRefDialect,
        scf::SCFDialect,
        BuiltinDialect>();
    communicationTarget.addLegalOp<ModuleOp>();
    communicationTarget.addIllegalOp<spatial::SpatConcatOp,
        spatial::SpatChannelReceiveOp,
        spatial::SpatChannelReceiveTensorOp,
        spatial::SpatChannelSendOp,
        spatial::SpatChannelSendTensorOp,
        spatial::SpatExtractRowsOp>();

    RewritePatternSet communicationPatterns(ctx);
    populateChannelLoweringPatterns(communicationPatterns);
    if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
      signalPassFailure();
      return;
    }
  }

  // Phase 12: verify no spatial ops leaked across the lowering boundary.
  if (failed(verifySpatialToPimBoundary(moduleOp))) {
    signalPassFailure();
    return;
  }

  // Dump to file for debug
  dumpModule(moduleOp, "pim0");
}
|
|
|
|
// Widen every PIM VMM op whose output column count differs from the
// crossbar width up to [rows, crossbarSize], then insert an extract_slice
// after the VMM that restores the original shape for all downstream users.
// The widening is applied in place via setType, walking backwards through
// the tied destination-style (DPS) operand chain so producer and consumer
// types stay consistent.
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
  // Recursively retype the chain of tied DPS init operands feeding `value`.
  // Stops at non-op values, non-DPS producers, or untied results. The
  // single-use assert guards the in-place setType: retyping a multi-use
  // value would silently change other users' operand types too.
  auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
    auto* definingOp = value.getDefiningOp();
    if (!definingOp)
      return;
    auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
    if (!dpsDefiningOp)
      return;
    auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
    if (!tiedOperand)
      return;
    Value tiedValue = tiedOperand->get();
    assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
    tiedValue.setType(newType);
    self(tiedValue, newType, self);
  };

  funcOp.walk([&](PimVMMOp vmmOp) {
    auto outTensorOperand = vmmOp.getOutputBuffer();
    auto resultTensor = vmmOp.getOutput();
    auto outShape = getTensorShape(outTensorOperand);
    // VMM outputs are expected to be horizontal vectors (1 x N shape —
    // assumption encoded by isHVectorShape).
    assert(isHVectorShape(outShape));
    if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
      auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
      auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
      if (outTensorOperand == vmmOp.getInput()) {
        // The output buffer aliases the input: retyping it would also retype
        // the input, so give the op a fresh, already-widened buffer instead.
        rewriter.setInsertionPoint(vmmOp);
        auto newOutputBuffer =
            tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
        vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
      }
      else {
        // Retype the whole producing DPS chain, then the buffer itself.
        enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
        outTensorOperand.setType(newType);
      }
      resultTensor.setType(newType);

      // Slice the widened result back down to the original shape so users
      // observe the pre-widening type.
      IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
      IntegerAttr oneAttr = rewriter.getIndexAttr(1);
      IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
      IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
      SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
      SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
      SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
      rewriter.setInsertionPointAfter(vmmOp);
      auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
      // Exclude the VMM and the slice itself so the slice keeps reading the
      // widened result rather than itself.
      SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
      resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
    }
  });
}
|
|
|
|
// For every argument-less spatial.compute region, find memref.get_global
// ops whose symbol names start with "arg" or "const_" and route their
// (single) user's value through a PIM host-to-device copy, so the data is
// resident on the device before the compute runs. Always succeeds today;
// returns LogicalResult to leave room for failure paths.
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
  Location loc = funcOp.getLoc();

  // Wrap `inputTensor` in a PimMemCopyHostToDevOp that copies it into a
  // fresh device tensor at `elementsOffset` (in elements, converted to
  // bytes). All former users of the tensor are redirected to the copy's
  // result.
  auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
    auto tensorType = cast<ShapedType>(inputTensor.getType());
    Type elementType = tensorType.getElementType();
    // NOTE(review): integer division assumes byte-multiple element widths;
    // sub-byte types (e.g. i1) would yield 0 — confirm they cannot occur here.
    size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
    rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());

    // Destination buffer on the device side, same shape/element type.
    auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);

    auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
        rewriter,
        loc,
        tensorType,
        deviceTensor,
        inputTensor,
        rewriter.getI32IntegerAttr(0),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
        rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));

    // Redirect all other users to the copied (device-resident) value; the
    // copy op itself keeps reading the original host tensor.
    rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
  };

  for (auto& op : funcOp.getBody().getOps())
    if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
      // Only computes with neither inputs nor block arguments qualify —
      // these are the ones whose data comes solely from globals.
      if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
        continue;
      for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
        if (getGlobal.getName().starts_with("arg") || getGlobal.getName().starts_with("const_")) {
          assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
          // The single user's first result (presumably a to_tensor-style op
          // — the code only relies on it having a result) is what gets
          // wrapped in the host-to-device copy.
          auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin();
          insertMemCopyHostToDev(toTensorOpValue, 0);
        }
      }
    }

  return success();
}
|
|
|
|
/// Record `op` for deferred erasure; repeated marks are ignored so the
/// pending-removal list never holds duplicates.
void SpatialToPimPass::markOpToRemove(Operation* op) {
  if (llvm::is_contained(operationsToRemove, op))
    return;
  operationsToRemove.push_back(op);
}
|
|
|
|
/// Public factory for the spatial-to-PIM lowering pass.
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
|
|
|
|
} // namespace onnx_mlir
|