promote weight inputs to actual weights in spat compute nodes
All checks were successful
Validate Operations / validate-operations (push) Successful in 17m36s
All checks were successful
Validate Operations / validate-operations (push) Successful in 17m36s
This commit is contained in:
@@ -53,6 +53,7 @@ private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
|
||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -149,6 +150,12 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
encapsulateGlobalInstruction(*entryFunc);
|
||||
|
||||
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
mergeTriviallyConnectedComputes(*entryFunc);
|
||||
|
||||
// Dump to file for debug
|
||||
@@ -184,8 +191,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
if (llvm::any_of(
|
||||
sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) {
|
||||
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
|
||||
llvm::SmallVector<Type> sourceTypes;
|
||||
llvm::SmallVector<Location> sourceLoc;
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLoc;
|
||||
for (auto source : sources) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLoc.push_back(loc);
|
||||
@@ -206,6 +213,63 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
||||
if (auto mapped = mapper.lookupOrNull(value))
|
||||
return cast<Value>(mapped);
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return failure();
|
||||
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!tensorType || !tensorType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
|
||||
sizes.reserve(tensorType.getRank());
|
||||
for (int64_t dim : tensorType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
|
||||
auto referencedValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
|
||||
mapper.map(value, referencedValue.getResult());
|
||||
return referencedValue.getResult();
|
||||
}
|
||||
|
||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
||||
return failure();
|
||||
|
||||
IRMapping localMapper;
|
||||
for (Value operand : definingOp->getOperands()) {
|
||||
if (auto mapped = mapper.lookupOrNull(operand)) {
|
||||
localMapper.map(operand, cast<Value>(mapped));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isWeightLikeComputeOperand(operand)) {
|
||||
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
||||
if (failed(clonedOperand))
|
||||
return failure();
|
||||
localMapper.map(operand, *clonedOperand);
|
||||
continue;
|
||||
}
|
||||
|
||||
localMapper.map(operand, operand);
|
||||
}
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
|
||||
auto mapped = mapper.lookupOrNull(value);
|
||||
if (!mapped)
|
||||
return failure();
|
||||
return cast<Value>(mapped);
|
||||
}
|
||||
|
||||
// TODO what we want to keep in global?
|
||||
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
Location loc = funcOp.getLoc();
|
||||
@@ -328,6 +392,85 @@ void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
});
|
||||
}
|
||||
|
||||
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(&getContext());
|
||||
SmallVector<spatial::SpatWeightedCompute> computes(funcOp.getOps<spatial::SpatWeightedCompute>());
|
||||
|
||||
for (auto compute : computes) {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatWeightedCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
auto& oldBlock = compute.getBody().front();
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], rewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return compute.emitError("failed to materialize promoted weight-like operand inside compute body");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (auto& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
compute.replaceAllUsesWith(newCompute);
|
||||
compute.erase();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user