generic gemm now works :)

This commit is contained in:
NiccoloN
2026-03-06 18:23:27 +01:00
parent 825188cc89
commit 1348bb1c97
5 changed files with 66 additions and 45 deletions

View File

@@ -24,8 +24,6 @@ namespace onnx_mlir {
namespace spatial {
void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n";
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();

View File

@@ -257,9 +257,16 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
outputTensors.reserve(returnOp->getNumOperands());
rewriter.setInsertionPointToStart(returnOp->getBlock());
for (auto returnValue : returnOp->getOperands()) {
auto newOutputTensor =
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
outputTensors.push_back(newOutputTensor);
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!returnValueDefiningOp->hasAttr("weightAlways"));
outputTensors.push_back(returnValue);
}
else {
auto newOutputTensor =
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
outputTensors.push_back(newOutputTensor);
}
}
}