big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
This commit is contained in:
@@ -116,7 +116,7 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
|
||||
|
||||
SmallVector<int32_t> coreIds;
|
||||
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
|
||||
for (int32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
|
||||
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
|
||||
coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
|
||||
return coreIds;
|
||||
}
|
||||
@@ -150,40 +150,33 @@ static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewri
|
||||
|
||||
static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) {
|
||||
rewriter.setInsertionPoint(sendManyOp);
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : sendManyOp.getTargetCoreIds())
|
||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||
PimSendManyOp::create(
|
||||
rewriter, sendManyOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), sendManyOp.getInputs());
|
||||
for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) {
|
||||
PimSendOp::create(rewriter,
|
||||
sendManyOp.getLoc(),
|
||||
input,
|
||||
getTensorSizeInBytesAttr(rewriter, input),
|
||||
rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(targetCoreId)));
|
||||
}
|
||||
rewriter.eraseOp(sendManyOp);
|
||||
}
|
||||
|
||||
static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter, Location loc, TypeRange outputTypes) {
|
||||
SmallVector<Type> tensorTypes;
|
||||
tensorTypes.reserve(outputTypes.size());
|
||||
for (Type outputType : outputTypes)
|
||||
tensorTypes.push_back(outputType);
|
||||
|
||||
auto emptyMany = pim::PimEmptyManyOp::create(rewriter, loc, TypeRange(tensorTypes));
|
||||
return SmallVector<Value>(emptyMany.getOutputs().begin(), emptyMany.getOutputs().end());
|
||||
}
|
||||
|
||||
static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) {
|
||||
rewriter.setInsertionPoint(receiveManyOp);
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
||||
SmallVector<Value> outputBuffers =
|
||||
createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
|
||||
|
||||
auto receiveMany = PimReceiveManyOp::create(rewriter,
|
||||
receiveManyOp.getLoc(),
|
||||
receiveManyOp.getResultTypes(),
|
||||
ValueRange(outputBuffers),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds));
|
||||
rewriter.replaceOp(receiveManyOp, receiveMany.getOutputs());
|
||||
SmallVector<Value> replacements;
|
||||
replacements.reserve(receiveManyOp.getNumResults());
|
||||
for (auto [output, sourceCoreId] : llvm::zip(receiveManyOp.getOutputs(), receiveManyOp.getSourceCoreIds())) {
|
||||
auto outputType = cast<ShapedType>(output.getType());
|
||||
Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyOp.getLoc(), outputType).getResult();
|
||||
replacements.push_back(
|
||||
PimReceiveOp::create(rewriter,
|
||||
receiveManyOp.getLoc(),
|
||||
output.getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, output),
|
||||
rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sourceCoreId)))
|
||||
.getOutput());
|
||||
}
|
||||
rewriter.replaceOp(receiveManyOp, replacements);
|
||||
}
|
||||
|
||||
static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp,
|
||||
@@ -198,8 +191,17 @@ static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendMa
|
||||
mappedInputs.reserve(sendManyBatchOp.getInputs().size());
|
||||
for (Value input : sendManyBatchOp.getInputs())
|
||||
mappedInputs.push_back(mapper.lookup(input));
|
||||
pim::PimSendManyBatchOp::create(
|
||||
rewriter, sendManyBatchOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), ValueRange(mappedInputs));
|
||||
for (auto [valueIndex, input] : llvm::enumerate(mappedInputs)) {
|
||||
SmallVector<int32_t> laneTargetCoreIds;
|
||||
laneTargetCoreIds.reserve(laneCount);
|
||||
for (int32_t lane = 0; lane < laneCount; ++lane)
|
||||
laneTargetCoreIds.push_back(targetCoreIds[valueIndex * laneCount + lane]);
|
||||
pim::PimSendBatchOp::create(rewriter,
|
||||
sendManyBatchOp.getLoc(),
|
||||
input,
|
||||
getTensorSizeInBytesAttr(rewriter, input),
|
||||
rewriter.getDenseI32ArrayAttr(laneTargetCoreIds));
|
||||
}
|
||||
}
|
||||
|
||||
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
|
||||
@@ -210,29 +212,44 @@ static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp
|
||||
sourceCoreIds.reserve(receiveManyBatchOp.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : receiveManyBatchOp.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
||||
SmallVector<Value> outputBuffers =
|
||||
createManyEmptyTensorsLike(rewriter, receiveManyBatchOp.getLoc(), receiveManyBatchOp.getResultTypes());
|
||||
|
||||
auto receiveMany = pim::PimReceiveManyBatchOp::create(rewriter,
|
||||
receiveManyBatchOp.getLoc(),
|
||||
receiveManyBatchOp.getResultTypes(),
|
||||
ValueRange(outputBuffers),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds));
|
||||
for (auto [output, received] : llvm::zip(receiveManyBatchOp.getOutputs(), receiveMany.getOutputs()))
|
||||
for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) {
|
||||
auto outputType = cast<ShapedType>(output.getType());
|
||||
Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), outputType).getResult();
|
||||
SmallVector<int32_t> laneSourceCoreIds;
|
||||
laneSourceCoreIds.reserve(laneCount);
|
||||
for (int32_t lane = 0; lane < laneCount; ++lane)
|
||||
laneSourceCoreIds.push_back(sourceCoreIds[valueIndex * laneCount + lane]);
|
||||
|
||||
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
||||
receiveManyBatchOp.getLoc(),
|
||||
output.getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, output),
|
||||
rewriter.getDenseI32ArrayAttr(laneSourceCoreIds))
|
||||
.getOutput();
|
||||
mapper.map(output, received);
|
||||
}
|
||||
}
|
||||
|
||||
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
|
||||
rewriter.setInsertionPoint(extractRowsOp);
|
||||
SmallVector<Value> outputBuffers =
|
||||
createManyEmptyTensorsLike(rewriter, extractRowsOp.getLoc(), extractRowsOp.getResultTypes());
|
||||
|
||||
auto extractRows = pim::PimExtractRowsOp::create(rewriter,
|
||||
extractRowsOp.getLoc(),
|
||||
extractRowsOp.getResultTypes(),
|
||||
extractRowsOp.getInput(),
|
||||
ValueRange(outputBuffers));
|
||||
rewriter.replaceOp(extractRowsOp, extractRows.getOutputs());
|
||||
auto inputType = cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
SmallVector<Value> replacements;
|
||||
replacements.reserve(extractRowsOp.getNumResults());
|
||||
for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
|
||||
auto outputType = cast<RankedTensorType>(output.getType());
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(inputType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
replacements.push_back(
|
||||
tensor::ExtractSliceOp::create(
|
||||
rewriter, extractRowsOp.getLoc(), outputType, extractRowsOp.getInput(), offsets, sizes, strides)
|
||||
.getResult());
|
||||
}
|
||||
rewriter.replaceOp(extractRowsOp, replacements);
|
||||
}
|
||||
|
||||
static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
|
||||
@@ -258,14 +275,26 @@ static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
|
||||
for (auto mapOp : mapOps) {
|
||||
Block& body = mapOp.getBody().front();
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
auto pimMap = pim::PimMapOp::create(rewriter, mapOp.getLoc(), mapOp.getResultTypes(), mapOp.getInputs());
|
||||
rewriter.inlineRegionBefore(mapOp.getBody(), pimMap.getBody(), pimMap.getBody().begin());
|
||||
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
rewriter.replaceOpWithNewOp<pim::PimYieldOp>(yieldOp, yieldOp.getOutputs());
|
||||
rewriter.replaceOp(mapOp, pimMap.getOutputs());
|
||||
SmallVector<Value> replacements;
|
||||
replacements.reserve(mapOp.getInputs().size());
|
||||
rewriter.setInsertionPoint(mapOp);
|
||||
|
||||
for (Value input : mapOp.getInputs()) {
|
||||
IRMapping mapping;
|
||||
mapping.map(body.getArgument(0), input);
|
||||
|
||||
for (Operation& bodyOp : body.without_terminator()) {
|
||||
Operation* cloned = rewriter.clone(bodyOp, mapping);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults()))
|
||||
mapping.map(originalResult, clonedResult);
|
||||
rewriter.setInsertionPointAfter(cloned);
|
||||
}
|
||||
|
||||
replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0)));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(mapOp, replacements);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -295,7 +324,7 @@ static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigne
|
||||
}
|
||||
|
||||
static Value createPackedExtractRowsSlice(
|
||||
pim::PimExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
auto rowType = dyn_cast<RankedTensorType>(extractRowsOp.getOutputs()[startIndex].getType());
|
||||
auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
|
||||
@@ -332,14 +361,17 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
|
||||
if (!getContiguousOpResults(values, owner, startIndex))
|
||||
return {};
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(owner))
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
|
||||
return createPackedExtractRowsSlice(extractRowsOp, startIndex, static_cast<unsigned>(values.size()), rewriter, loc);
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
static Value createPackedReceiveTensor(
|
||||
pim::PimReceiveManyOp receiveManyOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
static Value createPackedReceiveTensor(spatial::SpatChannelReceiveManyOp receiveManyOp,
|
||||
unsigned startIndex,
|
||||
unsigned count,
|
||||
IRRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto rowType = dyn_cast<RankedTensorType>(receiveManyOp.getOutputs()[startIndex].getType());
|
||||
if (!rowType || !rowType.hasStaticShape() || rowType.getRank() == 0)
|
||||
return {};
|
||||
@@ -351,15 +383,15 @@ static Value createPackedReceiveTensor(
|
||||
sourceCoreIds.reserve(count);
|
||||
ArrayRef<int32_t> allSourceCoreIds = receiveManyOp.getSourceCoreIds();
|
||||
for (unsigned index = 0; index < count; ++index)
|
||||
sourceCoreIds.push_back(allSourceCoreIds[startIndex + index]);
|
||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(allSourceCoreIds[startIndex + index]));
|
||||
|
||||
return pim::PimReceiveTensorOp::create(
|
||||
rewriter, loc, packedType, outputBuffer.getResult(), rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
static Value
|
||||
createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
static Value createPackedMapTensor(
|
||||
spatial::SpatMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) {
|
||||
Value packedInput = createPackedTensorForValues(mapOp.getInputs().slice(startIndex, count), rewriter, loc);
|
||||
if (!packedInput)
|
||||
return {};
|
||||
@@ -416,7 +448,7 @@ createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count,
|
||||
rewriter.setInsertionPointAfter(cloned);
|
||||
}
|
||||
|
||||
auto yieldOp = cast<pim::PimYieldOp>(body.getTerminator());
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
|
||||
Value mappedOutput = mapping.lookupOrDefault(yieldOp.getOperand(0));
|
||||
|
||||
int64_t outputRowsPerValue = outputType.getDimSize(0);
|
||||
@@ -446,9 +478,9 @@ createPackedMapTensor(pim::PimMapOp mapOp, unsigned startIndex, unsigned count,
|
||||
return loop.getResult(0);
|
||||
}
|
||||
|
||||
static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<pim::PimSendManyOp> sendManyOps;
|
||||
funcOp.walk([&](pim::PimSendManyOp sendManyOp) { sendManyOps.push_back(sendManyOp); });
|
||||
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatChannelSendManyOp> sendManyOps;
|
||||
funcOp.walk([&](spatial::SpatChannelSendManyOp sendManyOp) { sendManyOps.push_back(sendManyOp); });
|
||||
for (auto sendManyOp : sendManyOps) {
|
||||
if (sendManyOp.getInputs().empty())
|
||||
continue;
|
||||
@@ -458,12 +490,17 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
if (!packedInput)
|
||||
continue;
|
||||
|
||||
pim::PimSendTensorOp::create(rewriter, sendManyOp.getLoc(), packedInput, sendManyOp.getTargetCoreIdsAttr());
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : sendManyOp.getTargetCoreIds())
|
||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||
pim::PimSendTensorOp::create(
|
||||
rewriter, sendManyOp.getLoc(), packedInput, rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
rewriter.eraseOp(sendManyOp);
|
||||
}
|
||||
|
||||
SmallVector<pim::PimConcatOp> concatOps;
|
||||
funcOp.walk([&](pim::PimConcatOp concatOp) { concatOps.push_back(concatOp); });
|
||||
SmallVector<spatial::SpatConcatOp> concatOps;
|
||||
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
|
||||
for (auto concatOp : concatOps) {
|
||||
if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
|
||||
continue;
|
||||
@@ -494,11 +531,11 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
|
||||
unsigned count = endIndex - index;
|
||||
Value packedInput;
|
||||
if (auto mapOp = dyn_cast<pim::PimMapOp>(owner))
|
||||
if (auto mapOp = dyn_cast<spatial::SpatMapOp>(owner))
|
||||
packedInput = createPackedMapTensor(mapOp, startIndex, count, rewriter, concatOp.getLoc());
|
||||
else if (auto receiveManyOp = dyn_cast<pim::PimReceiveManyOp>(owner))
|
||||
else if (auto receiveManyOp = dyn_cast<spatial::SpatChannelReceiveManyOp>(owner))
|
||||
packedInput = createPackedReceiveTensor(receiveManyOp, startIndex, count, rewriter, concatOp.getLoc());
|
||||
else if (auto extractRowsOp = dyn_cast<pim::PimExtractRowsOp>(owner))
|
||||
else if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
|
||||
packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());
|
||||
|
||||
if (packedInput) {
|
||||
@@ -516,12 +553,14 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
if (!changed)
|
||||
continue;
|
||||
|
||||
auto newConcat = pim::PimConcatOp::create(rewriter,
|
||||
concatOp.getLoc(),
|
||||
concatOp.getOutput().getType(),
|
||||
concatOp.getAxisAttr(),
|
||||
ValueRange(packedInputs),
|
||||
concatOp.getOutputBuffer());
|
||||
auto newConcat = pim::PimConcatOp::create(
|
||||
rewriter,
|
||||
concatOp.getLoc(),
|
||||
concatOp.getOutput().getType(),
|
||||
concatOp.getAxisAttr(),
|
||||
ValueRange(packedInputs),
|
||||
createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), cast<ShapedType>(concatOp.getOutput().getType()))
|
||||
.getResult());
|
||||
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
||||
}
|
||||
|
||||
@@ -533,10 +572,9 @@ static void compactPimTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
if (op->use_empty())
|
||||
rewriter.eraseOp(op);
|
||||
};
|
||||
eraseUnusedOps(pim::PimMapOp {});
|
||||
eraseUnusedOps(pim::PimReceiveManyOp {});
|
||||
eraseUnusedOps(pim::PimExtractRowsOp {});
|
||||
eraseUnusedOps(pim::PimEmptyManyOp {});
|
||||
eraseUnusedOps(spatial::SpatMapOp {});
|
||||
eraseUnusedOps(spatial::SpatChannelReceiveManyOp {});
|
||||
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
||||
}
|
||||
|
||||
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
@@ -617,6 +655,7 @@ struct ConcatReturnUseInfo {
|
||||
size_t returnIndex;
|
||||
SmallVector<int64_t> sliceOffsets;
|
||||
SmallVector<int64_t> concatShape;
|
||||
SmallVector<Operation*> concatChain;
|
||||
SmallVector<Operation*> helperChain;
|
||||
};
|
||||
|
||||
@@ -669,6 +708,8 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
auto getConcatResult = [](Operation* op) -> Value {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getResult();
|
||||
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
|
||||
return spatialConcat.getOutput();
|
||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||
return pimConcat.getOutput();
|
||||
return {};
|
||||
@@ -676,6 +717,8 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
auto getConcatAxis = [](Operation* op) -> std::optional<int64_t> {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getDim();
|
||||
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
|
||||
return spatialConcat.getAxis();
|
||||
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
|
||||
return pimConcat.getAxis();
|
||||
return std::nullopt;
|
||||
@@ -683,11 +726,14 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
auto getConcatOperands = [](Operation* op) -> OperandRange {
|
||||
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
|
||||
return tensorConcat.getOperands();
|
||||
if (auto spatialConcat = dyn_cast<spatial::SpatConcatOp>(op))
|
||||
return spatialConcat.getInputs();
|
||||
return cast<pim::PimConcatOp>(op).getInputs();
|
||||
};
|
||||
|
||||
auto uses = value.getUses();
|
||||
if (rangeLength(uses) != 1 || !isa<tensor::ConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
|
||||
if (rangeLength(uses) != 1
|
||||
|| !isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
|
||||
return std::nullopt;
|
||||
|
||||
auto valueType = dyn_cast<ShapedType>(value.getType());
|
||||
@@ -696,10 +742,12 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
|
||||
SmallVector<int64_t> sliceOffsets(valueType.getRank(), 0);
|
||||
SmallVector<int64_t> concatShape(valueType.getShape().begin(), valueType.getShape().end());
|
||||
SmallVector<Operation*> concatChain;
|
||||
Value currentValue = value;
|
||||
Operation* currentUser = uses.begin()->getOwner();
|
||||
|
||||
while (isa<tensor::ConcatOp, pim::PimConcatOp>(currentUser)) {
|
||||
while (isa<tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp>(currentUser)) {
|
||||
concatChain.push_back(currentUser);
|
||||
size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
|
||||
int64_t axis = *getConcatAxis(currentUser);
|
||||
for (Value operand : getConcatOperands(currentUser).take_front(operandIndex))
|
||||
@@ -749,6 +797,7 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
|
||||
currentValue.getUses().begin()->getOperandNumber(),
|
||||
std::move(sliceOffsets),
|
||||
std::move(concatShape),
|
||||
std::move(concatChain),
|
||||
std::move(helperChain),
|
||||
};
|
||||
}
|
||||
@@ -918,11 +967,6 @@ void SpatialToPimPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatConcatOp> concatOps;
|
||||
funcOp.walk([&](spatial::SpatConcatOp op) { concatOps.push_back(op); });
|
||||
for (auto concatOp : concatOps)
|
||||
lowerConcat(concatOp, rewriter);
|
||||
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
markOpToRemove(computeOp);
|
||||
runOnComputeOp(computeOp, rewriter);
|
||||
@@ -933,6 +977,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
runOnComputeBatchOp(computeBatchOp, rewriter);
|
||||
}
|
||||
|
||||
compactSpatialTensorGroups(funcOp, rewriter);
|
||||
lowerMapOps(funcOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
||||
@@ -1036,6 +1081,8 @@ void SpatialToPimPass::runOnOperation() {
|
||||
assert(false && "tracked op removal reached a cycle or missed dependency");
|
||||
}
|
||||
|
||||
compactSpatialTensorGroups(funcOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatConcatOp> remainingConcatOps;
|
||||
funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); });
|
||||
for (auto concatOp : remainingConcatOps)
|
||||
@@ -1066,8 +1113,6 @@ void SpatialToPimPass::runOnOperation() {
|
||||
for (auto extractRowsOp : remainingExtractRowsOps)
|
||||
lowerExtractRows(extractRowsOp, rewriter);
|
||||
|
||||
compactPimTensorGroups(funcOp, rewriter);
|
||||
|
||||
// Dump to file for debug
|
||||
bool hasSpatialOps = false;
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
@@ -1170,6 +1215,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
|
||||
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
|
||||
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||
markOpToRemove(concatOp);
|
||||
|
||||
if (concatReturnUse->helperChain.empty()) {
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
@@ -1481,13 +1528,15 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||
outputTensors.reserve(returnOp->getNumOperands());
|
||||
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
||||
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
|
||||
Value currentReturnValue = returnValue;
|
||||
Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp();
|
||||
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
|
||||
assert(!hasWeightAlways(returnValueDefiningOp));
|
||||
outputTensors.push_back([returnValue](IRRewriter& rewriter, Location loc) -> Value { return returnValue; });
|
||||
outputTensors.push_back(
|
||||
[currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; });
|
||||
}
|
||||
else {
|
||||
auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(returnValue.getType());
|
||||
auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(currentReturnValue.getType());
|
||||
auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
|
||||
|
||||
std::string outputName = "output_" + std::to_string(index);
|
||||
@@ -1565,7 +1614,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
||||
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
|
||||
Operation* onlyUser = *op->getUsers().begin();
|
||||
isExclusivelyOwnedByReturnChain =
|
||||
isa<func::ReturnOp, tensor::ConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|
||||
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|
||||
|| isChannelUseChainOp(onlyUser);
|
||||
}
|
||||
if (!isExclusivelyOwnedByReturnChain)
|
||||
@@ -1593,6 +1642,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
|
||||
Reference in New Issue
Block a user