multiple-output spat computes
All checks were successful
Validate Operations / validate-operations (push) Successful in 1h2m3s

This commit is contained in:
NiccoloN
2026-04-22 18:29:06 +02:00
parent 0f13269040
commit 87922d994f
16 changed files with 403 additions and 396 deletions

View File

@@ -182,7 +182,7 @@ auto createSpatCompute(RewriterT& rewriter,
mlir::ValueRange inputs, mlir::ValueRange inputs,
BodyFn&& body) { BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values"); assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs); auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block(); auto* block = new mlir::Block();
for (mlir::Value input : inputs) for (mlir::Value input : inputs)
@@ -198,10 +198,10 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) { if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp); rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure()); return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
} }
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp); return mlir::FailureOr<spatial::SpatCompute>(computeOp);
} }
else { else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult"); static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
@@ -219,7 +219,7 @@ auto createSpatCompute(RewriterT& rewriter,
mlir::ValueRange weights, mlir::ValueRange weights,
mlir::ValueRange inputs, mlir::ValueRange inputs,
BodyFn&& body) { BodyFn&& body) {
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs); auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block(); auto* block = new mlir::Block();
for (mlir::Value input : inputs) for (mlir::Value input : inputs)
@@ -234,10 +234,10 @@ auto createSpatCompute(RewriterT& rewriter,
if (mlir::failed(bodyResult)) { if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp); rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure()); return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
} }
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp); return mlir::FailureOr<spatial::SpatCompute>(computeOp);
} }
else { else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult"); static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");

View File

