merge remote changes
This commit is contained in:
@@ -1,20 +1,26 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
@@ -24,6 +30,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -53,7 +60,7 @@ struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>>
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
SmallVector<Value> outputTensors;
|
||||
SmallVector<std::function<Value(IRRewriter& rewriter, Location loc)>> outputTensors;
|
||||
size_t coreId = 0;
|
||||
SmallVector<Operation*> operationsToRemove;
|
||||
|
||||
@@ -179,7 +186,22 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan
|
||||
}
|
||||
|
||||
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
|
||||
auto inputType = cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
Value input = extractRowsOp.getInput();
|
||||
RankedTensorType inputType;
|
||||
if (auto tensorType = dyn_cast<RankedTensorType>(input.getType())) {
|
||||
inputType = tensorType;
|
||||
}
|
||||
else if (auto memRefType = dyn_cast<MemRefType>(input.getType())) {
|
||||
inputType = RankedTensorType::get(memRefType.getShape(), memRefType.getElementType());
|
||||
rewriter.setInsertionPoint(extractRowsOp);
|
||||
input = bufferization::ToTensorOp::create(
|
||||
rewriter, extractRowsOp.getLoc(), inputType, input, rewriter.getUnitAttr(), rewriter.getUnitAttr())
|
||||
.getResult();
|
||||
}
|
||||
else {
|
||||
extractRowsOp.emitOpError("requires a ranked tensor or memref input during Spatial-to-PIM lowering");
|
||||
return;
|
||||
}
|
||||
int64_t numCols = inputType.getDimSize(1);
|
||||
|
||||
SmallVector<Value> replacements;
|
||||
@@ -187,11 +209,16 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite
|
||||
|
||||
rewriter.setInsertionPoint(extractRowsOp);
|
||||
for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
|
||||
auto outputType = dyn_cast<RankedTensorType>(output.getType());
|
||||
if (!outputType) {
|
||||
extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering");
|
||||
return;
|
||||
}
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto rowSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, extractRowsOp.getLoc(), cast<RankedTensorType>(output.getType()), extractRowsOp.getInput(), offsets, sizes, strides);
|
||||
rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
|
||||
replacements.push_back(rowSlice.getResult());
|
||||
}
|
||||
|
||||
@@ -205,6 +232,75 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
|
||||
rewriter.replaceOp(concatOp, concatenated);
|
||||
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
SmallVectorImpl<Operation*>& helperChain,
|
||||
bool requireReturnUse = true) {
|
||||
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
|
||||
return failure();
|
||||
if (requireReturnUse
|
||||
&& (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin())))
|
||||
return failure();
|
||||
|
||||
Block& block = computeOp.getBody().front();
|
||||
if (block.getNumArguments() != 1)
|
||||
return failure();
|
||||
|
||||
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||
return failure();
|
||||
|
||||
SmallVector<Operation*> reverseChain;
|
||||
Value currentValue = yieldOp.getOperands().front();
|
||||
Value blockArg = block.getArgument(0);
|
||||
|
||||
while (currentValue != blockArg) {
|
||||
Operation* definingOp = currentValue.getDefiningOp();
|
||||
if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp))
|
||||
return failure();
|
||||
reverseChain.push_back(definingOp);
|
||||
currentValue = definingOp->getOperand(0);
|
||||
}
|
||||
|
||||
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
|
||||
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
|
||||
if (!chainSet.contains(&op)
|
||||
&& !isa<tensor::EmptyOp, arith::ConstantOp>(op))
|
||||
return failure();
|
||||
|
||||
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||
return false;
|
||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(),
|
||||
[](Operation* user) { return isa<spatial::SpatComputeBatch, pim::PimCoreBatchOp>(user); }))
|
||||
return false;
|
||||
|
||||
Block& block = computeOp.getBody().front();
|
||||
if (block.getNumArguments() != 0)
|
||||
return false;
|
||||
|
||||
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||
return false;
|
||||
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
IRMapping mapping;
|
||||
for (Operation& op : block.without_terminator()) {
|
||||
cloneMappedHelperOperands(&op, mapping, rewriter);
|
||||
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
rewriter.setInsertionPointAfter(clonedOp);
|
||||
}
|
||||
|
||||
Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
|
||||
computeOp.getResult(0).replaceAllUsesWith(replacement);
|
||||
return true;
|
||||
}
|
||||
|
||||
struct ReturnUseInfo {
|
||||
size_t returnIndex;
|
||||
SmallVector<Operation*> helperChain;
|
||||
@@ -295,6 +391,20 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
}
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
if (auto helperCompute = dyn_cast<spatial::SpatCompute>(currentUser)) {
|
||||
if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue)
|
||||
return std::nullopt;
|
||||
|
||||
if (failed(collectHelperComputeChain(helperCompute, helperChain)))
|
||||
return std::nullopt;
|
||||
|
||||
currentValue = helperCompute.getResult(0);
|
||||
auto currentUses = currentValue.getUses();
|
||||
if (rangeLength(currentUses) != 1)
|
||||
return std::nullopt;
|
||||
currentUser = currentUses.begin()->getOwner();
|
||||
}
|
||||
|
||||
while (isChannelUseChainOp(currentUser)) {
|
||||
helperChain.push_back(currentUser);
|
||||
auto currentUses = currentUser->getResult(0).getUses();
|
||||
@@ -419,21 +529,22 @@ static void cloneHelperChain(Value sourceValue,
|
||||
}
|
||||
}
|
||||
|
||||
static void emitHostCopy(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
Value outputTensor,
|
||||
Value sourceValue,
|
||||
int32_t hostTargetOffset,
|
||||
int32_t deviceSourceOffset,
|
||||
int32_t sizeInBytes) {
|
||||
PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
sourceValue,
|
||||
rewriter.getI32IntegerAttr(hostTargetOffset),
|
||||
rewriter.getI32IntegerAttr(deviceSourceOffset),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes));
|
||||
static Value emitHostCopy(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
Value outputTensor,
|
||||
Value sourceValue,
|
||||
int32_t hostTargetOffset,
|
||||
int32_t deviceSourceOffset,
|
||||
int32_t sizeInBytes) {
|
||||
return PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
sourceValue,
|
||||
rewriter.getI32IntegerAttr(hostTargetOffset),
|
||||
rewriter.getI32IntegerAttr(deviceSourceOffset),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::runOnOperation() {
|
||||
@@ -458,12 +569,21 @@ void SpatialToPimPass::runOnOperation() {
|
||||
scf::SCFDialect,
|
||||
BuiltinDialect>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateGlobalTensorToMemrefPatterns(patterns);
|
||||
|
||||
walkAndApplyPatterns(moduleOp, std::move(patterns));
|
||||
}
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
|
||||
@@ -489,7 +609,8 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
||||
funcOp.walk([&](spatial::SpatChannelReceiveOp op) { receiveOps.push_back(op); });
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
||||
receiveOps.push_back(op);
|
||||
for (auto receiveOp : receiveOps) {
|
||||
bool onlyPendingRemovalUsers = llvm::all_of(
|
||||
receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); });
|
||||
@@ -505,22 +626,26 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveManyOp> receiveManyOps;
|
||||
funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { receiveManyOps.push_back(op); });
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveManyOp>())
|
||||
receiveManyOps.push_back(op);
|
||||
for (auto receiveManyOp : receiveManyOps)
|
||||
lowerChannelReceiveMany(receiveManyOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelSendOp> sendOps;
|
||||
funcOp.walk([&](spatial::SpatChannelSendOp op) { sendOps.push_back(op); });
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
|
||||
sendOps.push_back(op);
|
||||
for (auto sendOp : sendOps)
|
||||
lowerChannelSend(sendOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelSendManyOp> sendManyOps;
|
||||
funcOp.walk([&](spatial::SpatChannelSendManyOp op) { sendManyOps.push_back(op); });
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelSendManyOp>())
|
||||
sendManyOps.push_back(op);
|
||||
for (auto sendManyOp : sendManyOps)
|
||||
lowerChannelSendMany(sendManyOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
|
||||
funcOp.walk([&](spatial::SpatExtractRowsOp op) { extractRowsOps.push_back(op); });
|
||||
for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
|
||||
extractRowsOps.push_back(op);
|
||||
for (auto extractRowsOp : extractRowsOps)
|
||||
lowerExtractRows(extractRowsOp, rewriter);
|
||||
|
||||
@@ -560,6 +685,36 @@ void SpatialToPimPass::runOnOperation() {
|
||||
assert(false && "tracked op removal reached a cycle or missed dependency");
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatConcatOp> remainingConcatOps;
|
||||
funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); });
|
||||
for (auto concatOp : remainingConcatOps)
|
||||
lowerConcat(concatOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveOp> remainingReceiveOps;
|
||||
funcOp.walk([&](spatial::SpatChannelReceiveOp op) { remainingReceiveOps.push_back(op); });
|
||||
for (auto receiveOp : remainingReceiveOps)
|
||||
lowerChannelReceive(receiveOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveManyOp> remainingReceiveManyOps;
|
||||
funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { remainingReceiveManyOps.push_back(op); });
|
||||
for (auto receiveManyOp : remainingReceiveManyOps)
|
||||
lowerChannelReceiveMany(receiveManyOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelSendOp> remainingSendOps;
|
||||
funcOp.walk([&](spatial::SpatChannelSendOp op) { remainingSendOps.push_back(op); });
|
||||
for (auto sendOp : remainingSendOps)
|
||||
lowerChannelSend(sendOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelSendManyOp> remainingSendManyOps;
|
||||
funcOp.walk([&](spatial::SpatChannelSendManyOp op) { remainingSendManyOps.push_back(op); });
|
||||
for (auto sendManyOp : remainingSendManyOps)
|
||||
lowerChannelSendMany(sendManyOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatExtractRowsOp> remainingExtractRowsOps;
|
||||
funcOp.walk([&](spatial::SpatExtractRowsOp op) { remainingExtractRowsOps.push_back(op); });
|
||||
for (auto extractRowsOp : remainingExtractRowsOps)
|
||||
lowerExtractRows(extractRowsOp, rewriter);
|
||||
|
||||
// Dump to file for debug
|
||||
bool hasSpatialOps = false;
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
@@ -579,6 +734,13 @@ void SpatialToPimPass::runOnOperation() {
|
||||
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
Location loc = computeOp->getLoc();
|
||||
|
||||
if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter))
|
||||
return;
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
|
||||
return;
|
||||
|
||||
auto& block = computeOp.getRegion().front();
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
|
||||
@@ -616,9 +778,9 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
|
||||
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||
Value outputTensor = outputTensors[returnUse->returnIndex];
|
||||
if (auto storedOp = storedValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
@@ -637,8 +799,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
if (isa<func::ReturnOp>(resultUser)) {
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
Value outputTensor = outputTensors[resultIndexInReturn];
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
@@ -654,13 +816,13 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
}
|
||||
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex];
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
||||
|
||||
if (concatReturnUse->helperChain.empty()) {
|
||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
@@ -671,7 +833,15 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
continue;
|
||||
}
|
||||
|
||||
auto storedType = cast<RankedTensorType>(yieldValue.getType());
|
||||
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
||||
if (!storedType) {
|
||||
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
||||
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
||||
for (auto [dim, idx] : llvm::enumerate(sourceIndices))
|
||||
@@ -701,19 +871,18 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
|
||||
auto scalarTensorType =
|
||||
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
auto elementSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
|
||||
rewriter.setInsertionPointAfter(elementSlice);
|
||||
|
||||
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
elementSlice.getResult(),
|
||||
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(elementSize));
|
||||
outputTensor = emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
elementSlice.getResult(),
|
||||
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(elementSize));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -848,6 +1017,26 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
|
||||
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
auto clonedTensor = cloned->getResult(0);
|
||||
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
clonedTensor,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
getTensorSizeInBytesAttr(rewriter, clonedTensor))
|
||||
.getOutput();
|
||||
mapper.map(toTensorOp.getResult(), copied);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (Value operand : op.getOperands()) {
|
||||
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
|
||||
continue;
|
||||
@@ -922,17 +1111,33 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
|
||||
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||
outputTensors.reserve(returnOp->getNumOperands());
|
||||
rewriter.setInsertionPointToStart(returnOp->getBlock());
|
||||
for (auto returnValue : returnOp->getOperands()) {
|
||||
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
||||
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
|
||||
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
|
||||
assert(!hasWeightAlways(returnValueDefiningOp));
|
||||
outputTensors.push_back(returnValue);
|
||||
outputTensors.push_back([returnValue](IRRewriter& rewriter, Location loc) -> Value { return returnValue; });
|
||||
}
|
||||
else {
|
||||
auto newOutputTensor =
|
||||
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
|
||||
outputTensors.push_back(newOutputTensor);
|
||||
auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(returnValue.getType());
|
||||
auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
|
||||
|
||||
std::string outputName = "output_" + std::to_string(index);
|
||||
rewriter.setInsertionPoint(returnOp.getParentOp());
|
||||
memref::GlobalOp::create(rewriter,
|
||||
returnOp.getLoc(),
|
||||
rewriter.getStringAttr(outputName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(memRefType),
|
||||
{},
|
||||
{},
|
||||
{});
|
||||
outputTensors.push_back(
|
||||
[memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
return toTensor.getResult();
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -940,11 +1145,11 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
|
||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
Location loc = funcOp.getLoc();
|
||||
|
||||
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
||||
auto tensorType = cast<ShapedType>(valueToReplace.getType());
|
||||
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
|
||||
auto tensorType = cast<ShapedType>(inputTensor.getType());
|
||||
Type elementType = tensorType.getElementType();
|
||||
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
|
||||
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
|
||||
|
||||
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
||||
|
||||
@@ -953,86 +1158,27 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
||||
loc,
|
||||
tensorType,
|
||||
deviceTensor,
|
||||
hostTensor,
|
||||
inputTensor,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
||||
|
||||
rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult());
|
||||
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
|
||||
};
|
||||
|
||||
// Replace input tensors with memRefs
|
||||
SmallVector<bufferization::ToTensorOp, 8> inputTensors;
|
||||
for (size_t i = 0; i < funcOp.getNumArguments(); i++) {
|
||||
BlockArgument tensorArg = funcOp.getArgument(i);
|
||||
DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i);
|
||||
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
|
||||
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
|
||||
|
||||
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
|
||||
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
|
||||
BlockArgument memRefArg = funcOp.getArgument(i + 1);
|
||||
|
||||
Block& block = funcOp.getBody().front();
|
||||
rewriter.setInsertionPoint(&block.front());
|
||||
auto toTensorOp =
|
||||
bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
|
||||
inputTensors.push_back(toTensorOp);
|
||||
|
||||
tensorArg.replaceAllUsesWith(toTensorOp);
|
||||
if (failed(funcOp.eraseArgument(i)))
|
||||
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
|
||||
}
|
||||
|
||||
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
|
||||
for (auto& op : funcOp.getBody().getOps())
|
||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||
unsigned numComputeWeights = computeOp.getWeights().size();
|
||||
for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
|
||||
TypedValue<TensorType> tensorSource;
|
||||
int64_t elementsOffset = 0;
|
||||
|
||||
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
|
||||
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
|
||||
|
||||
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
|
||||
continue;
|
||||
|
||||
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
|
||||
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> sliceStrides = sliceOp.getStaticStrides();
|
||||
assert("Extracting slice non-contiguous in memory"
|
||||
&& isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides));
|
||||
|
||||
for (size_t i = 0; i < sliceOffsets.size(); i++) {
|
||||
int64_t partialOffset = sliceOffsets[i];
|
||||
if (partialOffset != 0)
|
||||
for (size_t j = i + 1; j < sourceShape.size(); j++)
|
||||
partialOffset *= sourceShape[j];
|
||||
elementsOffset += partialOffset;
|
||||
}
|
||||
|
||||
computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource);
|
||||
sliceOpsToRemove.insert(sliceOp);
|
||||
if (!computeOp.getInputs().empty() || computeOp.getBody().front().getNumArguments() != 0)
|
||||
continue;
|
||||
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
|
||||
if (getGlobal.getName().starts_with("arg") || getGlobal.getName().starts_with("const_")) {
|
||||
assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
|
||||
auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin();
|
||||
insertMemCopyHostToDev(toTensorOpValue, 0);
|
||||
}
|
||||
else
|
||||
tensorSource = cast<TypedValue<TensorType>>(computeOpInput);
|
||||
|
||||
// Values already produced inside the device-side graph must not be
|
||||
// copied back through a host-to-device staging step here.
|
||||
if (isa<spatial::SpatCompute, spatial::SpatChannelReceiveOp>(tensorSource.getDefiningOp()))
|
||||
continue;
|
||||
|
||||
BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
|
||||
insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto sliceOp : sliceOpsToRemove)
|
||||
if (sliceOp->getUses().empty())
|
||||
rewriter.eraseOp(sliceOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1050,7 +1196,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
||||
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
||||
Operation* onlyUser = *op->getUsers().begin();
|
||||
isExclusivelyOwnedByReturnChain =
|
||||
isa<func::ReturnOp, tensor::ConcatOp>(onlyUser) || isChannelUseChainOp(onlyUser);
|
||||
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatCompute>(onlyUser) || isChannelUseChainOp(onlyUser);
|
||||
}
|
||||
if (!isExclusivelyOwnedByReturnChain)
|
||||
return;
|
||||
@@ -1062,6 +1208,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||
markOpToRemove(computeOp);
|
||||
for (Value input : computeOp.getInputs())
|
||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getOperands())
|
||||
@@ -1070,12 +1223,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
||||
};
|
||||
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
auto loc = returnOp.getLoc();
|
||||
for (auto it : llvm::enumerate(originalOperands)) {
|
||||
size_t orderWithinReturn = it.index();
|
||||
Operation* returnOperand = it.value().getDefiningOp();
|
||||
|
||||
rewriter.modifyOpInPlace(returnOp,
|
||||
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
|
||||
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
||||
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user