slightly better bufferization
minor fixes
This commit is contained in:
@@ -253,7 +253,7 @@ def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> {
|
||||
// Math
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatWeightedVMMOp : SpatOp<"wvmm", []> {
|
||||
def SpatVMMOp : SpatOp<"wvmm", []> {
|
||||
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -272,7 +272,7 @@ def SpatWeightedVMMOp : SpatOp<"wvmm", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatWeightedMVMOp : SpatOp<"Wmvm", []> {
|
||||
def SpatMVMOp : SpatOp<"Wmvm", []> {
|
||||
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
|
||||
inline LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter,
|
||||
ArrayRef<int64_t>& matrixShape,
|
||||
ArrayRef<int64_t>& vectorShape,
|
||||
ArrayRef<int64_t>& outputShape) {
|
||||
@@ -45,7 +45,7 @@ inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
|
||||
return success();
|
||||
}
|
||||
|
||||
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
|
||||
inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
|
||||
ArrayRef<int64_t>& matrixShape,
|
||||
ArrayRef<int64_t>& vectorShape,
|
||||
ArrayRef<int64_t>& outputShape) {
|
||||
@@ -177,10 +177,10 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
|
||||
}
|
||||
|
||||
for (auto& bodyOp : block) {
|
||||
if (auto wvmm = dyn_cast<SpatWeightedVMMOp>(&bodyOp))
|
||||
if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp))
|
||||
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
|
||||
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
|
||||
if (auto wmvm = dyn_cast<SpatWeightedMVMOp>(&bodyOp))
|
||||
if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
|
||||
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
|
||||
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
|
||||
}
|
||||
@@ -189,10 +189,10 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult SpatWeightedMVMOp::verify() {
|
||||
LogicalResult SpatMVMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
|
||||
return emitError("SpatMVMOp was not within a SpatCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -204,10 +204,10 @@ LogicalResult SpatWeightedMVMOp::verify() {
|
||||
return emitError("matrix rank must be 2 or 4");
|
||||
}
|
||||
|
||||
LogicalResult SpatWeightedVMMOp::verify() {
|
||||
LogicalResult SpatVMMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
|
||||
return emitError("SpatVMMOp was not within a SpatCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
|
||||
@@ -133,7 +133,7 @@ CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
for (auto& block : body)
|
||||
for (auto& op : block)
|
||||
if (isa<SpatWeightedVMMOp>(op))
|
||||
if (isa<SpatVMMOp>(op))
|
||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||
return crossbarUsage;
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
for (auto& region : spatCompute.getBody())
|
||||
for (auto& inst : region)
|
||||
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst))
|
||||
if (llvm::isa<onnx_mlir::spatial::SpatVMMOp>(inst))
|
||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||
return crossbarUsage;
|
||||
}
|
||||
|
||||
@@ -838,9 +838,9 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
|
||||
for (auto& op : child.getBody().front()) {
|
||||
auto newInst = rewriter.clone(op, mapper);
|
||||
if (auto weightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(newInst))
|
||||
if (auto weightedMvmOp = dyn_cast<spatial::SpatMVMOp>(newInst))
|
||||
remapWeightIndex(weightedMvmOp);
|
||||
if (auto weightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(newInst))
|
||||
if (auto weightedVmmOp = dyn_cast<spatial::SpatVMMOp>(newInst))
|
||||
remapWeightIndex(weightedVmmOp);
|
||||
}
|
||||
|
||||
@@ -884,9 +884,9 @@ void emitMotifProfile(func::FuncOp funcOp) {
|
||||
ComputeMotifInfo& info = computeInfos[index];
|
||||
for (Operation& op : compute.getBody().front()) {
|
||||
info.instructionCount++;
|
||||
if (isa<spatial::SpatWeightedMVMOp>(&op))
|
||||
if (isa<spatial::SpatMVMOp>(&op))
|
||||
info.weightedMvmCount++;
|
||||
if (isa<spatial::SpatWeightedVMMOp>(&op))
|
||||
if (isa<spatial::SpatVMMOp>(&op))
|
||||
info.weightedVmmCount++;
|
||||
}
|
||||
if (info.weightedVmmCount > 0) {
|
||||
@@ -1617,13 +1617,13 @@ public:
|
||||
}
|
||||
|
||||
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) {
|
||||
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||
}
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) {
|
||||
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp);
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
|
||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||
}
|
||||
@@ -1643,22 +1643,22 @@ public:
|
||||
}
|
||||
|
||||
Operation* clonedOp = cpuRewriter.clone(op, mapper);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(&op)) {
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
||||
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto newWeightedMvmOp = cast<spatial::SpatWeightedMVMOp>(clonedOp);
|
||||
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||
}
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&op)) {
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
||||
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto newWeightedVmmOp = cast<spatial::SpatWeightedVMMOp>(clonedOp);
|
||||
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
|
||||
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
||||
}
|
||||
|
||||
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatWeightedVMMOp startOp) {
|
||||
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
RegularChunk chunk;
|
||||
chunk.startOp = startOp.getOperation();
|
||||
chunk.input = startOp.getInput();
|
||||
@@ -376,7 +376,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
|
||||
auto compactInBlock = [&](Block& block) {
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto startOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
|
||||
auto startOp = dyn_cast<spatial::SpatVMMOp>(&*it);
|
||||
if (!startOp) {
|
||||
++it;
|
||||
continue;
|
||||
@@ -391,7 +391,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
SmallVector<RegularChunk> run {*anchorChunk};
|
||||
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||
while (runIt != block.end()) {
|
||||
auto candidateStart = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
|
||||
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||
if (!candidateStart)
|
||||
break;
|
||||
|
||||
@@ -425,7 +425,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto wvmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(&*it);
|
||||
auto wvmmOp = dyn_cast<spatial::SpatVMMOp>(&*it);
|
||||
if (!wvmmOp) {
|
||||
++it;
|
||||
continue;
|
||||
@@ -440,11 +440,11 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatWeightedVMMOp> run;
|
||||
SmallVector<spatial::SpatVMMOp> run;
|
||||
auto runIt = it;
|
||||
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatWeightedVMMOp>(&*runIt);
|
||||
auto current = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||
@@ -545,7 +545,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
extractOffsets,
|
||||
extractSizes,
|
||||
extractStrides);
|
||||
auto loopWvmm = spatial::SpatWeightedVMMOp::create(
|
||||
auto loopWvmm = spatial::SpatVMMOp::create(
|
||||
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||
|
||||
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||
|
||||
Reference in New Issue
Block a user