fix much stuff
This commit is contained in:
@@ -422,12 +422,17 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlices[coreId].size());
|
||||
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
|
||||
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size())) {
|
||||
auto weightArg = computeOp.getWeightArgument(aHSliceId);
|
||||
auto inputArg = computeOp.getInputArgument(aHSliceId);
|
||||
if (!weightArg || !inputArg)
|
||||
return failure();
|
||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
|
||||
gemmLoc,
|
||||
currOutHSliceType,
|
||||
computeOp.getWeightArgument(aHSliceId),
|
||||
computeOp.getInputArgument(aHSliceId)));
|
||||
*weightArg,
|
||||
*inputArg));
|
||||
}
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
@@ -561,29 +566,31 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value lane = batchOp.getLaneArgument();
|
||||
Value weight = batchOp.getWeightArgument(0);
|
||||
Value packedInput = batchOp.getInputArgument(0);
|
||||
Value packedOutput = batchOp.getOutputArgument(0);
|
||||
auto lane = batchOp.getLaneArgument();
|
||||
auto weight = batchOp.getWeightArgument(0);
|
||||
auto packedInput = batchOp.getInputArgument(0);
|
||||
auto packedOutput = batchOp.getOutputArgument(0);
|
||||
if (!lane || !weight || !packedInput || !packedOutput)
|
||||
return failure();
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> inputOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value row =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides)
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, *packedInput, inputOffsets, inputSizes, unitStrides)
|
||||
.getResult();
|
||||
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult();
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, *weight, row).getResult();
|
||||
Value laneResult = vmmResult;
|
||||
if (sharedBias)
|
||||
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||
tensor::ParallelInsertSliceOp::create(
|
||||
rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, unitStrides);
|
||||
rewriter, loc, laneResult, *packedOutput, outputOffsets, outputSizes, unitStrides);
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
|
||||
rewriter.replaceOp(gemmOp, batchOp.getResults());
|
||||
|
||||
@@ -27,13 +27,16 @@ static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
||||
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
|
||||
}
|
||||
|
||||
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
|
||||
return arg && canPromoteInputBlockArgument(*arg);
|
||||
}
|
||||
|
||||
static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
@@ -104,20 +107,30 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
|
||||
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg || !newWeightArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
|
||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
@@ -184,12 +197,15 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
newWeights,
|
||||
newInputs);
|
||||
auto laneArg = compute.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
|
||||
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
|
||||
newBlockArgTypes.push_back(compute.getLaneArgument().getType());
|
||||
newBlockArgLocs.push_back(compute.getLaneArgument().getLoc());
|
||||
newBlockArgTypes.push_back(laneArg->getType());
|
||||
newBlockArgLocs.push_back(laneArg->getLoc());
|
||||
for (Value weight : newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
@@ -197,8 +213,11 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
|
||||
newBlockArgTypes.push_back(resultType);
|
||||
newBlockArgLocs.push_back(compute.getOutputArgument(resultIndex).getLoc());
|
||||
newBlockArgLocs.push_back(outputArg->getLoc());
|
||||
}
|
||||
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
@@ -211,25 +230,41 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
mapper.map(compute.getLaneArgument(), newCompute.getLaneArgument());
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
|
||||
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
|
||||
auto newLaneArg = newCompute.getLaneArgument();
|
||||
if (!newLaneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
|
||||
mapper.map(*laneArg, *newLaneArg);
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg || !newWeightArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
|
||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
|
||||
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||
}
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
|
||||
mapper.map(compute.getOutputArgument(resultIndex),
|
||||
newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||
|
||||
for (Operation& op : oldBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
@@ -183,17 +183,15 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<Value> hostOutputTensors;
|
||||
SmallVector<unsigned> returnOperandIndices;
|
||||
if (computeBatchOp.getNumResults() != 0) {
|
||||
hostOutputTensors.resize(computeBatchOp.getNumResults());
|
||||
returnOperandIndices.resize(computeBatchOp.getNumResults());
|
||||
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
||||
if (failed(returnOperandIndex))
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
||||
|
||||
hostOutputTensors[resultIndex] = outputTensors[*returnOperandIndex](rewriter, loc);
|
||||
result.replaceAllUsesWith(hostOutputTensors[resultIndex]);
|
||||
returnOperandIndices[resultIndex] = *returnOperandIndex;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,11 +207,20 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
|
||||
IRMapping mapper;
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
mapper.map(computeBatchOp.getLaneArgument(), coreBatchOp.getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex)
|
||||
mapper.map(computeBatchOp.getWeightArgument(weightIndex), coreBatchOp.getWeightArgument(weightIndex));
|
||||
auto oldLaneArg = computeBatchOp.getLaneArgument();
|
||||
if (!oldLaneArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch lane block argument before lowering");
|
||||
mapper.map(*oldLaneArg, coreBatchOp.getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) {
|
||||
auto oldWeightArg = computeBatchOp.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch weight block arguments before lowering");
|
||||
mapper.map(*oldWeightArg, coreBatchOp.getWeightArgument(weightIndex));
|
||||
}
|
||||
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
|
||||
BlockArgument oldArg = computeBatchOp.getInputArgument(inputIndex);
|
||||
auto oldArg = computeBatchOp.getInputArgument(inputIndex);
|
||||
if (!oldArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch input block arguments before lowering");
|
||||
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
|
||||
auto newArgType = cast<ShapedType>(newArg.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
||||
@@ -226,7 +233,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
getTensorSizeInBytesAttr(rewriter, newArg))
|
||||
.getOutput();
|
||||
mapper.map(oldArg, copied);
|
||||
mapper.map(*oldArg, copied);
|
||||
}
|
||||
|
||||
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
|
||||
@@ -248,13 +255,25 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
return copied;
|
||||
};
|
||||
|
||||
SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
|
||||
auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
|
||||
Value& hostOutputTensor = hostOutputTensors[resultIndex];
|
||||
if (hostOutputTensor)
|
||||
return hostOutputTensor;
|
||||
|
||||
hostOutputTensor = outputTensors[returnOperandIndices[resultIndex]](rewriter, resultLoc);
|
||||
return hostOutputTensor;
|
||||
};
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : oldBlock) {
|
||||
if (isa<spatial::SpatYieldOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
|
||||
unsigned firstOutputArg = computeBatchOp.getOutputArgument(0).getArgNumber();
|
||||
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
|
||||
if (!firstOutputArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch output block arguments before lowering");
|
||||
for (Operation& nestedOp : parallelOp.getRegion().front()) {
|
||||
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
|
||||
if (!insertSlice)
|
||||
@@ -264,12 +283,12 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
if (!outputArg || outputArg.getOwner() != &oldBlock)
|
||||
return insertSlice.emitOpError("expected compute_batch output block argument destination");
|
||||
|
||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg;
|
||||
if (resultIndex >= hostOutputTensors.size())
|
||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
||||
if (resultIndex >= returnOperandIndices.size())
|
||||
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||
|
||||
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||
auto hostTarget = hostOutputTensors[resultIndex];
|
||||
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
||||
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
||||
Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult();
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -29,9 +31,17 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
unsigned inputIndex,
|
||||
Value replacement) {
|
||||
Block& body = owner->getRegion(0).front();
|
||||
BlockArgument bodyArgument = isa<spatial::SpatCompute>(owner)
|
||||
? cast<spatial::SpatCompute>(owner).getInputArgument(inputIndex)
|
||||
: cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
||||
BlockArgument bodyArgument;
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||
auto computeArg = compute.getInputArgument(inputIndex);
|
||||
assert(computeArg && "expected compute input block argument");
|
||||
bodyArgument = *computeArg;
|
||||
}
|
||||
else {
|
||||
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
||||
assert(batchArg && "expected compute_batch input block argument");
|
||||
bodyArgument = *batchArg;
|
||||
}
|
||||
unsigned bodyArgIndex = bodyArgument.getArgNumber();
|
||||
|
||||
rewriter.startOpModification(owner);
|
||||
|
||||
@@ -131,8 +131,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
IRMapping mapping;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights()))
|
||||
mapping.map(computeOp.getWeightArgument(weightIndex), weight);
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) {
|
||||
auto weightArg = computeOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg)
|
||||
return false;
|
||||
mapping.map(*weightArg, weight);
|
||||
}
|
||||
for (Operation& op : block.without_terminator()) {
|
||||
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
|
||||
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||
@@ -164,31 +168,33 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
BlockArgument blockArg = computeOp.getInputArgument(inputIndex);
|
||||
auto blockArg = computeOp.getInputArgument(inputIndex);
|
||||
if (!blockArg)
|
||||
return computeOp.emitOpError("expected compute input block arguments during lowering");
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
|
||||
if (receiveOp && !blockArg.use_empty()) {
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg.getType());
|
||||
if (receiveOp && !blockArg->use_empty()) {
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg);
|
||||
Value received =
|
||||
PimReceiveOp::create(
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
blockArg->replaceAllUsesWith(received);
|
||||
markOpToRemove(receiveOp);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
|
||||
if (receiveTensorOp && !blockArg.use_empty()) {
|
||||
if (receiveTensorOp && !blockArg->use_empty()) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg.getType());
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
|
||||
Value received = PimReceiveTensorOp::create(rewriter,
|
||||
receiveTensorOp.getLoc(),
|
||||
@@ -196,7 +202,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
blockArg->replaceAllUsesWith(received);
|
||||
markOpToRemove(receiveTensorOp);
|
||||
}
|
||||
}
|
||||
@@ -238,12 +244,14 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
BlockArgument blockArg = computeOp.getInputArgument(inputIndex);
|
||||
if (blockArg.use_empty())
|
||||
auto blockArg = computeOp.getInputArgument(inputIndex);
|
||||
if (!blockArg)
|
||||
return computeOp.emitOpError("expected compute input block arguments during input materialization");
|
||||
if (blockArg->use_empty())
|
||||
continue;
|
||||
|
||||
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||
blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -261,7 +269,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
input,
|
||||
getTensorSizeInBytesAttr(rewriter, input))
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(copied);
|
||||
blockArg->replaceAllUsesWith(copied);
|
||||
}
|
||||
if (!computeOp.getInputs().empty())
|
||||
block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size());
|
||||
|
||||
@@ -77,8 +77,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
|
||||
if (!BBArgValue)
|
||||
return failure();
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
if (BBArgValue->use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
@@ -95,8 +97,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
|
||||
if (!BBArgValue)
|
||||
return failure();
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
if (BBArgValue->use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
|
||||
Reference in New Issue
Block a user