sightly better bufferization

minor fixes
This commit is contained in:
NiccoloN
2026-05-07 17:53:47 +02:00
parent f2fe147961
commit f6c8cc4aa5
19 changed files with 150 additions and 141 deletions

View File

@@ -20,7 +20,7 @@ namespace spatial {
namespace {
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
inline LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
@@ -45,7 +45,7 @@ inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
return success();
}
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
ArrayRef<int64_t>& matrixShape,
ArrayRef<int64_t>& vectorShape,
ArrayRef<int64_t>& outputShape) {
@@ -177,10 +177,10 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
}
for (auto& bodyOp : block) {
if (auto wvmm = dyn_cast<SpatWeightedVMMOp>(&bodyOp))
if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp))
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
if (auto wmvm = dyn_cast<SpatWeightedMVMOp>(&bodyOp))
if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
}
@@ -189,10 +189,10 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
} // namespace
LogicalResult SpatWeightedMVMOp::verify() {
LogicalResult SpatMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
return emitError("SpatMVMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
@@ -204,10 +204,10 @@ LogicalResult SpatWeightedMVMOp::verify() {
return emitError("matrix rank must be 2 or 4");
}
LogicalResult SpatWeightedVMMOp::verify() {
LogicalResult SpatVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
return emitError("SpatVMMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();