@@ -133,7 +133,7 @@ void ONNXToSpatialPass::runOnOperation() {
if (coresCount != -1) { if (coresCount != -1) {
int computeOpsCount = 0; int computeOpsCount = 0;
for (auto& op : entryFunc->getFunctionBody().front().getOperations()) for (auto& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatWeightedCompute>(op)) if (isa<spatial::SpatCompute>(op))
computeOpsCount++; computeOpsCount++;
if (computeOpsCount > coresCount) { if (computeOpsCount > coresCount) {
@@ -167,16 +167,16 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) { if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp); Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp); rewriter.setInsertionPointAfter(toRemoveOp);
if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) { if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source); auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB); rewriter.setInsertionPointToEnd(BB);
IRMapping mapper; IRMapping mapper;
mapper.map(source, BB->getArgument(0)); mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper); auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0)); spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute); inst->replaceAllUsesWith(newCompute->getResults());
inst->erase(); inst->erase();
return true; return true;
} }
@@ -189,8 +189,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
auto sources = toRemoveOp.getInputs(); auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp); rewriter.setInsertionPointAfter(toRemoveOp);
if (llvm::any_of( if (llvm::any_of(
sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) { sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources); auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes; SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc; SmallVector<Location> sourceLoc;
for (auto source : sources) { for (auto source : sources) {
@@ -204,8 +204,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments())) for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg); mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper); auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0)); spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
inst->replaceAllUsesWith(newCompute); inst->replaceAllUsesWith(newCompute->getResults());
inst->erase(); inst->erase();
return true; return true;
} }
@@ -298,14 +298,15 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatWeightedCompute> trivialComputes; SmallVector<spatial::SpatCompute> trivialComputes;
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase; llvm::SmallSet<spatial::SpatCompute, 8> toErase;
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>()) for (auto compute : funcOp.getOps<spatial::SpatCompute>())
if (compute->hasOneUse()) { if (compute->hasOneUse()) {
auto user = dyn_cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin()); auto& use = *compute->getUses().begin();
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
if (user && user.getInputs().size() == 1) if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(compute); trivialComputes.push_back(compute);
} }
@@ -317,12 +318,15 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
trivialComputes.pop_back(); trivialComputes.pop_back();
continue; continue;
} }
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin()); auto& computeUse = *compute->getUses().begin();
auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
rewriter.setInsertionPointAfter(compute.getOperation()); rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute = auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())}); {static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
@@ -343,7 +347,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper); compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
auto newTerminator = newCompute.getBody().front().getTerminator(); auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0)); mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
newTerminator->erase(); newTerminator->erase();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end()); rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front()) { for (auto& op : child.getBody().front()) {
@@ -371,14 +375,16 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
toErase.insert(compute); toErase.insert(compute);
if (newCompute->hasOneUse()) { if (newCompute->hasOneUse()) {
auto user = dyn_cast<spatial::SpatWeightedCompute>(*newCompute->getUsers().begin()); auto& use = *newCompute->getUses().begin();
if (user && user.getInputs().size() == 1) auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(newCompute); trivialComputes.push_back(newCompute);
} }
} }
for (auto compute : toErase) { for (auto compute : toErase) {
compute.getResult(0).dropAllUses(); for (Value result : compute->getResults())
result.dropAllUses();
compute.erase(); compute.erase();
} }
} }
@@ -386,7 +392,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) { funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight = bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); }); llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
if (isAlwaysWeight) if (isAlwaysWeight)
markWeightAlways(constantOp); markWeightAlways(constantOp);
}); });
@@ -394,7 +400,7 @@ void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) { LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatWeightedCompute> computes(funcOp.getOps<spatial::SpatWeightedCompute>()); SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
for (auto compute : computes) { for (auto compute : computes) {
SmallVector<bool> promoteInput(compute.getInputs().size(), false); SmallVector<bool> promoteInput(compute.getInputs().size(), false);
@@ -430,7 +436,7 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
} }
auto newCompute = auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs); spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock = auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs); rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(

View File

@@ -147,33 +147,37 @@ static Value buildPackedBias(bool hasBias,
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult(); return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
} }
static Value createIm2colCompute(Value x, static SmallVector<Value> createIm2colRowComputes(Value x,
RankedTensorType xType, RankedTensorType xType,
RankedTensorType im2colType, RankedTensorType im2colType,
RankedTensorType rowType, RankedTensorType im2colRowType,
int64_t batchSize, RankedTensorType gemmInputRowType,
int64_t numChannelsIn, int64_t batchSize,
int64_t xHeight, int64_t numChannelsIn,
int64_t xWidth, int64_t xHeight,
int64_t wHeight, int64_t xWidth,
int64_t wWidth, int64_t wHeight,
int64_t padHeightBegin, int64_t wWidth,
int64_t padHeightEnd, int64_t padHeightBegin,
int64_t padWidthBegin, int64_t padHeightEnd,
int64_t padWidthEnd, int64_t padWidthBegin,
int64_t strideHeight, int64_t padWidthEnd,
int64_t strideWidth, int64_t strideHeight,
int64_t dilationHeight, int64_t strideWidth,
int64_t dilationWidth, int64_t dilationHeight,
int64_t outWidth, int64_t dilationWidth,
int64_t patchSize, int64_t outWidth,
int64_t numPatches, int64_t patchSize,
int64_t numPatchesPerBatch, int64_t numPatches,
ConversionPatternRewriter& rewriter, int64_t numPatchesPerBatch,
Location loc) { int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elemType = xType.getElementType(); auto elemType = xType.getElementType();
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) { const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
Value paddedInput = xArg; Value paddedInput = xArg;
// Pad input with zeros if needed: // Pad input with zeros if needed:
@@ -240,7 +244,7 @@ static Value createIm2colCompute(Value x,
Value row = tensor::CollapseShapeOp::create(rewriter, Value row = tensor::CollapseShapeOp::create(rewriter,
loc, loc,
rowType, im2colRowType,
patch, patch,
SmallVector<ReassociationIndices> { SmallVector<ReassociationIndices> {
{0}, {0},
@@ -256,121 +260,117 @@ static Value createIm2colCompute(Value x,
rewriter.setInsertionPointAfter(im2colLoop); rewriter.setInsertionPointAfter(im2colLoop);
Value im2col = im2colLoop.getResult(0); Value im2col = im2colLoop.getResult(0);
spatial::SpatYieldOp::create(rewriter, loc, im2col);
});
return im2colComputeOp.getResult(0);
}
static Value createPackedIm2colRows(Value im2col, Value gemmInputRows = im2col;
RankedTensorType im2colType, if (packFactor != 1) {
Type elemType, const int64_t paddedNumPatches = packedNumRows * packFactor;
int64_t numPatches, auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
int64_t patchSize, auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
int64_t packFactor, Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
ConversionPatternRewriter& rewriter, Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
Location loc) { loc,
if (packFactor == 1) groupedType,
return im2col; paddedIm2col,
SmallVector<ReassociationIndices> {
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); {0, 1},
const int64_t paddedNumPatches = packedNumRows * packFactor; {2}
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType); });
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType); gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) { loc,
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc); packedType,
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter, groupedIm2col,
loc, SmallVector<ReassociationIndices> {
groupedType, {0},
paddedIm2col, {1, 2}
SmallVector<ReassociationIndices> { });
{0, 1},
{2}
});
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
});
return packedComputeOp.getResult(0);
}
static Value createUnpackedOutput(Value packedOutput,
RankedTensorType gemmOutType,
RankedTensorType outType,
int64_t numPatches,
int64_t numChannelsOut,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return packedOutput;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) {
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedOutputArg,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
loc,
paddedType,
expandedOutput,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value unpackedOutput = paddedOutput;
if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
unpackedOutput =
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
} }
spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput); SmallVector<Value> rowResults;
rowResults.reserve(packedNumRows);
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(packFactor * patchSize)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
rowResults.push_back(
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
}
spatial::SpatYieldOp::create(rewriter, loc, rowResults);
}); });
return unpackComputeOp.getResult(0);
SmallVector<Value> rows;
rows.reserve(im2colComputeOp.getNumResults());
for (Value result : im2colComputeOp.getResults())
rows.push_back(result);
return rows;
} }
static Value createCollectedConvOutput(Value gemmOut, static Value createCollectedConvOutput(ValueRange gemmRows,
Type convType, Type convType,
RankedTensorType gemmOutType,
RankedTensorType nhwcType, RankedTensorType nhwcType,
RankedTensorType outType, RankedTensorType outType,
int64_t numPatches,
int64_t numChannelsOut,
int64_t packFactor,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
Location loc) { Location loc) {
auto collectComputeOp = const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) { const int64_t paddedNumPatches = packedNumRows * packFactor;
Value gemmOutArg = gemmOutArgs.front(); auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
Value gemmOut;
// Restore to NCHW layout: if (packFactor == 1) {
// [numPatches, numChannelsOut] gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
// -> [1, outHeight, outWidth, numChannelsOut] : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
// -> [1, numChannelsOut, outHeight, outWidth] }
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, else {
loc, auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
nhwcType, auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
gemmOutArg, Value packedOutput =
SmallVector<ReassociationIndices> { gemmRowArgs.size() == 1
{0, 1, 2}, ? gemmRowArgs.front()
{3} : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedOutput,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
}); });
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2})); Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
spatial::SpatYieldOp::create(rewriter, loc, nchwOut); loc,
paddedType,
expandedOutput,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmOut = paddedOutput;
if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
}
}
// Restore to NCHW layout:
// [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut]
// -> [1, numChannelsOut, outHeight, outWidth]
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
loc,
nhwcType,
gemmOut,
SmallVector<ReassociationIndices> {
{0, 1, 2},
{3}
}); });
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
});
return collectComputeOp.getResult(0); return collectComputeOp.getResult(0);
} }
@@ -487,11 +487,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
// Pass bias through directly; Gemm handles rank-1 C canonicalization. // Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp()); bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
Value biasMatrix; Value biasMatrix;
DenseElementsAttr biasDenseAttr; DenseElementsAttr biasDenseAttr;
if (hasB) { if (hasB) {
gemmC = b; gemmBias = b;
biasDenseAttr = getDenseConstantAttr(b); biasDenseAttr = getDenseConstantAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc); biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
} }
@@ -500,94 +500,86 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t effectiveMaxParallelPixels = const int64_t effectiveMaxParallelPixels =
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1; (canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
Value im2col = createIm2colCompute(x, // Keep the standard im2col view of convolution:
xType, // A (im2col): [numPatches, patchSize] -- one row per output spatial position
im2colType, // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
rowType, // and optionally repack several old rows into one GEMM row to use the available crossbar size better.
batchSize, //
numChannelsIn, // The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only
xHeight, // the row it needs instead of receiving a full packed tensor and slicing it locally.
xWidth, auto gemmInputRowType =
wHeight, RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
wWidth, auto gemmOutputRowType =
padHeightBegin, RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
padHeightEnd, SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
padWidthBegin, xType,
padWidthEnd, im2colType,
strideHeight, rowType,
strideWidth, gemmInputRowType,
dilationHeight, batchSize,
dilationWidth, numChannelsIn,
outWidth, xHeight,
patchSize, xWidth,
numPatches, wHeight,
numPatchesPerBatch, wWidth,
rewriter, padHeightBegin,
loc); padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
effectiveMaxParallelPixels,
rewriter,
loc);
Value gemmOut; Value gemmB = buildPackedWeight(wDenseAttr,
if (effectiveMaxParallelPixels == 1) { wTrans,
// Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels. wType,
gemmOut = ONNXGemmOp::create(rewriter, numChannelsIn,
loc, numChannelsOut,
gemmOutType, wHeight,
im2col, wWidth,
wTrans, patchSize,
gemmC, effectiveMaxParallelPixels,
rewriter.getF32FloatAttr(1.0f), rewriter,
rewriter.getF32FloatAttr(1.0f), loc);
rewriter.getBoolAttr(false), Value gemmC = buildPackedBias(
rewriter.getBoolAttr(false)) hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
.getY();
}
else {
// Keep the standard im2col view of convolution:
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// but repack several old rows into one new row so we use the available crossbar size better.
//
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
auto packedOutType =
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
Value packedA = createPackedIm2colRows( SmallVector<Value> gemmRows;
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc); gemmRows.reserve(gemmInputRows.size());
Value packedB = buildPackedWeight(wDenseAttr, for (Value gemmInputRow : gemmInputRows) {
wTrans, Value gemmRow = ONNXGemmOp::create(rewriter,
wType, loc,
numChannelsIn, gemmOutputRowType,
numChannelsOut, gemmInputRow,
wHeight, gemmB,
wWidth, gemmC,
patchSize, rewriter.getF32FloatAttr(1.0f),
effectiveMaxParallelPixels, rewriter.getF32FloatAttr(1.0f),
rewriter, rewriter.getBoolAttr(false),
loc); rewriter.getBoolAttr(false))
Value packedC = buildPackedBias( .getY();
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc); gemmRows.push_back(gemmRow);
Value packedOut = ONNXGemmOp::create(rewriter,
loc,
packedOutType,
packedA,
packedB,
packedC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
gemmOut = createUnpackedOutput(
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
} }
rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc)); rewriter.replaceOp(convOp,
createCollectedConvOutput(gemmRows,
convOp.getType(),
gemmOutType,
nhwcType,
outType,
numPatches,
numChannelsOut,
effectiveMaxParallelPixels,
rewriter,
loc));
return success(); return success();
} }

