Add support for operations: reduceMean, add, mul, div, sigmoid
Some checks failed:
Validate Operations / validate-operations (push) — failing after 51m 52s

This commit is contained in:
NiccoloN
2026-03-30 15:41:12 +02:00
parent 5e7114f517
commit 39830be888
32 changed files with 1057 additions and 224 deletions

View File

@@ -161,26 +161,41 @@ void SpatialToPimPass::runOnOperation() {
}
for (auto receiveOp : funcOp.getOps<spatial::SpatChannelReceiveOp>()) {
operationsToRemove.push_back(receiveOp);
markOpToRemove(receiveOp);
runOnReceiveOp(receiveOp, rewriter);
}
for (auto computeOp : funcOp.getOps<spatial::SpatWeightedCompute>()) {
operationsToRemove.push_back(computeOp);
markOpToRemove(computeOp);
runOnComputeOp(computeOp, rewriter);
}
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnOpOperands(returnOp, rewriter);
// Remove all ComputeOps
for (auto opToRemove : llvm::reverse(operationsToRemove)) {
if (!opToRemove->use_empty()) {
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
while (!pendingRemovals.empty()) {
bool erasedAnyOp = false;
for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) {
Operation* opToRemove = *it;
if (!opToRemove->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(opToRemove);
it = pendingRemovals.erase(it);
erasedAnyOp = true;
}
if (erasedAnyOp)
continue;
for (auto opToRemove : pendingRemovals) {
opToRemove->dump();
for (auto user : opToRemove->getUsers())
user->dump();
assert(false && "opToRemove should be unused at this point");
}
rewriter.eraseOp(opToRemove);
assert(false && "tracked op removal reached a cycle or missed dependency");
}
// Dump to file for debug
@@ -284,10 +299,19 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
auto concatUses = concatValue.getUses();
auto numConcatUses = rangeLength(concatUses);
if (numConcatUses == 1) {
OpOperand& concatUse = *concatUses.begin();
Operation* concatUser = concatUse.getOwner();
Value chainedValue = concatValue;
Operation* concatUser = concatUses.begin()->getOwner();
while (isChannelUseChainOp(concatUser)) {
auto chainUses = concatUser->getResult(0).getUses();
if (rangeLength(chainUses) != 1)
break;
chainedValue = concatUser->getResult(0);
concatUser = chainUses.begin()->getOwner();
}
if (isa<func::ReturnOp>(concatUser)) {
size_t concatIndexInReturn = concatUse.getOperandNumber();
size_t concatIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
size_t resultIndexInConcat = resultUses.begin()->getOperandNumber();
size_t offset = 0;
for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat))
@@ -602,10 +626,22 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
rewriter.modifyOpInPlace(returnOp,
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
if (isa<tensor::ConcatOp>(returnOperand)) {
auto returnOperandUses = it.value().getUses();
if (rangeLength(returnOperandUses) == 0)
rewriter.eraseOp(returnOperand);
Operation* opToErase = returnOperand;
while (opToErase) {
bool isExclusivelyOwnedByReturnChain = opToErase->use_empty() || opToErase->hasOneUse();
if (!isExclusivelyOwnedByReturnChain)
break;
if (isChannelUseChainOp(opToErase)) {
Value source = opToErase->getOperand(0);
markOpToRemove(opToErase);
opToErase = source.getDefiningOp();
continue;
}
if (isa<tensor::ConcatOp>(opToErase))
markOpToRemove(opToErase);
break;
}
}
}