better MaterializeMergeSchedule.cpp (something still broken downstream)
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -141,152 +141,6 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
};
|
||||
|
||||
// Turns runtime constants consumed by compute regions into private globals and local loads.
|
||||
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
|
||||
Location loc = constantOp.getLoc();
|
||||
|
||||
if (hasWeightAlways(constantOp))
|
||||
return failure();
|
||||
|
||||
if (!isa<func::FuncOp>(constantOp->getParentOp()))
|
||||
return failure();
|
||||
|
||||
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
return false;
|
||||
if (isa<func::FuncOp>(op->getParentOp()))
|
||||
return true;
|
||||
return false;
|
||||
}))
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
|
||||
|
||||
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
|
||||
|
||||
if (constRankedTensorType) {
|
||||
mlir::MemRefType memRefType =
|
||||
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
|
||||
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
|
||||
loc,
|
||||
constantOp->getParentOfType<ModuleOp>(),
|
||||
"const",
|
||||
memRefType,
|
||||
constantOp.getValueAttr(),
|
||||
rewriter.getUnitAttr());
|
||||
std::string argName = globalOp.getSymName().str();
|
||||
|
||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||
|
||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter,
|
||||
spatComputeBatch.getOperation(),
|
||||
BBArgIndex,
|
||||
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||
}
|
||||
else {
|
||||
{
|
||||
|
||||
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatCompute.getOperation());
|
||||
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
||||
Value hostConstant = constantOp.getResult();
|
||||
|
||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, hostConstant);
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, hostConstant);
|
||||
}
|
||||
else if (constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
constUses.set(hostConstant);
|
||||
}
|
||||
else {
|
||||
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
|
||||
assert(batchParent && "Global Constant used direcly not within a compute");
|
||||
constUses.set(hostConstant);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (constantOp->use_empty())
|
||||
rewriter.eraseOp(constantOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
|
||||
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
@@ -363,7 +217,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
|
||||
} // namespace
|
||||
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user