View File

@@ -42,15 +42,15 @@ private:
raw_ostream& os; raw_ostream& os;
/** /**
* Draws the subgraph for a given spatial::SpatWeightedCompute, including: * Draws the subgraph for a given spatial::SpatCompute, including:
* 1. Input nodes (block arguments) * 1. Input nodes (block arguments)
* 2. Operations * 2. Operations
* 3. Edges between yield (output) and its users * 3. Edges between yield (output) and its users
* *
* @param op The spatial::SpatWeightedCompute to draw the subgraph for. * @param op The spatial::SpatCompute to draw the subgraph for.
* @param computeNum The number of the compute operation. * @param computeNum The number of the compute operation.
*/ */
void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) { void drawComputeOpSubgraph(spatial::SpatCompute op, size_t computeNum) {
os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n" os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
<< "\t\tstyle=filled;\n" << "\t\tstyle=filled;\n"
<< "\t\tcolor=lightblue;\n"; << "\t\tcolor=lightblue;\n";
@@ -217,7 +217,7 @@ void SpatialToGraphvizPass::runOnOperation() {
// 1. Print their subgraph // 1. Print their subgraph
// 2. Print the edges from its inputs to its outputs // 2. Print the edges from its inputs to its outputs
for (Operation& op : func.getOps()) { for (Operation& op : func.getOps()) {
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) { if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
drawComputeOpSubgraph(computeOp, computeNum++); drawComputeOpSubgraph(computeOp, computeNum++);
} }
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) { else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {

View File

@@ -62,7 +62,7 @@ private:
void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter); void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
void void
addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter); addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, void replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
unsigned int argIndex, unsigned int argIndex,
Value channelSourceOp, Value channelSourceOp,
Value consumerValue, Value consumerValue,
@@ -73,7 +73,7 @@ private:
void annotateChannelCoreIds(func::FuncOp funcOp); void annotateChannelCoreIds(func::FuncOp funcOp);
void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter); void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);
void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter); void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter);
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter); void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
@@ -116,7 +116,7 @@ static size_t countComputeLeafUsers(Value value) {
auto walkUses = [&](Value currentValue, auto& self) -> void { auto walkUses = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) { for (OpOperand& use : currentValue.getUses()) {
Operation* owner = use.getOwner(); Operation* owner = use.getOwner();
if (isa<spatial::SpatWeightedCompute>(owner)) { if (isa<spatial::SpatCompute>(owner)) {
leafUserCount++; leafUserCount++;
continue; continue;
} }
@@ -174,7 +174,7 @@ void SpatialToPimPass::runOnOperation() {
markOpToRemove(receiveOp); markOpToRemove(receiveOp);
runOnReceiveOp(receiveOp, rewriter); runOnReceiveOp(receiveOp, rewriter);
} }
for (auto computeOp : funcOp.getOps<spatial::SpatWeightedCompute>()) { for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp); markOpToRemove(computeOp);
runOnComputeOp(computeOp, rewriter); runOnComputeOp(computeOp, rewriter);
} }
@@ -222,7 +222,7 @@ void SpatialToPimPass::runOnOperation() {
dumpModule(moduleOp, "pim0"); dumpModule(moduleOp, "pim0");
} }
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) { void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
Location loc = computeOp->getLoc(); Location loc = computeOp->getLoc();
auto& block = computeOp.getRegion().front(); auto& block = computeOp.getRegion().front();
@@ -504,7 +504,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove; llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
for (auto& op : funcOp.getBody().getOps()) for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) { if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
unsigned numComputeWeights = computeOp.getWeights().size(); unsigned numComputeWeights = computeOp.getWeights().size();
for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) { for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
TypedValue<TensorType> tensorSource; TypedValue<TensorType> tensorSource;
@@ -513,7 +513,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) { if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource()); tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp())) if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
continue; continue;
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape(); ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
@@ -538,7 +538,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
tensorSource = cast<TypedValue<TensorType>>(computeOpInput); tensorSource = cast<TypedValue<TensorType>>(computeOpInput);
// Compute results must be transferred through channels via send/receive // Compute results must be transferred through channels via send/receive
if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp())) if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
continue; continue;
BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx); BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
@@ -553,7 +553,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
return success(); return success();
} }
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
unsigned int argIndex, unsigned int argIndex,
Value channelSourceOp, Value channelSourceOp,
Value consumerValue, Value consumerValue,
@@ -614,7 +614,7 @@ void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void { auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
for (OpOperand& use : currentValue.getUses()) { for (OpOperand& use : currentValue.getUses()) {
Operation* owner = use.getOwner(); Operation* owner = use.getOwner();
if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) { if (auto computeUser = dyn_cast<spatial::SpatCompute>(owner)) {
replaceBlockArgumentWithRecvOp( replaceBlockArgumentWithRecvOp(
computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter); computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
continue; continue;

View File

@@ -32,7 +32,7 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
// Execution // Execution
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> { def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compute region with attached constant weights"; let summary = "Compute region with attached constant weights";
let arguments = (ins let arguments = (ins

View File

@@ -119,7 +119,7 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
} }
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) { llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
auto wcomputeOp = dyn_cast<SpatWeightedCompute>(weigthedOp->getParentOp()); auto wcomputeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp());
if (wcomputeOp) if (wcomputeOp)
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape(); return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
@@ -134,7 +134,7 @@ llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigth
LogicalResult SpatWeightedMVMOp::verify() { LogicalResult SpatWeightedMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt)) if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op"); return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt; auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape(); auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape(); auto outputShape = getOutput().getType().getShape();
@@ -155,7 +155,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
LogicalResult SpatWeightedVMMOp::verify() { LogicalResult SpatWeightedVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt)) if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op"); return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt; auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape(); auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape(); auto outputShape = getOutput().getType().getShape();
@@ -200,9 +200,8 @@ LogicalResult SpatVMaxOp::verify() {
return OpTrait::impl::verifySameOperandsAndResultType(*this); return OpTrait::impl::verifySameOperandsAndResultType(*this);
} }
LogicalResult SpatWeightedCompute::verify() { LogicalResult SpatCompute::verify() {
// Check that it has a terminator, it is a yieldOp, and it has a single // Check that the terminator yields the same number and types as the compute results.
// operand with the same type as the result
auto& block = getBody().front(); auto& block = getBody().front();
if (block.mightHaveTerminator()) { if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator()); auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
@@ -257,7 +256,7 @@ LogicalResult SpatWeightedCompute::verify() {
return success(); return success();
} }
LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) { LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
Block& block = getBody().front(); Block& block = getBody().front();
if (!llvm::hasSingleElement(block)) if (!llvm::hasSingleElement(block))
return failure(); return failure();

View File

@@ -74,15 +74,15 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
return aggregatedEdges; return aggregatedEdges;
} }
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes, VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> edges) { llvm::ArrayRef<IndexedEdge> edges) {
VirtualGraph graph; VirtualGraph graph;
graph.nodes.reserve(spatWeightedComputes.size()); graph.nodes.reserve(spatComputes.size());
for (auto [index, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) { for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
VirtualNode node; VirtualNode node;
node.originalComputeIndices.push_back(index); node.originalComputeIndices.push_back(index);
node.weight = getSpatComputeWeight(spatWeightedCompute); node.weight = getSpatComputeWeight(spatCompute);
node.crossbarUsage = getSpatComputeCrossbarUsage(spatWeightedCompute); node.crossbarUsage = getSpatComputeCrossbarUsage(spatCompute);
graph.nodes.push_back(std::move(node)); graph.nodes.push_back(std::move(node));
} }
graph.edges = aggregateEdges(edges); graph.edges = aggregateEdges(edges);
@@ -344,22 +344,22 @@ std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::A
} }
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes, llvm::ArrayRef<SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> originalEdges) { llvm::ArrayRef<IndexedEdge> originalEdges) {
DCPAnalysisResult result; DCPAnalysisResult result;
std::vector<size_t> originalToVirtualNode(spatWeightedComputes.size(), 0); std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes)) for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
for (size_t originalIndex : virtualNode.originalComputeIndices) for (size_t originalIndex : virtualNode.originalComputeIndices)
originalToVirtualNode[originalIndex] = virtualNodeIndex; originalToVirtualNode[originalIndex] = virtualNodeIndex;
auto dominanceOrder = computeOriginalTopologicalOrder(spatWeightedComputes.size(), originalEdges); auto dominanceOrder = computeOriginalTopologicalOrder(spatComputes.size(), originalEdges);
result.dominanceOrderCompute.reserve(dominanceOrder.size()); result.dominanceOrderCompute.reserve(dominanceOrder.size());
for (size_t originalIndex : dominanceOrder) { for (size_t originalIndex : dominanceOrder) {
SpatWeightedCompute spatWeightedCompute = spatWeightedComputes[originalIndex]; SpatCompute spatCompute = spatComputes[originalIndex];
size_t cpu = originalToVirtualNode[originalIndex]; size_t cpu = originalToVirtualNode[originalIndex];
result.dominanceOrderCompute.push_back(spatWeightedCompute); result.dominanceOrderCompute.push_back(spatCompute);
result.computeToCpuMap[spatWeightedCompute] = cpu; result.computeToCpuMap[spatCompute] = cpu;
result.cpuToLastComputeMap[cpu] = spatWeightedCompute; result.cpuToLastComputeMap[cpu] = spatCompute;
} }
for (auto [cpu, lastCompute] : result.cpuToLastComputeMap) for (auto [cpu, lastCompute] : result.cpuToLastComputeMap)
@@ -367,10 +367,10 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
return result; return result;
} }
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes, DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> edges, llvm::ArrayRef<IndexedEdge> edges,
MLIRContext* context) { MLIRContext* context) {
GraphDCP graphDCP(spatWeightedComputes, edges); GraphDCP graphDCP(spatComputes, edges);
if (coresCount.getValue() > 0) if (coresCount.getValue() > 0)
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue())); graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
graphDCP.setContext(context); graphDCP.setContext(context);
@@ -380,7 +380,7 @@ DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatWeightedCompute> spatWeightedC
} // namespace } // namespace
SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) { SpatCompute getOriginalSpatCompute(Operation* op) {
if (!op) if (!op)
return {}; return {};
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) { while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
@@ -388,39 +388,33 @@ SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
if (!op) if (!op)
return {}; return {};
} }
if (auto res = llvm::dyn_cast<SpatWeightedCompute>(op)) if (auto res = llvm::dyn_cast<SpatCompute>(op))
return res; return res;
return {}; return {};
} }
DCPAnalysisResult DCPAnalysis::run() { DCPAnalysisResult DCPAnalysis::run() {
SmallVector<SpatWeightedCompute, 10> spatWeightedComputes; SmallVector<SpatCompute, 10> spatComputes;
SmallVector<IndexedEdge, 10> edges; SmallVector<IndexedEdge, 10> edges;
for (auto& region : entryOp->getRegions()) for (auto& region : entryOp->getRegions())
for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>()) for (SpatCompute spatCompute : region.getOps<SpatCompute>())
spatWeightedComputes.push_back(spatWeightedCompute); spatComputes.push_back(spatCompute);
for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) { for (auto [indexEndEdge, spatCompute] : llvm::enumerate(spatComputes)) {
for (Value input : spatWeightedCompute.getInputs()) { for (Value input : spatCompute.getInputs()) {
if (auto producerCompute = getOriginalSpatWeightedCompute(input.getDefiningOp())) { if (auto producerCompute = getOriginalSpatCompute(input.getDefiningOp())) {
auto producerIt = llvm::find(spatWeightedComputes, producerCompute); auto producerIt = llvm::find(spatComputes, producerCompute);
assert(producerIt != spatWeightedComputes.end()); assert(producerIt != spatComputes.end());
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), producerIt); auto indexStartEdge = std::distance(spatComputes.begin(), producerIt);
ResultRange outputs = producerCompute.getResults(); edges.push_back({indexStartEdge, indexEndEdge, getSizeInBytes(cast<ShapedType>(input.getType()))});
int64_t totalSize = 0;
for (auto output : outputs) {
ShapedType resultType = cast<ShapedType>(output.getType());
totalSize += getSizeInBytes(resultType);
}
edges.push_back({indexStartEdge, indexEndEdge, totalSize});
} }
} }
} }
if (dcpCriticalWindowSize.getValue() == 0) if (dcpCriticalWindowSize.getValue() == 0)
return runLegacyDcp(spatWeightedComputes, edges, entryOp->getContext()); return runLegacyDcp(spatComputes, edges, entryOp->getContext());
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatWeightedComputes, edges); VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges);
std::set<std::vector<size_t>> seenCriticalWindows; std::set<std::vector<size_t>> seenCriticalWindows;
while (virtualGraph.nodes.size() > 1) { while (virtualGraph.nodes.size() > 1) {
TimingInfo timing = computeTiming(virtualGraph); TimingInfo timing = computeTiming(virtualGraph);
@@ -446,7 +440,7 @@ DCPAnalysisResult DCPAnalysis::run() {
break; break;
} }
return buildResultFromVirtualGraph(virtualGraph, spatWeightedComputes, edges); return buildResultFromVirtualGraph(virtualGraph, spatComputes, edges);
} }
} // namespace spatial } // namespace spatial

