compact pim IR
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m15s

This commit is contained in:
NiccoloN
2026-05-06 17:16:51 +02:00
parent 7bb58e80de
commit f2fe147961
13 changed files with 2264 additions and 307 deletions

View File

@@ -150,116 +150,105 @@ static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewri
/// Lowers a spatial.channel_send_many op to a single pim.send_many op.
///
/// The rendered diff had kept the removed per-tensor PimSendOp loop alongside
/// the added batched path, which would emit every tensor twice; only the
/// batched lowering belongs here. Spatial core ids are translated into PIM
/// core ids and packed into a dense i32 array attribute.
static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) {
    rewriter.setInsertionPoint(sendManyOp);
    // Translate every spatial target core id into the PIM core-id space.
    SmallVector<int32_t> targetCoreIds;
    targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size());
    for (int32_t targetCoreId : sendManyOp.getTargetCoreIds())
        targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
    // One op carries all inputs; per-tensor size attributes are no longer needed.
    PimSendManyOp::create(
        rewriter, sendManyOp.getLoc(), rewriter.getDenseI32ArrayAttr(targetCoreIds), sendManyOp.getInputs());
    // The send has no results, so the original op is simply erased.
    rewriter.eraseOp(sendManyOp);
}
/// Creates one pim.empty_many op producing a fresh empty tensor for each of
/// the given result types, and returns its outputs as a value vector.
static SmallVector<Value> createManyEmptyTensorsLike(IRRewriter& rewriter,
                                                     Location loc,
                                                     TypeRange outputTypes) {
    // Materialize the requested types as an owned type vector for the op builder.
    SmallVector<Type> bufferTypes(outputTypes.begin(), outputTypes.end());
    auto emptyMany = pim::PimEmptyManyOp::create(rewriter, loc, TypeRange(bufferTypes));
    auto produced = emptyMany.getOutputs();
    return SmallVector<Value>(produced.begin(), produced.end());
}
/// Lowers a spatial.channel_receive_many op to a single pim.receive_many op.
///
/// The rendered diff had merged the removed per-result PimReceiveOp loop with
/// the added batched path, replacing the op twice; only the batched lowering
/// belongs here. Output buffers are allocated in bulk via pim.empty_many and
/// source core ids are translated into the PIM core-id space.
static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) {
    rewriter.setInsertionPoint(receiveManyOp);
    // Translate every spatial source core id into the PIM core-id space.
    SmallVector<int32_t> sourceCoreIds;
    sourceCoreIds.reserve(receiveManyOp.getSourceCoreIds().size());
    for (int32_t sourceCoreId : receiveManyOp.getSourceCoreIds())
        sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
    // One empty destination tensor per result, allocated by a single op.
    SmallVector<Value> outputBuffers =
        createManyEmptyTensorsLike(rewriter, receiveManyOp.getLoc(), receiveManyOp.getResultTypes());
    auto receiveMany = PimReceiveManyOp::create(rewriter,
                                                receiveManyOp.getLoc(),
                                                receiveManyOp.getResultTypes(),
                                                ValueRange(outputBuffers),
                                                rewriter.getDenseI32ArrayAttr(sourceCoreIds));
    // All uses of the spatial results are redirected to the PIM op's outputs.
    rewriter.replaceOp(receiveManyOp, receiveMany.getOutputs());
}
/// Lowers a spatial.channel_send_many_batch op to a single pim.send_many_batch op.
///
/// The rendered diff had kept the removed per-input PimSendBatchOp loop next
/// to the added batched path, redeclaring `targetCoreIds` in the same scope;
/// only the batched lowering belongs here. Inputs are remapped through
/// `mapper` because the op body was cloned into the PIM core batch region.
///
/// `laneCount` is retained for signature compatibility with the other batch
/// lowerings; the batched op takes the full flattened core-id array, so no
/// per-lane slicing is needed here.
static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp,
                                      int32_t laneCount,
                                      IRMapping& mapper,
                                      IRRewriter& rewriter) {
    // Translate the flattened (input-major, lane-minor) core-id list.
    SmallVector<int32_t> targetCoreIds;
    targetCoreIds.reserve(sendManyBatchOp.getTargetCoreIds().size());
    for (int32_t targetCoreId : sendManyBatchOp.getTargetCoreIds())
        targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
    // Look up the cloned counterpart of every input value.
    SmallVector<Value> mappedInputs;
    mappedInputs.reserve(sendManyBatchOp.getInputs().size());
    for (Value input : sendManyBatchOp.getInputs())
        mappedInputs.push_back(mapper.lookup(input));
    pim::PimSendManyBatchOp::create(rewriter,
                                    sendManyBatchOp.getLoc(),
                                    rewriter.getDenseI32ArrayAttr(targetCoreIds),
                                    ValueRange(mappedInputs));
}
/// Lowers a spatial.channel_receive_many_batch op to a single
/// pim.receive_many_batch op.
///
/// The rendered diff had left the removed per-output PimReceiveBatchOp loop
/// open (stray closing brace, shadowed `sourceCoreIds`) next to the added
/// batched path; only the batched lowering belongs here. Results are recorded
/// in `mapper` so later clones inside the core batch region use them.
///
/// `laneCount` is retained for signature compatibility with the other batch
/// lowerings; the batched op takes the full flattened core-id array, so no
/// per-lane slicing is needed here.
static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp,
                                         int32_t laneCount,
                                         IRMapping& mapper,
                                         IRRewriter& rewriter) {
    // Translate the flattened (output-major, lane-minor) core-id list.
    SmallVector<int32_t> sourceCoreIds;
    sourceCoreIds.reserve(receiveManyBatchOp.getSourceCoreIds().size());
    for (int32_t sourceCoreId : receiveManyBatchOp.getSourceCoreIds())
        sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
    // One empty destination tensor per result, allocated by a single op.
    SmallVector<Value> outputBuffers =
        createManyEmptyTensorsLike(rewriter, receiveManyBatchOp.getLoc(), receiveManyBatchOp.getResultTypes());
    auto receiveMany = pim::PimReceiveManyBatchOp::create(rewriter,
                                                          receiveManyBatchOp.getLoc(),
                                                          receiveManyBatchOp.getResultTypes(),
                                                          ValueRange(outputBuffers),
                                                          rewriter.getDenseI32ArrayAttr(sourceCoreIds));
    // Map each original output to the corresponding PIM result for later clones.
    for (auto [output, received] : llvm::zip(receiveManyBatchOp.getOutputs(), receiveMany.getOutputs()))
        mapper.map(output, received);
}
// Lowers spatial.extract_rows during the Spatial-to-PIM pass.
//
// NOTE(review): this span is a rendered diff with removed and added lines
// interleaved and the +/- markers stripped. As written it replaces the op
// TWICE (tensor::ExtractSliceOp path below, then PimExtractRowsOp), which
// cannot be the committed code. It is unclear from this view whether the
// memref->tensor conversion preamble survives in the new version (the new
// PimExtractRowsOp consumes extractRowsOp.getInput() directly, not the
// converted `input`) — reconstruct from the repository before editing.
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
// Accept either a ranked tensor input or a memref that is bridged to a
// tensor via bufferization.to_tensor; anything else is a lowering error.
Value input = extractRowsOp.getInput();
RankedTensorType inputType;
if (auto tensorType = dyn_cast<RankedTensorType>(input.getType())) {
inputType = tensorType;
}
else if (auto memRefType = dyn_cast<MemRefType>(input.getType())) {
inputType = RankedTensorType::get(memRefType.getShape(), memRefType.getElementType());
rewriter.setInsertionPoint(extractRowsOp);
input = bufferization::ToTensorOp::create(
rewriter, extractRowsOp.getLoc(), inputType, input, rewriter.getUnitAttr(), rewriter.getUnitAttr())
.getResult();
}
else {
extractRowsOp.emitOpError("requires a ranked tensor or memref input during Spatial-to-PIM lowering");
return;
}
// --- Apparent PRE-refactor (removed) path: one tensor.extract_slice per row.
int64_t numCols = inputType.getDimSize(1);
SmallVector<Value> replacements;
replacements.reserve(extractRowsOp.getNumResults());
rewriter.setInsertionPoint(extractRowsOp);
for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType) {
extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering");
return;
}
// Slice row `rowIndex`: offset [row, 0], size [1, numCols], unit strides.
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(static_cast<int64_t>(rowIndex)),
rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto rowSlice =
tensor::ExtractSliceOp::create(rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides);
replacements.push_back(rowSlice.getResult());
}
// --- Apparent POST-refactor (added) path: a single pim.extract_rows op
// writing into bulk-allocated empty output buffers.
SmallVector<Value> outputBuffers =
createManyEmptyTensorsLike(rewriter, extractRowsOp.getLoc(), extractRowsOp.getResultTypes());
// NOTE(review): first replaceOp — belongs to the removed path above.
rewriter.replaceOp(extractRowsOp, ValueRange(replacements));
auto extractRows = pim::PimExtractRowsOp::create(rewriter,
extractRowsOp.getLoc(),
extractRowsOp.getResultTypes(),
extractRowsOp.getInput(),
ValueRange(outputBuffers));
// NOTE(review): second replaceOp — belongs to the added path.
rewriter.replaceOp(extractRowsOp, extractRows.getOutputs());
}
/// Lowers a spatial.concat op to a pim.concat op.
///
/// The rendered diff had kept the removed tensor::ConcatOp line alongside the
/// added PimConcatOp path, defining `concatenated` twice; only the PIM
/// lowering belongs here. An empty destination tensor shaped like the result
/// is allocated for the op to write into.
static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) {
    rewriter.setInsertionPoint(concatOp);
    auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
    // Destination buffer matching the concatenated result's shape.
    Value outputBuffer = createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), outputType).getResult();
    Value concatenated = pim::PimConcatOp::create(rewriter,
                                                  concatOp.getLoc(),
                                                  concatOp.getOutput().getType(),
                                                  rewriter.getI64IntegerAttr(concatOp.getAxis()),
                                                  concatOp.getInputs(),
                                                  outputBuffer)
                             .getOutput();
    rewriter.replaceOp(concatOp, concatenated);
}
@@ -282,34 +271,23 @@ static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewrit
}
}
static void expandMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) {
SmallVector<spatial::SpatMapOp> mapOps;
funcOp.walk([&](spatial::SpatMapOp mapOp) { mapOps.push_back(mapOp); });
funcOp.walk([&](spatial::SpatMapOp mapOp) {
if (mapOp->getParentOfType<pim::PimCoreOp>() || mapOp->getParentOfType<pim::PimCoreBatchOp>())
mapOps.push_back(mapOp);
});
for (auto mapOp : mapOps) {
Block& body = mapOp.getBody().front();
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
SmallVector<Value> replacements;
replacements.reserve(mapOp.getInputs().size());
rewriter.setInsertionPoint(mapOp);
for (Value input : mapOp.getInputs()) {
IRMapping mapping;
mapping.map(body.getArgument(0), input);
auto pimMap = pim::PimMapOp::create(rewriter, mapOp.getLoc(), mapOp.getResultTypes(), mapOp.getInputs());
rewriter.inlineRegionBefore(mapOp.getBody(), pimMap.getBody(), pimMap.getBody().begin());
Value replacement = input;
for (Operation& op : body.without_terminator()) {
Operation* cloned = rewriter.clone(op, mapping);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapping.map(originalResult, clonedResult);
rewriter.setInsertionPointAfter(cloned);
}
replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
replacements.push_back(replacement);
}
rewriter.replaceOp(mapOp, replacements);
auto yieldOp = cast<spatial::SpatYieldOp>(body.getTerminator());
rewriter.setInsertionPoint(yieldOp);
rewriter.replaceOpWithNewOp<pim::PimYieldOp>(yieldOp, yieldOp.getOutputs());
rewriter.replaceOp(mapOp, pimMap.getOutputs());
}
}
@@ -440,8 +418,28 @@ static std::optional<ReturnUseInfo> analyzeReturnUse(Value value) {
}
static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
auto getConcatResult = [](Operation *op) -> Value {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getResult();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getOutput();
return {};
};
auto getConcatAxis = [](Operation *op) -> std::optional<int64_t> {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getDim();
if (auto pimConcat = dyn_cast<pim::PimConcatOp>(op))
return pimConcat.getAxis();
return std::nullopt;
};
auto getConcatOperands = [](Operation *op) -> OperandRange {
if (auto tensorConcat = dyn_cast<tensor::ConcatOp>(op))
return tensorConcat.getOperands();
return cast<pim::PimConcatOp>(op).getInputs();
};
auto uses = value.getUses();
if (rangeLength(uses) != 1 || !isa<tensor::ConcatOp>(uses.begin()->getOwner()))
if (rangeLength(uses) != 1 || !isa<tensor::ConcatOp, pim::PimConcatOp>(uses.begin()->getOwner()))
return std::nullopt;
auto valueType = dyn_cast<ShapedType>(value.getType());
@@ -453,18 +451,19 @@ static std::optional<ConcatReturnUseInfo> analyzeConcatReturnUse(Value value) {
Value currentValue = value;
Operation* currentUser = uses.begin()->getOwner();
while (auto concatOp = dyn_cast<tensor::ConcatOp>(currentUser)) {
while (isa<tensor::ConcatOp, pim::PimConcatOp>(currentUser)) {
size_t operandIndex = currentValue.getUses().begin()->getOperandNumber();
int64_t axis = concatOp.getDim();
for (Value operand : concatOp.getOperands().take_front(operandIndex))
int64_t axis = *getConcatAxis(currentUser);
for (Value operand : getConcatOperands(currentUser).take_front(operandIndex))
sliceOffsets[axis] += cast<ShapedType>(operand.getType()).getShape()[axis];
auto concatType = dyn_cast<ShapedType>(concatOp.getResult().getType());
Value concatResult = getConcatResult(currentUser);
auto concatType = dyn_cast<ShapedType>(concatResult.getType());
if (!concatType || !concatType.hasStaticShape())
return std::nullopt;
concatShape.assign(concatType.getShape().begin(), concatType.getShape().end());
currentValue = concatOp.getResult();
currentValue = concatResult;
auto currentUses = currentValue.getUses();
if (rangeLength(currentUses) != 1)
return std::nullopt;
@@ -638,7 +637,6 @@ void SpatialToPimPass::runOnOperation() {
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
expandMapOps(funcOp, rewriter);
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect,
@@ -687,6 +685,8 @@ void SpatialToPimPass::runOnOperation() {
runOnComputeBatchOp(computeBatchOp, rewriter);
}
lowerMapOps(funcOp, rewriter);
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
receiveOps.push_back(op);
@@ -1317,7 +1317,8 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) {
Operation* onlyUser = *op->getUsers().begin();
isExclusivelyOwnedByReturnChain =
isa<func::ReturnOp, tensor::ConcatOp, spatial::SpatCompute>(onlyUser) || isChannelUseChainOp(onlyUser);
isa<func::ReturnOp, tensor::ConcatOp, pim::PimConcatOp, spatial::SpatCompute>(onlyUser)
|| isChannelUseChainOp(onlyUser);
}
if (!isExclusivelyOwnedByReturnChain)
return;
@@ -1341,6 +1342,13 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
markOpToRemove(concatOp);
for (Value operand : concatOp.getOperands())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
markOpToRemove(concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
}
};