Support multiple-output spat computes
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m38s

This commit is contained in:
NiccoloN
2026-04-23 09:28:57 +02:00
parent 0f13269040
commit 412ca957f6
16 changed files with 415 additions and 420 deletions

View File

@@ -62,7 +62,7 @@ private:
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void
addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
void replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
@@ -73,7 +73,7 @@ private:
void annotateChannelCoreIds(func::FuncOp funcOp);
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
@@ -116,7 +116,7 @@ static size_t countComputeLeafUsers(Value value) {
auto walkUses = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) {
Operation* owner = use.getOwner();
if (isa<spatial::SpatWeightedCompute>(owner)) {
if (isa<spatial::SpatCompute>(owner)) {
leafUserCount++;
continue;
}
@@ -174,7 +174,7 @@ void SpatialToPimPass::runOnOperation() {
markOpToRemove(receiveOp);
runOnReceiveOp(receiveOp, rewriter);
}
for (auto computeOp : funcOp.getOps<spatial::SpatWeightedCompute>()) {
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp);
runOnComputeOp(computeOp, rewriter);
}
@@ -222,7 +222,7 @@ void SpatialToPimPass::runOnOperation() {
dumpModule(moduleOp, "pim0");
}
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
auto& block = computeOp.getRegion().front();
@@ -504,7 +504,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
unsigned numComputeWeights = computeOp.getWeights().size();
for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
TypedValue<TensorType> tensorSource;
@@ -513,7 +513,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
continue;
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
@@ -538,7 +538,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
tensorSource = cast<TypedValue<TensorType>>(computeOpInput);
// Compute results must be transferred through channels via send/receive
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
continue;
BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
@@ -553,7 +553,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
return success();
}
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
@@ -614,7 +614,7 @@ void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) {
Operation* owner = use.getOwner();
if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
if (auto computeUser = dyn_cast<spatial::SpatCompute>(owner)) {
replaceBlockArgumentWithRecvOp(
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
continue;