View File

@@ -10,10 +10,10 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
struct DCPAnalysisResult { struct DCPAnalysisResult {
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute; std::vector<onnx_mlir::spatial::SpatCompute> dominanceOrderCompute;
llvm::DenseMap<onnx_mlir::spatial::SpatWeightedCompute, size_t> computeToCpuMap; llvm::DenseMap<onnx_mlir::spatial::SpatCompute, size_t> computeToCpuMap;
llvm::DenseSet<onnx_mlir::spatial::SpatWeightedCompute> isLastComputeOfCpu; llvm::DenseSet<onnx_mlir::spatial::SpatCompute> isLastComputeOfCpu;
llvm::DenseMap<size_t, onnx_mlir::spatial::SpatWeightedCompute> cpuToLastComputeMap; llvm::DenseMap<size_t, onnx_mlir::spatial::SpatCompute> cpuToLastComputeMap;
}; };
namespace onnx_mlir { namespace onnx_mlir {

View File

@@ -1260,7 +1260,7 @@ DCPAnalysisResult GraphDCP::getResult() {
auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size()); auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size());
ret.dominanceOrderCompute.reserve(dominanceOrder.size()); ret.dominanceOrderCompute.reserve(dominanceOrder.size());
for (auto elem : dominanceOrder) for (auto elem : dominanceOrder)
ret.dominanceOrderCompute.push_back(elem->getSpatWeightedCompute()); ret.dominanceOrderCompute.push_back(elem->getSpatCompute());
for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) { for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) {
const CpuTaskList* tasks = findCpuTasks(cpu); const CpuTaskList* tasks = findCpuTasks(cpu);
@@ -1268,10 +1268,10 @@ DCPAnalysisResult GraphDCP::getResult() {
continue; continue;
size_t i = 0; size_t i = 0;
for (auto node : *tasks) { for (auto node : *tasks) {
ret.computeToCpuMap[node->getSpatWeightedCompute()] = cpu; ret.computeToCpuMap[node->getSpatCompute()] = cpu;
if (i++ == tasks->size() - 1) { if (i++ == tasks->size() - 1) {
ret.isLastComputeOfCpu.insert(node->getSpatWeightedCompute()); ret.isLastComputeOfCpu.insert(node->getSpatCompute());
ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute(); ret.cpuToLastComputeMap[cpu] = node->getSpatCompute();
} }
} }
} }

