multiple-output spat computes
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m38s
@@ -182,7 +182,7 @@ auto createSpatCompute(RewriterT& rewriter,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);

auto* block = new mlir::Block();
for (mlir::Value input : inputs)
@@ -198,10 +198,10 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
@@ -219,7 +219,7 @@ auto createSpatCompute(RewriterT& rewriter,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);

auto* block = new mlir::Block();
for (mlir::Value input : inputs)
@@ -234,10 +234,10 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");

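For context, a minimal sketch of how the updated helper is intended to be called (illustrative only; it assumes the createSpatCompute template and the spatial dialect declarations from this repository):

// Hypothetical caller, not part of this patch. One input, one result; because the body
// returns LogicalResult, the helper erases the half-built op and forwards the failure
// when the body bails out.
mlir::FailureOr<spatial::SpatCompute> buildIdentityCompute(mlir::IRRewriter& rewriter,
                                                           mlir::Location loc,
                                                           mlir::Value input) {
  return createSpatCompute<1>(rewriter, loc, input.getType(), /*weights=*/{}, input,
      [&](mlir::Value arg) -> mlir::LogicalResult {
        if (!mlir::isa<mlir::RankedTensorType>(arg.getType()))
          return mlir::failure();
        spatial::SpatYieldOp::create(rewriter, loc, arg);
        return mlir::success();
      });
}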
@@ -133,7 +133,7 @@ void ONNXToSpatialPass::runOnOperation() {
if (coresCount != -1) {
int computeOpsCount = 0;
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatWeightedCompute>(op))
if (isa<spatial::SpatCompute>(op))
computeOpsCount++;

if (computeOpsCount > coresCount) {
@@ -167,16 +167,16 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp);
if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
inst->replaceAllUsesWith(newCompute);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
@@ -189,8 +189,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp);
if (llvm::any_of(
sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : sources) {
@@ -204,8 +204,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
inst->replaceAllUsesWith(newCompute);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
@@ -298,14 +298,15 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatWeightedCompute> trivialComputes;
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
SmallVector<spatial::SpatCompute> trivialComputes;
llvm::SmallSet<spatial::SpatCompute, 8> toErase;

for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
if (compute->hasOneUse()) {
auto user = dyn_cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
auto& use = *compute->getUses().begin();
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());

if (user && user.getInputs().size() == 1)
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(compute);
}

@@ -317,12 +318,15 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
trivialComputes.pop_back();
continue;
}
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
auto& computeUse = *compute->getUses().begin();
auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();

rewriter.setInsertionPointAfter(compute.getOperation());

auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});

@@ -343,7 +347,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {

compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
newTerminator->erase();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front()) {
@@ -371,14 +375,16 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
toErase.insert(compute);

if (newCompute->hasOneUse()) {
auto user = dyn_cast<spatial::SpatWeightedCompute>(*newCompute->getUsers().begin());
if (user && user.getInputs().size() == 1)
auto& use = *newCompute->getUses().begin();
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(newCompute);
}
}

for (auto compute : toErase) {
compute.getResult(0).dropAllUses();
for (Value result : compute->getResults())
result.dropAllUses();
compute.erase();
}
}
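The operand-index test above relies on a SpatCompute's operand list being laid out as [weights..., inputs...]. A standalone illustration of that arithmetic (not project code; the names and segment layout are assumptions matching the checks in this pass):

#include <cassert>
#include <cstddef>

struct OperandLayout {
  std::size_t numWeights;
  std::size_t numInputs;
  // Mirrors use.getOperandNumber() >= user.getWeights().size()
  bool isInputOperand(std::size_t operandNumber) const {
    assert(operandNumber < numWeights + numInputs && "operand number out of range");
    return operandNumber >= numWeights;
  }
  // Mirrors childArgIndex = computeUse.getOperandNumber() - child.getWeights().size()
  std::size_t inputIndex(std::size_t operandNumber) const {
    assert(isInputOperand(operandNumber));
    return operandNumber - numWeights;
  }
};

int main() {
  OperandLayout layout{/*numWeights=*/2, /*numInputs=*/3};
  assert(layout.isInputOperand(2));   // first input operand
  assert(layout.inputIndex(4) == 2);  // last input maps to the third input block argument
  assert(!layout.isInputOperand(1));  // still a weight
  return 0;
}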
@@ -386,7 +392,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
if (isAlwaysWeight)
markWeightAlways(constantOp);
});
@@ -394,7 +400,7 @@ void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {

LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatWeightedCompute> computes(funcOp.getOps<spatial::SpatWeightedCompute>());
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());

for (auto compute : computes) {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
@@ -430,7 +436,7 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
}

auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(

@@ -147,33 +147,37 @@ static Value buildPackedBias(bool hasBias,
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
}

static Value createIm2colCompute(Value x,
RankedTensorType xType,
RankedTensorType im2colType,
RankedTensorType rowType,
int64_t batchSize,
int64_t numChannelsIn,
int64_t xHeight,
int64_t xWidth,
int64_t wHeight,
int64_t wWidth,
int64_t padHeightBegin,
int64_t padHeightEnd,
int64_t padWidthBegin,
int64_t padWidthEnd,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t outWidth,
int64_t patchSize,
int64_t numPatches,
int64_t numPatchesPerBatch,
ConversionPatternRewriter& rewriter,
Location loc) {
static SmallVector<Value> createIm2colRowComputes(Value x,
RankedTensorType xType,
RankedTensorType im2colType,
RankedTensorType im2colRowType,
RankedTensorType gemmInputRowType,
int64_t batchSize,
int64_t numChannelsIn,
int64_t xHeight,
int64_t xWidth,
int64_t wHeight,
int64_t wWidth,
int64_t padHeightBegin,
int64_t padHeightEnd,
int64_t padWidthBegin,
int64_t padWidthEnd,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t outWidth,
int64_t patchSize,
int64_t numPatches,
int64_t numPatchesPerBatch,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elemType = xType.getElementType();
constexpr size_t numInputs = 1;
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
Value paddedInput = xArg;

// Pad input with zeros if needed:
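A quick standalone sanity check (not project code) of the row-packing arithmetic introduced here: packFactor im2col rows of width patchSize are fused into one GEMM row, and the number of packed rows is the ceiling of numPatches / packFactor, so the last packed row may cover zero-padded patches. The ceilIntegerDivide semantics below are an assumption matching how the value is used in this file:

#include <cstdint>
#include <cstdio>

static int64_t ceilIntegerDivide(int64_t a, int64_t b) { return (a + b - 1) / b; }  // assumed semantics

int main() {
  const int64_t numPatches = 198, patchSize = 27, packFactor = 4;
  const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);  // 50
  const int64_t paddedNumPatches = packedNumRows * packFactor;              // 200, i.e. 2 padded patches
  std::printf("rows: %lld, row width: %lld, padded patches: %lld\n",
              static_cast<long long>(packedNumRows),
              static_cast<long long>(packFactor * patchSize),
              static_cast<long long>(paddedNumPatches));
  return 0;
}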
@@ -240,7 +244,7 @@ static Value createIm2colCompute(Value x,

Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
rowType,
im2colRowType,
patch,
SmallVector<ReassociationIndices> {
{0},
@@ -256,121 +260,115 @@ static Value createIm2colCompute(Value x,

rewriter.setInsertionPointAfter(im2colLoop);
Value im2col = im2colLoop.getResult(0);
spatial::SpatYieldOp::create(rewriter, loc, im2col);
});
return im2colComputeOp.getResult(0);
}

static Value createPackedIm2colRows(Value im2col,
RankedTensorType im2colType,
Type elemType,
int64_t numPatches,
int64_t patchSize,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return im2col;

const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) {
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
});
return packedComputeOp.getResult(0);
}

static Value createUnpackedOutput(Value packedOutput,
RankedTensorType gemmOutType,
RankedTensorType outType,
int64_t numPatches,
int64_t numChannelsOut,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return packedOutput;

const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) {
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedOutputArg,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
loc,
paddedType,
expandedOutput,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});

Value unpackedOutput = paddedOutput;
if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
unpackedOutput =
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
Value gemmInputRows = im2col;
if (packFactor != 1) {
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
}

spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput);
SmallVector<Value> rowResults;
rowResults.reserve(packedNumRows);
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(packFactor * patchSize)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
rowResults.push_back(
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
}
spatial::SpatYieldOp::create(rewriter, loc, rowResults);
});
return unpackComputeOp.getResult(0);

SmallVector<Value> rows;
rows.reserve(im2colComputeOp.getNumResults());
for (Value result : im2colComputeOp.getResults())
rows.push_back(result);
return rows;
}

static Value createCollectedConvOutput(Value gemmOut,
static Value createCollectedConvOutput(ValueRange gemmRows,
Type convType,
RankedTensorType gemmOutType,
RankedTensorType nhwcType,
RankedTensorType outType,
int64_t numPatches,
int64_t numChannelsOut,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
auto collectComputeOp =
createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
Value gemmOutArg = gemmOutArgs.front();

// Restore to NCHW layout:
// [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut]
// -> [1, numChannelsOut, outHeight, outWidth]
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
loc,
nhwcType,
gemmOutArg,
SmallVector<ReassociationIndices> {
{0, 1, 2},
{3}
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
Value gemmOut;
if (packFactor == 1) {
gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
}
else {
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
Value packedOutput = gemmRowArgs.size() == 1
? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedOutput,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
loc,
paddedType,
expandedOutput,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});

gemmOut = paddedOutput;
if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
}
}

// Restore to NCHW layout:
// [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut]
// -> [1, numChannelsOut, outHeight, outWidth]
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
loc,
nhwcType,
gemmOut,
SmallVector<ReassociationIndices> {
{0, 1, 2},
{3}
});
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
});
return collectComputeOp.getResult(0);
}

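A standalone illustration (not project code) of the layout restoration performed by the collect compute above: the [numPatches, numChannelsOut] GEMM result is viewed as [1, outH, outW, cOut] (NHWC) and then permuted with {0, 3, 1, 2} to NCHW. The small sizes below are arbitrary:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const std::size_t outH = 2, outW = 3, cOut = 4;
  // gemmOut[p][c] with p = h * outW + w, stored row-major.
  std::vector<float> gemmOut(outH * outW * cOut);
  for (std::size_t i = 0; i < gemmOut.size(); ++i) gemmOut[i] = static_cast<float>(i);

  // NCHW[c][h][w] = NHWC[h][w][c] = gemmOut[h * outW + w][c]
  std::vector<float> nchw(cOut * outH * outW);
  for (std::size_t c = 0; c < cOut; ++c)
    for (std::size_t h = 0; h < outH; ++h)
      for (std::size_t w = 0; w < outW; ++w)
        nchw[(c * outH + h) * outW + w] = gemmOut[(h * outW + w) * cOut + c];

  assert(nchw[0] == gemmOut[0]);                          // (c=0, h=0, w=0)
  assert(nchw[(1 * outH + 0) * outW + 0] == gemmOut[1]);  // channel 1 of the first pixel
  return 0;
}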
@@ -487,11 +485,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,

// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
Value biasMatrix;
DenseElementsAttr biasDenseAttr;
if (hasB) {
gemmC = b;
gemmBias = b;
biasDenseAttr = getDenseConstantAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
}
@@ -500,94 +498,89 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t effectiveMaxParallelPixels =
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;

Value im2col = createIm2colCompute(x,
xType,
im2colType,
rowType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
rewriter,
loc);
// Keep the standard im2col view of convolution:
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
//
// We want to process N pixels at the same time. Instead of doing N separate operations
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
auto gemmInputRowType = RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
auto gemmOutputRowType =
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
xType,
im2colType,
rowType,
gemmInputRowType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
effectiveMaxParallelPixels,
rewriter,
loc);

Value gemmOut;
if (effectiveMaxParallelPixels == 1) {
// Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels.
gemmOut = ONNXGemmOp::create(rewriter,
loc,
gemmOutType,
im2col,
wTrans,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
}
else {
// Keep the standard im2col view of convolution:
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// but repack several old rows into one new row so we use the available crossbar size better.
//
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
auto packedOutType =
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
Value gemmB = buildPackedWeight(wDenseAttr,
wTrans,
wType,
numChannelsIn,
numChannelsOut,
wHeight,
wWidth,
patchSize,
effectiveMaxParallelPixels,
rewriter,
loc);
Value gemmC = buildPackedBias(
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);

Value packedA = createPackedIm2colRows(
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
Value packedB = buildPackedWeight(wDenseAttr,
wTrans,
wType,
numChannelsIn,
numChannelsOut,
wHeight,
wWidth,
patchSize,
effectiveMaxParallelPixels,
rewriter,
loc);
Value packedC = buildPackedBias(
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
Value packedOut = ONNXGemmOp::create(rewriter,
loc,
packedOutType,
packedA,
packedB,
packedC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
gemmOut = createUnpackedOutput(
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
SmallVector<Value> gemmRows;
gemmRows.reserve(gemmInputRows.size());
for (Value gemmInputRow : gemmInputRows) {
Value gemmRow = ONNXGemmOp::create(rewriter,
loc,
gemmOutputRowType,
gemmInputRow,
gemmB,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
gemmRows.push_back(gemmRow);
}

rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc));
rewriter.replaceOp(convOp,
createCollectedConvOutput(gemmRows,
convOp.getType(),
gemmOutType,
nhwcType,
outType,
numPatches,
numChannelsOut,
effectiveMaxParallelPixels,
rewriter,
loc));
return success();
}

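A standalone numerical check (not project code) of the packing scheme described in the comment above: multiplying one concatenated row [1, N * patchSize] by a block-diagonal weight matrix [N * patchSize, N * cOut] holding N copies of W^T gives the same values as N independent (1 x patchSize) x (patchSize x cOut) products. Sizes and values below are arbitrary:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

static Matrix matmul(const Matrix& a, const Matrix& b) {
  Matrix y(a.size(), std::vector<float>(b[0].size(), 0.0f));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t k = 0; k < b.size(); ++k)
      for (std::size_t j = 0; j < b[0].size(); ++j)
        y[i][j] += a[i][k] * b[k][j];
  return y;
}

int main() {
  const std::size_t N = 3, patchSize = 4, cOut = 2;
  // W^T: [patchSize, cOut], plus N im2col rows of length patchSize.
  Matrix wT(patchSize, std::vector<float>(cOut));
  Matrix rows(N, std::vector<float>(patchSize));
  for (std::size_t i = 0; i < patchSize; ++i)
    for (std::size_t j = 0; j < cOut; ++j) wT[i][j] = 0.1f * static_cast<float>(i + 2 * j + 1);
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t i = 0; i < patchSize; ++i) rows[n][i] = static_cast<float>(n * patchSize + i);

  // A_packed: one row of length N * patchSize; B_packed: block diagonal with N copies of W^T.
  Matrix aPacked(1, std::vector<float>(N * patchSize));
  Matrix bPacked(N * patchSize, std::vector<float>(N * cOut, 0.0f));
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t i = 0; i < patchSize; ++i) {
      aPacked[0][n * patchSize + i] = rows[n][i];
      for (std::size_t j = 0; j < cOut; ++j) bPacked[n * patchSize + i][n * cOut + j] = wT[i][j];
    }

  Matrix yPacked = matmul(aPacked, bPacked);   // [1, N * cOut]
  for (std::size_t n = 0; n < N; ++n) {
    Matrix y = matmul(Matrix{rows[n]}, wT);    // [1, cOut], the unpacked per-pixel product
    for (std::size_t j = 0; j < cOut; ++j)
      assert(std::fabs(yPacked[0][n * cOut + j] - y[0][j]) < 1e-5f);
  }
  return 0;
}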
@@ -42,15 +42,15 @@ private:
raw_ostream& os;

/**
* Draws the subgraph for a given spatial::SpatWeightedCompute, including:
* Draws the subgraph for a given spatial::SpatCompute, including:
* 1. Input nodes (block arguments)
* 2. Operations
* 3. Edges between yield (output) and its users
*
* @param op The spatial::SpatWeightedCompute to draw the subgraph for.
* @param op The spatial::SpatCompute to draw the subgraph for.
* @param computeNum The number of the compute operation.
*/
void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) {
void drawComputeOpSubgraph(spatial::SpatCompute op, size_t computeNum) {
os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
<< "\t\tstyle=filled;\n"
<< "\t\tcolor=lightblue;\n";
@@ -217,7 +217,7 @@ void SpatialToGraphvizPass::runOnOperation() {
// 1. Print their subgraph
// 2. Print the edges from its inputs to its outputs
for (Operation& op : func.getOps()) {
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
drawComputeOpSubgraph(computeOp, computeNum++);
}
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {

@@ -62,7 +62,7 @@ private:
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void
addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
void replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
@@ -73,7 +73,7 @@ private:
void annotateChannelCoreIds(func::FuncOp funcOp);
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);

void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter);

void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);

@@ -116,7 +116,7 @@ static size_t countComputeLeafUsers(Value value) {
auto walkUses = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) {
Operation* owner = use.getOwner();
if (isa<spatial::SpatWeightedCompute>(owner)) {
if (isa<spatial::SpatCompute>(owner)) {
leafUserCount++;
continue;
}
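The walk above uses a recursive generic lambda that takes itself as a second parameter so it can recurse without a named function or std::function. A standalone sketch of that idiom on a toy use graph (not project code; node names are illustrative):

#include <cassert>
#include <cstddef>
#include <vector>

struct Node {
  bool isLeaf = false;            // stands in for "is a SpatCompute user"
  std::vector<Node*> users;
};

int main() {
  Node leaf1{true, {}}, leaf2{true, {}};
  Node forwarder{false, {&leaf2}};   // stands in for an op whose result is walked through
  Node root{false, {&leaf1, &forwarder}};

  std::size_t leafUserCount = 0;
  auto walkUses = [&](Node* current, auto& self) -> void {
    for (Node* user : current->users) {
      if (user->isLeaf) {
        ++leafUserCount;           // direct leaf consumer found
        continue;
      }
      self(user, self);            // otherwise keep walking through forwarding nodes
    }
  };
  walkUses(&root, walkUses);
  assert(leafUserCount == 2);
  return 0;
}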
@@ -174,7 +174,7 @@ void SpatialToPimPass::runOnOperation() {
markOpToRemove(receiveOp);
runOnReceiveOp(receiveOp, rewriter);
}
for (auto computeOp : funcOp.getOps<spatial::SpatWeightedCompute>()) {
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp);
runOnComputeOp(computeOp, rewriter);
}
@@ -222,7 +222,7 @@ void SpatialToPimPass::runOnOperation() {
dumpModule(moduleOp, "pim0");
}

void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();

auto& block = computeOp.getRegion().front();
@@ -504,7 +504,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu

llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
unsigned numComputeWeights = computeOp.getWeights().size();
for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
TypedValue<TensorType> tensorSource;
@@ -513,7 +513,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());

if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
continue;

ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
@@ -538,7 +538,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
tensorSource = cast<TypedValue<TensorType>>(computeOpInput);

// Compute results must be transferred through channels via send/receive
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
continue;

BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
@@ -553,7 +553,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
return success();
}

void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
unsigned int argIndex,
Value channelSourceOp,
Value consumerValue,
@@ -614,7 +614,7 @@ void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) {
Operation* owner = use.getOwner();
if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
if (auto computeUser = dyn_cast<spatial::SpatCompute>(owner)) {
replaceBlockArgumentWithRecvOp(
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
continue;