sightly better bufferization
minor fixes
This commit is contained in:
@@ -133,7 +133,7 @@ CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
for (auto& block : body)
|
||||
for (auto& op : block)
|
||||
if (isa<SpatWeightedVMMOp>(op))
|
||||
if (isa<SpatVMMOp>(op))
|
||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||
return crossbarUsage;
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
for (auto& region : spatCompute.getBody())
|
||||
for (auto& inst : region)
|
||||
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst))
|
||||
if (llvm::isa<onnx_mlir::spatial::SpatVMMOp>(inst))
|
||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||
return crossbarUsage;
|
||||
}
|
||||
|
||||
@@ -838,9 +838,9 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
|
||||
for (auto& op : child.getBody().front()) {
|
||||
auto newInst = rewriter.clone(op, mapper);
|
||||
if (auto weightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(newInst))
|
||||
if (auto weightedMvmOp = dyn_cast<spatial::SpatMVMOp>(newInst))
|
||||
remapWeightIndex(weightedMvmOp);
|
||||
if (auto weightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(newInst))
|
||||
if (auto weightedVmmOp = dyn_cast<spatial::SpatVMMOp>(newInst))
|
||||
remapWeightIndex(weightedVmmOp);
|
||||
}
|
||||
|
||||
@@ -884,9 +884,9 @@ void emitMotifProfile(func::FuncOp funcOp) {
|
||||
ComputeMotifInfo& info = computeInfos[index];
|
||||
for (Operation& op : compute.getBody().front()) {
|
||||
info.instructionCount++;
|
||||
if (isa<spatial::SpatWeightedMVMOp>(&op))
|
||||
if (isa<spatial::SpatMVMOp>(&op))
|
||||
info.weightedMvmCount++;
|
||||
if (isa<spatial::SpatWeightedVMMOp>(&op))
|
||||
if (isa<spatial::SpatVMMOp>(&op))
|
||||
info.weightedVmmCount++;
|
||||
}
|
||||
if (info.weightedVmmCount > 0) {
|
||||
@@ -1617,13 +1617,13 @@ public:
|
||||
}
|
||||
|
||||
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) {
|
||||
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||
}
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) {
|
||||
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp);
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
|
||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||
}
|
||||
@@ -1643,22 +1643,22 @@ public:
|
||||
}
|
||||
|
||||
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) {
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
||||
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp);
|
||||
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||
}
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) {
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
||||
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp);
|
||||
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
|
||||
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
||||
}
|
||||
|
||||
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatWeightedVMMOp startOp) {
|
||||
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
RegularChunk chunk;
|
||||
chunk.startOp = startOp.getOperation();
|
||||
chunk.input = startOp.getInput();
|
||||
@@ -376,7 +376,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
|
||||
auto compactInBlock = [&](Block& block) {
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto startOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
|
||||
auto startOp = dyn_cast<spatial::SpatVMMOp>(&*it);
|
||||
if (!startOp) {
|
||||
++it;
|
||||
continue;
|
||||
@@ -391,7 +391,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
SmallVector<RegularChunk> run {*anchorChunk};
|
||||
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||
while (runIt != block.end()) {
|
||||
auto candidateStart = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
|
||||
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||
if (!candidateStart)
|
||||
break;
|
||||
|
||||
@@ -425,7 +425,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto wvmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
|
||||
auto wvmmOp = dyn_cast<spatial::SpatVMMOp>(&*it);
|
||||
if (!wvmmOp) {
|
||||
++it;
|
||||
continue;
|
||||
@@ -440,11 +440,11 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatWeightedVMMOp> run;
|
||||
SmallVector<spatial::SpatVMMOp> run;
|
||||
auto runIt = it;
|
||||
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
|
||||
auto current = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||
@@ -545,7 +545,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
extractOffsets,
|
||||
extractSizes,
|
||||
extractStrides);
|
||||
auto loopWvmm = spatial::SpatWeightedVMMOp::create(
|
||||
auto loopWvmm = spatial::SpatVMMOp::create(
|
||||
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||
|
||||
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||
|
||||
Reference in New Issue
Block a user