View File

@@ -115,11 +115,11 @@ private:
public: public:
void runDcp(); void runDcp();
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes, GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> edges) llvm::ArrayRef<IndexedEdge> edges)
: nodes(), cpuTasks(), cpuCrossbarUsage() { : nodes(), cpuTasks(), cpuCrossbarUsage() {
for (auto spatWeightedCompute : spatWeightedComputes) for (auto spatCompute : spatComputes)
nodes.emplace_back(spatWeightedCompute); nodes.emplace_back(spatCompute);
for (auto [start, end, weight] : edges) for (auto [start, end, weight] : edges)
makeEdge(start, end, weight); makeEdge(start, end, weight);
} }

View File

@@ -8,7 +8,7 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> { class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute; onnx_mlir::spatial::SpatCompute spatCompute;
Time aest; Time aest;
Time alst; Time alst;
std::optional<CPU> scheduledCpu; std::optional<CPU> scheduledCpu;
@@ -38,22 +38,22 @@ public:
std::vector<Edge> parents; std::vector<Edge> parents;
std::vector<Edge> children; std::vector<Edge> children;
TaskDCP() = default; TaskDCP() = default;
TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) TaskDCP(onnx_mlir::spatial::SpatCompute spatCompute)
: onnx_mlir::LabeledListNode<TaskDCP>(), : onnx_mlir::LabeledListNode<TaskDCP>(),
spatWeightedCompute(spatWeightedCompute), spatCompute(spatCompute),
aest(0), aest(0),
alst(0), alst(0),
scheduledCpu(), scheduledCpu(),
weight(getSpatComputeWeight(spatWeightedCompute)), weight(getSpatComputeWeight(spatCompute)),
baseWeight(weight), baseWeight(weight),
crossbarUsage(getSpatComputeCrossbarUsage(spatWeightedCompute)), crossbarUsage(getSpatComputeCrossbarUsage(spatCompute)),
syntheticId(-1), syntheticId(-1),
parents(), parents(),
children() {} children() {}
TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0) TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0)
: onnx_mlir::LabeledListNode<TaskDCP>(), : onnx_mlir::LabeledListNode<TaskDCP>(),
spatWeightedCompute(), spatCompute(),
aest(0), aest(0),
alst(0), alst(0),
scheduledCpu(), scheduledCpu(),
@@ -90,14 +90,14 @@ public:
void setAlst(Time value) { alst = value; } void setAlst(Time value) { alst = value; }
bool hasDescendant(TaskDCP* child); bool hasDescendant(TaskDCP* child);
int64_t Id() const { int64_t Id() const {
if (spatWeightedCompute) if (spatCompute)
return reinterpret_cast<int64_t>(spatWeightedCompute.getAsOpaquePointer()); return reinterpret_cast<int64_t>(spatCompute.getAsOpaquePointer());
return syntheticId; return syntheticId;
} }
bool isCriticalPath() const { return alst == aest; } bool isCriticalPath() const { return alst == aest; }
bool isScheduled() const { return scheduledCpu.has_value(); } bool isScheduled() const { return scheduledCpu.has_value(); }
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() const { return spatWeightedCompute; } onnx_mlir::spatial::SpatCompute getSpatCompute() const { return spatCompute; }
void setFlag(long long val) { flag = val; } void setFlag(long long val) { flag = val; }
long long getFlag() const { return flag; } long long getFlag() const { return flag; }

View File

@@ -92,18 +92,18 @@ inline T subtractOrZero(T lhs, T rhs) {
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); } inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) { inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatCompute spatCompute) {
constexpr Weight kOperationWeight = 100; constexpr Weight kOperationWeight = 100;
Weight numOperations = 0; Weight numOperations = 0;
for (auto& block : spatWeightedCompute.getBody()) for (auto& block : spatCompute.getBody())
for ([[maybe_unused]] auto& op : block) for ([[maybe_unused]] auto& op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1)); numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight); return checkedMultiply(numOperations, kOperationWeight);
} }
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) { inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute spatCompute) {
CrossbarUsage crossbarUsage = 0; CrossbarUsage crossbarUsage = 0;
for (auto& region : spatWeightedCompute.getBody()) for (auto& region : spatCompute.getBody())
for (auto& inst : region) for (auto& inst : region)
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst)) if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1)); crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));

View File

@@ -24,30 +24,29 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
using SpatWeightedCompute = spatial::SpatWeightedCompute; using SpatCompute = spatial::SpatCompute;
struct ComputeValueResults { struct ComputeValueResults {
// Value yielded by the yieldOp SmallVector<Value> innerValues;
Value innerValue;
Value get(size_t resultIndex) const {
assert(resultIndex < innerValues.size() && "compute result index out of range");
return innerValues[resultIndex];
}
}; };
class LazyInsertComputeResult { class LazyInsertComputeResult {
using InsertPoint = mlir::IRRewriter::InsertPoint; using InsertPoint = mlir::IRRewriter::InsertPoint;
ComputeValueResults computeResults; ComputeValueResults computeResults;
Value channelValue;
bool onlyChannel; bool onlyChannel;
std::function<void(InsertPoint insertPoint)> channelSendInserter; std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter;
InsertPoint sendInsertPoint;
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter;
public: public:
LazyInsertComputeResult(ComputeValueResults computeValueResults, LazyInsertComputeResult(ComputeValueResults computeValueResults,
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter, std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter,
bool isOnlyChannel) bool isOnlyChannel)
: computeResults(computeValueResults), : computeResults(computeValueResults),
onlyChannel(isOnlyChannel), onlyChannel(isOnlyChannel),
channelSendInserter(nullptr),
sendInsertPoint({}),
channelNewInserter(channelNewInserter) {} channelNewInserter(channelNewInserter) {}
struct ChannelOrLocalOp { struct ChannelOrLocalOp {
@@ -57,12 +56,12 @@ public:
bool onlyChanneled() const { return onlyChannel; } bool onlyChanneled() const { return onlyChannel; }
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute currentCompute) { ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatCompute currentCompute, size_t resultIndex) {
Value innerValue = computeResults.get(resultIndex);
auto [newChannelValue, senderInserter] = channelNewInserter(); auto [channelValue, channelSendInserter] = channelNewInserter(resultIndex);
channelValue = newChannelValue; InsertPoint sendInsertPoint;
channelSendInserter = senderInserter; auto* block = innerValue.getParentBlock();
auto* block = computeResults.innerValue.getParentBlock();
if (!block->empty() && isa<spatial::SpatYieldOp>(block->back())) if (!block->empty() && isa<spatial::SpatYieldOp>(block->back()))
sendInsertPoint = InsertPoint(block, --block->end()); sendInsertPoint = InsertPoint(block, --block->end());
else else
@@ -70,28 +69,30 @@ public:
if (currentCompute) { if (currentCompute) {
for (auto& block : currentCompute.getBody()) for (auto& block : currentCompute.getBody())
if (&block == sendInsertPoint.getBlock()) if (&block == sendInsertPoint.getBlock())
return {computeResults.innerValue, false}; return {innerValue, false};
} }
channelSendInserter(sendInsertPoint); channelSendInserter(sendInsertPoint);
return {channelValue, true}; return {channelValue, true};
} }
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); } ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex) {
return getAsChannelValueAndInsertSender({}, resultIndex);
}
}; };
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> { struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
private: private:
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults; DenseMap<SpatCompute, LazyInsertComputeResult> newComputeNodeResults;
DenseMap<SpatWeightedCompute, SpatWeightedCompute> oldToNewComputeMap; DenseMap<SpatCompute, SpatCompute> oldToNewComputeMap;
DenseMap<int64_t, SpatWeightedCompute> cpuToNewComputeMap; DenseMap<int64_t, SpatCompute> cpuToNewComputeMap;
public: public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; } StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
StringRef getDescription() const override { StringRef getDescription() const override {
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total " return "Merge Spatial-Compute-Nodes in order to reduce the total "
"execution time"; "execution time";
} }
@@ -105,22 +106,22 @@ public:
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) { for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode); size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
if (!cpuToNewComputeMap.contains(cpu)) { if (!cpuToNewComputeMap.contains(cpu)) {
ValueTypeRange<ResultRange> newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes(); ValueTypeRange<ResultRange> newComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
auto [newWeightedCompute, computeValueResult] = createNewComputeNode( auto [newCompute, computeValueResult] = createNewComputeNode(
currentComputeNode, newWeightedComputeType, lastComputeOfCpu.contains(currentComputeNode)); currentComputeNode, newComputeType, lastComputeOfCpu.contains(currentComputeNode));
cpuToNewComputeMap[cpu] = newWeightedCompute; cpuToNewComputeMap[cpu] = newCompute;
newComputeNodeResults.insert( newComputeNodeResults.insert(
std::make_pair(currentComputeNode, std::make_pair(currentComputeNode,
createLazyComputeResult( createLazyComputeResult(
newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode)))); newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
} }
else { else {
auto [newWeightedCompute, computeValueResult] = mergeIntoComputeNode( auto [newCompute, computeValueResult] = mergeIntoComputeNode(
cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode)); cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
newComputeNodeResults.insert( newComputeNodeResults.insert(
std::make_pair(currentComputeNode, std::make_pair(currentComputeNode,
createLazyComputeResult( createLazyComputeResult(
newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode)))); newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
} }
} }
@@ -134,8 +135,8 @@ public:
} }
private: private:
std::pair<SpatWeightedCompute, ComputeValueResults> createNewComputeNode( std::pair<SpatCompute, ComputeValueResults> createNewComputeNode(
SpatWeightedCompute oldWeightedCompute, ValueTypeRange<ResultRange> newWeightedComputeType, bool lastCompute) { SpatCompute oldCompute, ValueTypeRange<ResultRange> newComputeType, bool lastCompute) {
func::FuncOp func = getOperation(); func::FuncOp func = getOperation();
auto loc = func.getLoc(); auto loc = func.getLoc();
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
@@ -148,50 +149,53 @@ private:
llvm::SmallVector<Type> newBBOperandType; llvm::SmallVector<Type> newBBOperandType;
llvm::SmallVector<Location> newBBLocations; llvm::SmallVector<Location> newBBLocations;
for (auto arg : oldWeightedCompute.getWeights()) for (auto arg : oldCompute.getWeights())
newComputeOperand.push_back(arg); newComputeOperand.push_back(arg);
for (auto arg : oldWeightedCompute.getInputs()) for (auto arg : oldCompute.getInputs())
if (!llvm::isa_and_present<SpatWeightedCompute>(arg.getDefiningOp())) { if (!llvm::isa_and_present<SpatCompute>(arg.getDefiningOp())) {
newComputeOperand.push_back(arg); newComputeOperand.push_back(arg);
newBBOperandType.push_back(arg.getType()); newBBOperandType.push_back(arg.getType());
newBBLocations.push_back(loc); newBBLocations.push_back(loc);
} }
auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand); auto newCompute = SpatCompute::create(rewriter, loc, newComputeType, newComputeOperand);
rewriter.createBlock( rewriter.createBlock(
&newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations); &newCompute.getBody(), newCompute.getBody().end(), newBBOperandType, newBBLocations);
newWeightedCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(
{(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()}); {(int) oldCompute.getWeights().size(), (int) newBBOperandType.size()});
auto& newBB = newWeightedCompute.getBody().front(); auto& newBB = newCompute.getBody().front();
auto& oldBB = oldWeightedCompute.getBody().front(); auto& oldBB = oldCompute.getBody().front();
rewriter.setInsertionPointToEnd(&newBB); rewriter.setInsertionPointToEnd(&newBB);
int indexNew = 0; int indexNew = 0;
size_t indexOld = oldWeightedCompute.getWeights().size(); size_t indexOld = oldCompute.getWeights().size();
size_t indexOldStart = oldWeightedCompute.getWeights().size(); size_t indexOldStart = oldCompute.getWeights().size();
for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) { for (; indexOld < oldCompute.getNumOperands(); ++indexOld) {
if (!llvm::isa_and_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) { if (!llvm::isa_and_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp())) {
mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++)); mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++));
} }
else { else {
auto argWeightCompute = auto argWeightCompute =
llvm::dyn_cast_if_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp()); llvm::dyn_cast_if_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp());
auto argResultIndex = cast<OpResult>(oldCompute.getOperand(indexOld)).getResultNumber();
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute); LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender(); auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender(argResultIndex);
assert(isChannel == true); assert(isChannel == true);
spatial::SpatChannelReceiveOp receiveOp = spatial::SpatChannelReceiveOp receiveOp = spatial::SpatChannelReceiveOp::create(
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelVal); rewriter, loc, oldCompute.getOperand(indexOld).getType(), channelVal);
mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp); mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp);
} }
} }
for (auto& op : oldWeightedCompute.getOps()) { for (auto& op : oldCompute.getOps()) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) { if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
computeValueResults.innerValue = mapper.lookup(yield.getOperand(0)); computeValueResults.innerValues.reserve(yield.getNumOperands());
for (Value yieldOperand : yield.getOperands())
computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand));
if (lastCompute) if (lastCompute)
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
} }
@@ -199,16 +203,18 @@ private:
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
} }
for (auto& use : llvm::make_early_inc_range(oldWeightedCompute->getUses())) for (auto& use : llvm::make_early_inc_range(oldCompute->getUses()))
if (isa<func::ReturnOp>(use.getOwner())) if (isa<func::ReturnOp>(use.getOwner())) {
use.assign(newWeightedCompute.getResult(0)); auto resultIndex = cast<OpResult>(use.get()).getResultNumber();
use.assign(newCompute.getResult(resultIndex));
}
oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute}); oldToNewComputeMap.insert({oldCompute, newCompute});
return {cast<SpatWeightedCompute>(newWeightedCompute), computeValueResults}; return {cast<SpatCompute>(newCompute), computeValueResults};
} }
std::pair<SpatWeightedCompute, ComputeValueResults> std::pair<SpatCompute, ComputeValueResults>
mergeIntoComputeNode(SpatWeightedCompute toCompute, SpatWeightedCompute fromCompute, bool lastCompute) { mergeIntoComputeNode(SpatCompute toCompute, SpatCompute fromCompute, bool lastCompute) {
func::FuncOp func = getOperation(); func::FuncOp func = getOperation();
auto loc = func.getLoc(); auto loc = func.getLoc();
IRRewriter rewriter(&getContext()); IRRewriter rewriter(&getContext());
@@ -239,14 +245,15 @@ private:
// Insert receiveOp // Insert receiveOp
rewriter.setInsertionPointToEnd(&toBB); rewriter.setInsertionPointToEnd(&toBB);
for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) { for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) { if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatCompute>(arg.getDefiningOp())) {
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute); LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
auto argResultIndex = cast<OpResult>(arg).getResultNumber();
LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal = LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal =
lazyArgWeight.getAsChannelValueAndInsertSender(toCompute); lazyArgWeight.getAsChannelValueAndInsertSender(toCompute, argResultIndex);
if (channelOrLocal.isChannel) { if (channelOrLocal.isChannel) {
spatial::SpatChannelReceiveOp receiveOp = spatial::SpatChannelReceiveOp receiveOp =
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data); spatial::SpatChannelReceiveOp::create(rewriter, loc, arg.getType(), channelOrLocal.data);
mapper.map(fromBB.getArgument(bbIndex), receiveOp.getResult()); mapper.map(fromBB.getArgument(bbIndex), receiveOp.getResult());
} }
else { else {
@@ -286,7 +293,9 @@ private:
}; };
for (auto& op : fromCompute.getOps()) { for (auto& op : fromCompute.getOps()) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) { if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
computeValueResults.innerValue = mapper.lookup(yield.getOperand(0)); computeValueResults.innerValues.reserve(yield.getNumOperands());
for (Value yieldOperand : yield.getOperands())
computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand));
if (lastCompute) if (lastCompute)
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
} }
@@ -299,33 +308,36 @@ private:
} }
} }
for (auto users : fromCompute->getUsers()) for (auto& use : llvm::make_early_inc_range(fromCompute->getUses()))
if (auto funcRet = dyn_cast<func::ReturnOp>(users)) if (isa<func::ReturnOp>(use.getOwner())) {
funcRet.setOperand(0, toCompute.getResult(0)); auto resultIndex = cast<OpResult>(use.get()).getResultNumber();
use.assign(toCompute.getResult(resultIndex));
}
oldToNewComputeMap.insert({fromCompute, toCompute}); oldToNewComputeMap.insert({fromCompute, toCompute});
return {cast<SpatWeightedCompute>(toCompute), computeValueResults}; return {cast<SpatCompute>(toCompute), computeValueResults};
} }
LazyInsertComputeResult createLazyComputeResult(SpatWeightedCompute weightedCompute, LazyInsertComputeResult createLazyComputeResult(SpatCompute compute,
ComputeValueResults computeValueResults, ComputeValueResults computeValueResults,
bool lastCompute) { bool lastCompute) {
func::FuncOp funcOp = cast<func::FuncOp>(weightedCompute->getParentOp()); func::FuncOp funcOp = cast<func::FuncOp>(compute->getParentOp());
auto* context = &getContext(); auto* context = &getContext();
auto loc = funcOp.getLoc(); auto loc = funcOp.getLoc();
IRRewriter rewriter(context); IRRewriter rewriter(context);
rewriter.setInsertionPointToStart(&funcOp.front()); rewriter.setInsertionPointToStart(&funcOp.front());
auto savedChannelInsertPoint = rewriter.saveInsertionPoint(); auto savedChannelInsertPoint = rewriter.saveInsertionPoint();
auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults]() { auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults](size_t resultIndex) {
IRRewriter rewriter(context); IRRewriter rewriter(context);
rewriter.restoreInsertionPoint(savedChannelInsertPoint); rewriter.restoreInsertionPoint(savedChannelInsertPoint);
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context)); auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context));
auto channelVal = channelOp.getResult(); auto channelVal = channelOp.getResult();
auto insertVal = [&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint sendInsertPoint) { auto insertVal =
[&context, loc, computeValueResults, channelVal, resultIndex](mlir::IRRewriter::InsertPoint sendInsertPoint) {
IRRewriter rewriter(context); IRRewriter rewriter(context);
rewriter.restoreInsertionPoint(sendInsertPoint); rewriter.restoreInsertionPoint(sendInsertPoint);
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue); auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.get(resultIndex));
return spatSend; return spatSend;
}; };
std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal}; std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};

View File

@@ -31,7 +31,7 @@ struct CountInstructionPass : public PassWrapper<CountInstructionPass, Operation
unsigned totalInstructionCount = 0; unsigned totalInstructionCount = 0;
unsigned computeId = 0; unsigned computeId = 0;
for (auto computeOp : func.getOps<spatial::SpatWeightedCompute>()) { for (auto computeOp : func.getOps<spatial::SpatCompute>()) {
unsigned instructionCount = 0; unsigned instructionCount = 0;
instructionCount += computeOp.getBody().front().getOperations().size(); instructionCount += computeOp.getBody().front().getOperations().size();
llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n"; llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n";

View File

@@ -26,6 +26,10 @@ STAGE_COUNT = len(STAGE_TITLES)
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation") GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
def sanitize_output_name(name):
return "".join(ch if ch.isalnum() or ch in "_.-" else "_" for ch in name[:255])
@dataclass @dataclass
class ValidationResult: class ValidationResult:
passed: bool passed: bool
@@ -205,7 +209,7 @@ def build_dump_ranges(config_path, outputs_descriptor):
def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None): def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
run_command( run_command(
["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--", ["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
"-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges], "-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
cwd=simulator_dir, cwd=simulator_dir,
reporter=reporter, reporter=reporter,
@@ -229,7 +233,7 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
all_passed = True all_passed = True
rows = [] rows = []
for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor): for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor):
csv_name = f"output{oi}_{name}.csv" csv_name = f"output{oi}_{sanitize_output_name(name)}.csv"
runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape) runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape)
max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64)))) max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64))))
passed = max_diff <= threshold passed = max_diff <= threshold